unbee tetrisd commited on
Commit
2aec976
0 Parent(s):

Duplicate from tetrisd/Diffusion-Attentive-Attribution-Maps

Browse files

Co-authored-by: Raphael Tang <tetrisd@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +31 -0
  2. README.md +14 -0
  3. app.py +170 -0
  4. diffusers/__init__.py +60 -0
  5. diffusers/__pycache__/__init__.cpython-310.pyc +0 -0
  6. diffusers/__pycache__/configuration_utils.cpython-310.pyc +0 -0
  7. diffusers/__pycache__/dependency_versions_check.cpython-310.pyc +0 -0
  8. diffusers/__pycache__/dependency_versions_table.cpython-310.pyc +0 -0
  9. diffusers/__pycache__/dynamic_modules_utils.cpython-310.pyc +0 -0
  10. diffusers/__pycache__/hub_utils.cpython-310.pyc +0 -0
  11. diffusers/__pycache__/modeling_utils.cpython-310.pyc +0 -0
  12. diffusers/__pycache__/onnx_utils.cpython-310.pyc +0 -0
  13. diffusers/__pycache__/optimization.cpython-310.pyc +0 -0
  14. diffusers/__pycache__/pipeline_utils.cpython-310.pyc +0 -0
  15. diffusers/__pycache__/testing_utils.cpython-310.pyc +0 -0
  16. diffusers/__pycache__/training_utils.cpython-310.pyc +0 -0
  17. diffusers/commands/__init__.py +27 -0
  18. diffusers/commands/__pycache__/__init__.cpython-310.pyc +0 -0
  19. diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc +0 -0
  20. diffusers/commands/__pycache__/env.cpython-310.pyc +0 -0
  21. diffusers/commands/diffusers_cli.py +41 -0
  22. diffusers/commands/env.py +70 -0
  23. diffusers/configuration_utils.py +403 -0
  24. diffusers/dependency_versions_check.py +47 -0
  25. diffusers/dependency_versions_table.py +26 -0
  26. diffusers/dynamic_modules_utils.py +335 -0
  27. diffusers/hub_utils.py +197 -0
  28. diffusers/modeling_utils.py +542 -0
  29. diffusers/models/__init__.py +17 -0
  30. diffusers/models/__pycache__/__init__.cpython-310.pyc +0 -0
  31. diffusers/models/__pycache__/attention.cpython-310.pyc +0 -0
  32. diffusers/models/__pycache__/embeddings.cpython-310.pyc +0 -0
  33. diffusers/models/__pycache__/resnet.cpython-310.pyc +0 -0
  34. diffusers/models/__pycache__/unet_2d.cpython-310.pyc +0 -0
  35. diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc +0 -0
  36. diffusers/models/__pycache__/unet_blocks.cpython-310.pyc +0 -0
  37. diffusers/models/__pycache__/vae.cpython-310.pyc +0 -0
  38. diffusers/models/attention.py +409 -0
  39. diffusers/models/embeddings.py +115 -0
  40. diffusers/models/resnet.py +483 -0
  41. diffusers/models/unet_2d.py +246 -0
  42. diffusers/models/unet_2d_condition.py +272 -0
  43. diffusers/models/unet_blocks.py +1484 -0
  44. diffusers/models/vae.py +585 -0
  45. diffusers/onnx_utils.py +189 -0
  46. diffusers/optimization.py +275 -0
  47. diffusers/pipeline_utils.py +417 -0
  48. diffusers/pipelines/__init__.py +19 -0
  49. diffusers/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
  50. diffusers/pipelines/ddim/__init__.py +2 -0
.gitattributes ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ftz filter=lfs diff=lfs merge=lfs -text
6
+ *.gz filter=lfs diff=lfs merge=lfs -text
7
+ *.h5 filter=lfs diff=lfs merge=lfs -text
8
+ *.joblib filter=lfs diff=lfs merge=lfs -text
9
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
10
+ *.model filter=lfs diff=lfs merge=lfs -text
11
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
12
+ *.npy filter=lfs diff=lfs merge=lfs -text
13
+ *.npz filter=lfs diff=lfs merge=lfs -text
14
+ *.onnx filter=lfs diff=lfs merge=lfs -text
15
+ *.ot filter=lfs diff=lfs merge=lfs -text
16
+ *.parquet filter=lfs diff=lfs merge=lfs -text
17
+ *.pickle filter=lfs diff=lfs merge=lfs -text
18
+ *.pkl filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pt filter=lfs diff=lfs merge=lfs -text
21
+ *.pth filter=lfs diff=lfs merge=lfs -text
22
+ *.rar filter=lfs diff=lfs merge=lfs -text
23
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
24
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
25
+ *.tflite filter=lfs diff=lfs merge=lfs -text
26
+ *.tgz filter=lfs diff=lfs merge=lfs -text
27
+ *.wasm filter=lfs diff=lfs merge=lfs -text
28
+ *.xz filter=lfs diff=lfs merge=lfs -text
29
+ *.zip filter=lfs diff=lfs merge=lfs -text
30
+ *.zst filter=lfs diff=lfs merge=lfs -text
31
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Stable Diffusion Attentive Attribution Maps
3
+ emoji: 👀
4
+ colorFrom: blue
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 3.4.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: tetrisd/Diffusion-Attentive-Attribution-Maps
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from huggingface_hub import HfApi, HfFolder
2
+ import os
3
+
4
+ api = HfApi()
5
+ api.set_access_token(os.environ['HF_SECRET'])
6
+ folder = HfFolder()
7
+ folder.save_token(os.environ['HF_SECRET'])
8
+
9
+ from threading import Lock
10
+ import math
11
+ import os
12
+ import random
13
+
14
+ from diffusers import StableDiffusionPipeline
15
+ from diffusers.models.attention import get_global_heat_map, clear_heat_maps
16
+ from matplotlib import pyplot as plt
17
+ import gradio as gr
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import spacy
21
+
22
+
23
+ if not os.environ.get('NO_DOWNLOAD_SPACY'):
24
+ spacy.cli.download('en_core_web_sm')
25
+
26
+
27
+ model_id = "runwayml/stable-diffusion-v1-5"
28
+ device = "cuda"
29
+
30
+ gen = torch.Generator(device='cuda')
31
+ gen.manual_seed(12758672)
32
+ orig_state = gen.get_state()
33
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).to(device)
34
+ lock = Lock()
35
+ nlp = spacy.load('en_core_web_sm')
36
+
37
+
38
+ def expand_m(m, n: int = 1, o=512, mode='bicubic'):
39
+ m = m.unsqueeze(0).unsqueeze(0) / n
40
+ m = F.interpolate(m.float().detach(), size=(o, o), mode='bicubic', align_corners=False)
41
+ m = (m - m.min()) / (m.max() - m.min() + 1e-8)
42
+ m = m.cpu().detach()
43
+
44
+ return m
45
+
46
+
47
+ @torch.no_grad()
48
+ def predict(prompt, inf_steps, threshold):
49
+ global lock
50
+ with torch.cuda.amp.autocast(), lock:
51
+ try:
52
+ plt.close('all')
53
+ except:
54
+ pass
55
+
56
+ gen.set_state(orig_state.clone())
57
+ clear_heat_maps()
58
+
59
+ out = pipe(prompt, guidance_scale=7.5, height=512, width=512, do_intermediates=False, generator=gen, num_inference_steps=int(inf_steps))
60
+ heat_maps = get_global_heat_map()
61
+
62
+ with torch.cuda.amp.autocast(dtype=torch.float32):
63
+ m = 0
64
+ n = 0
65
+ w = ''
66
+ w_idx = 0
67
+
68
+ fig, ax = plt.subplots()
69
+ ax.imshow(out.images[0].cpu().float().detach().permute(1, 2, 0).numpy())
70
+ ax.set_xticks([])
71
+ ax.set_yticks([])
72
+
73
+ fig1, axs1 = plt.subplots(math.ceil(len(out.words) / 4), 4)#, figsize=(20, 20))
74
+ fig2, axs2 = plt.subplots(math.ceil(len(out.words) / 4), 4) # , figsize=(20, 20))
75
+
76
+ for idx in range(len(out.words) + 1):
77
+ if idx == 0:
78
+ continue
79
+
80
+ word = out.words[idx - 1]
81
+ m += heat_maps[idx]
82
+ n += 1
83
+ w += word
84
+
85
+ if '</w>' not in word:
86
+ continue
87
+ else:
88
+ mplot = expand_m(m, n)
89
+ spotlit_im = out.images[0].cpu().float().detach()
90
+ w = w.replace('</w>', '')
91
+ spotlit_im2 = torch.cat((spotlit_im, (1 - mplot.squeeze(0)).pow(1)), dim=0)
92
+
93
+ if len(out.words) <= 4:
94
+ a1 = axs1[w_idx % 4]
95
+ a2 = axs2[w_idx % 4]
96
+ else:
97
+ a1 = axs1[w_idx // 4, w_idx % 4]
98
+ a2 = axs2[w_idx // 4, w_idx % 4]
99
+
100
+ a1.set_xticks([])
101
+ a1.set_yticks([])
102
+ a1.imshow(mplot.squeeze().numpy(), cmap='jet')
103
+ a1.imshow(spotlit_im2.permute(1, 2, 0).numpy())
104
+ a1.set_title(w)
105
+
106
+ mask = torch.ones_like(mplot)
107
+ mask[mplot < threshold * mplot.max()] = 0
108
+ im2 = spotlit_im * mask.squeeze(0)
109
+ a2.set_xticks([])
110
+ a2.set_yticks([])
111
+ a2.imshow(im2.permute(1, 2, 0).numpy())
112
+ a2.set_title(w)
113
+ m = 0
114
+ n = 0
115
+ w_idx += 1
116
+ w = ''
117
+
118
+ for idx in range(w_idx, len(axs1.flatten())):
119
+ fig1.delaxes(axs1.flatten()[idx])
120
+ fig2.delaxes(axs2.flatten()[idx])
121
+
122
+ return fig, fig1, fig2
123
+
124
+
125
+ def set_prompt(prompt):
126
+ return prompt
127
+
128
+
129
+ with gr.Blocks() as demo:
130
+ md = '''# DAAM: Attention Maps for Interpreting Stable Diffusion
131
+ Check out the paper: [What the DAAM: Interpreting Stable Diffusion Using Cross Attention](http://arxiv.org/abs/2210.04885).
132
+ See our (much cleaner) [DAAM codebase](https://github.com/castorini/daam) on GitHub.
133
+
134
+ **Update**: We got a community grant! I'll continue running and updating the space, with a major release planned in December.
135
+ '''
136
+ gr.Markdown(md)
137
+
138
+ with gr.Row():
139
+ with gr.Column():
140
+ dropdown = gr.Dropdown([
141
+ 'An angry, bald man doing research',
142
+ 'Doing research at Comcast Applied AI labs',
143
+ 'Professor Jimmy Lin from the University of Waterloo',
144
+ 'Yann Lecun teaching machine learning on a chalkboard',
145
+ 'A cat eating cake for her birthday',
146
+ 'Steak and dollars on a plate',
147
+ 'A fox, a dog, and a wolf in a field'
148
+ ], label='Examples', value='An angry, bald man doing research')
149
+
150
+ text = gr.Textbox(label='Prompt', value='An angry, bald man doing research')
151
+ slider1 = gr.Slider(15, 35, value=25, interactive=True, step=1, label='Inference steps')
152
+ slider2 = gr.Slider(0, 1.0, value=0.4, interactive=True, step=0.05, label='Threshold (tau)')
153
+ submit_btn = gr.Button('Submit')
154
+
155
+ with gr.Tab('Original Image'):
156
+ p0 = gr.Plot()
157
+
158
+ with gr.Tab('Soft DAAM Maps'):
159
+ p1 = gr.Plot()
160
+
161
+ with gr.Tab('Hard DAAM Maps'):
162
+ p2 = gr.Plot()
163
+
164
+ submit_btn.click(fn=predict, inputs=[text, slider1, slider2], outputs=[p0, p1, p2])
165
+ dropdown.change(set_prompt, dropdown, text)
166
+ dropdown.update()
167
+
168
+
169
+ demo.launch()
170
+
diffusers/__init__.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .utils import (
2
+ is_inflect_available,
3
+ is_onnx_available,
4
+ is_scipy_available,
5
+ is_transformers_available,
6
+ is_unidecode_available,
7
+ )
8
+
9
+
10
+ __version__ = "0.3.0"
11
+
12
+ from .configuration_utils import ConfigMixin
13
+ from .modeling_utils import ModelMixin
14
+ from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
15
+ from .onnx_utils import OnnxRuntimeModel
16
+ from .optimization import (
17
+ get_constant_schedule,
18
+ get_constant_schedule_with_warmup,
19
+ get_cosine_schedule_with_warmup,
20
+ get_cosine_with_hard_restarts_schedule_with_warmup,
21
+ get_linear_schedule_with_warmup,
22
+ get_polynomial_decay_schedule_with_warmup,
23
+ get_scheduler,
24
+ )
25
+ from .pipeline_utils import DiffusionPipeline
26
+ from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
27
+ from .schedulers import (
28
+ DDIMScheduler,
29
+ DDPMScheduler,
30
+ KarrasVeScheduler,
31
+ PNDMScheduler,
32
+ SchedulerMixin,
33
+ ScoreSdeVeScheduler,
34
+ )
35
+ from .utils import logging
36
+
37
+
38
+ if is_scipy_available():
39
+ from .schedulers import LMSDiscreteScheduler
40
+ else:
41
+ from .utils.dummy_scipy_objects import * # noqa F403
42
+
43
+ from .training_utils import EMAModel
44
+
45
+
46
+ if is_transformers_available():
47
+ from .pipelines import (
48
+ LDMTextToImagePipeline,
49
+ StableDiffusionImg2ImgPipeline,
50
+ StableDiffusionInpaintPipeline,
51
+ StableDiffusionPipeline,
52
+ )
53
+ else:
54
+ from .utils.dummy_transformers_objects import * # noqa F403
55
+
56
+
57
+ if is_transformers_available() and is_onnx_available():
58
+ from .pipelines import StableDiffusionOnnxPipeline
59
+ else:
60
+ from .utils.dummy_transformers_and_onnx_objects import * # noqa F403
diffusers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.85 kB). View file
 
diffusers/__pycache__/configuration_utils.cpython-310.pyc ADDED
Binary file (15.4 kB). View file
 
diffusers/__pycache__/dependency_versions_check.cpython-310.pyc ADDED
Binary file (967 Bytes). View file
 
diffusers/__pycache__/dependency_versions_table.cpython-310.pyc ADDED
Binary file (819 Bytes). View file
 
diffusers/__pycache__/dynamic_modules_utils.cpython-310.pyc ADDED
Binary file (11.6 kB). View file
 
diffusers/__pycache__/hub_utils.cpython-310.pyc ADDED
Binary file (5.46 kB). View file
 
diffusers/__pycache__/modeling_utils.cpython-310.pyc ADDED
Binary file (18.7 kB). View file
 
diffusers/__pycache__/onnx_utils.cpython-310.pyc ADDED
Binary file (6.3 kB). View file
 
diffusers/__pycache__/optimization.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
diffusers/__pycache__/pipeline_utils.cpython-310.pyc ADDED
Binary file (14 kB). View file
 
diffusers/__pycache__/testing_utils.cpython-310.pyc ADDED
Binary file (1.66 kB). View file
 
diffusers/__pycache__/training_utils.cpython-310.pyc ADDED
Binary file (3.64 kB). View file
 
diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
+ class BaseDiffusersCLICommand(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def register_subcommand(parser: ArgumentParser):
23
+ raise NotImplementedError()
24
+
25
+ @abstractmethod
26
+ def run(self):
27
+ raise NotImplementedError()
diffusers/commands/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (817 Bytes). View file
 
diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc ADDED
Binary file (778 Bytes). View file
 
diffusers/commands/__pycache__/env.cpython-310.pyc ADDED
Binary file (2.17 kB). View file
 
diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .env import EnvironmentCommand
19
+
20
+
21
+ def main():
22
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
23
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
24
+
25
+ # Register commands
26
+ EnvironmentCommand.register_subcommand(commands_parser)
27
+
28
+ # Let's go
29
+ args = parser.parse_args()
30
+
31
+ if not hasattr(args, "func"):
32
+ parser.print_help()
33
+ exit(1)
34
+
35
+ # Run
36
+ service = args.func(args)
37
+ service.run()
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
diffusers/commands/env.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ from argparse import ArgumentParser
17
+
18
+ import huggingface_hub
19
+
20
+ from .. import __version__ as version
21
+ from ..utils import is_torch_available, is_transformers_available
22
+ from . import BaseDiffusersCLICommand
23
+
24
+
25
+ def info_command_factory(_):
26
+ return EnvironmentCommand()
27
+
28
+
29
+ class EnvironmentCommand(BaseDiffusersCLICommand):
30
+ @staticmethod
31
+ def register_subcommand(parser: ArgumentParser):
32
+ download_parser = parser.add_parser("env")
33
+ download_parser.set_defaults(func=info_command_factory)
34
+
35
+ def run(self):
36
+ hub_version = huggingface_hub.__version__
37
+
38
+ pt_version = "not installed"
39
+ pt_cuda_available = "NA"
40
+ if is_torch_available():
41
+ import torch
42
+
43
+ pt_version = torch.__version__
44
+ pt_cuda_available = torch.cuda.is_available()
45
+
46
+ transformers_version = "not installed"
47
+ if is_transformers_available:
48
+ import transformers
49
+
50
+ transformers_version = transformers.__version__
51
+
52
+ info = {
53
+ "`diffusers` version": version,
54
+ "Platform": platform.platform(),
55
+ "Python version": platform.python_version(),
56
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
57
+ "Huggingface_hub version": hub_version,
58
+ "Transformers version": transformers_version,
59
+ "Using GPU in script?": "<fill in>",
60
+ "Using distributed or parallel set-up in script?": "<fill in>",
61
+ }
62
+
63
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
64
+ print(self.format_dict(info))
65
+
66
+ return info
67
+
68
+ @staticmethod
69
+ def format_dict(d):
70
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
diffusers/configuration_utils.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixinuration base class and utilities."""
17
+ import functools
18
+ import inspect
19
+ import json
20
+ import os
21
+ import re
22
+ from collections import OrderedDict
23
+ from typing import Any, Dict, Tuple, Union
24
+
25
+ from huggingface_hub import hf_hub_download
26
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
27
+ from requests import HTTPError
28
+
29
+ from . import __version__
30
+ from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
36
+
37
+
38
+ class ConfigMixin:
39
+ r"""
40
+ Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
41
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
42
+ - [`~ConfigMixin.from_config`]
43
+ - [`~ConfigMixin.save_config`]
44
+
45
+ Class attributes:
46
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
47
+ [`~ConfigMixin.save_config`] (should be overriden by parent class).
48
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
49
+ overriden by parent class).
50
+ """
51
+ config_name = None
52
+ ignore_for_config = []
53
+
54
+ def register_to_config(self, **kwargs):
55
+ if self.config_name is None:
56
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
57
+ kwargs["_class_name"] = self.__class__.__name__
58
+ kwargs["_diffusers_version"] = __version__
59
+
60
+ for key, value in kwargs.items():
61
+ try:
62
+ setattr(self, key, value)
63
+ except AttributeError as err:
64
+ logger.error(f"Can't set {key} with value {value} for {self}")
65
+ raise err
66
+
67
+ if not hasattr(self, "_internal_dict"):
68
+ internal_dict = kwargs
69
+ else:
70
+ previous_dict = dict(self._internal_dict)
71
+ internal_dict = {**self._internal_dict, **kwargs}
72
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
73
+
74
+ self._internal_dict = FrozenDict(internal_dict)
75
+
76
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
77
+ """
78
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
79
+ [`~ConfigMixin.from_config`] class method.
80
+
81
+ Args:
82
+ save_directory (`str` or `os.PathLike`):
83
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
84
+ """
85
+ if os.path.isfile(save_directory):
86
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
87
+
88
+ os.makedirs(save_directory, exist_ok=True)
89
+
90
+ # If we save using the predefined names, we can load using `from_config`
91
+ output_config_file = os.path.join(save_directory, self.config_name)
92
+
93
+ self.to_json_file(output_config_file)
94
+ logger.info(f"ConfigMixinuration saved in {output_config_file}")
95
+
96
+ @classmethod
97
+ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
98
+ r"""
99
+ Instantiate a Python class from a pre-defined JSON-file.
100
+
101
+ Parameters:
102
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
103
+ Can be either:
104
+
105
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
106
+ organization name, like `google/ddpm-celebahq-256`.
107
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
108
+ `./my_model_directory/`.
109
+
110
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
111
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
112
+ standard cache should not be used.
113
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
114
+ Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
115
+ as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
116
+ checkpoint with 3 labels).
117
+ force_download (`bool`, *optional*, defaults to `False`):
118
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
119
+ cached versions if they exist.
120
+ resume_download (`bool`, *optional*, defaults to `False`):
121
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
122
+ file exists.
123
+ proxies (`Dict[str, str]`, *optional*):
124
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
125
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
126
+ output_loading_info(`bool`, *optional*, defaults to `False`):
127
+ Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
128
+ local_files_only(`bool`, *optional*, defaults to `False`):
129
+ Whether or not to only look at local files (i.e., do not try to download the model).
130
+ use_auth_token (`str` or *bool*, *optional*):
131
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
132
+ when running `transformers-cli login` (stored in `~/.huggingface`).
133
+ revision (`str`, *optional*, defaults to `"main"`):
134
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
135
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
136
+ identifier allowed by git.
137
+ mirror (`str`, *optional*):
138
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
139
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
140
+ Please refer to the mirror site for more information.
141
+
142
+ <Tip>
143
+
144
+ Passing `use_auth_token=True`` is required when you want to use a private model.
145
+
146
+ </Tip>
147
+
148
+ <Tip>
149
+
150
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
151
+ use this method in a firewalled environment.
152
+
153
+ </Tip>
154
+
155
+ """
156
+ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
157
+
158
+ init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
159
+
160
+ model = cls(**init_dict)
161
+
162
+ if return_unused_kwargs:
163
+ return model, unused_kwargs
164
+ else:
165
+ return model
166
+
167
+ @classmethod
168
+ def get_config_dict(
169
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
170
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
171
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
172
+ force_download = kwargs.pop("force_download", False)
173
+ resume_download = kwargs.pop("resume_download", False)
174
+ proxies = kwargs.pop("proxies", None)
175
+ use_auth_token = kwargs.pop("use_auth_token", None)
176
+ local_files_only = kwargs.pop("local_files_only", False)
177
+ revision = kwargs.pop("revision", None)
178
+ subfolder = kwargs.pop("subfolder", None)
179
+
180
+ user_agent = {"file_type": "config"}
181
+
182
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
183
+
184
+ if cls.config_name is None:
185
+ raise ValueError(
186
+ "`self.config_name` is not defined. Note that one should not load a config from "
187
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
188
+ )
189
+
190
+ if os.path.isfile(pretrained_model_name_or_path):
191
+ config_file = pretrained_model_name_or_path
192
+ elif os.path.isdir(pretrained_model_name_or_path):
193
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
194
+ # Load from a PyTorch checkpoint
195
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
196
+ elif subfolder is not None and os.path.isfile(
197
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
198
+ ):
199
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
200
+ else:
201
+ raise EnvironmentError(
202
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
203
+ )
204
+ else:
205
+ try:
206
+ # Load from URL or cache if already cached
207
+ config_file = hf_hub_download(
208
+ pretrained_model_name_or_path,
209
+ filename=cls.config_name,
210
+ cache_dir=cache_dir,
211
+ force_download=force_download,
212
+ proxies=proxies,
213
+ resume_download=resume_download,
214
+ local_files_only=local_files_only,
215
+ use_auth_token=use_auth_token,
216
+ user_agent=user_agent,
217
+ subfolder=subfolder,
218
+ revision=revision,
219
+ )
220
+
221
+ except RepositoryNotFoundError:
222
+ raise EnvironmentError(
223
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
224
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
225
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
226
+ " login` and pass `use_auth_token=True`."
227
+ )
228
+ except RevisionNotFoundError:
229
+ raise EnvironmentError(
230
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
231
+ " this model name. Check the model page at"
232
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
233
+ )
234
+ except EntryNotFoundError:
235
+ raise EnvironmentError(
236
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
237
+ )
238
+ except HTTPError as err:
239
+ raise EnvironmentError(
240
+ "There was a specific connection error when trying to load"
241
+ f" {pretrained_model_name_or_path}:\n{err}"
242
+ )
243
+ except ValueError:
244
+ raise EnvironmentError(
245
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
246
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
247
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
248
+ " run the library in offline mode at"
249
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
250
+ )
251
+ except EnvironmentError:
252
+ raise EnvironmentError(
253
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
254
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
255
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
256
+ f"containing a {cls.config_name} file"
257
+ )
258
+
259
+ try:
260
+ # Load config dict
261
+ config_dict = cls._dict_from_json_file(config_file)
262
+ except (json.JSONDecodeError, UnicodeDecodeError):
263
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
264
+
265
+ return config_dict
266
+
267
+ @classmethod
268
+ def extract_init_dict(cls, config_dict, **kwargs):
269
+ expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
270
+ expected_keys.remove("self")
271
+ # remove general kwargs if present in dict
272
+ if "kwargs" in expected_keys:
273
+ expected_keys.remove("kwargs")
274
+ # remove keys to be ignored
275
+ if len(cls.ignore_for_config) > 0:
276
+ expected_keys = expected_keys - set(cls.ignore_for_config)
277
+ init_dict = {}
278
+ for key in expected_keys:
279
+ if key in kwargs:
280
+ # overwrite key
281
+ init_dict[key] = kwargs.pop(key)
282
+ elif key in config_dict:
283
+ # use value from config dict
284
+ init_dict[key] = config_dict.pop(key)
285
+
286
+ unused_kwargs = config_dict.update(kwargs)
287
+
288
+ passed_keys = set(init_dict.keys())
289
+ if len(expected_keys - passed_keys) > 0:
290
+ logger.warning(
291
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
292
+ )
293
+
294
+ return init_dict, unused_kwargs
295
+
296
+ @classmethod
297
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
298
+ with open(json_file, "r", encoding="utf-8") as reader:
299
+ text = reader.read()
300
+ return json.loads(text)
301
+
302
+ def __repr__(self):
303
+ return f"{self.__class__.__name__} {self.to_json_string()}"
304
+
305
+ @property
306
+ def config(self) -> Dict[str, Any]:
307
+ return self._internal_dict
308
+
309
+ def to_json_string(self) -> str:
310
+ """
311
+ Serializes this instance to a JSON string.
312
+
313
+ Returns:
314
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
315
+ """
316
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
317
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
318
+
319
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
320
+ """
321
+ Save this instance to a JSON file.
322
+
323
+ Args:
324
+ json_file_path (`str` or `os.PathLike`):
325
+ Path to the JSON file in which this configuration instance's parameters will be saved.
326
+ """
327
+ with open(json_file_path, "w", encoding="utf-8") as writer:
328
+ writer.write(self.to_json_string())
329
+
330
+
331
+ class FrozenDict(OrderedDict):
332
+ def __init__(self, *args, **kwargs):
333
+ super().__init__(*args, **kwargs)
334
+
335
+ for key, value in self.items():
336
+ setattr(self, key, value)
337
+
338
+ self.__frozen = True
339
+
340
+ def __delitem__(self, *args, **kwargs):
341
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
342
+
343
+ def setdefault(self, *args, **kwargs):
344
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
345
+
346
+ def pop(self, *args, **kwargs):
347
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
348
+
349
+ def update(self, *args, **kwargs):
350
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
351
+
352
+ def __setattr__(self, name, value):
353
+ if hasattr(self, "__frozen") and self.__frozen:
354
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
355
+ super().__setattr__(name, value)
356
+
357
+ def __setitem__(self, name, value):
358
+ if hasattr(self, "__frozen") and self.__frozen:
359
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
360
+ super().__setitem__(name, value)
361
+
362
+
363
+ def register_to_config(init):
364
+ r"""
365
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
366
+ automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
367
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
368
+
369
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
370
+ """
371
+
372
+ @functools.wraps(init)
373
+ def inner_init(self, *args, **kwargs):
374
+ # Ignore private kwargs in the init.
375
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
376
+ init(self, *args, **init_kwargs)
377
+ if not isinstance(self, ConfigMixin):
378
+ raise RuntimeError(
379
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
380
+ "not inherit from `ConfigMixin`."
381
+ )
382
+
383
+ ignore = getattr(self, "ignore_for_config", [])
384
+ # Get positional arguments aligned with kwargs
385
+ new_kwargs = {}
386
+ signature = inspect.signature(init)
387
+ parameters = {
388
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
389
+ }
390
+ for arg, name in zip(args, parameters.keys()):
391
+ new_kwargs[name] = arg
392
+
393
+ # Then add all kwargs
394
+ new_kwargs.update(
395
+ {
396
+ k: init_kwargs.get(k, default)
397
+ for k, default in parameters.items()
398
+ if k not in ignore and k not in new_kwargs
399
+ }
400
+ )
401
+ getattr(self, "register_to_config")(**new_kwargs)
402
+
403
+ return inner_init
diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+
16
+ from .dependency_versions_table import deps
17
+ from .utils.versions import require_version, require_version_core
18
+
19
+
20
+ # define which module versions we always want to check at run time
21
+ # (usually the ones defined in `install_requires` in setup.py)
22
+ #
23
+ # order specific notes:
24
+ # - tqdm must be checked before tokenizers
25
+
26
+ pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
27
+ if sys.version_info < (3, 7):
28
+ pkgs_to_check_at_runtime.append("dataclasses")
29
+ if sys.version_info < (3, 8):
30
+ pkgs_to_check_at_runtime.append("importlib_metadata")
31
+
32
+ for pkg in pkgs_to_check_at_runtime:
33
+ if pkg in deps:
34
+ if pkg == "tokenizers":
35
+ # must be loaded here, or else tqdm check may fail
36
+ from .utils import is_tokenizers_available
37
+
38
+ if not is_tokenizers_available():
39
+ continue # not required, check version only if installed
40
+
41
+ require_version_core(deps[pkg])
42
+ else:
43
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44
+
45
+
46
+ def dep_version_check(pkg, hint=None):
47
+ require_version(deps[pkg], hint)
diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update``
4
+ deps = {
5
+ "Pillow": "Pillow",
6
+ "accelerate": "accelerate>=0.11.0",
7
+ "black": "black==22.3",
8
+ "datasets": "datasets",
9
+ "filelock": "filelock",
10
+ "flake8": "flake8>=3.8.3",
11
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
12
+ "huggingface-hub": "huggingface-hub>=0.8.1",
13
+ "importlib_metadata": "importlib_metadata",
14
+ "isort": "isort>=5.5.4",
15
+ "modelcards": "modelcards==0.1.4",
16
+ "numpy": "numpy",
17
+ "pytest": "pytest",
18
+ "pytest-timeout": "pytest-timeout",
19
+ "pytest-xdist": "pytest-xdist",
20
+ "scipy": "scipy",
21
+ "regex": "regex!=2019.12.17",
22
+ "requests": "requests",
23
+ "tensorboard": "tensorboard",
24
+ "torch": "torch>=1.4",
25
+ "transformers": "transformers>=4.21.0",
26
+ }
diffusers/dynamic_modules_utils.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities to dynamically load objects from the Hub."""
16
+
17
+ import importlib
18
+ import os
19
+ import re
20
+ import shutil
21
+ import sys
22
+ from pathlib import Path
23
+ from typing import Dict, Optional, Union
24
+
25
+ from huggingface_hub import cached_download
26
+
27
+ from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ def init_hf_modules():
34
+ """
35
+ Creates the cache directory for modules with an init, and adds it to the Python path.
36
+ """
37
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
38
+ if HF_MODULES_CACHE in sys.path:
39
+ return
40
+
41
+ sys.path.append(HF_MODULES_CACHE)
42
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
43
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
44
+ if not init_path.exists():
45
+ init_path.touch()
46
+
47
+
48
+ def create_dynamic_module(name: Union[str, os.PathLike]):
49
+ """
50
+ Creates a dynamic module in the cache directory for modules.
51
+ """
52
+ init_hf_modules()
53
+ dynamic_module_path = Path(HF_MODULES_CACHE) / name
54
+ # If the parent module does not exist yet, recursively create it.
55
+ if not dynamic_module_path.parent.exists():
56
+ create_dynamic_module(dynamic_module_path.parent)
57
+ os.makedirs(dynamic_module_path, exist_ok=True)
58
+ init_path = dynamic_module_path / "__init__.py"
59
+ if not init_path.exists():
60
+ init_path.touch()
61
+
62
+
63
+ def get_relative_imports(module_file):
64
+ """
65
+ Get the list of modules that are relatively imported in a module file.
66
+
67
+ Args:
68
+ module_file (`str` or `os.PathLike`): The module file to inspect.
69
+ """
70
+ with open(module_file, "r", encoding="utf-8") as f:
71
+ content = f.read()
72
+
73
+ # Imports of the form `import .xxx`
74
+ relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
75
+ # Imports of the form `from .xxx import yyy`
76
+ relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
77
+ # Unique-ify
78
+ return list(set(relative_imports))
79
+
80
+
81
+ def get_relative_import_files(module_file):
82
+ """
83
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
84
+ imports (if a imports b and b imports c, it will return module files for b and c).
85
+
86
+ Args:
87
+ module_file (`str` or `os.PathLike`): The module file to inspect.
88
+ """
89
+ no_change = False
90
+ files_to_check = [module_file]
91
+ all_relative_imports = []
92
+
93
+ # Let's recurse through all relative imports
94
+ while not no_change:
95
+ new_imports = []
96
+ for f in files_to_check:
97
+ new_imports.extend(get_relative_imports(f))
98
+
99
+ module_path = Path(module_file).parent
100
+ new_import_files = [str(module_path / m) for m in new_imports]
101
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
102
+ files_to_check = [f"{f}.py" for f in new_import_files]
103
+
104
+ no_change = len(new_import_files) == 0
105
+ all_relative_imports.extend(files_to_check)
106
+
107
+ return all_relative_imports
108
+
109
+
110
+ def check_imports(filename):
111
+ """
112
+ Check if the current Python environment contains all the libraries that are imported in a file.
113
+ """
114
+ with open(filename, "r", encoding="utf-8") as f:
115
+ content = f.read()
116
+
117
+ # Imports of the form `import xxx`
118
+ imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
119
+ # Imports of the form `from xxx import yyy`
120
+ imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
121
+ # Only keep the top-level module
122
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
123
+
124
+ # Unique-ify and test we got them all
125
+ imports = list(set(imports))
126
+ missing_packages = []
127
+ for imp in imports:
128
+ try:
129
+ importlib.import_module(imp)
130
+ except ImportError:
131
+ missing_packages.append(imp)
132
+
133
+ if len(missing_packages) > 0:
134
+ raise ImportError(
135
+ "This modeling file requires the following packages that were not found in your environment: "
136
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
137
+ )
138
+
139
+ return get_relative_imports(filename)
140
+
141
+
142
+ def get_class_in_module(class_name, module_path):
143
+ """
144
+ Import a module on the cache directory for modules and extract a class from it.
145
+ """
146
+ module_path = module_path.replace(os.path.sep, ".")
147
+ module = importlib.import_module(module_path)
148
+ return getattr(module, class_name)
149
+
150
+
151
+ def get_cached_module_file(
152
+ pretrained_model_name_or_path: Union[str, os.PathLike],
153
+ module_file: str,
154
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
155
+ force_download: bool = False,
156
+ resume_download: bool = False,
157
+ proxies: Optional[Dict[str, str]] = None,
158
+ use_auth_token: Optional[Union[bool, str]] = None,
159
+ revision: Optional[str] = None,
160
+ local_files_only: bool = False,
161
+ ):
162
+ """
163
+ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
164
+ Transformers module.
165
+
166
+ Args:
167
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
168
+ This can be either:
169
+
170
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
171
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
172
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
173
+ - a path to a *directory* containing a configuration file saved using the
174
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
175
+
176
+ module_file (`str`):
177
+ The name of the module file containing the class to look for.
178
+ cache_dir (`str` or `os.PathLike`, *optional*):
179
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
180
+ cache should not be used.
181
+ force_download (`bool`, *optional*, defaults to `False`):
182
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
183
+ exist.
184
+ resume_download (`bool`, *optional*, defaults to `False`):
185
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
186
+ proxies (`Dict[str, str]`, *optional*):
187
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
188
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
189
+ use_auth_token (`str` or *bool*, *optional*):
190
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
191
+ when running `transformers-cli login` (stored in `~/.huggingface`).
192
+ revision (`str`, *optional*, defaults to `"main"`):
193
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
194
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
195
+ identifier allowed by git.
196
+ local_files_only (`bool`, *optional*, defaults to `False`):
197
+ If `True`, will only try to load the tokenizer configuration from local files.
198
+
199
+ <Tip>
200
+
201
+ Passing `use_auth_token=True` is required when you want to use a private model.
202
+
203
+ </Tip>
204
+
205
+ Returns:
206
+ `str`: The path to the module inside the cache.
207
+ """
208
+ # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
209
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
210
+ module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
211
+ submodule = "local"
212
+
213
+ if os.path.isfile(module_file_or_url):
214
+ resolved_module_file = module_file_or_url
215
+ else:
216
+ try:
217
+ # Load from URL or cache if already cached
218
+ resolved_module_file = cached_download(
219
+ module_file_or_url,
220
+ cache_dir=cache_dir,
221
+ force_download=force_download,
222
+ proxies=proxies,
223
+ resume_download=resume_download,
224
+ local_files_only=local_files_only,
225
+ use_auth_token=use_auth_token,
226
+ )
227
+
228
+ except EnvironmentError:
229
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
230
+ raise
231
+
232
+ # Check we have all the requirements in our environment
233
+ modules_needed = check_imports(resolved_module_file)
234
+
235
+ # Now we move the module inside our cached dynamic modules.
236
+ full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
237
+ create_dynamic_module(full_submodule)
238
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
239
+ # We always copy local files (we could hash the file to see if there was a change, and give them the name of
240
+ # that hash, to only copy when there is a modification but it seems overkill for now).
241
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
242
+ shutil.copy(resolved_module_file, submodule_path / module_file)
243
+ for module_needed in modules_needed:
244
+ module_needed = f"{module_needed}.py"
245
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
246
+ return os.path.join(full_submodule, module_file)
247
+
248
+
249
+ def get_class_from_dynamic_module(
250
+ pretrained_model_name_or_path: Union[str, os.PathLike],
251
+ module_file: str,
252
+ class_name: str,
253
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
254
+ force_download: bool = False,
255
+ resume_download: bool = False,
256
+ proxies: Optional[Dict[str, str]] = None,
257
+ use_auth_token: Optional[Union[bool, str]] = None,
258
+ revision: Optional[str] = None,
259
+ local_files_only: bool = False,
260
+ **kwargs,
261
+ ):
262
+ """
263
+ Extracts a class from a module file, present in the local folder or repository of a model.
264
+
265
+ <Tip warning={true}>
266
+
267
+ Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
268
+ therefore only be called on trusted repos.
269
+
270
+ </Tip>
271
+
272
+ Args:
273
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
274
+ This can be either:
275
+
276
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
277
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
278
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
279
+ - a path to a *directory* containing a configuration file saved using the
280
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
281
+
282
+ module_file (`str`):
283
+ The name of the module file containing the class to look for.
284
+ class_name (`str`):
285
+ The name of the class to import in the module.
286
+ cache_dir (`str` or `os.PathLike`, *optional*):
287
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
288
+ cache should not be used.
289
+ force_download (`bool`, *optional*, defaults to `False`):
290
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
291
+ exist.
292
+ resume_download (`bool`, *optional*, defaults to `False`):
293
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
294
+ proxies (`Dict[str, str]`, *optional*):
295
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
296
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
297
+ use_auth_token (`str` or `bool`, *optional*):
298
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
299
+ when running `transformers-cli login` (stored in `~/.huggingface`).
300
+ revision (`str`, *optional*, defaults to `"main"`):
301
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
302
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
303
+ identifier allowed by git.
304
+ local_files_only (`bool`, *optional*, defaults to `False`):
305
+ If `True`, will only try to load the tokenizer configuration from local files.
306
+
307
+ <Tip>
308
+
309
+ Passing `use_auth_token=True` is required when you want to use a private model.
310
+
311
+ </Tip>
312
+
313
+ Returns:
314
+ `type`: The class, dynamically imported from the module.
315
+
316
+ Examples:
317
+
318
+ ```python
319
+ # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
320
+ # module.
321
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
322
+ ```"""
323
+ # And lastly we get the class inside our newly created module
324
+ final_module = get_cached_module_file(
325
+ pretrained_model_name_or_path,
326
+ module_file,
327
+ cache_dir=cache_dir,
328
+ force_download=force_download,
329
+ resume_download=resume_download,
330
+ proxies=proxies,
331
+ use_auth_token=use_auth_token,
332
+ revision=revision,
333
+ local_files_only=local_files_only,
334
+ )
335
+ return get_class_in_module(class_name, final_module.replace(".py", ""))
diffusers/hub_utils.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import os
18
+ import shutil
19
+ from pathlib import Path
20
+ from typing import Optional
21
+
22
+ from huggingface_hub import HfFolder, Repository, whoami
23
+
24
+ from .pipeline_utils import DiffusionPipeline
25
+ from .utils import is_modelcards_available, logging
26
+
27
+
28
+ if is_modelcards_available():
29
+ from modelcards import CardData, ModelCard
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
36
+
37
+
38
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
39
+ if token is None:
40
+ token = HfFolder.get_token()
41
+ if organization is None:
42
+ username = whoami(token)["name"]
43
+ return f"{username}/{model_id}"
44
+ else:
45
+ return f"{organization}/{model_id}"
46
+
47
+
48
+ def init_git_repo(args, at_init: bool = False):
49
+ """
50
+ Args:
51
+ Initializes a git repo in `args.hub_model_id`.
52
+ at_init (`bool`, *optional*, defaults to `False`):
53
+ Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
54
+ and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
55
+ """
56
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
57
+ return
58
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
59
+ use_auth_token = True if hub_token is None else hub_token
60
+ if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
61
+ repo_name = Path(args.output_dir).absolute().name
62
+ else:
63
+ repo_name = args.hub_model_id
64
+ if "/" not in repo_name:
65
+ repo_name = get_full_repo_name(repo_name, token=hub_token)
66
+
67
+ try:
68
+ repo = Repository(
69
+ args.output_dir,
70
+ clone_from=repo_name,
71
+ use_auth_token=use_auth_token,
72
+ private=args.hub_private_repo,
73
+ )
74
+ except EnvironmentError:
75
+ if args.overwrite_output_dir and at_init:
76
+ # Try again after wiping output_dir
77
+ shutil.rmtree(args.output_dir)
78
+ repo = Repository(
79
+ args.output_dir,
80
+ clone_from=repo_name,
81
+ use_auth_token=use_auth_token,
82
+ )
83
+ else:
84
+ raise
85
+
86
+ repo.git_pull()
87
+
88
+ # By default, ignore the checkpoint folders
89
+ if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
90
+ with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
91
+ writer.writelines(["checkpoint-*/"])
92
+
93
+ return repo
94
+
95
+
96
+ def push_to_hub(
97
+ args,
98
+ pipeline: DiffusionPipeline,
99
+ repo: Repository,
100
+ commit_message: Optional[str] = "End of training",
101
+ blocking: bool = True,
102
+ **kwargs,
103
+ ) -> str:
104
+ """
105
+ Parameters:
106
+ Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
107
+ commit_message (`str`, *optional*, defaults to `"End of training"`):
108
+ Message to commit while pushing.
109
+ blocking (`bool`, *optional*, defaults to `True`):
110
+ Whether the function should return only when the `git push` has finished.
111
+ kwargs:
112
+ Additional keyword arguments passed along to [`create_model_card`].
113
+ Returns:
114
+ The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
115
+ commit and an object to track the progress of the commit if `blocking=True`
116
+ """
117
+
118
+ if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
119
+ model_name = Path(args.output_dir).name
120
+ else:
121
+ model_name = args.hub_model_id.split("/")[-1]
122
+
123
+ output_dir = args.output_dir
124
+ os.makedirs(output_dir, exist_ok=True)
125
+ logger.info(f"Saving pipeline checkpoint to {output_dir}")
126
+ pipeline.save_pretrained(output_dir)
127
+
128
+ # Only push from one node.
129
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
130
+ return
131
+
132
+ # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
133
+ if (
134
+ blocking
135
+ and len(repo.command_queue) > 0
136
+ and repo.command_queue[-1] is not None
137
+ and not repo.command_queue[-1].is_done
138
+ ):
139
+ repo.command_queue[-1]._process.kill()
140
+
141
+ git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
142
+ # push separately the model card to be independent from the rest of the model
143
+ create_model_card(args, model_name=model_name)
144
+ try:
145
+ repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
146
+ except EnvironmentError as exc:
147
+ logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
148
+
149
+ return git_head_commit_url
150
+
151
+
152
+ def create_model_card(args, model_name):
153
+ if not is_modelcards_available:
154
+ raise ValueError(
155
+ "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
156
+ " install the package with `pip install modelcards`."
157
+ )
158
+
159
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
160
+ return
161
+
162
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
163
+ repo_name = get_full_repo_name(model_name, token=hub_token)
164
+
165
+ model_card = ModelCard.from_template(
166
+ card_data=CardData( # Card metadata object that will be converted to YAML block
167
+ language="en",
168
+ license="apache-2.0",
169
+ library_name="diffusers",
170
+ tags=[],
171
+ datasets=args.dataset_name,
172
+ metrics=[],
173
+ ),
174
+ template_path=MODEL_CARD_TEMPLATE_PATH,
175
+ model_name=model_name,
176
+ repo_name=repo_name,
177
+ dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
178
+ learning_rate=args.learning_rate,
179
+ train_batch_size=args.train_batch_size,
180
+ eval_batch_size=args.eval_batch_size,
181
+ gradient_accumulation_steps=args.gradient_accumulation_steps
182
+ if hasattr(args, "gradient_accumulation_steps")
183
+ else None,
184
+ adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
185
+ adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
186
+ adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
187
+ adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
188
+ lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
189
+ lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
190
+ ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
191
+ ema_power=args.ema_power if hasattr(args, "ema_power") else None,
192
+ ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
193
+ mixed_precision=args.mixed_precision,
194
+ )
195
+
196
+ card_path = os.path.join(args.output_dir, "README.md")
197
+ model_card.save(card_path)
diffusers/modeling_utils.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ from typing import Callable, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ from torch import Tensor, device
22
+
23
+ from huggingface_hub import hf_hub_download
24
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
25
+ from requests import HTTPError
26
+
27
+ from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
28
+
29
+
30
+ WEIGHTS_NAME = "diffusion_pytorch_model.bin"
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ def get_parameter_device(parameter: torch.nn.Module):
37
+ try:
38
+ return next(parameter.parameters()).device
39
+ except StopIteration:
40
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
41
+
42
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
43
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
44
+ return tuples
45
+
46
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
47
+ first_tuple = next(gen)
48
+ return first_tuple[1].device
49
+
50
+
51
+ def get_parameter_dtype(parameter: torch.nn.Module):
52
+ try:
53
+ return next(parameter.parameters()).dtype
54
+ except StopIteration:
55
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
56
+
57
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
58
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
59
+ return tuples
60
+
61
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
62
+ first_tuple = next(gen)
63
+ return first_tuple[1].dtype
64
+
65
+
66
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
67
+ """
68
+ Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
69
+ """
70
+ try:
71
+ return torch.load(checkpoint_file, map_location="cpu")
72
+ except Exception as e:
73
+ try:
74
+ with open(checkpoint_file) as f:
75
+ if f.read().startswith("version"):
76
+ raise OSError(
77
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
78
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
79
+ "you cloned."
80
+ )
81
+ else:
82
+ raise ValueError(
83
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
84
+ "model. Make sure you have saved the model properly."
85
+ ) from e
86
+ except (UnicodeDecodeError, ValueError):
87
+ raise OSError(
88
+ f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
89
+ f"at '{checkpoint_file}'. "
90
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
91
+ )
92
+
93
+
94
+ def _load_state_dict_into_model(model_to_load, state_dict):
95
+ # Convert old format to new format if needed from a PyTorch state_dict
96
+ # copy state_dict so _load_from_state_dict can modify it
97
+ state_dict = state_dict.copy()
98
+ error_msgs = []
99
+
100
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
101
+ # so we need to apply the function recursively.
102
+ def load(module: torch.nn.Module, prefix=""):
103
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
104
+ module._load_from_state_dict(*args)
105
+
106
+ for name, child in module._modules.items():
107
+ if child is not None:
108
+ load(child, prefix + name + ".")
109
+
110
+ load(model_to_load)
111
+
112
+ return error_msgs
113
+
114
+
115
+ class ModelMixin(torch.nn.Module):
116
+ r"""
117
+ Base class for all models.
118
+
119
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
120
+ and saving models.
121
+
122
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
123
+ [`~modeling_utils.ModelMixin.save_pretrained`].
124
+ """
125
+ config_name = CONFIG_NAME
126
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
127
+
128
+ def __init__(self):
129
+ super().__init__()
130
+
131
+ def save_pretrained(
132
+ self,
133
+ save_directory: Union[str, os.PathLike],
134
+ is_main_process: bool = True,
135
+ save_function: Callable = torch.save,
136
+ ):
137
+ """
138
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
139
+ `[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
140
+
141
+ Arguments:
142
+ save_directory (`str` or `os.PathLike`):
143
+ Directory to which to save. Will be created if it doesn't exist.
144
+ is_main_process (`bool`, *optional*, defaults to `True`):
145
+ Whether the process calling this is the main process or not. Useful when in distributed training like
146
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
147
+ the main process to avoid race conditions.
148
+ save_function (`Callable`):
149
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
150
+ need to replace `torch.save` by another method.
151
+ """
152
+ if os.path.isfile(save_directory):
153
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
154
+ return
155
+
156
+ os.makedirs(save_directory, exist_ok=True)
157
+
158
+ model_to_save = self
159
+
160
+ # Attach architecture to the config
161
+ # Save the config
162
+ if is_main_process:
163
+ model_to_save.save_config(save_directory)
164
+
165
+ # Save the model
166
+ state_dict = model_to_save.state_dict()
167
+
168
+ # Clean the folder from a previous save
169
+ for filename in os.listdir(save_directory):
170
+ full_filename = os.path.join(save_directory, filename)
171
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
172
+ # in distributed settings to avoid race conditions.
173
+ if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
174
+ os.remove(full_filename)
175
+
176
+ # Save the model
177
+ save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
178
+
179
+ logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
180
+
181
+ @classmethod
182
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
183
+ r"""
184
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
185
+
186
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
187
+ the model, you should first set it back in training mode with `model.train()`.
188
+
189
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
190
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
191
+ task.
192
+
193
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
194
+ weights are discarded.
195
+
196
+ Parameters:
197
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
198
+ Can be either:
199
+
200
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
201
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
202
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
203
+ `./my_model_directory/`.
204
+
205
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
206
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
207
+ standard cache should not be used.
208
+ torch_dtype (`str` or `torch.dtype`, *optional*):
209
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
210
+ will be automatically derived from the model's weights.
211
+ force_download (`bool`, *optional*, defaults to `False`):
212
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
213
+ cached versions if they exist.
214
+ resume_download (`bool`, *optional*, defaults to `False`):
215
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
216
+ file exists.
217
+ proxies (`Dict[str, str]`, *optional*):
218
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
219
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
220
+ output_loading_info(`bool`, *optional*, defaults to `False`):
221
+ Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
222
+ local_files_only(`bool`, *optional*, defaults to `False`):
223
+ Whether or not to only look at local files (i.e., do not try to download the model).
224
+ use_auth_token (`str` or *bool*, *optional*):
225
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
226
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
227
+ revision (`str`, *optional*, defaults to `"main"`):
228
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
229
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
230
+ identifier allowed by git.
231
+ mirror (`str`, *optional*):
232
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
233
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
234
+ Please refer to the mirror site for more information.
235
+
236
+ <Tip>
237
+
238
+ Passing `use_auth_token=True`` is required when you want to use a private model.
239
+
240
+ </Tip>
241
+
242
+ <Tip>
243
+
244
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
245
+ this method in a firewalled environment.
246
+
247
+ </Tip>
248
+
249
+ """
250
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
251
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
252
+ force_download = kwargs.pop("force_download", False)
253
+ resume_download = kwargs.pop("resume_download", False)
254
+ proxies = kwargs.pop("proxies", None)
255
+ output_loading_info = kwargs.pop("output_loading_info", False)
256
+ local_files_only = kwargs.pop("local_files_only", False)
257
+ use_auth_token = kwargs.pop("use_auth_token", None)
258
+ revision = kwargs.pop("revision", None)
259
+ from_auto_class = kwargs.pop("_from_auto", False)
260
+ torch_dtype = kwargs.pop("torch_dtype", None)
261
+ subfolder = kwargs.pop("subfolder", None)
262
+
263
+ user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
264
+
265
+ # Load config if we don't provide a configuration
266
+ config_path = pretrained_model_name_or_path
267
+ model, unused_kwargs = cls.from_config(
268
+ config_path,
269
+ cache_dir=cache_dir,
270
+ return_unused_kwargs=True,
271
+ force_download=force_download,
272
+ resume_download=resume_download,
273
+ proxies=proxies,
274
+ local_files_only=local_files_only,
275
+ use_auth_token=use_auth_token,
276
+ revision=revision,
277
+ subfolder=subfolder,
278
+ **kwargs,
279
+ )
280
+
281
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
282
+ raise ValueError(
283
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
284
+ )
285
+ elif torch_dtype is not None:
286
+ model = model.to(torch_dtype)
287
+
288
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
289
+ # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
290
+ # Load model
291
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
292
+ if os.path.isdir(pretrained_model_name_or_path):
293
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
294
+ # Load from a PyTorch checkpoint
295
+ model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
296
+ elif subfolder is not None and os.path.isfile(
297
+ os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
298
+ ):
299
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
300
+ else:
301
+ raise EnvironmentError(
302
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
303
+ )
304
+ else:
305
+ try:
306
+ # Load from URL or cache if already cached
307
+ model_file = hf_hub_download(
308
+ pretrained_model_name_or_path,
309
+ filename=WEIGHTS_NAME,
310
+ cache_dir=cache_dir,
311
+ force_download=force_download,
312
+ proxies=proxies,
313
+ resume_download=resume_download,
314
+ local_files_only=local_files_only,
315
+ use_auth_token=use_auth_token,
316
+ user_agent=user_agent,
317
+ subfolder=subfolder,
318
+ revision=revision,
319
+ )
320
+
321
+ except RepositoryNotFoundError:
322
+ raise EnvironmentError(
323
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
324
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
325
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
326
+ "login` and pass `use_auth_token=True`."
327
+ )
328
+ except RevisionNotFoundError:
329
+ raise EnvironmentError(
330
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
331
+ "this model name. Check the model page at "
332
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
333
+ )
334
+ except EntryNotFoundError:
335
+ raise EnvironmentError(
336
+ f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
337
+ )
338
+ except HTTPError as err:
339
+ raise EnvironmentError(
340
+ "There was a specific connection error when trying to load"
341
+ f" {pretrained_model_name_or_path}:\n{err}"
342
+ )
343
+ except ValueError:
344
+ raise EnvironmentError(
345
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
346
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
347
+ f" directory containing a file named {WEIGHTS_NAME} or"
348
+ " \nCheckout your internet connection or see how to run the library in"
349
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
350
+ )
351
+ except EnvironmentError:
352
+ raise EnvironmentError(
353
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
354
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
355
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
356
+ f"containing a file named {WEIGHTS_NAME}"
357
+ )
358
+
359
+ # restore default dtype
360
+ state_dict = load_state_dict(model_file)
361
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
362
+ model,
363
+ state_dict,
364
+ model_file,
365
+ pretrained_model_name_or_path,
366
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
367
+ )
368
+
369
+ # Set model in evaluation mode to deactivate DropOut modules by default
370
+ model.eval()
371
+
372
+ if output_loading_info:
373
+ loading_info = {
374
+ "missing_keys": missing_keys,
375
+ "unexpected_keys": unexpected_keys,
376
+ "mismatched_keys": mismatched_keys,
377
+ "error_msgs": error_msgs,
378
+ }
379
+ return model, loading_info
380
+
381
+ return model
382
+
383
+ @classmethod
384
+ def _load_pretrained_model(
385
+ cls,
386
+ model,
387
+ state_dict,
388
+ resolved_archive_file,
389
+ pretrained_model_name_or_path,
390
+ ignore_mismatched_sizes=False,
391
+ ):
392
+ # Retrieve missing & unexpected_keys
393
+ model_state_dict = model.state_dict()
394
+ loaded_keys = [k for k in state_dict.keys()]
395
+
396
+ expected_keys = list(model_state_dict.keys())
397
+
398
+ original_loaded_keys = loaded_keys
399
+
400
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
401
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
402
+
403
+ # Make sure we are able to load base models as well as derived models (with heads)
404
+ model_to_load = model
405
+
406
+ def _find_mismatched_keys(
407
+ state_dict,
408
+ model_state_dict,
409
+ loaded_keys,
410
+ ignore_mismatched_sizes,
411
+ ):
412
+ mismatched_keys = []
413
+ if ignore_mismatched_sizes:
414
+ for checkpoint_key in loaded_keys:
415
+ model_key = checkpoint_key
416
+
417
+ if (
418
+ model_key in model_state_dict
419
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
420
+ ):
421
+ mismatched_keys.append(
422
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
423
+ )
424
+ del state_dict[checkpoint_key]
425
+ return mismatched_keys
426
+
427
+ if state_dict is not None:
428
+ # Whole checkpoint
429
+ mismatched_keys = _find_mismatched_keys(
430
+ state_dict,
431
+ model_state_dict,
432
+ original_loaded_keys,
433
+ ignore_mismatched_sizes,
434
+ )
435
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
436
+
437
+ if len(error_msgs) > 0:
438
+ error_msg = "\n\t".join(error_msgs)
439
+ if "size mismatch" in error_msg:
440
+ error_msg += (
441
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
442
+ )
443
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
444
+
445
+ if len(unexpected_keys) > 0:
446
+ logger.warning(
447
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
448
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
449
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
450
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
451
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
452
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
453
+ " identical (initializing a BertForSequenceClassification model from a"
454
+ " BertForSequenceClassification model)."
455
+ )
456
+ else:
457
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
458
+ if len(missing_keys) > 0:
459
+ logger.warning(
460
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
461
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
462
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
463
+ )
464
+ elif len(mismatched_keys) == 0:
465
+ logger.info(
466
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
467
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
468
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
469
+ " without further training."
470
+ )
471
+ if len(mismatched_keys) > 0:
472
+ mismatched_warning = "\n".join(
473
+ [
474
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
475
+ for key, shape1, shape2 in mismatched_keys
476
+ ]
477
+ )
478
+ logger.warning(
479
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
480
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
481
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
482
+ " able to use it for predictions and inference."
483
+ )
484
+
485
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
486
+
487
+ @property
488
+ def device(self) -> device:
489
+ """
490
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
491
+ device).
492
+ """
493
+ return get_parameter_device(self)
494
+
495
+ @property
496
+ def dtype(self) -> torch.dtype:
497
+ """
498
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
499
+ """
500
+ return get_parameter_dtype(self)
501
+
502
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
503
+ """
504
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
505
+
506
+ Args:
507
+ only_trainable (`bool`, *optional*, defaults to `False`):
508
+ Whether or not to return only the number of trainable parameters
509
+
510
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
511
+ Whether or not to return only the number of non-embeddings parameters
512
+
513
+ Returns:
514
+ `int`: The number of parameters.
515
+ """
516
+
517
+ if exclude_embeddings:
518
+ embedding_param_names = [
519
+ f"{name}.weight"
520
+ for name, module_type in self.named_modules()
521
+ if isinstance(module_type, torch.nn.Embedding)
522
+ ]
523
+ non_embedding_parameters = [
524
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
525
+ ]
526
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
527
+ else:
528
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
529
+
530
+
531
+ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
532
+ """
533
+ Recursively unwraps a model from potential containers (as used in distributed training).
534
+
535
+ Args:
536
+ model (`torch.nn.Module`): The model to unwrap.
537
+ """
538
+ # since there could be multiple levels of wrapping, unwrap recursively
539
+ if hasattr(model, "module"):
540
+ return unwrap_model(model.module)
541
+ else:
542
+ return model
diffusers/models/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .unet_2d import UNet2DModel
16
+ from .unet_2d_condition import UNet2DConditionModel
17
+ from .vae import AutoencoderKL, VQModel
diffusers/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (313 Bytes). View file
 
diffusers/models/__pycache__/attention.cpython-310.pyc ADDED
Binary file (14.3 kB). View file
 
diffusers/models/__pycache__/embeddings.cpython-310.pyc ADDED
Binary file (3.72 kB). View file
 
diffusers/models/__pycache__/resnet.cpython-310.pyc ADDED
Binary file (14.5 kB). View file
 
diffusers/models/__pycache__/unet_2d.cpython-310.pyc ADDED
Binary file (7.94 kB). View file
 
diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc ADDED
Binary file (8.73 kB). View file
 
diffusers/models/__pycache__/unet_blocks.cpython-310.pyc ADDED
Binary file (23.7 kB). View file
 
diffusers/models/__pycache__/vae.cpython-310.pyc ADDED
Binary file (16.5 kB). View file
 
diffusers/models/attention.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from collections import defaultdict
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class AttentionBlock(nn.Module):
11
+ """
12
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
13
+ to the N-d case.
14
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
15
+ Uses three q, k, v linear layers to compute attention.
16
+
17
+ Parameters:
18
+ channels (:obj:`int`): The number of channels in the input and output.
19
+ num_head_channels (:obj:`int`, *optional*):
20
+ The number of channels in each head. If None, then `num_heads` = 1.
21
+ num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
22
+ rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
23
+ eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ channels: int,
29
+ num_head_channels: Optional[int] = None,
30
+ num_groups: int = 32,
31
+ rescale_output_factor: float = 1.0,
32
+ eps: float = 1e-5,
33
+ ):
34
+ super().__init__()
35
+ self.channels = channels
36
+
37
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
38
+ self.num_head_size = num_head_channels
39
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
40
+
41
+ # define q,k,v as linear layers
42
+ self.query = nn.Linear(channels, channels)
43
+ self.key = nn.Linear(channels, channels)
44
+ self.value = nn.Linear(channels, channels)
45
+
46
+ self.rescale_output_factor = rescale_output_factor
47
+ self.proj_attn = nn.Linear(channels, channels, 1)
48
+
49
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
50
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
51
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
52
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
53
+ return new_projection
54
+
55
+ def forward(self, hidden_states):
56
+ residual = hidden_states
57
+ batch, channel, height, width = hidden_states.shape
58
+
59
+ # norm
60
+ hidden_states = self.group_norm(hidden_states)
61
+
62
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
63
+
64
+ # proj to q, k, v
65
+ query_proj = self.query(hidden_states)
66
+ key_proj = self.key(hidden_states)
67
+ value_proj = self.value(hidden_states)
68
+
69
+ # transpose
70
+ query_states = self.transpose_for_scores(query_proj)
71
+ key_states = self.transpose_for_scores(key_proj)
72
+ value_states = self.transpose_for_scores(value_proj)
73
+
74
+ # get scores
75
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
76
+
77
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
78
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
79
+
80
+ # compute attention output
81
+ hidden_states = torch.matmul(attention_probs, value_states)
82
+
83
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
84
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
85
+ hidden_states = hidden_states.view(new_hidden_states_shape)
86
+
87
+ # compute next hidden_states
88
+ hidden_states = self.proj_attn(hidden_states)
89
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
90
+
91
+ # res connect and rescale
92
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
93
+ return hidden_states
94
+
95
+
96
+ class SpatialTransformer(nn.Module):
97
+ """
98
+ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
99
+ standard transformer action. Finally, reshape to image.
100
+
101
+ Parameters:
102
+ in_channels (:obj:`int`): The number of channels in the input and output.
103
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
104
+ d_head (:obj:`int`): The number of channels in each head.
105
+ depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
106
+ dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
107
+ context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
108
+ """
109
+
110
+ def __init__(
111
+ self,
112
+ in_channels: int,
113
+ n_heads: int,
114
+ d_head: int,
115
+ depth: int = 1,
116
+ dropout: float = 0.0,
117
+ context_dim: Optional[int] = None,
118
+ ):
119
+ super().__init__()
120
+ self.n_heads = n_heads
121
+ self.d_head = d_head
122
+ self.in_channels = in_channels
123
+ inner_dim = n_heads * d_head
124
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
125
+
126
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
127
+
128
+ self.transformer_blocks = nn.ModuleList(
129
+ [
130
+ BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
131
+ for d in range(depth)
132
+ ]
133
+ )
134
+
135
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
136
+
137
+ def _set_attention_slice(self, slice_size):
138
+ for block in self.transformer_blocks:
139
+ block._set_attention_slice(slice_size)
140
+
141
+ def forward(self, x, context=None):
142
+ # note: if no context is given, cross-attention defaults to self-attention
143
+ b, c, h, w = x.shape
144
+ x_in = x
145
+ x = self.norm(x)
146
+ x = self.proj_in(x)
147
+ x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
148
+ for block in self.transformer_blocks:
149
+ x = block(x, context=context)
150
+ x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
151
+ x = self.proj_out(x)
152
+ return x + x_in
153
+
154
+
155
+ class BasicTransformerBlock(nn.Module):
156
+ r"""
157
+ A basic Transformer block.
158
+
159
+ Parameters:
160
+ dim (:obj:`int`): The number of channels in the input and output.
161
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
162
+ d_head (:obj:`int`): The number of channels in each head.
163
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
164
+ context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
165
+ gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
166
+ checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
167
+ """
168
+
169
+ def __init__(
170
+ self,
171
+ dim: int,
172
+ n_heads: int,
173
+ d_head: int,
174
+ dropout=0.0,
175
+ context_dim: Optional[int] = None,
176
+ gated_ff: bool = True,
177
+ checkpoint: bool = True,
178
+ ):
179
+ super().__init__()
180
+ self.attn1 = CrossAttention(
181
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
182
+ ) # is a self-attention
183
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
184
+ self.attn2 = CrossAttention(
185
+ query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
186
+ ) # is self-attn if context is none
187
+ self.norm1 = nn.LayerNorm(dim)
188
+ self.norm2 = nn.LayerNorm(dim)
189
+ self.norm3 = nn.LayerNorm(dim)
190
+ self.checkpoint = checkpoint
191
+
192
+ def _set_attention_slice(self, slice_size):
193
+ self.attn1._slice_size = slice_size
194
+ self.attn2._slice_size = slice_size
195
+
196
+ def forward(self, x, context=None):
197
+ x = x.contiguous() if x.device.type == "mps" else x
198
+ x = self.attn1(self.norm1(x)) + x
199
+ x = self.attn2(self.norm2(x), context=context) + x
200
+ x = self.ff(self.norm3(x)) + x
201
+ return x
202
+
203
+
204
+ heat_maps = defaultdict(list)
205
+ all_heat_maps = []
206
+
207
+
208
+ def clear_heat_maps():
209
+ global heat_maps, all_heat_maps
210
+ heat_maps = defaultdict(list)
211
+ all_heat_maps = []
212
+
213
+
214
+ def next_heat_map():
215
+ global heat_maps, all_heat_maps
216
+ all_heat_maps.append(heat_maps)
217
+ heat_maps = defaultdict(list)
218
+
219
+
220
+ def get_global_heat_map(last_n: int = None, idx: int = None, factors=None):
221
+ global heat_maps, all_heat_maps
222
+
223
+ if idx is not None:
224
+ heat_maps2 = [all_heat_maps[idx]]
225
+ else:
226
+ heat_maps2 = all_heat_maps[-last_n:] if last_n is not None else all_heat_maps
227
+
228
+ if factors is None:
229
+ factors = {1, 2, 4, 8, 16, 32}
230
+
231
+ all_merges = []
232
+
233
+ for heat_map_map in heat_maps2:
234
+ merge_list = []
235
+
236
+ for k, v in heat_map_map.items():
237
+ if k in factors:
238
+ merge_list.append(torch.stack(v, 0).mean(0))
239
+
240
+ all_merges.append(merge_list)
241
+
242
+ maps = torch.stack([torch.stack(x, 0) for x in all_merges], dim=0)
243
+ return maps.sum(0).cuda().sum(2).sum(0)
244
+
245
+
246
+ class CrossAttention(nn.Module):
247
+ r"""
248
+ A cross attention layer.
249
+
250
+ Parameters:
251
+ query_dim (:obj:`int`): The number of channels in the query.
252
+ context_dim (:obj:`int`, *optional*):
253
+ The number of channels in the context. If not given, defaults to `query_dim`.
254
+ heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
255
+ dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
256
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
257
+ """
258
+
259
+ def __init__(
260
+ self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
261
+ ):
262
+ super().__init__()
263
+ inner_dim = dim_head * heads
264
+ context_dim = context_dim if context_dim is not None else query_dim
265
+
266
+ self.scale = dim_head**-0.5
267
+ self.heads = heads
268
+ # for slice_size > 0 the attention score computation
269
+ # is split across the batch axis to save memory
270
+ # You can set slice_size with `set_attention_slice`
271
+ self._slice_size = None
272
+
273
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
274
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
275
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
276
+
277
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
278
+
279
+ def reshape_heads_to_batch_dim(self, tensor):
280
+ batch_size, seq_len, dim = tensor.shape
281
+ head_size = self.heads
282
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
283
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
284
+ return tensor
285
+
286
+ def reshape_batch_dim_to_heads(self, tensor):
287
+ batch_size, seq_len, dim = tensor.shape
288
+ head_size = self.heads
289
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
290
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
291
+ return tensor
292
+
293
+ def forward(self, x, context=None, mask=None):
294
+ batch_size, sequence_length, dim = x.shape
295
+
296
+ use_context = context is not None
297
+
298
+ q = self.to_q(x)
299
+ context = context if context is not None else x
300
+ k = self.to_k(context)
301
+ v = self.to_v(context)
302
+
303
+ q = self.reshape_heads_to_batch_dim(q)
304
+ k = self.reshape_heads_to_batch_dim(k)
305
+ v = self.reshape_heads_to_batch_dim(v)
306
+
307
+ # TODO(PVP) - mask is currently never used. Remember to re-implement when used
308
+
309
+ # attention, what we cannot get enough of
310
+ hidden_states = self._attention(q, k, v, sequence_length, dim, use_context=use_context)
311
+
312
+ return self.to_out(hidden_states)
313
+
314
+ @torch.no_grad()
315
+ def _up_sample_attn(self, x, factor, method: str = 'bicubic'):
316
+ weight = torch.full((factor, factor), 1 / factor**2, device=x.device)
317
+ weight = weight.view(1, 1, factor, factor)
318
+
319
+ h = w = int(math.sqrt(x.size(1)))
320
+ maps = []
321
+ x = x.permute(2, 0, 1)
322
+
323
+ with torch.cuda.amp.autocast(dtype=torch.float32):
324
+ for map_ in x:
325
+ map_ = map_.unsqueeze(1).view(map_.size(0), 1, h, w)
326
+ if method == 'bicubic':
327
+ map_ = F.interpolate(map_, size=(64, 64), mode="bicubic", align_corners=False)
328
+ maps.append(map_.squeeze(1))
329
+ else:
330
+ maps.append(F.conv_transpose2d(map_, weight, stride=factor).squeeze(1).cpu())
331
+
332
+ maps = torch.stack(maps, 0).sum(1, keepdim=True).cpu()
333
+ return maps
334
+
335
+ def _attention(self, query, key, value, sequence_length, dim, use_context: bool = True):
336
+ batch_size_attention = query.shape[0]
337
+ hidden_states = torch.zeros(
338
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
339
+ )
340
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
341
+ for i in range(hidden_states.shape[0] // slice_size):
342
+ start_idx = i * slice_size
343
+ end_idx = (i + 1) * slice_size
344
+ attn_slice = (
345
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
346
+ )
347
+ factor = int(math.sqrt(4096 // attn_slice.shape[1]))
348
+ attn_slice = attn_slice.softmax(-1)
349
+
350
+ if use_context and attn_slice.shape[-1] == 77:
351
+ if factor >= 1:
352
+ factor //= 1
353
+ maps = self._up_sample_attn(attn_slice, factor)
354
+ global heat_maps
355
+ heat_maps[factor].append(maps)
356
+ # print(attn_slice.size(), query.size(), key.size(), value.size())
357
+
358
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
359
+
360
+ hidden_states[start_idx:end_idx] = attn_slice
361
+
362
+ # reshape hidden_states
363
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
364
+ return hidden_states
365
+
366
+
367
+ class FeedForward(nn.Module):
368
+ r"""
369
+ A feed-forward layer.
370
+
371
+ Parameters:
372
+ dim (:obj:`int`): The number of channels in the input.
373
+ dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
374
+ mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
375
+ glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
376
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
377
+ """
378
+
379
+ def __init__(
380
+ self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
381
+ ):
382
+ super().__init__()
383
+ inner_dim = int(dim * mult)
384
+ dim_out = dim_out if dim_out is not None else dim
385
+ project_in = GEGLU(dim, inner_dim)
386
+
387
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
388
+
389
+ def forward(self, x):
390
+ return self.net(x)
391
+
392
+
393
+ # feedforward
394
+ class GEGLU(nn.Module):
395
+ r"""
396
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
397
+
398
+ Parameters:
399
+ dim_in (:obj:`int`): The number of channels in the input.
400
+ dim_out (:obj:`int`): The number of channels in the output.
401
+ """
402
+
403
+ def __init__(self, dim_in: int, dim_out: int):
404
+ super().__init__()
405
+ self.proj = nn.Linear(dim_in, dim_out * 2)
406
+
407
+ def forward(self, x):
408
+ x, gate = self.proj(x).chunk(2, dim=-1)
409
+ return x * F.gelu(gate)
diffusers/models/embeddings.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import numpy as np
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
+ def get_timestep_embedding(
22
+ timesteps: torch.Tensor,
23
+ embedding_dim: int,
24
+ flip_sin_to_cos: bool = False,
25
+ downscale_freq_shift: float = 1,
26
+ scale: float = 1,
27
+ max_period: int = 10000,
28
+ ):
29
+ """
30
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
31
+
32
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
33
+ These may be fractional.
34
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
35
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
36
+ """
37
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
38
+
39
+ half_dim = embedding_dim // 2
40
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
41
+ exponent = exponent / (half_dim - downscale_freq_shift)
42
+
43
+ emb = torch.exp(exponent).to(device=timesteps.device)
44
+ emb = timesteps[:, None].float() * emb[None, :]
45
+
46
+ # scale embeddings
47
+ emb = scale * emb
48
+
49
+ # concat sine and cosine embeddings
50
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
51
+
52
+ # flip sine and cosine embeddings
53
+ if flip_sin_to_cos:
54
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
55
+
56
+ # zero pad
57
+ if embedding_dim % 2 == 1:
58
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
59
+ return emb
60
+
61
+
62
+ class TimestepEmbedding(nn.Module):
63
+ def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
64
+ super().__init__()
65
+
66
+ self.linear_1 = nn.Linear(channel, time_embed_dim)
67
+ self.act = None
68
+ if act_fn == "silu":
69
+ self.act = nn.SiLU()
70
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
71
+
72
+ def forward(self, sample):
73
+ sample = self.linear_1(sample)
74
+
75
+ if self.act is not None:
76
+ sample = self.act(sample)
77
+
78
+ sample = self.linear_2(sample)
79
+ return sample
80
+
81
+
82
+ class Timesteps(nn.Module):
83
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
84
+ super().__init__()
85
+ self.num_channels = num_channels
86
+ self.flip_sin_to_cos = flip_sin_to_cos
87
+ self.downscale_freq_shift = downscale_freq_shift
88
+
89
+ def forward(self, timesteps):
90
+ t_emb = get_timestep_embedding(
91
+ timesteps,
92
+ self.num_channels,
93
+ flip_sin_to_cos=self.flip_sin_to_cos,
94
+ downscale_freq_shift=self.downscale_freq_shift,
95
+ )
96
+ return t_emb
97
+
98
+
99
+ class GaussianFourierProjection(nn.Module):
100
+ """Gaussian Fourier embeddings for noise levels."""
101
+
102
+ def __init__(self, embedding_size: int = 256, scale: float = 1.0):
103
+ super().__init__()
104
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
105
+
106
+ # to delete later
107
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
108
+
109
+ self.weight = self.W
110
+
111
+ def forward(self, x):
112
+ x = torch.log(x)
113
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
114
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
115
+ return out
diffusers/models/resnet.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Upsample2D(nn.Module):
10
+ """
11
+ An upsampling layer with an optional convolution.
12
+
13
+ :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
14
+ applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
15
+ upsampling occurs in the inner-two dimensions.
16
+ """
17
+
18
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
19
+ super().__init__()
20
+ self.channels = channels
21
+ self.out_channels = out_channels or channels
22
+ self.use_conv = use_conv
23
+ self.use_conv_transpose = use_conv_transpose
24
+ self.name = name
25
+
26
+ conv = None
27
+ if use_conv_transpose:
28
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
29
+ elif use_conv:
30
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
31
+
32
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
33
+ if name == "conv":
34
+ self.conv = conv
35
+ else:
36
+ self.Conv2d_0 = conv
37
+
38
+ def forward(self, x):
39
+ assert x.shape[1] == self.channels
40
+ if self.use_conv_transpose:
41
+ return self.conv(x)
42
+
43
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
44
+
45
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
46
+ if self.use_conv:
47
+ if self.name == "conv":
48
+ x = self.conv(x)
49
+ else:
50
+ x = self.Conv2d_0(x)
51
+
52
+ return x
53
+
54
+
55
+ class Downsample2D(nn.Module):
56
+ """
57
+ A downsampling layer with an optional convolution.
58
+
59
+ :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
60
+ applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
61
+ downsampling occurs in the inner-two dimensions.
62
+ """
63
+
64
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
65
+ super().__init__()
66
+ self.channels = channels
67
+ self.out_channels = out_channels or channels
68
+ self.use_conv = use_conv
69
+ self.padding = padding
70
+ stride = 2
71
+ self.name = name
72
+
73
+ if use_conv:
74
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
75
+ else:
76
+ assert self.channels == self.out_channels
77
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
78
+
79
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
80
+ if name == "conv":
81
+ self.Conv2d_0 = conv
82
+ self.conv = conv
83
+ elif name == "Conv2d_0":
84
+ self.conv = conv
85
+ else:
86
+ self.conv = conv
87
+
88
+ def forward(self, x):
89
+ assert x.shape[1] == self.channels
90
+ if self.use_conv and self.padding == 0:
91
+ pad = (0, 1, 0, 1)
92
+ x = F.pad(x, pad, mode="constant", value=0)
93
+
94
+ assert x.shape[1] == self.channels
95
+ x = self.conv(x)
96
+
97
+ return x
98
+
99
+
100
+ class FirUpsample2D(nn.Module):
101
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
102
+ super().__init__()
103
+ out_channels = out_channels if out_channels else channels
104
+ if use_conv:
105
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
106
+ self.use_conv = use_conv
107
+ self.fir_kernel = fir_kernel
108
+ self.out_channels = out_channels
109
+
110
+ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
111
+ """Fused `upsample_2d()` followed by `Conv2d()`.
112
+
113
+ Args:
114
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
115
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
116
+ order.
117
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
118
+ C]`.
119
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
120
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
121
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
122
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
123
+ factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
124
+
125
+ Returns:
126
+ Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
127
+ `x`.
128
+ """
129
+
130
+ assert isinstance(factor, int) and factor >= 1
131
+
132
+ # Setup filter kernel.
133
+ if kernel is None:
134
+ kernel = [1] * factor
135
+
136
+ # setup kernel
137
+ kernel = np.asarray(kernel, dtype=np.float32)
138
+ if kernel.ndim == 1:
139
+ kernel = np.outer(kernel, kernel)
140
+ kernel /= np.sum(kernel)
141
+
142
+ kernel = kernel * (gain * (factor**2))
143
+
144
+ if self.use_conv:
145
+ convH = weight.shape[2]
146
+ convW = weight.shape[3]
147
+ inC = weight.shape[1]
148
+
149
+ p = (kernel.shape[0] - factor) - (convW - 1)
150
+
151
+ stride = (factor, factor)
152
+ # Determine data dimensions.
153
+ stride = [1, 1, factor, factor]
154
+ output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
155
+ output_padding = (
156
+ output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
157
+ output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
158
+ )
159
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
160
+ inC = weight.shape[1]
161
+ num_groups = x.shape[1] // inC
162
+
163
+ # Transpose weights.
164
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
165
+ weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
166
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
167
+
168
+ x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
169
+
170
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
171
+ else:
172
+ p = kernel.shape[0] - factor
173
+ x = upfirdn2d_native(
174
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
175
+ )
176
+
177
+ return x
178
+
179
+ def forward(self, x):
180
+ if self.use_conv:
181
+ height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
182
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
183
+ else:
184
+ height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
185
+
186
+ return height
187
+
188
+
189
+ class FirDownsample2D(nn.Module):
190
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
191
+ super().__init__()
192
+ out_channels = out_channels if out_channels else channels
193
+ if use_conv:
194
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
195
+ self.fir_kernel = fir_kernel
196
+ self.use_conv = use_conv
197
+ self.out_channels = out_channels
198
+
199
+ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
200
+ """Fused `Conv2d()` followed by `downsample_2d()`.
201
+
202
+ Args:
203
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
204
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
205
+ order.
206
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
207
+ filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
208
+ numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
209
+ factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
210
+ Scaling factor for signal magnitude (default: 1.0).
211
+
212
+ Returns:
213
+ Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
214
+ datatype as `x`.
215
+ """
216
+
217
+ assert isinstance(factor, int) and factor >= 1
218
+ if kernel is None:
219
+ kernel = [1] * factor
220
+
221
+ # setup kernel
222
+ kernel = np.asarray(kernel, dtype=np.float32)
223
+ if kernel.ndim == 1:
224
+ kernel = np.outer(kernel, kernel)
225
+ kernel /= np.sum(kernel)
226
+
227
+ kernel = kernel * gain
228
+
229
+ if self.use_conv:
230
+ _, _, convH, convW = weight.shape
231
+ p = (kernel.shape[0] - factor) + (convW - 1)
232
+ s = [factor, factor]
233
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
234
+ x = F.conv2d(x, weight, stride=s, padding=0)
235
+ else:
236
+ p = kernel.shape[0] - factor
237
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
238
+
239
+ return x
240
+
241
+ def forward(self, x):
242
+ if self.use_conv:
243
+ x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
244
+ x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
245
+ else:
246
+ x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
247
+
248
+ return x
249
+
250
+
251
+ class ResnetBlock2D(nn.Module):
252
+ def __init__(
253
+ self,
254
+ *,
255
+ in_channels,
256
+ out_channels=None,
257
+ conv_shortcut=False,
258
+ dropout=0.0,
259
+ temb_channels=512,
260
+ groups=32,
261
+ groups_out=None,
262
+ pre_norm=True,
263
+ eps=1e-6,
264
+ non_linearity="swish",
265
+ time_embedding_norm="default",
266
+ kernel=None,
267
+ output_scale_factor=1.0,
268
+ use_nin_shortcut=None,
269
+ up=False,
270
+ down=False,
271
+ ):
272
+ super().__init__()
273
+ self.pre_norm = pre_norm
274
+ self.pre_norm = True
275
+ self.in_channels = in_channels
276
+ out_channels = in_channels if out_channels is None else out_channels
277
+ self.out_channels = out_channels
278
+ self.use_conv_shortcut = conv_shortcut
279
+ self.time_embedding_norm = time_embedding_norm
280
+ self.up = up
281
+ self.down = down
282
+ self.output_scale_factor = output_scale_factor
283
+
284
+ if groups_out is None:
285
+ groups_out = groups
286
+
287
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
288
+
289
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
290
+
291
+ if temb_channels is not None:
292
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
293
+ else:
294
+ self.time_emb_proj = None
295
+
296
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
297
+ self.dropout = torch.nn.Dropout(dropout)
298
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
299
+
300
+ if non_linearity == "swish":
301
+ self.nonlinearity = lambda x: F.silu(x)
302
+ elif non_linearity == "mish":
303
+ self.nonlinearity = Mish()
304
+ elif non_linearity == "silu":
305
+ self.nonlinearity = nn.SiLU()
306
+
307
+ self.upsample = self.downsample = None
308
+ if self.up:
309
+ if kernel == "fir":
310
+ fir_kernel = (1, 3, 3, 1)
311
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
312
+ elif kernel == "sde_vp":
313
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
314
+ else:
315
+ self.upsample = Upsample2D(in_channels, use_conv=False)
316
+ elif self.down:
317
+ if kernel == "fir":
318
+ fir_kernel = (1, 3, 3, 1)
319
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
320
+ elif kernel == "sde_vp":
321
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
322
+ else:
323
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
324
+
325
+ self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
326
+
327
+ self.conv_shortcut = None
328
+ if self.use_nin_shortcut:
329
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
330
+
331
+ def forward(self, x, temb):
332
+ hidden_states = x
333
+
334
+ # make sure hidden states is in float32
335
+ # when running in half-precision
336
+ hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
337
+ hidden_states = self.nonlinearity(hidden_states)
338
+
339
+ if self.upsample is not None:
340
+ x = self.upsample(x)
341
+ hidden_states = self.upsample(hidden_states)
342
+ elif self.downsample is not None:
343
+ x = self.downsample(x)
344
+ hidden_states = self.downsample(hidden_states)
345
+
346
+ hidden_states = self.conv1(hidden_states)
347
+
348
+ if temb is not None:
349
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
350
+ hidden_states = hidden_states + temb
351
+
352
+ # make sure hidden states is in float32
353
+ # when running in half-precision
354
+ hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
355
+ hidden_states = self.nonlinearity(hidden_states)
356
+
357
+ hidden_states = self.dropout(hidden_states)
358
+ hidden_states = self.conv2(hidden_states)
359
+
360
+ if self.conv_shortcut is not None:
361
+ x = self.conv_shortcut(x)
362
+
363
+ out = (x + hidden_states) / self.output_scale_factor
364
+
365
+ return out
366
+
367
+
368
+ class Mish(torch.nn.Module):
369
+ def forward(self, x):
370
+ return x * torch.tanh(torch.nn.functional.softplus(x))
371
+
372
+
373
+ def upsample_2d(x, kernel=None, factor=2, gain=1):
374
+ r"""Upsample2D a batch of 2D images with the given filter.
375
+
376
+ Args:
377
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
378
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
379
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
380
+ multiple of the upsampling factor.
381
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
382
+ C]`.
383
+ k: FIR filter of the shape `[firH, firW]` or `[firN]`
384
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
385
+ factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
386
+
387
+ Returns:
388
+ Tensor of the shape `[N, C, H * factor, W * factor]`
389
+ """
390
+ assert isinstance(factor, int) and factor >= 1
391
+ if kernel is None:
392
+ kernel = [1] * factor
393
+
394
+ kernel = np.asarray(kernel, dtype=np.float32)
395
+ if kernel.ndim == 1:
396
+ kernel = np.outer(kernel, kernel)
397
+ kernel /= np.sum(kernel)
398
+
399
+ kernel = kernel * (gain * (factor**2))
400
+ p = kernel.shape[0] - factor
401
+ return upfirdn2d_native(
402
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
403
+ )
404
+
405
+
406
+ def downsample_2d(x, kernel=None, factor=2, gain=1):
407
+ r"""Downsample2D a batch of 2D images with the given filter.
408
+
409
+ Args:
410
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
411
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
412
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
413
+ shape is a multiple of the downsampling factor.
414
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
415
+ C]`.
416
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
417
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
418
+ factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
419
+
420
+ Returns:
421
+ Tensor of the shape `[N, C, H // factor, W // factor]`
422
+ """
423
+
424
+ assert isinstance(factor, int) and factor >= 1
425
+ if kernel is None:
426
+ kernel = [1] * factor
427
+
428
+ kernel = np.asarray(kernel, dtype=np.float32)
429
+ if kernel.ndim == 1:
430
+ kernel = np.outer(kernel, kernel)
431
+ kernel /= np.sum(kernel)
432
+
433
+ kernel = kernel * gain
434
+ p = kernel.shape[0] - factor
435
+ return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
436
+
437
+
438
+ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
439
+ up_x = up_y = up
440
+ down_x = down_y = down
441
+ pad_x0 = pad_y0 = pad[0]
442
+ pad_x1 = pad_y1 = pad[1]
443
+
444
+ _, channel, in_h, in_w = input.shape
445
+ input = input.reshape(-1, in_h, in_w, 1)
446
+
447
+ _, in_h, in_w, minor = input.shape
448
+ kernel_h, kernel_w = kernel.shape
449
+
450
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
451
+
452
+ # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
453
+ if input.device.type == "mps":
454
+ out = out.to("cpu")
455
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
456
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
457
+
458
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
459
+ out = out.to(input.device) # Move back to mps if necessary
460
+ out = out[
461
+ :,
462
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
463
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
464
+ :,
465
+ ]
466
+
467
+ out = out.permute(0, 3, 1, 2)
468
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
469
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
470
+ out = F.conv2d(out, w)
471
+ out = out.reshape(
472
+ -1,
473
+ minor,
474
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
475
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
476
+ )
477
+ out = out.permute(0, 2, 3, 1)
478
+ out = out[:, ::down_y, ::down_x, :]
479
+
480
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
481
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
482
+
483
+ return out.view(-1, channel, out_h, out_w)
diffusers/models/unet_2d.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from ..configuration_utils import ConfigMixin, register_to_config
8
+ from ..modeling_utils import ModelMixin
9
+ from ..utils import BaseOutput
10
+ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
11
+ from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
12
+
13
+
14
+ @dataclass
15
+ class UNet2DOutput(BaseOutput):
16
+ """
17
+ Args:
18
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
19
+ Hidden states output. Output of last layer of model.
20
+ """
21
+
22
+ sample: torch.FloatTensor
23
+
24
+
25
+ class UNet2DModel(ModelMixin, ConfigMixin):
26
+ r"""
27
+ UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
28
+
29
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
30
+ implements for all the model (such as downloading or saving, etc.)
31
+
32
+ Parameters:
33
+ sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
34
+ Input sample size.
35
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
36
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
37
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
38
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
39
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
40
+ flip_sin_to_cos (`bool`, *optional*, defaults to :
41
+ obj:`False`): Whether to flip sin to cos for fourier time embedding.
42
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
43
+ obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
44
+ types.
45
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
46
+ obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
47
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
48
+ obj:`(224, 448, 672, 896)`): Tuple of block output channels.
49
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
50
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
51
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
52
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
53
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
54
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
55
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
56
+ """
57
+
58
+ @register_to_config
59
+ def __init__(
60
+ self,
61
+ sample_size: Optional[int] = None,
62
+ in_channels: int = 3,
63
+ out_channels: int = 3,
64
+ center_input_sample: bool = False,
65
+ time_embedding_type: str = "positional",
66
+ freq_shift: int = 0,
67
+ flip_sin_to_cos: bool = True,
68
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
69
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
70
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
71
+ layers_per_block: int = 2,
72
+ mid_block_scale_factor: float = 1,
73
+ downsample_padding: int = 1,
74
+ act_fn: str = "silu",
75
+ attention_head_dim: int = 8,
76
+ norm_num_groups: int = 32,
77
+ norm_eps: float = 1e-5,
78
+ ):
79
+ super().__init__()
80
+
81
+ self.sample_size = sample_size
82
+ time_embed_dim = block_out_channels[0] * 4
83
+
84
+ # input
85
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
86
+
87
+ # time
88
+ if time_embedding_type == "fourier":
89
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
90
+ timestep_input_dim = 2 * block_out_channels[0]
91
+ elif time_embedding_type == "positional":
92
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
93
+ timestep_input_dim = block_out_channels[0]
94
+
95
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
96
+
97
+ self.down_blocks = nn.ModuleList([])
98
+ self.mid_block = None
99
+ self.up_blocks = nn.ModuleList([])
100
+
101
+ # down
102
+ output_channel = block_out_channels[0]
103
+ for i, down_block_type in enumerate(down_block_types):
104
+ input_channel = output_channel
105
+ output_channel = block_out_channels[i]
106
+ is_final_block = i == len(block_out_channels) - 1
107
+
108
+ down_block = get_down_block(
109
+ down_block_type,
110
+ num_layers=layers_per_block,
111
+ in_channels=input_channel,
112
+ out_channels=output_channel,
113
+ temb_channels=time_embed_dim,
114
+ add_downsample=not is_final_block,
115
+ resnet_eps=norm_eps,
116
+ resnet_act_fn=act_fn,
117
+ attn_num_head_channels=attention_head_dim,
118
+ downsample_padding=downsample_padding,
119
+ )
120
+ self.down_blocks.append(down_block)
121
+
122
+ # mid
123
+ self.mid_block = UNetMidBlock2D(
124
+ in_channels=block_out_channels[-1],
125
+ temb_channels=time_embed_dim,
126
+ resnet_eps=norm_eps,
127
+ resnet_act_fn=act_fn,
128
+ output_scale_factor=mid_block_scale_factor,
129
+ resnet_time_scale_shift="default",
130
+ attn_num_head_channels=attention_head_dim,
131
+ resnet_groups=norm_num_groups,
132
+ )
133
+
134
+ # up
135
+ reversed_block_out_channels = list(reversed(block_out_channels))
136
+ output_channel = reversed_block_out_channels[0]
137
+ for i, up_block_type in enumerate(up_block_types):
138
+ prev_output_channel = output_channel
139
+ output_channel = reversed_block_out_channels[i]
140
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
141
+
142
+ is_final_block = i == len(block_out_channels) - 1
143
+
144
+ up_block = get_up_block(
145
+ up_block_type,
146
+ num_layers=layers_per_block + 1,
147
+ in_channels=input_channel,
148
+ out_channels=output_channel,
149
+ prev_output_channel=prev_output_channel,
150
+ temb_channels=time_embed_dim,
151
+ add_upsample=not is_final_block,
152
+ resnet_eps=norm_eps,
153
+ resnet_act_fn=act_fn,
154
+ attn_num_head_channels=attention_head_dim,
155
+ )
156
+ self.up_blocks.append(up_block)
157
+ prev_output_channel = output_channel
158
+
159
+ # out
160
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
161
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
162
+ self.conv_act = nn.SiLU()
163
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
164
+
165
+ def forward(
166
+ self,
167
+ sample: torch.FloatTensor,
168
+ timestep: Union[torch.Tensor, float, int],
169
+ return_dict: bool = True,
170
+ ) -> Union[UNet2DOutput, Tuple]:
171
+ """r
172
+ Args:
173
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
174
+ timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
175
+ return_dict (`bool`, *optional*, defaults to `True`):
176
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
177
+
178
+ Returns:
179
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
180
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
181
+ """
182
+ # 0. center input if necessary
183
+ if self.config.center_input_sample:
184
+ sample = 2 * sample - 1.0
185
+
186
+ # 1. time
187
+ timesteps = timestep
188
+ if not torch.is_tensor(timesteps):
189
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
190
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
191
+ timesteps = timesteps[None].to(sample.device)
192
+
193
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
194
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
195
+
196
+ t_emb = self.time_proj(timesteps)
197
+ emb = self.time_embedding(t_emb)
198
+
199
+ # 2. pre-process
200
+ skip_sample = sample
201
+ sample = self.conv_in(sample)
202
+
203
+ # 3. down
204
+ down_block_res_samples = (sample,)
205
+ for downsample_block in self.down_blocks:
206
+ if hasattr(downsample_block, "skip_conv"):
207
+ sample, res_samples, skip_sample = downsample_block(
208
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
209
+ )
210
+ else:
211
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
212
+
213
+ down_block_res_samples += res_samples
214
+
215
+ # 4. mid
216
+ sample = self.mid_block(sample, emb)
217
+
218
+ # 5. up
219
+ skip_sample = None
220
+ for upsample_block in self.up_blocks:
221
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
222
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
223
+
224
+ if hasattr(upsample_block, "skip_conv"):
225
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
226
+ else:
227
+ sample = upsample_block(sample, res_samples, emb)
228
+
229
+ # 6. post-process
230
+ # make sure hidden states is in float32
231
+ # when running in half-precision
232
+ sample = self.conv_norm_out(sample.float()).type(sample.dtype)
233
+ sample = self.conv_act(sample)
234
+ sample = self.conv_out(sample)
235
+
236
+ if skip_sample is not None:
237
+ sample += skip_sample
238
+
239
+ if self.config.time_embedding_type == "fourier":
240
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
241
+ sample = sample / timesteps
242
+
243
+ if not return_dict:
244
+ return (sample,)
245
+
246
+ return UNet2DOutput(sample=sample)
diffusers/models/unet_2d_condition.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from ..configuration_utils import ConfigMixin, register_to_config
8
+ from ..modeling_utils import ModelMixin
9
+ from ..utils import BaseOutput
10
+ from .embeddings import TimestepEmbedding, Timesteps
11
+ from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
12
+
13
+
14
+ @dataclass
15
+ class UNet2DConditionOutput(BaseOutput):
16
+ """
17
+ Args:
18
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
19
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
20
+ """
21
+
22
+ sample: torch.FloatTensor
23
+
24
+
25
+ class UNet2DConditionModel(ModelMixin, ConfigMixin):
26
+ r"""
27
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
28
+ and returns sample shaped output.
29
+
30
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
31
+ implements for all the model (such as downloading or saving, etc.)
32
+
33
+ Parameters:
34
+ sample_size (`int`, *optional*): The size of the input sample.
35
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
36
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
37
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
38
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
39
+ Whether to flip the sin to cos in the time embedding.
40
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
41
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
42
+ The tuple of downsample blocks to use.
43
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
44
+ The tuple of upsample blocks to use.
45
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
46
+ The tuple of output channels for each block.
47
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
48
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
49
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
50
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
51
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
52
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
53
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
54
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
55
+ """
56
+
57
+ @register_to_config
58
+ def __init__(
59
+ self,
60
+ sample_size: Optional[int] = None,
61
+ in_channels: int = 4,
62
+ out_channels: int = 4,
63
+ center_input_sample: bool = False,
64
+ flip_sin_to_cos: bool = True,
65
+ freq_shift: int = 0,
66
+ down_block_types: Tuple[str] = (
67
+ "CrossAttnDownBlock2D",
68
+ "CrossAttnDownBlock2D",
69
+ "CrossAttnDownBlock2D",
70
+ "DownBlock2D",
71
+ ),
72
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
73
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
74
+ layers_per_block: int = 2,
75
+ downsample_padding: int = 1,
76
+ mid_block_scale_factor: float = 1,
77
+ act_fn: str = "silu",
78
+ norm_num_groups: int = 32,
79
+ norm_eps: float = 1e-5,
80
+ cross_attention_dim: int = 1280,
81
+ attention_head_dim: int = 8,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.sample_size = sample_size
86
+ time_embed_dim = block_out_channels[0] * 4
87
+
88
+ # input
89
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
90
+
91
+ # time
92
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
93
+ timestep_input_dim = block_out_channels[0]
94
+
95
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
96
+
97
+ self.down_blocks = nn.ModuleList([])
98
+ self.mid_block = None
99
+ self.up_blocks = nn.ModuleList([])
100
+
101
+ # down
102
+ output_channel = block_out_channels[0]
103
+ for i, down_block_type in enumerate(down_block_types):
104
+ input_channel = output_channel
105
+ output_channel = block_out_channels[i]
106
+ is_final_block = i == len(block_out_channels) - 1
107
+
108
+ down_block = get_down_block(
109
+ down_block_type,
110
+ num_layers=layers_per_block,
111
+ in_channels=input_channel,
112
+ out_channels=output_channel,
113
+ temb_channels=time_embed_dim,
114
+ add_downsample=not is_final_block,
115
+ resnet_eps=norm_eps,
116
+ resnet_act_fn=act_fn,
117
+ cross_attention_dim=cross_attention_dim,
118
+ attn_num_head_channels=attention_head_dim,
119
+ downsample_padding=downsample_padding,
120
+ )
121
+ self.down_blocks.append(down_block)
122
+
123
+ # mid
124
+ self.mid_block = UNetMidBlock2DCrossAttn(
125
+ in_channels=block_out_channels[-1],
126
+ temb_channels=time_embed_dim,
127
+ resnet_eps=norm_eps,
128
+ resnet_act_fn=act_fn,
129
+ output_scale_factor=mid_block_scale_factor,
130
+ resnet_time_scale_shift="default",
131
+ cross_attention_dim=cross_attention_dim,
132
+ attn_num_head_channels=attention_head_dim,
133
+ resnet_groups=norm_num_groups,
134
+ )
135
+
136
+ # up
137
+ reversed_block_out_channels = list(reversed(block_out_channels))
138
+ output_channel = reversed_block_out_channels[0]
139
+ for i, up_block_type in enumerate(up_block_types):
140
+ prev_output_channel = output_channel
141
+ output_channel = reversed_block_out_channels[i]
142
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
143
+
144
+ is_final_block = i == len(block_out_channels) - 1
145
+
146
+ up_block = get_up_block(
147
+ up_block_type,
148
+ num_layers=layers_per_block + 1,
149
+ in_channels=input_channel,
150
+ out_channels=output_channel,
151
+ prev_output_channel=prev_output_channel,
152
+ temb_channels=time_embed_dim,
153
+ add_upsample=not is_final_block,
154
+ resnet_eps=norm_eps,
155
+ resnet_act_fn=act_fn,
156
+ cross_attention_dim=cross_attention_dim,
157
+ attn_num_head_channels=attention_head_dim,
158
+ )
159
+ self.up_blocks.append(up_block)
160
+ prev_output_channel = output_channel
161
+
162
+ # out
163
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
164
+ self.conv_act = nn.SiLU()
165
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
166
+
167
+ def set_attention_slice(self, slice_size):
168
+ if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
169
+ raise ValueError(
170
+ f"Make sure slice_size {slice_size} is a divisor of "
171
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
172
+ )
173
+ if slice_size is not None and slice_size > self.config.attention_head_dim:
174
+ raise ValueError(
175
+ f"Chunk_size {slice_size} has to be smaller or equal to "
176
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
177
+ )
178
+
179
+ for block in self.down_blocks:
180
+ if hasattr(block, "attentions") and block.attentions is not None:
181
+ block.set_attention_slice(slice_size)
182
+
183
+ self.mid_block.set_attention_slice(slice_size)
184
+
185
+ for block in self.up_blocks:
186
+ if hasattr(block, "attentions") and block.attentions is not None:
187
+ block.set_attention_slice(slice_size)
188
+
189
+ def forward(
190
+ self,
191
+ sample: torch.FloatTensor,
192
+ timestep: Union[torch.Tensor, float, int],
193
+ encoder_hidden_states: torch.Tensor,
194
+ return_dict: bool = True,
195
+ ) -> Union[UNet2DConditionOutput, Tuple]:
196
+ """r
197
+ Args:
198
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
199
+ timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
200
+ encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
201
+ return_dict (`bool`, *optional*, defaults to `True`):
202
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
203
+
204
+ Returns:
205
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
206
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
207
+ returning a tuple, the first element is the sample tensor.
208
+ """
209
+ # 0. center input if necessary
210
+ if self.config.center_input_sample:
211
+ sample = 2 * sample - 1.0
212
+
213
+ # 1. time
214
+ timesteps = timestep
215
+ if not torch.is_tensor(timesteps):
216
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
217
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
218
+ timesteps = timesteps.to(dtype=torch.float32)
219
+ timesteps = timesteps[None].to(device=sample.device)
220
+
221
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
222
+ timesteps = timesteps.expand(sample.shape[0])
223
+
224
+ t_emb = self.time_proj(timesteps)
225
+ emb = self.time_embedding(t_emb)
226
+
227
+ # 2. pre-process
228
+ sample = self.conv_in(sample)
229
+
230
+ # 3. down
231
+ down_block_res_samples = (sample,)
232
+ for downsample_block in self.down_blocks:
233
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
234
+ sample, res_samples = downsample_block(
235
+ hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
236
+ )
237
+ else:
238
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
239
+
240
+ down_block_res_samples += res_samples
241
+
242
+ # 4. mid
243
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
244
+
245
+ # 5. up
246
+ for upsample_block in self.up_blocks:
247
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
248
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
249
+
250
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
251
+ sample = upsample_block(
252
+ hidden_states=sample,
253
+ temb=emb,
254
+ res_hidden_states_tuple=res_samples,
255
+ encoder_hidden_states=encoder_hidden_states,
256
+ )
257
+ else:
258
+ sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
259
+
260
+ # 6. post-process
261
+ # make sure hidden states is in float32
262
+ # when running in half-precision
263
+ sample = self.conv_norm_out(sample.float()).type(sample.dtype)
264
+ sample = self.conv_act(sample)
265
+ sample = self.conv_out(sample)
266
+
267
+ return sample
268
+
269
+ if not return_dict:
270
+ return (sample,)
271
+
272
+ return UNet2DConditionOutput(sample=sample)
diffusers/models/unet_blocks.py ADDED
@@ -0,0 +1,1484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+
14
+ import numpy as np
15
+
16
+ # limitations under the License.
17
+ import torch
18
+ from torch import nn
19
+
20
+ from .attention import AttentionBlock, SpatialTransformer
21
+ from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
22
+
23
+
24
+ def get_down_block(
25
+ down_block_type,
26
+ num_layers,
27
+ in_channels,
28
+ out_channels,
29
+ temb_channels,
30
+ add_downsample,
31
+ resnet_eps,
32
+ resnet_act_fn,
33
+ attn_num_head_channels,
34
+ cross_attention_dim=None,
35
+ downsample_padding=None,
36
+ ):
37
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
38
+ print(down_block_type)
39
+ if down_block_type == "DownBlock2D":
40
+ return DownBlock2D(
41
+ num_layers=num_layers,
42
+ in_channels=in_channels,
43
+ out_channels=out_channels,
44
+ temb_channels=temb_channels,
45
+ add_downsample=add_downsample,
46
+ resnet_eps=resnet_eps,
47
+ resnet_act_fn=resnet_act_fn,
48
+ downsample_padding=downsample_padding,
49
+ )
50
+ elif down_block_type == "AttnDownBlock2D":
51
+ return AttnDownBlock2D(
52
+ num_layers=num_layers,
53
+ in_channels=in_channels,
54
+ out_channels=out_channels,
55
+ temb_channels=temb_channels,
56
+ add_downsample=add_downsample,
57
+ resnet_eps=resnet_eps,
58
+ resnet_act_fn=resnet_act_fn,
59
+ downsample_padding=downsample_padding,
60
+ attn_num_head_channels=attn_num_head_channels,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock2D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
65
+ return CrossAttnDownBlock2D(
66
+ num_layers=num_layers,
67
+ in_channels=in_channels,
68
+ out_channels=out_channels,
69
+ temb_channels=temb_channels,
70
+ add_downsample=add_downsample,
71
+ resnet_eps=resnet_eps,
72
+ resnet_act_fn=resnet_act_fn,
73
+ downsample_padding=downsample_padding,
74
+ cross_attention_dim=cross_attention_dim,
75
+ attn_num_head_channels=attn_num_head_channels,
76
+ )
77
+ elif down_block_type == "SkipDownBlock2D":
78
+ return SkipDownBlock2D(
79
+ num_layers=num_layers,
80
+ in_channels=in_channels,
81
+ out_channels=out_channels,
82
+ temb_channels=temb_channels,
83
+ add_downsample=add_downsample,
84
+ resnet_eps=resnet_eps,
85
+ resnet_act_fn=resnet_act_fn,
86
+ downsample_padding=downsample_padding,
87
+ )
88
+ elif down_block_type == "AttnSkipDownBlock2D":
89
+ return AttnSkipDownBlock2D(
90
+ num_layers=num_layers,
91
+ in_channels=in_channels,
92
+ out_channels=out_channels,
93
+ temb_channels=temb_channels,
94
+ add_downsample=add_downsample,
95
+ resnet_eps=resnet_eps,
96
+ resnet_act_fn=resnet_act_fn,
97
+ downsample_padding=downsample_padding,
98
+ attn_num_head_channels=attn_num_head_channels,
99
+ )
100
+ elif down_block_type == "DownEncoderBlock2D":
101
+ return DownEncoderBlock2D(
102
+ num_layers=num_layers,
103
+ in_channels=in_channels,
104
+ out_channels=out_channels,
105
+ add_downsample=add_downsample,
106
+ resnet_eps=resnet_eps,
107
+ resnet_act_fn=resnet_act_fn,
108
+ downsample_padding=downsample_padding,
109
+ )
110
+
111
+
112
+ def get_up_block(
113
+ up_block_type,
114
+ num_layers,
115
+ in_channels,
116
+ out_channels,
117
+ prev_output_channel,
118
+ temb_channels,
119
+ add_upsample,
120
+ resnet_eps,
121
+ resnet_act_fn,
122
+ attn_num_head_channels,
123
+ cross_attention_dim=None,
124
+ ):
125
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
126
+ print(up_block_type)
127
+ if up_block_type == "UpBlock2D":
128
+ return UpBlock2D(
129
+ num_layers=num_layers,
130
+ in_channels=in_channels,
131
+ out_channels=out_channels,
132
+ prev_output_channel=prev_output_channel,
133
+ temb_channels=temb_channels,
134
+ add_upsample=add_upsample,
135
+ resnet_eps=resnet_eps,
136
+ resnet_act_fn=resnet_act_fn,
137
+ )
138
+ elif up_block_type == "CrossAttnUpBlock2D":
139
+ if cross_attention_dim is None:
140
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
141
+ return CrossAttnUpBlock2D(
142
+ num_layers=num_layers,
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ prev_output_channel=prev_output_channel,
146
+ temb_channels=temb_channels,
147
+ add_upsample=add_upsample,
148
+ resnet_eps=resnet_eps,
149
+ resnet_act_fn=resnet_act_fn,
150
+ cross_attention_dim=cross_attention_dim,
151
+ attn_num_head_channels=attn_num_head_channels,
152
+ )
153
+ elif up_block_type == "AttnUpBlock2D":
154
+ return AttnUpBlock2D(
155
+ num_layers=num_layers,
156
+ in_channels=in_channels,
157
+ out_channels=out_channels,
158
+ prev_output_channel=prev_output_channel,
159
+ temb_channels=temb_channels,
160
+ add_upsample=add_upsample,
161
+ resnet_eps=resnet_eps,
162
+ resnet_act_fn=resnet_act_fn,
163
+ attn_num_head_channels=attn_num_head_channels,
164
+ )
165
+ elif up_block_type == "SkipUpBlock2D":
166
+ return SkipUpBlock2D(
167
+ num_layers=num_layers,
168
+ in_channels=in_channels,
169
+ out_channels=out_channels,
170
+ prev_output_channel=prev_output_channel,
171
+ temb_channels=temb_channels,
172
+ add_upsample=add_upsample,
173
+ resnet_eps=resnet_eps,
174
+ resnet_act_fn=resnet_act_fn,
175
+ )
176
+ elif up_block_type == "AttnSkipUpBlock2D":
177
+ return AttnSkipUpBlock2D(
178
+ num_layers=num_layers,
179
+ in_channels=in_channels,
180
+ out_channels=out_channels,
181
+ prev_output_channel=prev_output_channel,
182
+ temb_channels=temb_channels,
183
+ add_upsample=add_upsample,
184
+ resnet_eps=resnet_eps,
185
+ resnet_act_fn=resnet_act_fn,
186
+ attn_num_head_channels=attn_num_head_channels,
187
+ )
188
+ elif up_block_type == "UpDecoderBlock2D":
189
+ return UpDecoderBlock2D(
190
+ num_layers=num_layers,
191
+ in_channels=in_channels,
192
+ out_channels=out_channels,
193
+ add_upsample=add_upsample,
194
+ resnet_eps=resnet_eps,
195
+ resnet_act_fn=resnet_act_fn,
196
+ )
197
+ raise ValueError(f"{up_block_type} does not exist.")
198
+
199
+
200
+ class UNetMidBlock2D(nn.Module):
201
+ def __init__(
202
+ self,
203
+ in_channels: int,
204
+ temb_channels: int,
205
+ dropout: float = 0.0,
206
+ num_layers: int = 1,
207
+ resnet_eps: float = 1e-6,
208
+ resnet_time_scale_shift: str = "default",
209
+ resnet_act_fn: str = "swish",
210
+ resnet_groups: int = 32,
211
+ resnet_pre_norm: bool = True,
212
+ attn_num_head_channels=1,
213
+ attention_type="default",
214
+ output_scale_factor=1.0,
215
+ **kwargs,
216
+ ):
217
+ super().__init__()
218
+
219
+ self.attention_type = attention_type
220
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
221
+
222
+ # there is always at least one resnet
223
+ resnets = [
224
+ ResnetBlock2D(
225
+ in_channels=in_channels,
226
+ out_channels=in_channels,
227
+ temb_channels=temb_channels,
228
+ eps=resnet_eps,
229
+ groups=resnet_groups,
230
+ dropout=dropout,
231
+ time_embedding_norm=resnet_time_scale_shift,
232
+ non_linearity=resnet_act_fn,
233
+ output_scale_factor=output_scale_factor,
234
+ pre_norm=resnet_pre_norm,
235
+ )
236
+ ]
237
+ attentions = []
238
+
239
+ for _ in range(num_layers):
240
+ attentions.append(
241
+ AttentionBlock(
242
+ in_channels,
243
+ num_head_channels=attn_num_head_channels,
244
+ rescale_output_factor=output_scale_factor,
245
+ eps=resnet_eps,
246
+ num_groups=resnet_groups,
247
+ )
248
+ )
249
+ resnets.append(
250
+ ResnetBlock2D(
251
+ in_channels=in_channels,
252
+ out_channels=in_channels,
253
+ temb_channels=temb_channels,
254
+ eps=resnet_eps,
255
+ groups=resnet_groups,
256
+ dropout=dropout,
257
+ time_embedding_norm=resnet_time_scale_shift,
258
+ non_linearity=resnet_act_fn,
259
+ output_scale_factor=output_scale_factor,
260
+ pre_norm=resnet_pre_norm,
261
+ )
262
+ )
263
+
264
+ self.attentions = nn.ModuleList(attentions)
265
+ self.resnets = nn.ModuleList(resnets)
266
+
267
+ def forward(self, hidden_states, temb=None, encoder_states=None):
268
+ hidden_states = self.resnets[0](hidden_states, temb)
269
+ print(self.attention_type)
270
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
271
+ if self.attention_type == "default":
272
+ hidden_states = attn(hidden_states)
273
+ else:
274
+ hidden_states = attn(hidden_states, encoder_states)
275
+ hidden_states = resnet(hidden_states, temb)
276
+
277
+ return hidden_states
278
+
279
+
280
+ class UNetMidBlock2DCrossAttn(nn.Module):
281
+ def __init__(
282
+ self,
283
+ in_channels: int,
284
+ temb_channels: int,
285
+ dropout: float = 0.0,
286
+ num_layers: int = 1,
287
+ resnet_eps: float = 1e-6,
288
+ resnet_time_scale_shift: str = "default",
289
+ resnet_act_fn: str = "swish",
290
+ resnet_groups: int = 32,
291
+ resnet_pre_norm: bool = True,
292
+ attn_num_head_channels=1,
293
+ attention_type="default",
294
+ output_scale_factor=1.0,
295
+ cross_attention_dim=1280,
296
+ **kwargs,
297
+ ):
298
+ super().__init__()
299
+
300
+ self.attention_type = attention_type
301
+ self.attn_num_head_channels = attn_num_head_channels
302
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
303
+
304
+ # there is always at least one resnet
305
+ resnets = [
306
+ ResnetBlock2D(
307
+ in_channels=in_channels,
308
+ out_channels=in_channels,
309
+ temb_channels=temb_channels,
310
+ eps=resnet_eps,
311
+ groups=resnet_groups,
312
+ dropout=dropout,
313
+ time_embedding_norm=resnet_time_scale_shift,
314
+ non_linearity=resnet_act_fn,
315
+ output_scale_factor=output_scale_factor,
316
+ pre_norm=resnet_pre_norm,
317
+ )
318
+ ]
319
+ attentions = []
320
+
321
+ for _ in range(num_layers):
322
+ attentions.append(
323
+ SpatialTransformer(
324
+ in_channels,
325
+ attn_num_head_channels,
326
+ in_channels // attn_num_head_channels,
327
+ depth=1,
328
+ context_dim=cross_attention_dim,
329
+ )
330
+ )
331
+ resnets.append(
332
+ ResnetBlock2D(
333
+ in_channels=in_channels,
334
+ out_channels=in_channels,
335
+ temb_channels=temb_channels,
336
+ eps=resnet_eps,
337
+ groups=resnet_groups,
338
+ dropout=dropout,
339
+ time_embedding_norm=resnet_time_scale_shift,
340
+ non_linearity=resnet_act_fn,
341
+ output_scale_factor=output_scale_factor,
342
+ pre_norm=resnet_pre_norm,
343
+ )
344
+ )
345
+
346
+ self.attentions = nn.ModuleList(attentions)
347
+ self.resnets = nn.ModuleList(resnets)
348
+
349
+ def set_attention_slice(self, slice_size):
350
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
351
+ raise ValueError(
352
+ f"Make sure slice_size {slice_size} is a divisor of "
353
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
354
+ )
355
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
356
+ raise ValueError(
357
+ f"Chunk_size {slice_size} has to be smaller or equal to "
358
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
359
+ )
360
+
361
+ for attn in self.attentions:
362
+ attn._set_attention_slice(slice_size)
363
+
364
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
365
+ hidden_states = self.resnets[0](hidden_states, temb)
366
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
367
+ hidden_states = attn(hidden_states, encoder_hidden_states)
368
+ hidden_states = resnet(hidden_states, temb)
369
+
370
+ return hidden_states
371
+
372
+
373
+ class AttnDownBlock2D(nn.Module):
374
+ def __init__(
375
+ self,
376
+ in_channels: int,
377
+ out_channels: int,
378
+ temb_channels: int,
379
+ dropout: float = 0.0,
380
+ num_layers: int = 1,
381
+ resnet_eps: float = 1e-6,
382
+ resnet_time_scale_shift: str = "default",
383
+ resnet_act_fn: str = "swish",
384
+ resnet_groups: int = 32,
385
+ resnet_pre_norm: bool = True,
386
+ attn_num_head_channels=1,
387
+ attention_type="default",
388
+ output_scale_factor=1.0,
389
+ downsample_padding=1,
390
+ add_downsample=True,
391
+ ):
392
+ super().__init__()
393
+ resnets = []
394
+ attentions = []
395
+
396
+ self.attention_type = attention_type
397
+
398
+ for i in range(num_layers):
399
+ in_channels = in_channels if i == 0 else out_channels
400
+ resnets.append(
401
+ ResnetBlock2D(
402
+ in_channels=in_channels,
403
+ out_channels=out_channels,
404
+ temb_channels=temb_channels,
405
+ eps=resnet_eps,
406
+ groups=resnet_groups,
407
+ dropout=dropout,
408
+ time_embedding_norm=resnet_time_scale_shift,
409
+ non_linearity=resnet_act_fn,
410
+ output_scale_factor=output_scale_factor,
411
+ pre_norm=resnet_pre_norm,
412
+ )
413
+ )
414
+ attentions.append(
415
+ AttentionBlock(
416
+ out_channels,
417
+ num_head_channels=attn_num_head_channels,
418
+ rescale_output_factor=output_scale_factor,
419
+ eps=resnet_eps,
420
+ )
421
+ )
422
+
423
+ self.attentions = nn.ModuleList(attentions)
424
+ self.resnets = nn.ModuleList(resnets)
425
+
426
+ if add_downsample:
427
+ self.downsamplers = nn.ModuleList(
428
+ [
429
+ Downsample2D(
430
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
431
+ )
432
+ ]
433
+ )
434
+ else:
435
+ self.downsamplers = None
436
+
437
+ def forward(self, hidden_states, temb=None):
438
+ output_states = ()
439
+
440
+ for resnet, attn in zip(self.resnets, self.attentions):
441
+ hidden_states = resnet(hidden_states, temb)
442
+ hidden_states = attn(hidden_states)
443
+ output_states += (hidden_states,)
444
+
445
+ if self.downsamplers is not None:
446
+ for downsampler in self.downsamplers:
447
+ hidden_states = downsampler(hidden_states)
448
+
449
+ output_states += (hidden_states,)
450
+
451
+ return hidden_states, output_states
452
+
453
+
454
+ class CrossAttnDownBlock2D(nn.Module):
455
+ def __init__(
456
+ self,
457
+ in_channels: int,
458
+ out_channels: int,
459
+ temb_channels: int,
460
+ dropout: float = 0.0,
461
+ num_layers: int = 1,
462
+ resnet_eps: float = 1e-6,
463
+ resnet_time_scale_shift: str = "default",
464
+ resnet_act_fn: str = "swish",
465
+ resnet_groups: int = 32,
466
+ resnet_pre_norm: bool = True,
467
+ attn_num_head_channels=1,
468
+ cross_attention_dim=1280,
469
+ attention_type="default",
470
+ output_scale_factor=1.0,
471
+ downsample_padding=1,
472
+ add_downsample=True,
473
+ ):
474
+ super().__init__()
475
+ resnets = []
476
+ attentions = []
477
+
478
+ self.attention_type = attention_type
479
+ self.attn_num_head_channels = attn_num_head_channels
480
+
481
+ for i in range(num_layers):
482
+ in_channels = in_channels if i == 0 else out_channels
483
+ resnets.append(
484
+ ResnetBlock2D(
485
+ in_channels=in_channels,
486
+ out_channels=out_channels,
487
+ temb_channels=temb_channels,
488
+ eps=resnet_eps,
489
+ groups=resnet_groups,
490
+ dropout=dropout,
491
+ time_embedding_norm=resnet_time_scale_shift,
492
+ non_linearity=resnet_act_fn,
493
+ output_scale_factor=output_scale_factor,
494
+ pre_norm=resnet_pre_norm,
495
+ )
496
+ )
497
+ attentions.append(
498
+ SpatialTransformer(
499
+ out_channels,
500
+ attn_num_head_channels,
501
+ out_channels // attn_num_head_channels,
502
+ depth=1,
503
+ context_dim=cross_attention_dim,
504
+ )
505
+ )
506
+ self.attentions = nn.ModuleList(attentions)
507
+ self.resnets = nn.ModuleList(resnets)
508
+
509
+ if add_downsample:
510
+ self.downsamplers = nn.ModuleList(
511
+ [
512
+ Downsample2D(
513
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
514
+ )
515
+ ]
516
+ )
517
+ else:
518
+ self.downsamplers = None
519
+
520
+ def set_attention_slice(self, slice_size):
521
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
522
+ raise ValueError(
523
+ f"Make sure slice_size {slice_size} is a divisor of "
524
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
525
+ )
526
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
527
+ raise ValueError(
528
+ f"Chunk_size {slice_size} has to be smaller or equal to "
529
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
530
+ )
531
+
532
+ for attn in self.attentions:
533
+ attn._set_attention_slice(slice_size)
534
+
535
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
536
+ output_states = ()
537
+
538
+ for resnet, attn in zip(self.resnets, self.attentions):
539
+ hidden_states = resnet(hidden_states, temb)
540
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
541
+ output_states += (hidden_states,)
542
+
543
+ if self.downsamplers is not None:
544
+ for downsampler in self.downsamplers:
545
+ hidden_states = downsampler(hidden_states)
546
+
547
+ output_states += (hidden_states,)
548
+
549
+ return hidden_states, output_states
550
+
551
+
552
+ class DownBlock2D(nn.Module):
553
+ def __init__(
554
+ self,
555
+ in_channels: int,
556
+ out_channels: int,
557
+ temb_channels: int,
558
+ dropout: float = 0.0,
559
+ num_layers: int = 1,
560
+ resnet_eps: float = 1e-6,
561
+ resnet_time_scale_shift: str = "default",
562
+ resnet_act_fn: str = "swish",
563
+ resnet_groups: int = 32,
564
+ resnet_pre_norm: bool = True,
565
+ output_scale_factor=1.0,
566
+ add_downsample=True,
567
+ downsample_padding=1,
568
+ ):
569
+ super().__init__()
570
+ resnets = []
571
+
572
+ for i in range(num_layers):
573
+ in_channels = in_channels if i == 0 else out_channels
574
+ resnets.append(
575
+ ResnetBlock2D(
576
+ in_channels=in_channels,
577
+ out_channels=out_channels,
578
+ temb_channels=temb_channels,
579
+ eps=resnet_eps,
580
+ groups=resnet_groups,
581
+ dropout=dropout,
582
+ time_embedding_norm=resnet_time_scale_shift,
583
+ non_linearity=resnet_act_fn,
584
+ output_scale_factor=output_scale_factor,
585
+ pre_norm=resnet_pre_norm,
586
+ )
587
+ )
588
+
589
+ self.resnets = nn.ModuleList(resnets)
590
+
591
+ if add_downsample:
592
+ self.downsamplers = nn.ModuleList(
593
+ [
594
+ Downsample2D(
595
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
596
+ )
597
+ ]
598
+ )
599
+ else:
600
+ self.downsamplers = None
601
+
602
+ def forward(self, hidden_states, temb=None):
603
+ output_states = ()
604
+
605
+ for resnet in self.resnets:
606
+ hidden_states = resnet(hidden_states, temb)
607
+ output_states += (hidden_states,)
608
+
609
+ if self.downsamplers is not None:
610
+ for downsampler in self.downsamplers:
611
+ hidden_states = downsampler(hidden_states)
612
+
613
+ output_states += (hidden_states,)
614
+
615
+ return hidden_states, output_states
616
+
617
+
618
+ class DownEncoderBlock2D(nn.Module):
619
+ def __init__(
620
+ self,
621
+ in_channels: int,
622
+ out_channels: int,
623
+ dropout: float = 0.0,
624
+ num_layers: int = 1,
625
+ resnet_eps: float = 1e-6,
626
+ resnet_time_scale_shift: str = "default",
627
+ resnet_act_fn: str = "swish",
628
+ resnet_groups: int = 32,
629
+ resnet_pre_norm: bool = True,
630
+ output_scale_factor=1.0,
631
+ add_downsample=True,
632
+ downsample_padding=1,
633
+ ):
634
+ super().__init__()
635
+ resnets = []
636
+
637
+ for i in range(num_layers):
638
+ in_channels = in_channels if i == 0 else out_channels
639
+ resnets.append(
640
+ ResnetBlock2D(
641
+ in_channels=in_channels,
642
+ out_channels=out_channels,
643
+ temb_channels=None,
644
+ eps=resnet_eps,
645
+ groups=resnet_groups,
646
+ dropout=dropout,
647
+ time_embedding_norm=resnet_time_scale_shift,
648
+ non_linearity=resnet_act_fn,
649
+ output_scale_factor=output_scale_factor,
650
+ pre_norm=resnet_pre_norm,
651
+ )
652
+ )
653
+
654
+ self.resnets = nn.ModuleList(resnets)
655
+
656
+ if add_downsample:
657
+ self.downsamplers = nn.ModuleList(
658
+ [
659
+ Downsample2D(
660
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
661
+ )
662
+ ]
663
+ )
664
+ else:
665
+ self.downsamplers = None
666
+
667
+ def forward(self, hidden_states):
668
+ for resnet in self.resnets:
669
+ hidden_states = resnet(hidden_states, temb=None)
670
+
671
+ if self.downsamplers is not None:
672
+ for downsampler in self.downsamplers:
673
+ hidden_states = downsampler(hidden_states)
674
+
675
+ return hidden_states
676
+
677
+
678
+ class AttnDownEncoderBlock2D(nn.Module):
679
+ def __init__(
680
+ self,
681
+ in_channels: int,
682
+ out_channels: int,
683
+ dropout: float = 0.0,
684
+ num_layers: int = 1,
685
+ resnet_eps: float = 1e-6,
686
+ resnet_time_scale_shift: str = "default",
687
+ resnet_act_fn: str = "swish",
688
+ resnet_groups: int = 32,
689
+ resnet_pre_norm: bool = True,
690
+ attn_num_head_channels=1,
691
+ output_scale_factor=1.0,
692
+ add_downsample=True,
693
+ downsample_padding=1,
694
+ ):
695
+ super().__init__()
696
+ resnets = []
697
+ attentions = []
698
+
699
+ for i in range(num_layers):
700
+ in_channels = in_channels if i == 0 else out_channels
701
+ resnets.append(
702
+ ResnetBlock2D(
703
+ in_channels=in_channels,
704
+ out_channels=out_channels,
705
+ temb_channels=None,
706
+ eps=resnet_eps,
707
+ groups=resnet_groups,
708
+ dropout=dropout,
709
+ time_embedding_norm=resnet_time_scale_shift,
710
+ non_linearity=resnet_act_fn,
711
+ output_scale_factor=output_scale_factor,
712
+ pre_norm=resnet_pre_norm,
713
+ )
714
+ )
715
+ attentions.append(
716
+ AttentionBlock(
717
+ out_channels,
718
+ num_head_channels=attn_num_head_channels,
719
+ rescale_output_factor=output_scale_factor,
720
+ eps=resnet_eps,
721
+ num_groups=resnet_groups,
722
+ )
723
+ )
724
+
725
+ self.attentions = nn.ModuleList(attentions)
726
+ self.resnets = nn.ModuleList(resnets)
727
+
728
+ if add_downsample:
729
+ self.downsamplers = nn.ModuleList(
730
+ [
731
+ Downsample2D(
732
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
733
+ )
734
+ ]
735
+ )
736
+ else:
737
+ self.downsamplers = None
738
+
739
+ def forward(self, hidden_states):
740
+ for resnet, attn in zip(self.resnets, self.attentions):
741
+ hidden_states = resnet(hidden_states, temb=None)
742
+ hidden_states = attn(hidden_states)
743
+
744
+ if self.downsamplers is not None:
745
+ for downsampler in self.downsamplers:
746
+ hidden_states = downsampler(hidden_states)
747
+
748
+ return hidden_states
749
+
750
+
751
+ class AttnSkipDownBlock2D(nn.Module):
752
+ def __init__(
753
+ self,
754
+ in_channels: int,
755
+ out_channels: int,
756
+ temb_channels: int,
757
+ dropout: float = 0.0,
758
+ num_layers: int = 1,
759
+ resnet_eps: float = 1e-6,
760
+ resnet_time_scale_shift: str = "default",
761
+ resnet_act_fn: str = "swish",
762
+ resnet_pre_norm: bool = True,
763
+ attn_num_head_channels=1,
764
+ attention_type="default",
765
+ output_scale_factor=np.sqrt(2.0),
766
+ downsample_padding=1,
767
+ add_downsample=True,
768
+ ):
769
+ super().__init__()
770
+ self.attentions = nn.ModuleList([])
771
+ self.resnets = nn.ModuleList([])
772
+
773
+ self.attention_type = attention_type
774
+
775
+ for i in range(num_layers):
776
+ in_channels = in_channels if i == 0 else out_channels
777
+ self.resnets.append(
778
+ ResnetBlock2D(
779
+ in_channels=in_channels,
780
+ out_channels=out_channels,
781
+ temb_channels=temb_channels,
782
+ eps=resnet_eps,
783
+ groups=min(in_channels // 4, 32),
784
+ groups_out=min(out_channels // 4, 32),
785
+ dropout=dropout,
786
+ time_embedding_norm=resnet_time_scale_shift,
787
+ non_linearity=resnet_act_fn,
788
+ output_scale_factor=output_scale_factor,
789
+ pre_norm=resnet_pre_norm,
790
+ )
791
+ )
792
+ self.attentions.append(
793
+ AttentionBlock(
794
+ out_channels,
795
+ num_head_channels=attn_num_head_channels,
796
+ rescale_output_factor=output_scale_factor,
797
+ eps=resnet_eps,
798
+ )
799
+ )
800
+
801
+ if add_downsample:
802
+ self.resnet_down = ResnetBlock2D(
803
+ in_channels=out_channels,
804
+ out_channels=out_channels,
805
+ temb_channels=temb_channels,
806
+ eps=resnet_eps,
807
+ groups=min(out_channels // 4, 32),
808
+ dropout=dropout,
809
+ time_embedding_norm=resnet_time_scale_shift,
810
+ non_linearity=resnet_act_fn,
811
+ output_scale_factor=output_scale_factor,
812
+ pre_norm=resnet_pre_norm,
813
+ use_nin_shortcut=True,
814
+ down=True,
815
+ kernel="fir",
816
+ )
817
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
818
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
819
+ else:
820
+ self.resnet_down = None
821
+ self.downsamplers = None
822
+ self.skip_conv = None
823
+
824
+ def forward(self, hidden_states, temb=None, skip_sample=None):
825
+ output_states = ()
826
+
827
+ for resnet, attn in zip(self.resnets, self.attentions):
828
+ hidden_states = resnet(hidden_states, temb)
829
+ hidden_states = attn(hidden_states)
830
+ output_states += (hidden_states,)
831
+
832
+ if self.downsamplers is not None:
833
+ hidden_states = self.resnet_down(hidden_states, temb)
834
+ for downsampler in self.downsamplers:
835
+ skip_sample = downsampler(skip_sample)
836
+
837
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
838
+
839
+ output_states += (hidden_states,)
840
+
841
+ return hidden_states, output_states, skip_sample
842
+
843
+
844
+ class SkipDownBlock2D(nn.Module):
845
+ def __init__(
846
+ self,
847
+ in_channels: int,
848
+ out_channels: int,
849
+ temb_channels: int,
850
+ dropout: float = 0.0,
851
+ num_layers: int = 1,
852
+ resnet_eps: float = 1e-6,
853
+ resnet_time_scale_shift: str = "default",
854
+ resnet_act_fn: str = "swish",
855
+ resnet_pre_norm: bool = True,
856
+ output_scale_factor=np.sqrt(2.0),
857
+ add_downsample=True,
858
+ downsample_padding=1,
859
+ ):
860
+ super().__init__()
861
+ self.resnets = nn.ModuleList([])
862
+
863
+ for i in range(num_layers):
864
+ in_channels = in_channels if i == 0 else out_channels
865
+ self.resnets.append(
866
+ ResnetBlock2D(
867
+ in_channels=in_channels,
868
+ out_channels=out_channels,
869
+ temb_channels=temb_channels,
870
+ eps=resnet_eps,
871
+ groups=min(in_channels // 4, 32),
872
+ groups_out=min(out_channels // 4, 32),
873
+ dropout=dropout,
874
+ time_embedding_norm=resnet_time_scale_shift,
875
+ non_linearity=resnet_act_fn,
876
+ output_scale_factor=output_scale_factor,
877
+ pre_norm=resnet_pre_norm,
878
+ )
879
+ )
880
+
881
+ if add_downsample:
882
+ self.resnet_down = ResnetBlock2D(
883
+ in_channels=out_channels,
884
+ out_channels=out_channels,
885
+ temb_channels=temb_channels,
886
+ eps=resnet_eps,
887
+ groups=min(out_channels // 4, 32),
888
+ dropout=dropout,
889
+ time_embedding_norm=resnet_time_scale_shift,
890
+ non_linearity=resnet_act_fn,
891
+ output_scale_factor=output_scale_factor,
892
+ pre_norm=resnet_pre_norm,
893
+ use_nin_shortcut=True,
894
+ down=True,
895
+ kernel="fir",
896
+ )
897
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
898
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
899
+ else:
900
+ self.resnet_down = None
901
+ self.downsamplers = None
902
+ self.skip_conv = None
903
+
904
+ def forward(self, hidden_states, temb=None, skip_sample=None):
905
+ output_states = ()
906
+
907
+ for resnet in self.resnets:
908
+ hidden_states = resnet(hidden_states, temb)
909
+ output_states += (hidden_states,)
910
+
911
+ if self.downsamplers is not None:
912
+ hidden_states = self.resnet_down(hidden_states, temb)
913
+ for downsampler in self.downsamplers:
914
+ skip_sample = downsampler(skip_sample)
915
+
916
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
917
+
918
+ output_states += (hidden_states,)
919
+
920
+ return hidden_states, output_states, skip_sample
921
+
922
+
923
+ class AttnUpBlock2D(nn.Module):
924
+ def __init__(
925
+ self,
926
+ in_channels: int,
927
+ prev_output_channel: int,
928
+ out_channels: int,
929
+ temb_channels: int,
930
+ dropout: float = 0.0,
931
+ num_layers: int = 1,
932
+ resnet_eps: float = 1e-6,
933
+ resnet_time_scale_shift: str = "default",
934
+ resnet_act_fn: str = "swish",
935
+ resnet_groups: int = 32,
936
+ resnet_pre_norm: bool = True,
937
+ attention_type="default",
938
+ attn_num_head_channels=1,
939
+ output_scale_factor=1.0,
940
+ add_upsample=True,
941
+ ):
942
+ super().__init__()
943
+ resnets = []
944
+ attentions = []
945
+
946
+ self.attention_type = attention_type
947
+
948
+ for i in range(num_layers):
949
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
950
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
951
+
952
+ resnets.append(
953
+ ResnetBlock2D(
954
+ in_channels=resnet_in_channels + res_skip_channels,
955
+ out_channels=out_channels,
956
+ temb_channels=temb_channels,
957
+ eps=resnet_eps,
958
+ groups=resnet_groups,
959
+ dropout=dropout,
960
+ time_embedding_norm=resnet_time_scale_shift,
961
+ non_linearity=resnet_act_fn,
962
+ output_scale_factor=output_scale_factor,
963
+ pre_norm=resnet_pre_norm,
964
+ )
965
+ )
966
+ attentions.append(
967
+ AttentionBlock(
968
+ out_channels,
969
+ num_head_channels=attn_num_head_channels,
970
+ rescale_output_factor=output_scale_factor,
971
+ eps=resnet_eps,
972
+ )
973
+ )
974
+
975
+ self.attentions = nn.ModuleList(attentions)
976
+ self.resnets = nn.ModuleList(resnets)
977
+
978
+ if add_upsample:
979
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
980
+ else:
981
+ self.upsamplers = None
982
+
983
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
984
+ for resnet, attn in zip(self.resnets, self.attentions):
985
+
986
+ # pop res hidden states
987
+ res_hidden_states = res_hidden_states_tuple[-1]
988
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
989
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
990
+
991
+ hidden_states = resnet(hidden_states, temb)
992
+ hidden_states = attn(hidden_states)
993
+
994
+ if self.upsamplers is not None:
995
+ for upsampler in self.upsamplers:
996
+ hidden_states = upsampler(hidden_states)
997
+
998
+ return hidden_states
999
+
1000
+
1001
+ class CrossAttnUpBlock2D(nn.Module):
1002
+ def __init__(
1003
+ self,
1004
+ in_channels: int,
1005
+ out_channels: int,
1006
+ prev_output_channel: int,
1007
+ temb_channels: int,
1008
+ dropout: float = 0.0,
1009
+ num_layers: int = 1,
1010
+ resnet_eps: float = 1e-6,
1011
+ resnet_time_scale_shift: str = "default",
1012
+ resnet_act_fn: str = "swish",
1013
+ resnet_groups: int = 32,
1014
+ resnet_pre_norm: bool = True,
1015
+ attn_num_head_channels=1,
1016
+ cross_attention_dim=1280,
1017
+ attention_type="default",
1018
+ output_scale_factor=1.0,
1019
+ downsample_padding=1,
1020
+ add_upsample=True,
1021
+ ):
1022
+ super().__init__()
1023
+ resnets = []
1024
+ attentions = []
1025
+
1026
+ self.attention_type = attention_type
1027
+ self.attn_num_head_channels = attn_num_head_channels
1028
+
1029
+ for i in range(num_layers):
1030
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1031
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1032
+
1033
+ resnets.append(
1034
+ ResnetBlock2D(
1035
+ in_channels=resnet_in_channels + res_skip_channels,
1036
+ out_channels=out_channels,
1037
+ temb_channels=temb_channels,
1038
+ eps=resnet_eps,
1039
+ groups=resnet_groups,
1040
+ dropout=dropout,
1041
+ time_embedding_norm=resnet_time_scale_shift,
1042
+ non_linearity=resnet_act_fn,
1043
+ output_scale_factor=output_scale_factor,
1044
+ pre_norm=resnet_pre_norm,
1045
+ )
1046
+ )
1047
+ attentions.append(
1048
+ SpatialTransformer(
1049
+ out_channels,
1050
+ attn_num_head_channels,
1051
+ out_channels // attn_num_head_channels,
1052
+ depth=1,
1053
+ context_dim=cross_attention_dim,
1054
+ )
1055
+ )
1056
+ self.attentions = nn.ModuleList(attentions)
1057
+ self.resnets = nn.ModuleList(resnets)
1058
+
1059
+ if add_upsample:
1060
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1061
+ else:
1062
+ self.upsamplers = None
1063
+
1064
+ def set_attention_slice(self, slice_size):
1065
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
1066
+ raise ValueError(
1067
+ f"Make sure slice_size {slice_size} is a divisor of "
1068
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1069
+ )
1070
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
1071
+ raise ValueError(
1072
+ f"Chunk_size {slice_size} has to be smaller or equal to "
1073
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1074
+ )
1075
+
1076
+ for attn in self.attentions:
1077
+ attn._set_attention_slice(slice_size)
1078
+
1079
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
1080
+ for resnet, attn in zip(self.resnets, self.attentions):
1081
+
1082
+ # pop res hidden states
1083
+ res_hidden_states = res_hidden_states_tuple[-1]
1084
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1085
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1086
+
1087
+ hidden_states = resnet(hidden_states, temb)
1088
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
1089
+
1090
+ if self.upsamplers is not None:
1091
+ for upsampler in self.upsamplers:
1092
+ hidden_states = upsampler(hidden_states)
1093
+
1094
+ return hidden_states
1095
+
1096
+
1097
+ class UpBlock2D(nn.Module):
1098
+ def __init__(
1099
+ self,
1100
+ in_channels: int,
1101
+ prev_output_channel: int,
1102
+ out_channels: int,
1103
+ temb_channels: int,
1104
+ dropout: float = 0.0,
1105
+ num_layers: int = 1,
1106
+ resnet_eps: float = 1e-6,
1107
+ resnet_time_scale_shift: str = "default",
1108
+ resnet_act_fn: str = "swish",
1109
+ resnet_groups: int = 32,
1110
+ resnet_pre_norm: bool = True,
1111
+ output_scale_factor=1.0,
1112
+ add_upsample=True,
1113
+ ):
1114
+ super().__init__()
1115
+ resnets = []
1116
+
1117
+ for i in range(num_layers):
1118
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1119
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1120
+
1121
+ resnets.append(
1122
+ ResnetBlock2D(
1123
+ in_channels=resnet_in_channels + res_skip_channels,
1124
+ out_channels=out_channels,
1125
+ temb_channels=temb_channels,
1126
+ eps=resnet_eps,
1127
+ groups=resnet_groups,
1128
+ dropout=dropout,
1129
+ time_embedding_norm=resnet_time_scale_shift,
1130
+ non_linearity=resnet_act_fn,
1131
+ output_scale_factor=output_scale_factor,
1132
+ pre_norm=resnet_pre_norm,
1133
+ )
1134
+ )
1135
+
1136
+ self.resnets = nn.ModuleList(resnets)
1137
+
1138
+ if add_upsample:
1139
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1140
+ else:
1141
+ self.upsamplers = None
1142
+
1143
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1144
+ for resnet in self.resnets:
1145
+
1146
+ # pop res hidden states
1147
+ res_hidden_states = res_hidden_states_tuple[-1]
1148
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1149
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1150
+
1151
+ hidden_states = resnet(hidden_states, temb)
1152
+
1153
+ if self.upsamplers is not None:
1154
+ for upsampler in self.upsamplers:
1155
+ hidden_states = upsampler(hidden_states)
1156
+
1157
+ return hidden_states
1158
+
1159
+
1160
+ class UpDecoderBlock2D(nn.Module):
1161
+ def __init__(
1162
+ self,
1163
+ in_channels: int,
1164
+ out_channels: int,
1165
+ dropout: float = 0.0,
1166
+ num_layers: int = 1,
1167
+ resnet_eps: float = 1e-6,
1168
+ resnet_time_scale_shift: str = "default",
1169
+ resnet_act_fn: str = "swish",
1170
+ resnet_groups: int = 32,
1171
+ resnet_pre_norm: bool = True,
1172
+ output_scale_factor=1.0,
1173
+ add_upsample=True,
1174
+ ):
1175
+ super().__init__()
1176
+ resnets = []
1177
+
1178
+ for i in range(num_layers):
1179
+ input_channels = in_channels if i == 0 else out_channels
1180
+
1181
+ resnets.append(
1182
+ ResnetBlock2D(
1183
+ in_channels=input_channels,
1184
+ out_channels=out_channels,
1185
+ temb_channels=None,
1186
+ eps=resnet_eps,
1187
+ groups=resnet_groups,
1188
+ dropout=dropout,
1189
+ time_embedding_norm=resnet_time_scale_shift,
1190
+ non_linearity=resnet_act_fn,
1191
+ output_scale_factor=output_scale_factor,
1192
+ pre_norm=resnet_pre_norm,
1193
+ )
1194
+ )
1195
+
1196
+ self.resnets = nn.ModuleList(resnets)
1197
+
1198
+ if add_upsample:
1199
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1200
+ else:
1201
+ self.upsamplers = None
1202
+
1203
+ def forward(self, hidden_states):
1204
+ for resnet in self.resnets:
1205
+ hidden_states = resnet(hidden_states, temb=None)
1206
+
1207
+ if self.upsamplers is not None:
1208
+ for upsampler in self.upsamplers:
1209
+ hidden_states = upsampler(hidden_states)
1210
+
1211
+ return hidden_states
1212
+
1213
+
1214
+ class AttnUpDecoderBlock2D(nn.Module):
1215
+ def __init__(
1216
+ self,
1217
+ in_channels: int,
1218
+ out_channels: int,
1219
+ dropout: float = 0.0,
1220
+ num_layers: int = 1,
1221
+ resnet_eps: float = 1e-6,
1222
+ resnet_time_scale_shift: str = "default",
1223
+ resnet_act_fn: str = "swish",
1224
+ resnet_groups: int = 32,
1225
+ resnet_pre_norm: bool = True,
1226
+ attn_num_head_channels=1,
1227
+ output_scale_factor=1.0,
1228
+ add_upsample=True,
1229
+ ):
1230
+ super().__init__()
1231
+ resnets = []
1232
+ attentions = []
1233
+
1234
+ for i in range(num_layers):
1235
+ input_channels = in_channels if i == 0 else out_channels
1236
+
1237
+ resnets.append(
1238
+ ResnetBlock2D(
1239
+ in_channels=input_channels,
1240
+ out_channels=out_channels,
1241
+ temb_channels=None,
1242
+ eps=resnet_eps,
1243
+ groups=resnet_groups,
1244
+ dropout=dropout,
1245
+ time_embedding_norm=resnet_time_scale_shift,
1246
+ non_linearity=resnet_act_fn,
1247
+ output_scale_factor=output_scale_factor,
1248
+ pre_norm=resnet_pre_norm,
1249
+ )
1250
+ )
1251
+ attentions.append(
1252
+ AttentionBlock(
1253
+ out_channels,
1254
+ num_head_channels=attn_num_head_channels,
1255
+ rescale_output_factor=output_scale_factor,
1256
+ eps=resnet_eps,
1257
+ num_groups=resnet_groups,
1258
+ )
1259
+ )
1260
+
1261
+ self.attentions = nn.ModuleList(attentions)
1262
+ self.resnets = nn.ModuleList(resnets)
1263
+
1264
+ if add_upsample:
1265
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1266
+ else:
1267
+ self.upsamplers = None
1268
+
1269
+ def forward(self, hidden_states):
1270
+ for resnet, attn in zip(self.resnets, self.attentions):
1271
+ hidden_states = resnet(hidden_states, temb=None)
1272
+ hidden_states = attn(hidden_states)
1273
+
1274
+ if self.upsamplers is not None:
1275
+ for upsampler in self.upsamplers:
1276
+ hidden_states = upsampler(hidden_states)
1277
+
1278
+ return hidden_states
1279
+
1280
+
1281
+ class AttnSkipUpBlock2D(nn.Module):
1282
+ def __init__(
1283
+ self,
1284
+ in_channels: int,
1285
+ prev_output_channel: int,
1286
+ out_channels: int,
1287
+ temb_channels: int,
1288
+ dropout: float = 0.0,
1289
+ num_layers: int = 1,
1290
+ resnet_eps: float = 1e-6,
1291
+ resnet_time_scale_shift: str = "default",
1292
+ resnet_act_fn: str = "swish",
1293
+ resnet_pre_norm: bool = True,
1294
+ attn_num_head_channels=1,
1295
+ attention_type="default",
1296
+ output_scale_factor=np.sqrt(2.0),
1297
+ upsample_padding=1,
1298
+ add_upsample=True,
1299
+ ):
1300
+ super().__init__()
1301
+ self.attentions = nn.ModuleList([])
1302
+ self.resnets = nn.ModuleList([])
1303
+
1304
+ self.attention_type = attention_type
1305
+
1306
+ for i in range(num_layers):
1307
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1308
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1309
+
1310
+ self.resnets.append(
1311
+ ResnetBlock2D(
1312
+ in_channels=resnet_in_channels + res_skip_channels,
1313
+ out_channels=out_channels,
1314
+ temb_channels=temb_channels,
1315
+ eps=resnet_eps,
1316
+ groups=min(resnet_in_channels + res_skip_channels // 4, 32),
1317
+ groups_out=min(out_channels // 4, 32),
1318
+ dropout=dropout,
1319
+ time_embedding_norm=resnet_time_scale_shift,
1320
+ non_linearity=resnet_act_fn,
1321
+ output_scale_factor=output_scale_factor,
1322
+ pre_norm=resnet_pre_norm,
1323
+ )
1324
+ )
1325
+
1326
+ self.attentions.append(
1327
+ AttentionBlock(
1328
+ out_channels,
1329
+ num_head_channels=attn_num_head_channels,
1330
+ rescale_output_factor=output_scale_factor,
1331
+ eps=resnet_eps,
1332
+ )
1333
+ )
1334
+
1335
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1336
+ if add_upsample:
1337
+ self.resnet_up = ResnetBlock2D(
1338
+ in_channels=out_channels,
1339
+ out_channels=out_channels,
1340
+ temb_channels=temb_channels,
1341
+ eps=resnet_eps,
1342
+ groups=min(out_channels // 4, 32),
1343
+ groups_out=min(out_channels // 4, 32),
1344
+ dropout=dropout,
1345
+ time_embedding_norm=resnet_time_scale_shift,
1346
+ non_linearity=resnet_act_fn,
1347
+ output_scale_factor=output_scale_factor,
1348
+ pre_norm=resnet_pre_norm,
1349
+ use_nin_shortcut=True,
1350
+ up=True,
1351
+ kernel="fir",
1352
+ )
1353
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1354
+ self.skip_norm = torch.nn.GroupNorm(
1355
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1356
+ )
1357
+ self.act = nn.SiLU()
1358
+ else:
1359
+ self.resnet_up = None
1360
+ self.skip_conv = None
1361
+ self.skip_norm = None
1362
+ self.act = None
1363
+
1364
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1365
+ for resnet in self.resnets:
1366
+ # pop res hidden states
1367
+ res_hidden_states = res_hidden_states_tuple[-1]
1368
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1369
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1370
+
1371
+ hidden_states = resnet(hidden_states, temb)
1372
+
1373
+ hidden_states = self.attentions[0](hidden_states)
1374
+
1375
+ if skip_sample is not None:
1376
+ skip_sample = self.upsampler(skip_sample)
1377
+ else:
1378
+ skip_sample = 0
1379
+
1380
+ if self.resnet_up is not None:
1381
+ skip_sample_states = self.skip_norm(hidden_states)
1382
+ skip_sample_states = self.act(skip_sample_states)
1383
+ skip_sample_states = self.skip_conv(skip_sample_states)
1384
+
1385
+ skip_sample = skip_sample + skip_sample_states
1386
+
1387
+ hidden_states = self.resnet_up(hidden_states, temb)
1388
+
1389
+ return hidden_states, skip_sample
1390
+
1391
+
1392
+ class SkipUpBlock2D(nn.Module):
1393
+ def __init__(
1394
+ self,
1395
+ in_channels: int,
1396
+ prev_output_channel: int,
1397
+ out_channels: int,
1398
+ temb_channels: int,
1399
+ dropout: float = 0.0,
1400
+ num_layers: int = 1,
1401
+ resnet_eps: float = 1e-6,
1402
+ resnet_time_scale_shift: str = "default",
1403
+ resnet_act_fn: str = "swish",
1404
+ resnet_pre_norm: bool = True,
1405
+ output_scale_factor=np.sqrt(2.0),
1406
+ add_upsample=True,
1407
+ upsample_padding=1,
1408
+ ):
1409
+ super().__init__()
1410
+ self.resnets = nn.ModuleList([])
1411
+
1412
+ for i in range(num_layers):
1413
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1414
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1415
+
1416
+ self.resnets.append(
1417
+ ResnetBlock2D(
1418
+ in_channels=resnet_in_channels + res_skip_channels,
1419
+ out_channels=out_channels,
1420
+ temb_channels=temb_channels,
1421
+ eps=resnet_eps,
1422
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1423
+ groups_out=min(out_channels // 4, 32),
1424
+ dropout=dropout,
1425
+ time_embedding_norm=resnet_time_scale_shift,
1426
+ non_linearity=resnet_act_fn,
1427
+ output_scale_factor=output_scale_factor,
1428
+ pre_norm=resnet_pre_norm,
1429
+ )
1430
+ )
1431
+
1432
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1433
+ if add_upsample:
1434
+ self.resnet_up = ResnetBlock2D(
1435
+ in_channels=out_channels,
1436
+ out_channels=out_channels,
1437
+ temb_channels=temb_channels,
1438
+ eps=resnet_eps,
1439
+ groups=min(out_channels // 4, 32),
1440
+ groups_out=min(out_channels // 4, 32),
1441
+ dropout=dropout,
1442
+ time_embedding_norm=resnet_time_scale_shift,
1443
+ non_linearity=resnet_act_fn,
1444
+ output_scale_factor=output_scale_factor,
1445
+ pre_norm=resnet_pre_norm,
1446
+ use_nin_shortcut=True,
1447
+ up=True,
1448
+ kernel="fir",
1449
+ )
1450
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1451
+ self.skip_norm = torch.nn.GroupNorm(
1452
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1453
+ )
1454
+ self.act = nn.SiLU()
1455
+ else:
1456
+ self.resnet_up = None
1457
+ self.skip_conv = None
1458
+ self.skip_norm = None
1459
+ self.act = None
1460
+
1461
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1462
+ for resnet in self.resnets:
1463
+ # pop res hidden states
1464
+ res_hidden_states = res_hidden_states_tuple[-1]
1465
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1466
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1467
+
1468
+ hidden_states = resnet(hidden_states, temb)
1469
+
1470
+ if skip_sample is not None:
1471
+ skip_sample = self.upsampler(skip_sample)
1472
+ else:
1473
+ skip_sample = 0
1474
+
1475
+ if self.resnet_up is not None:
1476
+ skip_sample_states = self.skip_norm(hidden_states)
1477
+ skip_sample_states = self.act(skip_sample_states)
1478
+ skip_sample_states = self.skip_conv(skip_sample_states)
1479
+
1480
+ skip_sample = skip_sample + skip_sample_states
1481
+
1482
+ hidden_states = self.resnet_up(hidden_states, temb)
1483
+
1484
+ return hidden_states, skip_sample
diffusers/models/vae.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ..configuration_utils import ConfigMixin, register_to_config
9
+ from ..modeling_utils import ModelMixin
10
+ from ..utils import BaseOutput
11
+ from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
12
+
13
+
14
+ @dataclass
15
+ class DecoderOutput(BaseOutput):
16
+ """
17
+ Output of decoding method.
18
+
19
+ Args:
20
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
21
+ Decoded output sample of the model. Output of the last layer of the model.
22
+ """
23
+
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ @dataclass
28
+ class VQEncoderOutput(BaseOutput):
29
+ """
30
+ Output of VQModel encoding method.
31
+
32
+ Args:
33
+ latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
34
+ Encoded output sample of the model. Output of the last layer of the model.
35
+ """
36
+
37
+ latents: torch.FloatTensor
38
+
39
+
40
+ @dataclass
41
+ class AutoencoderKLOutput(BaseOutput):
42
+ """
43
+ Output of AutoencoderKL encoding method.
44
+
45
+ Args:
46
+ latent_dist (`DiagonalGaussianDistribution`):
47
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
48
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
49
+ """
50
+
51
+ latent_dist: "DiagonalGaussianDistribution"
52
+
53
+
54
+ class Encoder(nn.Module):
55
+ def __init__(
56
+ self,
57
+ in_channels=3,
58
+ out_channels=3,
59
+ down_block_types=("DownEncoderBlock2D",),
60
+ block_out_channels=(64,),
61
+ layers_per_block=2,
62
+ act_fn="silu",
63
+ double_z=True,
64
+ ):
65
+ super().__init__()
66
+ self.layers_per_block = layers_per_block
67
+
68
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
69
+
70
+ self.mid_block = None
71
+ self.down_blocks = nn.ModuleList([])
72
+
73
+ # down
74
+ output_channel = block_out_channels[0]
75
+ for i, down_block_type in enumerate(down_block_types):
76
+ input_channel = output_channel
77
+ output_channel = block_out_channels[i]
78
+ is_final_block = i == len(block_out_channels) - 1
79
+
80
+ down_block = get_down_block(
81
+ down_block_type,
82
+ num_layers=self.layers_per_block,
83
+ in_channels=input_channel,
84
+ out_channels=output_channel,
85
+ add_downsample=not is_final_block,
86
+ resnet_eps=1e-6,
87
+ downsample_padding=0,
88
+ resnet_act_fn=act_fn,
89
+ attn_num_head_channels=None,
90
+ temb_channels=None,
91
+ )
92
+ self.down_blocks.append(down_block)
93
+
94
+ # mid
95
+ self.mid_block = UNetMidBlock2D(
96
+ in_channels=block_out_channels[-1],
97
+ resnet_eps=1e-6,
98
+ resnet_act_fn=act_fn,
99
+ output_scale_factor=1,
100
+ resnet_time_scale_shift="default",
101
+ attn_num_head_channels=None,
102
+ resnet_groups=32,
103
+ temb_channels=None,
104
+ )
105
+
106
+ # out
107
+ num_groups_out = 32
108
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
109
+ self.conv_act = nn.SiLU()
110
+
111
+ conv_out_channels = 2 * out_channels if double_z else out_channels
112
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
113
+
114
+ def forward(self, x):
115
+ sample = x
116
+ sample = self.conv_in(sample)
117
+
118
+ # down
119
+ for down_block in self.down_blocks:
120
+ sample = down_block(sample)
121
+
122
+ # middle
123
+ sample = self.mid_block(sample)
124
+
125
+ # post-process
126
+ sample = self.conv_norm_out(sample)
127
+ sample = self.conv_act(sample)
128
+ sample = self.conv_out(sample)
129
+
130
+ return sample
131
+
132
+
133
+ class Decoder(nn.Module):
134
+ def __init__(
135
+ self,
136
+ in_channels=3,
137
+ out_channels=3,
138
+ up_block_types=("UpDecoderBlock2D",),
139
+ block_out_channels=(64,),
140
+ layers_per_block=2,
141
+ act_fn="silu",
142
+ ):
143
+ super().__init__()
144
+ self.layers_per_block = layers_per_block
145
+
146
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
147
+
148
+ self.mid_block = None
149
+ self.up_blocks = nn.ModuleList([])
150
+
151
+ # mid
152
+ self.mid_block = UNetMidBlock2D(
153
+ in_channels=block_out_channels[-1],
154
+ resnet_eps=1e-6,
155
+ resnet_act_fn=act_fn,
156
+ output_scale_factor=1,
157
+ resnet_time_scale_shift="default",
158
+ attn_num_head_channels=None,
159
+ resnet_groups=32,
160
+ temb_channels=None,
161
+ )
162
+
163
+ # up
164
+ reversed_block_out_channels = list(reversed(block_out_channels))
165
+ output_channel = reversed_block_out_channels[0]
166
+ for i, up_block_type in enumerate(up_block_types):
167
+ prev_output_channel = output_channel
168
+ output_channel = reversed_block_out_channels[i]
169
+
170
+ is_final_block = i == len(block_out_channels) - 1
171
+
172
+ up_block = get_up_block(
173
+ up_block_type,
174
+ num_layers=self.layers_per_block + 1,
175
+ in_channels=prev_output_channel,
176
+ out_channels=output_channel,
177
+ prev_output_channel=None,
178
+ add_upsample=not is_final_block,
179
+ resnet_eps=1e-6,
180
+ resnet_act_fn=act_fn,
181
+ attn_num_head_channels=None,
182
+ temb_channels=None,
183
+ )
184
+ self.up_blocks.append(up_block)
185
+ prev_output_channel = output_channel
186
+
187
+ # out
188
+ num_groups_out = 32
189
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
190
+ self.conv_act = nn.SiLU()
191
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
192
+
193
+ def forward(self, z):
194
+ sample = z
195
+ sample = self.conv_in(sample)
196
+
197
+ # middle
198
+ sample = self.mid_block(sample)
199
+
200
+ # up
201
+ for up_block in self.up_blocks:
202
+ sample = up_block(sample)
203
+
204
+ # post-process
205
+ sample = self.conv_norm_out(sample)
206
+ sample = self.conv_act(sample)
207
+ sample = self.conv_out(sample)
208
+
209
+ return sample
210
+
211
+
212
+ class VectorQuantizer(nn.Module):
213
+ """
214
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
215
+ multiplications and allows for post-hoc remapping of indices.
216
+ """
217
+
218
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
219
+ # backwards compatibility we use the buggy version by default, but you can
220
+ # specify legacy=False to fix it.
221
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
222
+ super().__init__()
223
+ self.n_e = n_e
224
+ self.e_dim = e_dim
225
+ self.beta = beta
226
+ self.legacy = legacy
227
+
228
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
229
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
230
+
231
+ self.remap = remap
232
+ if self.remap is not None:
233
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
234
+ self.re_embed = self.used.shape[0]
235
+ self.unknown_index = unknown_index # "random" or "extra" or integer
236
+ if self.unknown_index == "extra":
237
+ self.unknown_index = self.re_embed
238
+ self.re_embed = self.re_embed + 1
239
+ print(
240
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
241
+ f"Using {self.unknown_index} for unknown indices."
242
+ )
243
+ else:
244
+ self.re_embed = n_e
245
+
246
+ self.sane_index_shape = sane_index_shape
247
+
248
+ def remap_to_used(self, inds):
249
+ ishape = inds.shape
250
+ assert len(ishape) > 1
251
+ inds = inds.reshape(ishape[0], -1)
252
+ used = self.used.to(inds)
253
+ match = (inds[:, :, None] == used[None, None, ...]).long()
254
+ new = match.argmax(-1)
255
+ unknown = match.sum(2) < 1
256
+ if self.unknown_index == "random":
257
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
258
+ else:
259
+ new[unknown] = self.unknown_index
260
+ return new.reshape(ishape)
261
+
262
+ def unmap_to_all(self, inds):
263
+ ishape = inds.shape
264
+ assert len(ishape) > 1
265
+ inds = inds.reshape(ishape[0], -1)
266
+ used = self.used.to(inds)
267
+ if self.re_embed > self.used.shape[0]: # extra token
268
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
269
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
270
+ return back.reshape(ishape)
271
+
272
+ def forward(self, z):
273
+ # reshape z -> (batch, height, width, channel) and flatten
274
+ z = z.permute(0, 2, 3, 1).contiguous()
275
+ z_flattened = z.view(-1, self.e_dim)
276
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
277
+
278
+ d = (
279
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
280
+ + torch.sum(self.embedding.weight**2, dim=1)
281
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
282
+ )
283
+
284
+ min_encoding_indices = torch.argmin(d, dim=1)
285
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
286
+ perplexity = None
287
+ min_encodings = None
288
+
289
+ # compute loss for embedding
290
+ if not self.legacy:
291
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
292
+ else:
293
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
294
+
295
+ # preserve gradients
296
+ z_q = z + (z_q - z).detach()
297
+
298
+ # reshape back to match original input shape
299
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
300
+
301
+ if self.remap is not None:
302
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
303
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
304
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
305
+
306
+ if self.sane_index_shape:
307
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
308
+
309
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
310
+
311
+ def get_codebook_entry(self, indices, shape):
312
+ # shape specifying (batch, height, width, channel)
313
+ if self.remap is not None:
314
+ indices = indices.reshape(shape[0], -1) # add batch axis
315
+ indices = self.unmap_to_all(indices)
316
+ indices = indices.reshape(-1) # flatten again
317
+
318
+ # get quantized latent vectors
319
+ z_q = self.embedding(indices)
320
+
321
+ if shape is not None:
322
+ z_q = z_q.view(shape)
323
+ # reshape back to match original input shape
324
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
325
+
326
+ return z_q
327
+
328
+
329
+ class DiagonalGaussianDistribution(object):
330
+ def __init__(self, parameters, deterministic=False):
331
+ self.parameters = parameters
332
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
333
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
334
+ self.deterministic = deterministic
335
+ self.std = torch.exp(0.5 * self.logvar)
336
+ self.var = torch.exp(self.logvar)
337
+ if self.deterministic:
338
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
339
+
340
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
341
+ device = self.parameters.device
342
+ sample_device = "cpu" if device.type == "mps" else device
343
+ sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
344
+ x = self.mean + self.std * sample
345
+ return x
346
+
347
+ def kl(self, other=None):
348
+ if self.deterministic:
349
+ return torch.Tensor([0.0])
350
+ else:
351
+ if other is None:
352
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
353
+ else:
354
+ return 0.5 * torch.sum(
355
+ torch.pow(self.mean - other.mean, 2) / other.var
356
+ + self.var / other.var
357
+ - 1.0
358
+ - self.logvar
359
+ + other.logvar,
360
+ dim=[1, 2, 3],
361
+ )
362
+
363
+ def nll(self, sample, dims=[1, 2, 3]):
364
+ if self.deterministic:
365
+ return torch.Tensor([0.0])
366
+ logtwopi = np.log(2.0 * np.pi)
367
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
368
+
369
+ def mode(self):
370
+ return self.mean
371
+
372
+
373
+ class VQModel(ModelMixin, ConfigMixin):
374
+ r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
375
+ Kavukcuoglu.
376
+
377
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
378
+ implements for all the model (such as downloading or saving, etc.)
379
+
380
+ Parameters:
381
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
382
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
383
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
384
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
385
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
386
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
387
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
388
+ obj:`(64,)`): Tuple of block output channels.
389
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
390
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
391
+ sample_size (`int`, *optional*, defaults to `32`): TODO
392
+ num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
393
+ """
394
+
395
+ @register_to_config
396
+ def __init__(
397
+ self,
398
+ in_channels: int = 3,
399
+ out_channels: int = 3,
400
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
401
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
402
+ block_out_channels: Tuple[int] = (64,),
403
+ layers_per_block: int = 1,
404
+ act_fn: str = "silu",
405
+ latent_channels: int = 3,
406
+ sample_size: int = 32,
407
+ num_vq_embeddings: int = 256,
408
+ ):
409
+ super().__init__()
410
+
411
+ # pass init params to Encoder
412
+ self.encoder = Encoder(
413
+ in_channels=in_channels,
414
+ out_channels=latent_channels,
415
+ down_block_types=down_block_types,
416
+ block_out_channels=block_out_channels,
417
+ layers_per_block=layers_per_block,
418
+ act_fn=act_fn,
419
+ double_z=False,
420
+ )
421
+
422
+ self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
423
+ self.quantize = VectorQuantizer(
424
+ num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
425
+ )
426
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
427
+
428
+ # pass init params to Decoder
429
+ self.decoder = Decoder(
430
+ in_channels=latent_channels,
431
+ out_channels=out_channels,
432
+ up_block_types=up_block_types,
433
+ block_out_channels=block_out_channels,
434
+ layers_per_block=layers_per_block,
435
+ act_fn=act_fn,
436
+ )
437
+
438
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
439
+ h = self.encoder(x)
440
+ h = self.quant_conv(h)
441
+
442
+ if not return_dict:
443
+ return (h,)
444
+
445
+ return VQEncoderOutput(latents=h)
446
+
447
+ def decode(
448
+ self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
449
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
450
+ # also go through quantization layer
451
+ if not force_not_quantize:
452
+ quant, emb_loss, info = self.quantize(h)
453
+ else:
454
+ quant = h
455
+ quant = self.post_quant_conv(quant)
456
+ dec = self.decoder(quant)
457
+
458
+ return dec
459
+
460
+ # if not return_dict:
461
+ # return (dec,)
462
+ #
463
+ # return DecoderOutput(sample=dec)
464
+
465
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
466
+ r"""
467
+ Args:
468
+ sample (`torch.FloatTensor`): Input sample.
469
+ return_dict (`bool`, *optional*, defaults to `True`):
470
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
471
+ """
472
+ x = sample
473
+ h = self.encode(x).latents
474
+ dec = self.decode(h).sample
475
+
476
+ if not return_dict:
477
+ return (dec,)
478
+
479
+ return DecoderOutput(sample=dec)
480
+
481
+
482
+ class AutoencoderKL(ModelMixin, ConfigMixin):
483
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
484
+ and Max Welling.
485
+
486
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
487
+ implements for all the model (such as downloading or saving, etc.)
488
+
489
+ Parameters:
490
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
491
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
492
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
493
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
494
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
495
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
496
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
497
+ obj:`(64,)`): Tuple of block output channels.
498
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
499
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
500
+ sample_size (`int`, *optional*, defaults to `32`): TODO
501
+ """
502
+
503
+ @register_to_config
504
+ def __init__(
505
+ self,
506
+ in_channels: int = 3,
507
+ out_channels: int = 3,
508
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
509
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
510
+ block_out_channels: Tuple[int] = (64,),
511
+ layers_per_block: int = 1,
512
+ act_fn: str = "silu",
513
+ latent_channels: int = 4,
514
+ sample_size: int = 32,
515
+ ):
516
+ super().__init__()
517
+
518
+ # pass init params to Encoder
519
+ self.encoder = Encoder(
520
+ in_channels=in_channels,
521
+ out_channels=latent_channels,
522
+ down_block_types=down_block_types,
523
+ block_out_channels=block_out_channels,
524
+ layers_per_block=layers_per_block,
525
+ act_fn=act_fn,
526
+ double_z=True,
527
+ )
528
+
529
+ # pass init params to Decoder
530
+ self.decoder = Decoder(
531
+ in_channels=latent_channels,
532
+ out_channels=out_channels,
533
+ up_block_types=up_block_types,
534
+ block_out_channels=block_out_channels,
535
+ layers_per_block=layers_per_block,
536
+ act_fn=act_fn,
537
+ )
538
+
539
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
540
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
541
+
542
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
543
+ h = self.encoder(x)
544
+ moments = self.quant_conv(h)
545
+ posterior = DiagonalGaussianDistribution(moments)
546
+
547
+ if not return_dict:
548
+ return (posterior,)
549
+
550
+ return AutoencoderKLOutput(latent_dist=posterior)
551
+
552
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
553
+ z = self.post_quant_conv(z)
554
+ dec = self.decoder(z)
555
+
556
+ return dec
557
+ #
558
+ # if not return_dict:
559
+ # return (dec,)
560
+ #
561
+ # return DecoderOutput(sample=dec)
562
+
563
+ def forward(
564
+ self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
565
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
566
+ r"""
567
+ Args:
568
+ sample (`torch.FloatTensor`): Input sample.
569
+ sample_posterior (`bool`, *optional*, defaults to `False`):
570
+ Whether to sample from the posterior.
571
+ return_dict (`bool`, *optional*, defaults to `True`):
572
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
573
+ """
574
+ x = sample
575
+ posterior = self.encode(x).latent_dist
576
+ if sample_posterior:
577
+ z = posterior.sample()
578
+ else:
579
+ z = posterior.mode()
580
+ dec = self.decode(z).sample
581
+
582
+ if not return_dict:
583
+ return (dec,)
584
+
585
+ return DecoderOutput(sample=dec)
diffusers/onnx_utils.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import os
19
+ import shutil
20
+ from pathlib import Path
21
+ from typing import Optional, Union
22
+
23
+ import numpy as np
24
+
25
+ from huggingface_hub import hf_hub_download
26
+
27
+ from .utils import is_onnx_available, logging
28
+
29
+
30
+ if is_onnx_available():
31
+ import onnxruntime as ort
32
+
33
+
34
+ ONNX_WEIGHTS_NAME = "model.onnx"
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ class OnnxRuntimeModel:
41
+ base_model_prefix = "onnx_model"
42
+
43
+ def __init__(self, model=None, **kwargs):
44
+ logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
45
+ self.model = model
46
+ self.model_save_dir = kwargs.get("model_save_dir", None)
47
+ self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
48
+
49
+ def __call__(self, **kwargs):
50
+ inputs = {k: np.array(v) for k, v in kwargs.items()}
51
+ return self.model.run(None, inputs)
52
+
53
+ @staticmethod
54
+ def load_model(path: Union[str, Path], provider=None):
55
+ """
56
+ Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
57
+
58
+ Arguments:
59
+ path (`str` or `Path`):
60
+ Directory from which to load
61
+ provider(`str`, *optional*):
62
+ Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
63
+ """
64
+ if provider is None:
65
+ logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
66
+ provider = "CPUExecutionProvider"
67
+
68
+ return ort.InferenceSession(path, providers=[provider])
69
+
70
+ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
71
+ """
72
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
73
+ [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
74
+ latest_model_name.
75
+
76
+ Arguments:
77
+ save_directory (`str` or `Path`):
78
+ Directory where to save the model file.
79
+ file_name(`str`, *optional*):
80
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
81
+ model with a different name.
82
+ """
83
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
84
+
85
+ src_path = self.model_save_dir.joinpath(self.latest_model_name)
86
+ dst_path = Path(save_directory).joinpath(model_file_name)
87
+ if not src_path.samefile(dst_path):
88
+ shutil.copyfile(src_path, dst_path)
89
+
90
+ def save_pretrained(
91
+ self,
92
+ save_directory: Union[str, os.PathLike],
93
+ **kwargs,
94
+ ):
95
+ """
96
+ Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
97
+ method.:
98
+
99
+ Arguments:
100
+ save_directory (`str` or `os.PathLike`):
101
+ Directory to which to save. Will be created if it doesn't exist.
102
+ """
103
+ if os.path.isfile(save_directory):
104
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
105
+ return
106
+
107
+ os.makedirs(save_directory, exist_ok=True)
108
+
109
+ # saving model weights/files
110
+ self._save_pretrained(save_directory, **kwargs)
111
+
112
+ @classmethod
113
+ def _from_pretrained(
114
+ cls,
115
+ model_id: Union[str, Path],
116
+ use_auth_token: Optional[Union[bool, str, None]] = None,
117
+ revision: Optional[Union[str, None]] = None,
118
+ force_download: bool = False,
119
+ cache_dir: Optional[str] = None,
120
+ file_name: Optional[str] = None,
121
+ provider: Optional[str] = None,
122
+ **kwargs,
123
+ ):
124
+ """
125
+ Load a model from a directory or the HF Hub.
126
+
127
+ Arguments:
128
+ model_id (`str` or `Path`):
129
+ Directory from which to load
130
+ use_auth_token (`str` or `bool`):
131
+ Is needed to load models from a private or gated repository
132
+ revision (`str`):
133
+ Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
134
+ cache_dir (`Union[str, Path]`, *optional*):
135
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
136
+ standard cache should not be used.
137
+ force_download (`bool`, *optional*, defaults to `False`):
138
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
139
+ cached versions if they exist.
140
+ file_name(`str`):
141
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
142
+ different model files from the same repository or directory.
143
+ provider(`str`):
144
+ The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
145
+ kwargs (`Dict`, *optional*):
146
+ kwargs will be passed to the model during initialization
147
+ """
148
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
149
+ # load model from local directory
150
+ if os.path.isdir(model_id):
151
+ model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
152
+ kwargs["model_save_dir"] = Path(model_id)
153
+ # load model from hub
154
+ else:
155
+ # download model
156
+ model_cache_path = hf_hub_download(
157
+ repo_id=model_id,
158
+ filename=model_file_name,
159
+ use_auth_token=use_auth_token,
160
+ revision=revision,
161
+ cache_dir=cache_dir,
162
+ force_download=force_download,
163
+ )
164
+ kwargs["model_save_dir"] = Path(model_cache_path).parent
165
+ kwargs["latest_model_name"] = Path(model_cache_path).name
166
+ model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
167
+ return cls(model=model, **kwargs)
168
+
169
+ @classmethod
170
+ def from_pretrained(
171
+ cls,
172
+ model_id: Union[str, Path],
173
+ force_download: bool = True,
174
+ use_auth_token: Optional[str] = None,
175
+ cache_dir: Optional[str] = None,
176
+ **model_kwargs,
177
+ ):
178
+ revision = None
179
+ if len(str(model_id).split("@")) == 2:
180
+ model_id, revision = model_id.split("@")
181
+
182
+ return cls._from_pretrained(
183
+ model_id=model_id,
184
+ revision=revision,
185
+ cache_dir=cache_dir,
186
+ force_download=force_download,
187
+ use_auth_token=use_auth_token,
188
+ **model_kwargs,
189
+ )
diffusers/optimization.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch optimization for diffusion models."""
16
+
17
+ import math
18
+ from enum import Enum
19
+ from typing import Optional, Union
20
+
21
+ from torch.optim import Optimizer
22
+ from torch.optim.lr_scheduler import LambdaLR
23
+
24
+ from .utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class SchedulerType(Enum):
31
+ LINEAR = "linear"
32
+ COSINE = "cosine"
33
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
34
+ POLYNOMIAL = "polynomial"
35
+ CONSTANT = "constant"
36
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
37
+
38
+
39
+ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
40
+ """
41
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
42
+
43
+ Args:
44
+ optimizer ([`~torch.optim.Optimizer`]):
45
+ The optimizer for which to schedule the learning rate.
46
+ last_epoch (`int`, *optional*, defaults to -1):
47
+ The index of the last epoch when resuming training.
48
+
49
+ Return:
50
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
51
+ """
52
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
53
+
54
+
55
+ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
56
+ """
57
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
58
+ increases linearly between 0 and the initial lr set in the optimizer.
59
+
60
+ Args:
61
+ optimizer ([`~torch.optim.Optimizer`]):
62
+ The optimizer for which to schedule the learning rate.
63
+ num_warmup_steps (`int`):
64
+ The number of steps for the warmup phase.
65
+ last_epoch (`int`, *optional*, defaults to -1):
66
+ The index of the last epoch when resuming training.
67
+
68
+ Return:
69
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
70
+ """
71
+
72
+ def lr_lambda(current_step: int):
73
+ if current_step < num_warmup_steps:
74
+ return float(current_step) / float(max(1.0, num_warmup_steps))
75
+ return 1.0
76
+
77
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
78
+
79
+
80
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
81
+ """
82
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
83
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
84
+
85
+ Args:
86
+ optimizer ([`~torch.optim.Optimizer`]):
87
+ The optimizer for which to schedule the learning rate.
88
+ num_warmup_steps (`int`):
89
+ The number of steps for the warmup phase.
90
+ num_training_steps (`int`):
91
+ The total number of training steps.
92
+ last_epoch (`int`, *optional*, defaults to -1):
93
+ The index of the last epoch when resuming training.
94
+
95
+ Return:
96
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
97
+ """
98
+
99
+ def lr_lambda(current_step: int):
100
+ if current_step < num_warmup_steps:
101
+ return float(current_step) / float(max(1, num_warmup_steps))
102
+ return max(
103
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
104
+ )
105
+
106
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
107
+
108
+
109
+ def get_cosine_schedule_with_warmup(
110
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
111
+ ):
112
+ """
113
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
114
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
115
+ initial lr set in the optimizer.
116
+
117
+ Args:
118
+ optimizer ([`~torch.optim.Optimizer`]):
119
+ The optimizer for which to schedule the learning rate.
120
+ num_warmup_steps (`int`):
121
+ The number of steps for the warmup phase.
122
+ num_training_steps (`int`):
123
+ The total number of training steps.
124
+ num_cycles (`float`, *optional*, defaults to 0.5):
125
+ The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
126
+ following a half-cosine).
127
+ last_epoch (`int`, *optional*, defaults to -1):
128
+ The index of the last epoch when resuming training.
129
+
130
+ Return:
131
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
132
+ """
133
+
134
+ def lr_lambda(current_step):
135
+ if current_step < num_warmup_steps:
136
+ return float(current_step) / float(max(1, num_warmup_steps))
137
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
138
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
139
+
140
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
141
+
142
+
143
+ def get_cosine_with_hard_restarts_schedule_with_warmup(
144
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
145
+ ):
146
+ """
147
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
148
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
149
+ linearly between 0 and the initial lr set in the optimizer.
150
+
151
+ Args:
152
+ optimizer ([`~torch.optim.Optimizer`]):
153
+ The optimizer for which to schedule the learning rate.
154
+ num_warmup_steps (`int`):
155
+ The number of steps for the warmup phase.
156
+ num_training_steps (`int`):
157
+ The total number of training steps.
158
+ num_cycles (`int`, *optional*, defaults to 1):
159
+ The number of hard restarts to use.
160
+ last_epoch (`int`, *optional*, defaults to -1):
161
+ The index of the last epoch when resuming training.
162
+
163
+ Return:
164
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
165
+ """
166
+
167
+ def lr_lambda(current_step):
168
+ if current_step < num_warmup_steps:
169
+ return float(current_step) / float(max(1, num_warmup_steps))
170
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
171
+ if progress >= 1.0:
172
+ return 0.0
173
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
174
+
175
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
176
+
177
+
178
+ def get_polynomial_decay_schedule_with_warmup(
179
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
180
+ ):
181
+ """
182
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
183
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
184
+ initial lr set in the optimizer.
185
+
186
+ Args:
187
+ optimizer ([`~torch.optim.Optimizer`]):
188
+ The optimizer for which to schedule the learning rate.
189
+ num_warmup_steps (`int`):
190
+ The number of steps for the warmup phase.
191
+ num_training_steps (`int`):
192
+ The total number of training steps.
193
+ lr_end (`float`, *optional*, defaults to 1e-7):
194
+ The end LR.
195
+ power (`float`, *optional*, defaults to 1.0):
196
+ Power factor.
197
+ last_epoch (`int`, *optional*, defaults to -1):
198
+ The index of the last epoch when resuming training.
199
+
200
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
201
+ implementation at
202
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
203
+
204
+ Return:
205
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
206
+
207
+ """
208
+
209
+ lr_init = optimizer.defaults["lr"]
210
+ if not (lr_init > lr_end):
211
+ raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
212
+
213
+ def lr_lambda(current_step: int):
214
+ if current_step < num_warmup_steps:
215
+ return float(current_step) / float(max(1, num_warmup_steps))
216
+ elif current_step > num_training_steps:
217
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
218
+ else:
219
+ lr_range = lr_init - lr_end
220
+ decay_steps = num_training_steps - num_warmup_steps
221
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
222
+ decay = lr_range * pct_remaining**power + lr_end
223
+ return decay / lr_init # as LambdaLR multiplies by lr_init
224
+
225
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
226
+
227
+
228
+ TYPE_TO_SCHEDULER_FUNCTION = {
229
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
230
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
231
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
232
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
233
+ SchedulerType.CONSTANT: get_constant_schedule,
234
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
235
+ }
236
+
237
+
238
+ def get_scheduler(
239
+ name: Union[str, SchedulerType],
240
+ optimizer: Optimizer,
241
+ num_warmup_steps: Optional[int] = None,
242
+ num_training_steps: Optional[int] = None,
243
+ ):
244
+ """
245
+ Unified API to get any scheduler from its name.
246
+
247
+ Args:
248
+ name (`str` or `SchedulerType`):
249
+ The name of the scheduler to use.
250
+ optimizer (`torch.optim.Optimizer`):
251
+ The optimizer that will be used during training.
252
+ num_warmup_steps (`int`, *optional*):
253
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
254
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
255
+ num_training_steps (`int``, *optional*):
256
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
257
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
258
+ """
259
+ name = SchedulerType(name)
260
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
261
+ if name == SchedulerType.CONSTANT:
262
+ return schedule_func(optimizer)
263
+
264
+ # All other schedulers require `num_warmup_steps`
265
+ if num_warmup_steps is None:
266
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
267
+
268
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
269
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
270
+
271
+ # All other schedulers require `num_training_steps`
272
+ if num_training_steps is None:
273
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
274
+
275
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
diffusers/pipeline_utils.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import importlib
18
+ import inspect
19
+ import os
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ import diffusers
27
+ import PIL
28
+ from huggingface_hub import snapshot_download
29
+ from PIL import Image
30
+ from tqdm.auto import tqdm
31
+
32
+ from .configuration_utils import ConfigMixin
33
+ from .utils import DIFFUSERS_CACHE, BaseOutput, logging
34
+
35
+
36
+ INDEX_FILE = "diffusion_pytorch_model.bin"
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ LOADABLE_CLASSES = {
43
+ "diffusers": {
44
+ "ModelMixin": ["save_pretrained", "from_pretrained"],
45
+ "SchedulerMixin": ["save_config", "from_config"],
46
+ "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
47
+ "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
48
+ },
49
+ "transformers": {
50
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
51
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
52
+ "PreTrainedModel": ["save_pretrained", "from_pretrained"],
53
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
54
+ },
55
+ }
56
+
57
+ ALL_IMPORTABLE_CLASSES = {}
58
+ for library in LOADABLE_CLASSES:
59
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
60
+
61
+
62
+ @dataclass
63
+ class ImagePipelineOutput(BaseOutput):
64
+ """
65
+ Output class for image pipelines.
66
+
67
+ Args:
68
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
69
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
70
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
71
+ """
72
+
73
+ images: Union[List[PIL.Image.Image], np.ndarray]
74
+
75
+
76
+ class DiffusionPipeline(ConfigMixin):
77
+ r"""
78
+ Base class for all models.
79
+
80
+ [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
81
+ and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
82
+
83
+ - move all PyTorch modules to the device of your choice
84
+ - enabling/disabling the progress bar for the denoising iteration
85
+
86
+ Class attributes:
87
+
88
+ - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
89
+ compenents of the diffusion pipeline.
90
+ """
91
+ config_name = "model_index.json"
92
+
93
+ def register_modules(self, **kwargs):
94
+ # import it here to avoid circular import
95
+ from diffusers import pipelines
96
+
97
+ for name, module in kwargs.items():
98
+ # retrive library
99
+ library = module.__module__.split(".")[0]
100
+
101
+ # check if the module is a pipeline module
102
+ pipeline_dir = module.__module__.split(".")[-2]
103
+ path = module.__module__.split(".")
104
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
105
+
106
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
107
+ # Or if it's a pipeline module, then the module is inside the pipeline
108
+ # folder so we set the library to module name.
109
+ if library not in LOADABLE_CLASSES or is_pipeline_module:
110
+ library = pipeline_dir
111
+
112
+ # retrive class_name
113
+ class_name = module.__class__.__name__
114
+
115
+ register_dict = {name: (library, class_name)}
116
+
117
+ # save model index config
118
+ self.register_to_config(**register_dict)
119
+
120
+ # set models
121
+ setattr(self, name, module)
122
+
123
+ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
124
+ """
125
+ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
126
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
127
+ method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
128
+
129
+ Arguments:
130
+ save_directory (`str` or `os.PathLike`):
131
+ Directory to which to save. Will be created if it doesn't exist.
132
+ """
133
+ self.save_config(save_directory)
134
+
135
+ model_index_dict = dict(self.config)
136
+ model_index_dict.pop("_class_name")
137
+ model_index_dict.pop("_diffusers_version")
138
+ model_index_dict.pop("_module", None)
139
+
140
+ for pipeline_component_name in model_index_dict.keys():
141
+ sub_model = getattr(self, pipeline_component_name)
142
+ model_cls = sub_model.__class__
143
+
144
+ save_method_name = None
145
+ # search for the model's base class in LOADABLE_CLASSES
146
+ for library_name, library_classes in LOADABLE_CLASSES.items():
147
+ library = importlib.import_module(library_name)
148
+ for base_class, save_load_methods in library_classes.items():
149
+ class_candidate = getattr(library, base_class)
150
+ if issubclass(model_cls, class_candidate):
151
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
152
+ save_method_name = save_load_methods[0]
153
+ break
154
+ if save_method_name is not None:
155
+ break
156
+
157
+ save_method = getattr(sub_model, save_method_name)
158
+ save_method(os.path.join(save_directory, pipeline_component_name))
159
+
160
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
161
+ if torch_device is None:
162
+ return self
163
+
164
+ module_names, _ = self.extract_init_dict(dict(self.config))
165
+ for name in module_names.keys():
166
+ module = getattr(self, name)
167
+ if isinstance(module, torch.nn.Module):
168
+ module.to(torch_device)
169
+ return self
170
+
171
+ @property
172
+ def device(self) -> torch.device:
173
+ r"""
174
+ Returns:
175
+ `torch.device`: The torch device on which the pipeline is located.
176
+ """
177
+ module_names, _ = self.extract_init_dict(dict(self.config))
178
+ for name in module_names.keys():
179
+ module = getattr(self, name)
180
+ if isinstance(module, torch.nn.Module):
181
+ return module.device
182
+ return torch.device("cpu")
183
+
184
+ @classmethod
185
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
186
+ r"""
187
+ Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
188
+
189
+ The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
190
+
191
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
192
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
193
+ task.
194
+
195
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
196
+ weights are discarded.
197
+
198
+ Parameters:
199
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
200
+ Can be either:
201
+
202
+ - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
203
+ https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
204
+ `CompVis/ldm-text2im-large-256`.
205
+ - A path to a *directory* containing pipeline weights saved using
206
+ [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
207
+ torch_dtype (`str` or `torch.dtype`, *optional*):
208
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
209
+ will be automatically derived from the model's weights.
210
+ force_download (`bool`, *optional*, defaults to `False`):
211
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
212
+ cached versions if they exist.
213
+ resume_download (`bool`, *optional*, defaults to `False`):
214
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
215
+ file exists.
216
+ proxies (`Dict[str, str]`, *optional*):
217
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
218
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
219
+ output_loading_info(`bool`, *optional*, defaults to `False`):
220
+ Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
221
+ local_files_only(`bool`, *optional*, defaults to `False`):
222
+ Whether or not to only look at local files (i.e., do not try to download the model).
223
+ use_auth_token (`str` or *bool*, *optional*):
224
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
225
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
226
+ revision (`str`, *optional*, defaults to `"main"`):
227
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
228
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
229
+ identifier allowed by git.
230
+ mirror (`str`, *optional*):
231
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
232
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
233
+ Please refer to the mirror site for more information. specify the folder name here.
234
+
235
+ kwargs (remaining dictionary of keyword arguments, *optional*):
236
+ Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
237
+ speficic pipeline class. The overritten components are then directly passed to the pipelines `__init__`
238
+ method. See example below for more information.
239
+
240
+ <Tip>
241
+
242
+ Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.*
243
+ `"CompVis/stable-diffusion-v1-4"`
244
+
245
+ </Tip>
246
+
247
+ <Tip>
248
+
249
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
250
+ this method in a firewalled environment.
251
+
252
+ </Tip>
253
+
254
+ Examples:
255
+
256
+ ```py
257
+ >>> from diffusers import DiffusionPipeline
258
+
259
+ >>> # Download pipeline from huggingface.co and cache.
260
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
261
+
262
+ >>> # Download pipeline that requires an authorization token
263
+ >>> # For more information on access tokens, please refer to this section
264
+ >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
265
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
266
+
267
+ >>> # Download pipeline, but overwrite scheduler
268
+ >>> from diffusers import LMSDiscreteScheduler
269
+
270
+ >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
271
+ >>> pipeline = DiffusionPipeline.from_pretrained(
272
+ ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
273
+ ... )
274
+ ```
275
+ """
276
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
277
+ resume_download = kwargs.pop("resume_download", False)
278
+ proxies = kwargs.pop("proxies", None)
279
+ local_files_only = kwargs.pop("local_files_only", False)
280
+ use_auth_token = kwargs.pop("use_auth_token", None)
281
+ revision = kwargs.pop("revision", None)
282
+ torch_dtype = kwargs.pop("torch_dtype", None)
283
+ provider = kwargs.pop("provider", None)
284
+
285
+ # 1. Download the checkpoints and configs
286
+ # use snapshot download here to get it working from from_pretrained
287
+ if not os.path.isdir(pretrained_model_name_or_path):
288
+ cached_folder = snapshot_download(
289
+ pretrained_model_name_or_path,
290
+ cache_dir=cache_dir,
291
+ resume_download=resume_download,
292
+ proxies=proxies,
293
+ local_files_only=local_files_only,
294
+ use_auth_token=use_auth_token,
295
+ revision=revision,
296
+ )
297
+ else:
298
+ cached_folder = pretrained_model_name_or_path
299
+
300
+ config_dict = cls.get_config_dict(cached_folder)
301
+
302
+ # 2. Load the pipeline class, if using custom module then load it from the hub
303
+ # if we load from explicit class, let's use it
304
+ if cls != DiffusionPipeline:
305
+ pipeline_class = cls
306
+ else:
307
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
308
+ pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
309
+
310
+ # some modules can be passed directly to the init
311
+ # in this case they are already instantiated in `kwargs`
312
+ # extract them here
313
+ expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
314
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
315
+
316
+ init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
317
+
318
+ init_kwargs = {}
319
+
320
+ # import it here to avoid circular import
321
+ from diffusers import pipelines
322
+
323
+ # 3. Load each module in the pipeline
324
+ for name, (library_name, class_name) in init_dict.items():
325
+ is_pipeline_module = hasattr(pipelines, library_name)
326
+ loaded_sub_model = None
327
+
328
+ # if the model is in a pipeline module, then we load it from the pipeline
329
+ if name in passed_class_obj:
330
+ # 1. check that passed_class_obj has correct parent class
331
+ if not is_pipeline_module:
332
+ library = importlib.import_module(library_name)
333
+ class_obj = getattr(library, class_name)
334
+ importable_classes = LOADABLE_CLASSES[library_name]
335
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
336
+
337
+ expected_class_obj = None
338
+ for class_name, class_candidate in class_candidates.items():
339
+ if issubclass(class_obj, class_candidate):
340
+ expected_class_obj = class_candidate
341
+
342
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
343
+ raise ValueError(
344
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
345
+ f" {expected_class_obj}"
346
+ )
347
+ else:
348
+ logger.warn(
349
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
350
+ " has the correct type"
351
+ )
352
+
353
+ # set passed class object
354
+ loaded_sub_model = passed_class_obj[name]
355
+ elif is_pipeline_module:
356
+ pipeline_module = getattr(pipelines, library_name)
357
+ class_obj = getattr(pipeline_module, class_name)
358
+ importable_classes = ALL_IMPORTABLE_CLASSES
359
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
360
+ else:
361
+ # else we just import it from the library.
362
+ library = importlib.import_module(library_name)
363
+ class_obj = getattr(library, class_name)
364
+ importable_classes = LOADABLE_CLASSES[library_name]
365
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
366
+
367
+ if loaded_sub_model is None:
368
+ load_method_name = None
369
+ for class_name, class_candidate in class_candidates.items():
370
+ if issubclass(class_obj, class_candidate):
371
+ load_method_name = importable_classes[class_name][1]
372
+
373
+ load_method = getattr(class_obj, load_method_name)
374
+
375
+ loading_kwargs = {}
376
+ if issubclass(class_obj, torch.nn.Module):
377
+ loading_kwargs["torch_dtype"] = torch_dtype
378
+ if issubclass(class_obj, diffusers.OnnxRuntimeModel):
379
+ loading_kwargs["provider"] = provider
380
+
381
+ # check if the module is in a subdirectory
382
+ if os.path.isdir(os.path.join(cached_folder, name)):
383
+ loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
384
+ else:
385
+ # else load from the root directory
386
+ loaded_sub_model = load_method(cached_folder, **loading_kwargs)
387
+
388
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
389
+
390
+ # 4. Instantiate the pipeline
391
+ model = pipeline_class(**init_kwargs)
392
+ return model
393
+
394
+ @staticmethod
395
+ def numpy_to_pil(images):
396
+ """
397
+ Convert a numpy image or a batch of images to a PIL image.
398
+ """
399
+ if images.ndim == 3:
400
+ images = images[None, ...]
401
+ images = (images * 255).round().astype("uint8")
402
+ pil_images = [Image.fromarray(image) for image in images]
403
+
404
+ return pil_images
405
+
406
+ def progress_bar(self, iterable):
407
+ if not hasattr(self, "_progress_bar_config"):
408
+ self._progress_bar_config = {}
409
+ elif not isinstance(self._progress_bar_config, dict):
410
+ raise ValueError(
411
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
412
+ )
413
+
414
+ return tqdm(iterable, **self._progress_bar_config)
415
+
416
+ def set_progress_bar_config(self, **kwargs):
417
+ self._progress_bar_config = kwargs
diffusers/pipelines/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..utils import is_onnx_available, is_transformers_available
2
+ from .ddim import DDIMPipeline
3
+ from .ddpm import DDPMPipeline
4
+ from .latent_diffusion_uncond import LDMPipeline
5
+ from .pndm import PNDMPipeline
6
+ from .score_sde_ve import ScoreSdeVePipeline
7
+ from .stochastic_karras_ve import KarrasVePipeline
8
+
9
+
10
+ if is_transformers_available():
11
+ from .latent_diffusion import LDMTextToImagePipeline
12
+ from .stable_diffusion import (
13
+ StableDiffusionImg2ImgPipeline,
14
+ StableDiffusionInpaintPipeline,
15
+ StableDiffusionPipeline,
16
+ )
17
+
18
+ if is_transformers_available() and is_onnx_available():
19
+ from .stable_diffusion import StableDiffusionOnnxPipeline
diffusers/pipelines/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (829 Bytes). View file
 
diffusers/pipelines/ddim/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # flake8: noqa
2
+ from .pipeline_ddim import DDIMPipeline