Borys Tymchenko committed
Commit ae2e28c
1 Parent(s): 1545cf6

Initial commit

.gitignore ADDED
@@ -0,0 +1,192 @@
1
+
2
+ # Created by https://www.toptal.com/developers/gitignore/api/linux,macos,python
3
+ # Edit at https://www.toptal.com/developers/gitignore?templates=linux,macos,python
4
+
5
+ ### Linux ###
6
+ *~
7
+
8
+ # temporary files which can be created if a process still has a handle open of a deleted file
9
+ .fuse_hidden*
10
+
11
+ # KDE directory preferences
12
+ .directory
13
+
14
+ # Linux trash folder which might appear on any partition or disk
15
+ .Trash-*
16
+
17
+ # .nfs files are created when an open file is removed but is still being accessed
18
+ .nfs*
19
+
20
+ ### macOS ###
21
+ # General
22
+ .DS_Store
23
+ .AppleDouble
24
+ .LSOverride
25
+
26
+ # Icon must end with two \r
27
+ Icon
28
+
29
+
30
+ # Thumbnails
31
+ ._*
32
+
33
+ # Files that might appear in the root of a volume
34
+ .DocumentRevisions-V100
35
+ .fseventsd
36
+ .Spotlight-V100
37
+ .TemporaryItems
38
+ .Trashes
39
+ .VolumeIcon.icns
40
+ .com.apple.timemachine.donotpresent
41
+
42
+ # Directories potentially created on remote AFP share
43
+ .AppleDB
44
+ .AppleDesktop
45
+ Network Trash Folder
46
+ Temporary Items
47
+ .apdisk
48
+
49
+ ### Python ###
50
+ # Byte-compiled / optimized / DLL files
51
+ __pycache__/
52
+ *.py[cod]
53
+ *$py.class
54
+
55
+ # C extensions
56
+ *.so
57
+
58
+ # Distribution / packaging
59
+ .Python
60
+ */build/
61
+ develop-eggs/
62
+ dist/
63
+ downloads/
64
+ eggs/
65
+ .eggs/
66
+ lib/
67
+ lib64/
68
+ parts/
69
+ sdist/
70
+ var/
71
+ wheels/
72
+ share/python-wheels/
73
+ *.egg-info/
74
+ .installed.cfg
75
+ *.egg
76
+ MANIFEST
77
+
78
+ # PyInstaller
79
+ # Usually these files are written by a python script from a template
80
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
81
+ *.manifest
82
+ *.spec
83
+
84
+ # Installer logs
85
+ pip-log.txt
86
+ pip-delete-this-directory.txt
87
+
88
+ # Unit test / coverage reports
89
+ htmlcov/
90
+ .tox/
91
+ .nox/
92
+ .coverage
93
+ .coverage.*
94
+ .cache
95
+ nosetests.xml
96
+ coverage.xml
97
+ *.cover
98
+ *.py,cover
99
+ .hypothesis/
100
+ .pytest_cache/
101
+ cover/
102
+
103
+ # Translations
104
+ *.mo
105
+ *.pot
106
+
107
+ # Django stuff:
108
+ *.log
109
+ local_settings.py
110
+ db.sqlite3
111
+ db.sqlite3-journal
112
+
113
+ # Flask stuff:
114
+ instance/
115
+ .webassets-cache
116
+
117
+ # Scrapy stuff:
118
+ .scrapy
119
+
120
+ # Sphinx documentation
121
+ documentation/_build/
122
+ documentation/build/
123
+
124
+ # PyBuilder
125
+ .pybuilder/
126
+ target/
127
+
128
+ # Jupyter Notebook
129
+ .ipynb_checkpoints
130
+
131
+ # IPython
132
+ profile_default/
133
+ ipython_config.py
134
+
135
+ # pyenv
136
+ # For a library or package, you might want to ignore these files since the code is
137
+ # intended to run in multiple environments; otherwise, check them in:
138
+ # .python-version
139
+
140
+ # pipenv
141
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
142
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
143
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
144
+ # install all needed dependencies.
145
+ #Pipfile.lock
146
+
147
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
148
+ __pypackages__/
149
+
150
+ # Celery stuff
151
+ celerybeat-schedule
152
+ celerybeat.pid
153
+
154
+ # SageMath parsed files
155
+ *.sage.py
156
+
157
+ # Environments
158
+ .env
159
+ .venv
160
+ env/
161
+ venv/
162
+ ENV/
163
+ env.bak/
164
+ venv.bak/
165
+
166
+ # Spyder project settings
167
+ .spyderproject
168
+ .spyproject
169
+
170
+ # Rope project settings
171
+ .ropeproject
172
+
173
+ # mkdocs documentation
174
+ /site
175
+
176
+ # mypy
177
+ .mypy_cache/
178
+ .dmypy.json
179
+ dmypy.json
180
+
181
+ # Pyre type checker
182
+ .pyre/
183
+
184
+ # pytype static type analyzer
185
+ .pytype/
186
+
187
+ # Cython debug symbols
188
+ cython_debug/
189
+
190
+ .idea
191
+ docs/.doctrees
192
+ # End of https://www.toptal.com/developers/gitignore/api/linux,macos,python
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ }
28
+ }
flexible_unet/config.json ADDED
@@ -0,0 +1,125 @@
1
+ {
2
+ "_class_name": "FlexibleUNet2DConditionModel",
3
+ "_diffusers_version": "0.23.0",
4
+ "_name_or_path": "/home/borys.tymchenko/qcomdiffusion/checkpoint-286000-2050048000/pipeline/unet",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "attention_type": "default",
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280,
15
+ 1280
16
+ ],
17
+ "center_input_sample": false,
18
+ "class_embed_type": null,
19
+ "class_embeddings_concat": false,
20
+ "configurations": {
21
+ "add_downsample": [
22
+ true,
23
+ true,
24
+ false
25
+ ],
26
+ "add_upsample": [
27
+ true,
28
+ true,
29
+ false
30
+ ],
31
+ "add_upsample_mid_block": null,
32
+ "cross_attention_dim": 768,
33
+ "down_blocks_in_channels": [
34
+ 320,
35
+ 320,
36
+ 640
37
+ ],
38
+ "down_blocks_num_attentions": [
39
+ 0,
40
+ 1,
41
+ 3
42
+ ],
43
+ "down_blocks_num_resnets": [
44
+ 2,
45
+ 2,
46
+ 1
47
+ ],
48
+ "down_blocks_out_channels": [
49
+ 320,
50
+ 640,
51
+ 1280
52
+ ],
53
+ "mid_num_attentions": 0,
54
+ "mid_num_resnets": 0,
55
+ "mix_block_in_forward": true,
56
+ "num_attention_heads": 8,
57
+ "prev_output_channels": [
58
+ 1280,
59
+ 1280,
60
+ 640
61
+ ],
62
+ "resnet_act_fn": "silu",
63
+ "resnet_eps": 1e-05,
64
+ "sample_size": 64,
65
+ "temb_dim": 1280,
66
+ "up_blocks_num_attentions": [
67
+ 5,
68
+ 3,
69
+ 0
70
+ ],
71
+ "up_blocks_num_resnets": [
72
+ 2,
73
+ 3,
74
+ 3
75
+ ]
76
+ },
77
+ "conv_in_kernel": 3,
78
+ "conv_out_kernel": 3,
79
+ "cross_attention_dim": 768,
80
+ "cross_attention_norm": null,
81
+ "down_block_types": [
82
+ "CrossAttnDownBlock2D",
83
+ "CrossAttnDownBlock2D",
84
+ "CrossAttnDownBlock2D",
85
+ "DownBlock2D"
86
+ ],
87
+ "downsample_padding": 1,
88
+ "dropout": 0.0,
89
+ "dual_cross_attention": false,
90
+ "encoder_hid_dim": null,
91
+ "encoder_hid_dim_type": null,
92
+ "flip_sin_to_cos": true,
93
+ "freq_shift": 0,
94
+ "in_channels": 4,
95
+ "layers_per_block": 2,
96
+ "mid_block_only_cross_attention": null,
97
+ "mid_block_scale_factor": 1,
98
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
99
+ "norm_eps": 1e-05,
100
+ "norm_num_groups": 32,
101
+ "num_attention_heads": null,
102
+ "num_class_embeds": null,
103
+ "only_cross_attention": false,
104
+ "out_channels": 4,
105
+ "projection_class_embeddings_input_dim": null,
106
+ "resnet_out_scale_factor": 1.0,
107
+ "resnet_skip_time_act": false,
108
+ "resnet_time_scale_shift": "default",
109
+ "reverse_transformer_layers_per_block": null,
110
+ "sample_size": 64,
111
+ "time_cond_proj_dim": null,
112
+ "time_embedding_act_fn": null,
113
+ "time_embedding_dim": null,
114
+ "time_embedding_type": "positional",
115
+ "timestep_post_act": null,
116
+ "transformer_layers_per_block": 1,
117
+ "up_block_types": [
118
+ "UpBlock2D",
119
+ "CrossAttnUpBlock2D",
120
+ "CrossAttnUpBlock2D",
121
+ "CrossAttnUpBlock2D"
122
+ ],
123
+ "upcast_attention": false,
124
+ "use_linear_projection": false
125
+ }
flexible_unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:337322d55ebf3ad224f25121b3ab439e3406f5517bdb61b252d1d2aaea06024d
3
+ size 2101170216
model_index.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "_class_name": "DeciDiffusionPipeline",
3
+ "_diffusers_version": "0.21.4",
4
+ "_name_or_path": "Deci/DeciDiffusion-v2-0",
5
+ "feature_extractor": [
6
+ "transformers",
7
+ "CLIPImageProcessor"
8
+ ],
9
+ "requires_safety_checker": true,
10
+ "safety_checker": [
11
+ "stable_diffusion",
12
+ "StableDiffusionSafetyChecker"
13
+ ],
14
+ "scheduler": [
15
+ "diffusers",
16
+ "DDIMScheduler"
17
+ ],
18
+ "text_encoder": [
19
+ "transformers",
20
+ "CLIPTextModel"
21
+ ],
22
+ "tokenizer": [
23
+ "transformers",
24
+ "CLIPTokenizer"
25
+ ],
26
+ "unet": [
27
+ "diffusers",
28
+ "UNet2DConditionModel"
29
+ ],
30
+ "vae": [
31
+ "diffusers",
32
+ "AutoencoderKL"
33
+ ]
34
+ }
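The model_index.json above points at the custom `DeciDiffusionPipeline` implemented in pipeline.py below. The following is a minimal loading sketch, not part of the commit; the repo id is taken from `_name_or_path` above, and `custom_pipeline` is the standard diffusers mechanism for loading a repo-hosted pipeline.py (tested behavior assumed for a diffusers version close to the 0.23.0 referenced in the configs; the prompt, dtype and device are arbitrary example choices):

import torch
from diffusers import DiffusionPipeline

# `custom_pipeline` makes diffusers import pipeline.py from this repo, which swaps in
# FlexibleUNet2DConditionModel and SqueezedDPMSolverMultistepScheduler at construction time.
pipe = DiffusionPipeline.from_pretrained(
    "Deci/DeciDiffusion-v2-0",
    custom_pipeline="Deci/DeciDiffusion-v2-0",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# __call__ defaults to 16 scheduler steps, which the timestep squeezer reduces to 10.
image = pipe(prompt="A photo of an astronaut riding a horse").images[0]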
pipeline.py ADDED
@@ -0,0 +1,1010 @@
1
+ import itertools
2
+ from functools import partial
3
+ from typing import Any, Dict, Tuple, Callable
4
+ from typing import Union, Optional, List
5
+
6
+ import numpy as np
7
+ import torch
8
+ from diffusers import DPMSolverMultistepScheduler
9
+ from diffusers import StableDiffusionPipeline, AutoencoderKL
10
+ from diffusers import Transformer2DModel, ModelMixin, ConfigMixin
11
+ from diffusers import UNet2DConditionModel
12
+ from diffusers.configuration_utils import register_to_config
13
+ from diffusers.models.attention import BasicTransformerBlock
14
+ from diffusers.models.resnet import ResnetBlock2D, Downsample2D, Upsample2D
15
+ from diffusers.models.transformer_2d import Transformer2DModelOutput
16
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
17
+ from diffusers.schedulers import KarrasDiffusionSchedulers
18
+ from diffusers.utils import replace_example_docstring
19
+ from torch import nn
20
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
21
+
22
+
23
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
24
+ """
25
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
26
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
27
+ """
28
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
29
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
30
+ # rescale the results from guidance (fixes overexposure)
31
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
32
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
33
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
34
+ return noise_cfg
35
+
36
+
37
+ def custom_sort_order(obj):
38
+ """
39
+ Key function for sorting order of execution in forward methods
40
+ """
41
+ return {ResnetBlock2D: 0, Transformer2DModel: 1, FlexibleTransformer2DModel: 1}.get(obj.__class__)
42
+
43
+
44
+ def squeeze_to_len_n_starting_from_index_i(n, i, timestep_spacing):
45
+ """
46
+ :param timestep_spacing: the timestep_spacing array we want to squeeze
47
+ :param n: the size of the squeezed array
48
+ :param i: the index we start squeezing from
49
+ :return: squeezed timestep_spacing
50
+ Example:
51
+ timesteps = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60]) (len=16)
52
+ n = 10, i = 6
53
+ Expected:
54
+ [967, 907, 846, 786, 725, 665, 4k, 3k, 2k, k], and if we define 665=5k => k = 133
55
+ """
56
+ assert i < n
57
+ squeezed = np.flip(np.arange(n)) + 1 # [n, n-1, ..., 2, 1]
58
+ squeezed[:i] = timestep_spacing[:i]
59
+ k = squeezed[i - 1] // (n - i + 1)
60
+ squeezed[i:] *= k
61
+
62
+ return squeezed
63
+
64
+
65
+ PREDEFINED_TIMESTEP_SQUEEZERS = {
66
+ # Tested with DPM 16-steps (reduced 16 -> 10 or 11 steps)
67
+ "10,6": partial(squeeze_to_len_n_starting_from_index_i, 10, 6),
68
+ "11,7": partial(squeeze_to_len_n_starting_from_index_i, 11, 7),
69
+ }
70
+
71
+ FlexibleUnetConfigurations = {
72
+ # General parameters for all blocks
73
+ "sample_size": 64,
74
+ "temb_dim": 320 * 4,
75
+ "resnet_eps": 1e-5,
76
+ "resnet_act_fn": "silu",
77
+ "num_attention_heads": 8,
78
+ "cross_attention_dim": 768,
79
+ # Controls modules execute order in unet's forward
80
+ "mix_block_in_forward": True,
81
+ # Down blocks parameters
82
+ "down_blocks_in_channels": [320, 320, 640],
83
+ "down_blocks_out_channels": [320, 640, 1280],
84
+ "down_blocks_num_attentions": [0, 1, 3],
85
+ "down_blocks_num_resnets": [2, 2, 1],
86
+ "add_downsample": [True, True, False],
87
+ # Middle block parameters
88
+ "add_upsample_mid_block": None,
89
+ "mid_num_resnets": 0,
90
+ "mid_num_attentions": 0,
91
+ # Up block parameters
92
+ "prev_output_channels": [1280, 1280, 640],
93
+ "up_blocks_num_attentions": [5, 3, 0],
94
+ "up_blocks_num_resnets": [2, 3, 3],
95
+ "add_upsample": [True, True, False],
96
+ }
97
+
98
+
99
+ class SqueezedDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
100
+ """
101
+ This is a copy-paste of Diffusers' `DPMSolverMultistepScheduler`, with minor differences:
102
+ * Defaults are modified to accommodate DeciDiffusion
103
+ * It supports a squeezer to squeeze the number of inference steps to a smaller number
104
+ //!\\ IMPORTANT: the actual number of inference steps is deduced by the squeezer, and not the pipeline!
105
+ """
106
+
107
+ @register_to_config
108
+ def __init__(
109
+ self,
110
+ num_train_timesteps: int = 1000,
111
+ beta_start: float = 0.0001,
112
+ beta_end: float = 0.02,
113
+ beta_schedule: str = "squaredcos_cap_v2", # NOTE THIS DEFAULT VALUE
114
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
115
+ solver_order: int = 2,
116
+ prediction_type: str = "v_prediction", # NOTE THIS DEFAULT VALUE
117
+ thresholding: bool = False,
118
+ dynamic_thresholding_ratio: float = 0.995,
119
+ sample_max_value: float = 1.0,
120
+ algorithm_type: str = "dpmsolver++",
121
+ solver_type: str = "heun", # NOTE THIS DEFAULT VALUE
122
+ lower_order_final: bool = True,
123
+ use_karras_sigmas: Optional[bool] = False,
124
+ lambda_min_clipped: float = -3.0, # NOTE THIS DEFAULT VALUE
125
+ variance_type: Optional[str] = None,
126
+ timestep_spacing: str = "linspace",
127
+ steps_offset: int = 1,
128
+ squeeze_mode: Optional[str] = None, # NOTE THIS ADDITION. Supports keys from `PREDEFINED_TIMESTEP_SQUEEZERS` defined above
129
+ ):
130
+ self._squeezer = PREDEFINED_TIMESTEP_SQUEEZERS.get(squeeze_mode)
131
+
132
+ if use_karras_sigmas:
133
+ raise NotImplementedError("Squeezing isn't tested with `use_karras_sigmas`. Please provide `use_karras_sigmas=False`")
134
+
135
+ super().__init__(
136
+ num_train_timesteps=num_train_timesteps,
137
+ beta_start=beta_start,
138
+ beta_end=beta_end,
139
+ beta_schedule=beta_schedule,
140
+ trained_betas=trained_betas,
141
+ solver_order=solver_order,
142
+ prediction_type=prediction_type,
143
+ thresholding=thresholding,
144
+ dynamic_thresholding_ratio=dynamic_thresholding_ratio,
145
+ sample_max_value=sample_max_value,
146
+ algorithm_type=algorithm_type,
147
+ solver_type=solver_type,
148
+ lower_order_final=lower_order_final,
149
+ use_karras_sigmas=False,
150
+ lambda_min_clipped=lambda_min_clipped,
151
+ variance_type=variance_type,
152
+ timestep_spacing=timestep_spacing,
153
+ steps_offset=steps_offset,
154
+ )
155
+
156
+ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
157
+ """
158
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
159
+
160
+ Args:
161
+ num_inference_steps (`int`):
162
+ The number of diffusion steps used when generating samples with a pre-trained model.
163
+ device (`str` or `torch.device`, *optional*):
164
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
165
+ """
166
+ super().set_timesteps(num_inference_steps=num_inference_steps, device=device)
167
+ if self._squeezer is not None:
168
+ timesteps = self._squeezer(self.timesteps.cpu())
169
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
170
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
171
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
172
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
173
+ self.sigmas = torch.from_numpy(sigmas)
174
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
175
+ self.num_inference_steps = len(timesteps)
176
+
177
+
178
+ class FlexibleIdentityBlock(nn.Module):
179
+ def forward(
180
+ self,
181
+ hidden_states: torch.FloatTensor,
182
+ temb: Optional[torch.FloatTensor] = None,
183
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
184
+ attention_mask: Optional[torch.FloatTensor] = None,
185
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
186
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
187
+ ):
188
+ return hidden_states
189
+
190
+
191
+ class FlexibleUNet2DConditionModel(UNet2DConditionModel, ModelMixin):
192
+ configurations = FlexibleUnetConfigurations
193
+
194
+ @register_to_config
195
+ def __init__(self):
196
+ super().__init__(
197
+ sample_size=self.configurations.get("sample_size", FlexibleUnetConfigurations["sample_size"]),
198
+ cross_attention_dim=self.configurations.get("cross_attention_dim", FlexibleUnetConfigurations["cross_attention_dim"]),
199
+ )
200
+
201
+ num_attention_heads = self.configurations.get("num_attention_heads")
202
+ cross_attention_dim = self.configurations.get("cross_attention_dim")
203
+ mix_block_in_forward = self.configurations.get("mix_block_in_forward")
204
+ resnet_act_fn = self.configurations.get("resnet_act_fn")
205
+ resnet_eps = self.configurations.get("resnet_eps")
206
+ temb_dim = self.configurations.get("temb_dim")
207
+
208
+ ###############
209
+ # Down blocks #
210
+ ###############
211
+ down_blocks_num_attentions = self.configurations.get("down_blocks_num_attentions")
212
+ down_blocks_out_channels = self.configurations.get("down_blocks_out_channels")
213
+ down_blocks_in_channels = self.configurations.get("down_blocks_in_channels")
214
+ down_blocks_num_resnets = self.configurations.get("down_blocks_num_resnets")
215
+ add_downsample = self.configurations.get("add_downsample")
216
+
217
+ self.down_blocks = nn.ModuleList()
218
+
219
+ for i, (in_c, out_c, n_res, n_att, add_down) in enumerate(
220
+ zip(down_blocks_in_channels, down_blocks_out_channels, down_blocks_num_resnets, down_blocks_num_attentions, add_downsample)
221
+ ):
222
+ last_block = i == len(down_blocks_in_channels) - 1
223
+ self.down_blocks.append(
224
+ FlexibleCrossAttnDownBlock2D(
225
+ in_channels=in_c,
226
+ out_channels=out_c,
227
+ temb_channels=temb_dim,
228
+ num_resnets=n_res,
229
+ num_attentions=n_att,
230
+ resnet_eps=resnet_eps,
231
+ resnet_act_fn=resnet_act_fn,
232
+ num_attention_heads=num_attention_heads,
233
+ cross_attention_dim=cross_attention_dim,
234
+ add_downsample=add_down,
235
+ last_block=last_block,
236
+ mix_block_in_forward=mix_block_in_forward,
237
+ )
238
+ )
239
+
240
+ ###############
241
+ # Mid blocks #
242
+ ###############
243
+
244
+ mid_block_add_upsample = self.configurations.get("add_upsample_mid_block")
245
+ mid_num_attentions = self.configurations.get("mid_num_attentions")
246
+ mid_num_resnets = self.configurations.get("mid_num_resnets")
247
+
248
+ if mid_num_resnets == mid_num_attentions == 0:
249
+ self.mid_block = FlexibleIdentityBlock()
250
+ else:
251
+ self.mid_block = FlexibleUNetMidBlock2DCrossAttn(
252
+ in_channels=down_blocks_out_channels[-1],
253
+ temb_channels=temb_dim,
254
+ resnet_act_fn=resnet_act_fn,
255
+ resnet_eps=resnet_eps,
256
+ cross_attention_dim=cross_attention_dim,
257
+ num_attention_heads=num_attention_heads,
258
+ num_resnets=mid_num_resnets,
259
+ num_attentions=mid_num_attentions,
260
+ mix_block_in_forward=mix_block_in_forward,
261
+ add_upsample=mid_block_add_upsample,
262
+ )
263
+
264
+ ###############
265
+ # Up blocks #
266
+ ###############
267
+
268
+ up_blocks_num_attentions = self.configurations.get("up_blocks_num_attentions")
269
+ up_blocks_num_resnets = self.configurations.get("up_blocks_num_resnets")
270
+ prev_output_channels = self.configurations.get("prev_output_channels")
271
+ up_upsample = self.configurations.get("add_upsample")
272
+
273
+ self.up_blocks = nn.ModuleList()
274
+ for in_c, out_c, prev_out, n_res, n_att, add_up in zip(
275
+ reversed(down_blocks_in_channels),
276
+ reversed(down_blocks_out_channels),
277
+ prev_output_channels,
278
+ up_blocks_num_resnets,
279
+ up_blocks_num_attentions,
280
+ up_upsample,
281
+ ):
282
+ self.up_blocks.append(
283
+ FlexibleCrossAttnUpBlock2D(
284
+ in_channels=in_c,
285
+ out_channels=out_c,
286
+ prev_output_channel=prev_out,
287
+ temb_channels=temb_dim,
288
+ num_resnets=n_res,
289
+ num_attentions=n_att,
290
+ resnet_eps=resnet_eps,
291
+ resnet_act_fn=resnet_act_fn,
292
+ num_attention_heads=num_attention_heads,
293
+ cross_attention_dim=cross_attention_dim,
294
+ add_upsample=add_up,
295
+ mix_block_in_forward=mix_block_in_forward,
296
+ )
297
+ )
298
+
299
+
300
+ class FlexibleCrossAttnDownBlock2D(nn.Module):
301
+ def __init__(
302
+ self,
303
+ in_channels: int,
304
+ out_channels: int,
305
+ temb_channels: int,
306
+ dropout: float = 0.0,
307
+ num_resnets: int = 1,
308
+ num_attentions: int = 1,
309
+ transformer_layers_per_block: int = 1,
310
+ resnet_eps: float = 1e-6,
311
+ resnet_time_scale_shift: str = "default",
312
+ resnet_act_fn: str = "swish",
313
+ resnet_groups: int = 32,
314
+ resnet_pre_norm: bool = True,
315
+ num_attention_heads: int = 1,
316
+ cross_attention_dim: int = 1280,
317
+ output_scale_factor: float = 1.0,
318
+ downsample_padding: int = 1,
319
+ add_downsample: bool = True,
320
+ use_linear_projection: bool = False,
321
+ only_cross_attention: bool = False,
322
+ upcast_attention: bool = False,
323
+ last_block: bool = False,
324
+ mix_block_in_forward: bool = True,
325
+ ):
326
+ super().__init__()
327
+
328
+ self.last_block = last_block
329
+ self.mix_block_in_forward = mix_block_in_forward
330
+ self.has_cross_attention = True
331
+ self.num_attention_heads = num_attention_heads
332
+
333
+ modules = []
334
+
335
+ add_resnets = [True] * num_resnets
336
+ add_cross_attentions = [True] * num_attentions
337
+ for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
338
+ in_channels = in_channels if i == 0 else out_channels
339
+ if add_resnet:
340
+ modules.append(
341
+ ResnetBlock2D(
342
+ in_channels=in_channels,
343
+ out_channels=out_channels,
344
+ temb_channels=temb_channels,
345
+ eps=resnet_eps,
346
+ groups=resnet_groups,
347
+ dropout=dropout,
348
+ time_embedding_norm=resnet_time_scale_shift,
349
+ non_linearity=resnet_act_fn,
350
+ output_scale_factor=output_scale_factor,
351
+ pre_norm=resnet_pre_norm,
352
+ )
353
+ )
354
+ if add_cross_attention:
355
+ modules.append(
356
+ FlexibleTransformer2DModel(
357
+ num_attention_heads=num_attention_heads,
358
+ attention_head_dim=out_channels // num_attention_heads,
359
+ in_channels=out_channels,
360
+ num_layers=transformer_layers_per_block,
361
+ cross_attention_dim=cross_attention_dim,
362
+ norm_num_groups=resnet_groups,
363
+ use_linear_projection=use_linear_projection,
364
+ only_cross_attention=only_cross_attention,
365
+ upcast_attention=upcast_attention,
366
+ )
367
+ )
368
+
369
+ if not mix_block_in_forward:
370
+ modules = sorted(modules, key=custom_sort_order)
371
+
372
+ self.modules_list = nn.ModuleList(modules)
373
+
374
+ if add_downsample:
375
+ self.downsamplers = nn.ModuleList([Downsample2D(out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op")])
376
+ else:
377
+ self.downsamplers = None
378
+
379
+ self.gradient_checkpointing = False
380
+
381
+ def forward(
382
+ self,
383
+ hidden_states: torch.FloatTensor,
384
+ temb: Optional[torch.FloatTensor] = None,
385
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
386
+ attention_mask: Optional[torch.FloatTensor] = None,
387
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
388
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
389
+ ):
390
+ output_states = ()
391
+
392
+ for module in self.modules_list:
393
+ if isinstance(module, ResnetBlock2D):
394
+ hidden_states = module(hidden_states, temb)
395
+ elif isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
396
+ hidden_states = module(
397
+ hidden_states,
398
+ encoder_hidden_states=encoder_hidden_states,
399
+ cross_attention_kwargs=cross_attention_kwargs,
400
+ attention_mask=attention_mask,
401
+ encoder_attention_mask=encoder_attention_mask,
402
+ return_dict=False,
403
+ )[0]
404
+ else:
405
+ raise ValueError(f"Got an unexpected module in modules list! {type(module)}")
406
+ if isinstance(module, ResnetBlock2D):
407
+ output_states = output_states + (hidden_states,)
408
+
409
+ if self.downsamplers is not None:
410
+ for downsampler in self.downsamplers:
411
+ hidden_states = downsampler(hidden_states)
412
+
413
+ if not self.last_block:
414
+ output_states = output_states + (hidden_states,)
415
+
416
+ return hidden_states, output_states
417
+
418
+
419
+ class FlexibleCrossAttnUpBlock2D(nn.Module):
420
+ def __init__(
421
+ self,
422
+ in_channels: int,
423
+ out_channels: int,
424
+ prev_output_channel: int,
425
+ temb_channels: int,
426
+ dropout: float = 0.0,
427
+ num_resnets: int = 1,
428
+ num_attentions: int = 1,
429
+ transformer_layers_per_block: int = 1,
430
+ resnet_eps: float = 1e-6,
431
+ resnet_time_scale_shift: str = "default",
432
+ resnet_act_fn: str = "swish",
433
+ resnet_groups: int = 32,
434
+ resnet_pre_norm: bool = True,
435
+ num_attention_heads: int = 1,
436
+ cross_attention_dim: int = 1280,
437
+ output_scale_factor: float = 1.0,
438
+ add_upsample: bool = True,
439
+ use_linear_projection: bool = False,
440
+ only_cross_attention: bool = False,
441
+ upcast_attention: bool = False,
442
+ mix_block_in_forward: bool = True,
443
+ ):
444
+ super().__init__()
445
+ modules = []
446
+
447
+ # WARNING: This list is filled with one entry per resnet and is used within StableDiffusionPipeline
448
+ self.resnets = []
449
+
450
+ self.has_cross_attention = True
451
+ self.num_attention_heads = num_attention_heads
452
+
453
+ add_resnets = [True] * num_resnets
454
+ add_cross_attentions = [True] * num_attentions
455
+ for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
456
+ res_skip_channels = in_channels if (i == len(add_resnets) - 1) else out_channels
457
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
458
+
459
+ if add_resnet:
460
+ self.resnets += [True]
461
+ modules.append(
462
+ ResnetBlock2D(
463
+ in_channels=resnet_in_channels + res_skip_channels,
464
+ out_channels=out_channels,
465
+ temb_channels=temb_channels,
466
+ eps=resnet_eps,
467
+ groups=resnet_groups,
468
+ dropout=dropout,
469
+ time_embedding_norm=resnet_time_scale_shift,
470
+ non_linearity=resnet_act_fn,
471
+ output_scale_factor=output_scale_factor,
472
+ pre_norm=resnet_pre_norm,
473
+ )
474
+ )
475
+ if add_cross_attention:
476
+ modules.append(
477
+ FlexibleTransformer2DModel(
478
+ num_attention_heads,
479
+ out_channels // num_attention_heads,
480
+ in_channels=out_channels,
481
+ num_layers=transformer_layers_per_block,
482
+ cross_attention_dim=cross_attention_dim,
483
+ norm_num_groups=resnet_groups,
484
+ use_linear_projection=use_linear_projection,
485
+ only_cross_attention=only_cross_attention,
486
+ upcast_attention=upcast_attention,
487
+ )
488
+ )
489
+
490
+ if not mix_block_in_forward:
491
+ modules = sorted(modules, key=custom_sort_order)
492
+
493
+ self.modules_list = nn.ModuleList(modules)
494
+
495
+ self.upsamplers = None
496
+ if add_upsample:
497
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
498
+
499
+ self.gradient_checkpointing = False
500
+
501
+ def forward(
502
+ self,
503
+ hidden_states: torch.FloatTensor,
504
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
505
+ temb: Optional[torch.FloatTensor] = None,
506
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
507
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
508
+ upsample_size: Optional[int] = None,
509
+ attention_mask: Optional[torch.FloatTensor] = None,
510
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
511
+ ):
512
+
513
+ for module in self.modules_list:
514
+ if isinstance(module, ResnetBlock2D):
515
+ res_hidden_states = res_hidden_states_tuple[-1]
516
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
517
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
518
+ hidden_states = module(hidden_states, temb)
519
+ if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
520
+ hidden_states = module(
521
+ hidden_states,
522
+ encoder_hidden_states=encoder_hidden_states,
523
+ cross_attention_kwargs=cross_attention_kwargs,
524
+ attention_mask=attention_mask,
525
+ encoder_attention_mask=encoder_attention_mask,
526
+ return_dict=False,
527
+ )[0]
528
+
529
+ if self.upsamplers is not None:
530
+ for upsampler in self.upsamplers:
531
+ hidden_states = upsampler(hidden_states, upsample_size)
532
+
533
+ return hidden_states
534
+
535
+
536
+ class FlexibleUNetMidBlock2DCrossAttn(nn.Module):
537
+ def __init__(
538
+ self,
539
+ in_channels: int,
540
+ temb_channels: int,
541
+ dropout: float = 0.0,
542
+ num_resnets: int = 1,
543
+ num_attentions: int = 1,
544
+ transformer_layers_per_block: int = 1,
545
+ resnet_eps: float = 1e-6,
546
+ resnet_time_scale_shift: str = "default",
547
+ resnet_act_fn: str = "swish",
548
+ resnet_groups: int = 32,
549
+ resnet_pre_norm: bool = True,
550
+ num_attention_heads: int = 1,
551
+ output_scale_factor: float = 1.0,
552
+ cross_attention_dim: int = 1280,
553
+ use_linear_projection: bool = False,
554
+ upcast_attention: bool = False,
555
+ mix_block_in_forward: bool = True,
556
+ add_upsample: bool = True,
557
+ ):
558
+ super().__init__()
559
+
560
+ self.has_cross_attention = True
561
+ self.num_attention_heads = num_attention_heads
562
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
563
+ # There is always at least one resnet
564
+ modules = [
565
+ ResnetBlock2D(
566
+ in_channels=in_channels,
567
+ out_channels=in_channels,
568
+ temb_channels=temb_channels,
569
+ eps=resnet_eps,
570
+ groups=resnet_groups,
571
+ dropout=dropout,
572
+ time_embedding_norm=resnet_time_scale_shift,
573
+ non_linearity=resnet_act_fn,
574
+ output_scale_factor=output_scale_factor,
575
+ pre_norm=resnet_pre_norm,
576
+ )
577
+ ]
578
+
579
+ add_resnets = [True] * num_resnets
580
+ add_cross_attentions = [True] * num_attentions
581
+ for i, (add_resnet, add_cross_attention) in enumerate(itertools.zip_longest(add_resnets, add_cross_attentions, fillvalue=False)):
582
+ if add_cross_attention:
583
+ modules.append(
584
+ FlexibleTransformer2DModel(
585
+ num_attention_heads,
586
+ in_channels // num_attention_heads,
587
+ in_channels=in_channels,
588
+ num_layers=transformer_layers_per_block,
589
+ cross_attention_dim=cross_attention_dim,
590
+ norm_num_groups=resnet_groups,
591
+ use_linear_projection=use_linear_projection,
592
+ upcast_attention=upcast_attention,
593
+ )
594
+ )
595
+
596
+ if add_resnet:
597
+ modules.append(
598
+ ResnetBlock2D(
599
+ in_channels=in_channels,
600
+ out_channels=in_channels,
601
+ temb_channels=temb_channels,
602
+ eps=resnet_eps,
603
+ groups=resnet_groups,
604
+ dropout=dropout,
605
+ time_embedding_norm=resnet_time_scale_shift,
606
+ non_linearity=resnet_act_fn,
607
+ output_scale_factor=output_scale_factor,
608
+ pre_norm=resnet_pre_norm,
609
+ )
610
+ )
611
+ if not mix_block_in_forward:
612
+ modules = sorted(modules, key=custom_sort_order)
613
+
614
+ self.modules_list = nn.ModuleList(modules)
615
+
616
+ self.upsamplers = nn.ModuleList([nn.Identity()])
617
+ if add_upsample:
618
+ self.upsamplers = nn.ModuleList([Upsample2D(in_channels, use_conv=True, out_channels=in_channels)])
619
+
620
+ def forward(
621
+ self,
622
+ hidden_states: torch.FloatTensor,
623
+ temb: Optional[torch.FloatTensor] = None,
624
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
625
+ attention_mask: Optional[torch.FloatTensor] = None,
626
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
627
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
628
+ ) -> torch.FloatTensor:
629
+ hidden_states = self.modules_list[0](hidden_states, temb)
630
+
631
+ for module in self.modules_list:
632
+ if isinstance(module, (FlexibleTransformer2DModel, Transformer2DModel)):
633
+ hidden_states = module(
634
+ hidden_states,
635
+ encoder_hidden_states=encoder_hidden_states,
636
+ cross_attention_kwargs=cross_attention_kwargs,
637
+ attention_mask=attention_mask,
638
+ encoder_attention_mask=encoder_attention_mask,
639
+ return_dict=False,
640
+ )[0]
641
+ elif isinstance(module, ResnetBlock2D):
642
+ hidden_states = module(hidden_states, temb)
643
+
644
+ for upsampler in self.upsamplers:
645
+ hidden_states = upsampler(hidden_states)
646
+
647
+ return hidden_states
648
+
649
+
650
+ class FlexibleTransformer2DModel(ModelMixin, ConfigMixin):
651
+ @register_to_config
652
+ def __init__(
653
+ self,
654
+ num_attention_heads: int = 16,
655
+ attention_head_dim: int = 88,
656
+ in_channels: Optional[int] = None,
657
+ out_channels: Optional[int] = None,
658
+ num_layers: int = 1,
659
+ dropout: float = 0.0,
660
+ norm_num_groups: int = 32,
661
+ cross_attention_dim: Optional[int] = None,
662
+ attention_bias: bool = False,
663
+ activation_fn: str = "geglu",
664
+ num_embeds_ada_norm: Optional[int] = None,
665
+ only_cross_attention: bool = False,
666
+ use_linear_projection: bool = False,
667
+ upcast_attention: bool = False,
668
+ norm_type: str = "layer_norm",
669
+ norm_elementwise_affine: bool = True,
670
+ ):
671
+ super().__init__()
672
+ self.num_attention_heads = num_attention_heads
673
+ self.attention_head_dim = attention_head_dim
674
+ self.in_channels = in_channels
675
+ inner_dim = num_attention_heads * attention_head_dim
676
+
677
+ # Define input layers
678
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
679
+ self.use_linear_projection = use_linear_projection
680
+ if self.use_linear_projection:
681
+ self.proj_in = nn.Linear(in_channels, inner_dim)
682
+ else:
683
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
684
+
685
+ # Define transformers blocks
686
+ self.transformer_blocks = nn.ModuleList(
687
+ [
688
+ BasicTransformerBlock(
689
+ inner_dim,
690
+ num_attention_heads,
691
+ attention_head_dim,
692
+ dropout=dropout,
693
+ cross_attention_dim=cross_attention_dim,
694
+ activation_fn=activation_fn,
695
+ num_embeds_ada_norm=num_embeds_ada_norm,
696
+ attention_bias=attention_bias,
697
+ only_cross_attention=only_cross_attention,
698
+ upcast_attention=upcast_attention,
699
+ norm_type=norm_type,
700
+ norm_elementwise_affine=norm_elementwise_affine,
701
+ )
702
+ for _ in range(num_layers)
703
+ ]
704
+ )
705
+
706
+ # Define output layers
707
+ self.out_channels = in_channels if out_channels is None else out_channels
708
+ if self.use_linear_projection:
709
+ self.proj_out = nn.Linear(inner_dim, in_channels)
710
+ else:
711
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
712
+
713
+ def forward(
714
+ self,
715
+ hidden_states: torch.Tensor,
716
+ encoder_hidden_states: Optional[torch.Tensor] = None,
717
+ timestep: Optional[torch.LongTensor] = None,
718
+ class_labels: Optional[torch.LongTensor] = None,
719
+ cross_attention_kwargs: Dict[str, Any] = None,
720
+ attention_mask: Optional[torch.Tensor] = None,
721
+ encoder_attention_mask: Optional[torch.Tensor] = None,
722
+ return_dict: bool = False,
723
+ ):
724
+ # 1. Input
725
+ batch, _, height, width = hidden_states.shape
726
+ residual = hidden_states
727
+
728
+ hidden_states = self.norm(hidden_states)
729
+ if not self.use_linear_projection:
730
+ hidden_states = self.proj_in(hidden_states)
731
+ inner_dim = hidden_states.shape[1]
732
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
733
+ else:
734
+ inner_dim = hidden_states.shape[1]
735
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
736
+ hidden_states = self.proj_in(hidden_states)
737
+
738
+ # 2. Blocks
739
+ for block in self.transformer_blocks:
740
+ hidden_states = block(
741
+ hidden_states,
742
+ attention_mask=attention_mask,
743
+ encoder_hidden_states=encoder_hidden_states,
744
+ encoder_attention_mask=encoder_attention_mask,
745
+ timestep=timestep,
746
+ cross_attention_kwargs=cross_attention_kwargs,
747
+ class_labels=class_labels,
748
+ )
749
+
750
+ # 3. Output
751
+ if not self.use_linear_projection:
752
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
753
+ hidden_states = self.proj_out(hidden_states)
754
+ else:
755
+ hidden_states = self.proj_out(hidden_states)
756
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
757
+
758
+ output = hidden_states + residual
759
+ if not return_dict:
760
+ return (output,)
761
+ return Transformer2DModelOutput(sample=output)
762
+
763
+
764
+ class DeciDiffusionPipeline(StableDiffusionPipeline):
765
+ deci_default_squeeze_mode = "10,6"
766
+ deci_default_number_of_iterations = 16
767
+ deci_default_guidance_rescale = 0.7
768
+
769
+ def __init__(
770
+ self,
771
+ vae: AutoencoderKL,
772
+ text_encoder: CLIPTextModel,
773
+ tokenizer: CLIPTokenizer,
774
+ unet: UNet2DConditionModel,
775
+ scheduler: KarrasDiffusionSchedulers,
776
+ safety_checker: StableDiffusionSafetyChecker,
777
+ feature_extractor: CLIPImageProcessor,
778
+ requires_safety_checker: bool = True,
779
+ ):
780
+ # Replace UNet with Deci's UNet
781
+ del unet
782
+ unet = FlexibleUNet2DConditionModel()
783
+
784
+ # Replace with custom scheduler
785
+ del scheduler
786
+ scheduler = SqueezedDPMSolverMultistepScheduler(squeeze_mode=self.deci_default_squeeze_mode)
787
+
788
+ super().__init__(
789
+ vae=vae,
790
+ text_encoder=text_encoder,
791
+ tokenizer=tokenizer,
792
+ unet=unet,
793
+ scheduler=scheduler,
794
+ safety_checker=safety_checker,
795
+ feature_extractor=feature_extractor,
796
+ requires_safety_checker=requires_safety_checker,
797
+ )
798
+
799
+ self.register_modules(
800
+ vae=vae,
801
+ text_encoder=text_encoder,
802
+ tokenizer=tokenizer,
803
+ unet=unet,
804
+ scheduler=scheduler,
805
+ safety_checker=safety_checker,
806
+ feature_extractor=feature_extractor,
807
+ )
808
+
809
+ @torch.no_grad()
810
+ def __call__(
811
+ self,
812
+ prompt: Union[str, List[str]] = None,
813
+ height: Optional[int] = None,
814
+ width: Optional[int] = None,
815
+ num_inference_steps: int = 16,
816
+ guidance_scale: float = 7.5,
817
+ negative_prompt: Optional[Union[str, List[str]]] = None,
818
+ num_images_per_prompt: Optional[int] = 1,
819
+ eta: float = 0.0,
820
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
821
+ latents: Optional[torch.FloatTensor] = None,
822
+ prompt_embeds: Optional[torch.FloatTensor] = None,
823
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
824
+ output_type: Optional[str] = "pil",
825
+ return_dict: bool = True,
826
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
827
+ callback_steps: int = 1,
828
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
829
+ guidance_rescale: float = 0.7,
830
+ ):
831
+ r"""
832
+ The call function to the pipeline for generation.
833
+
834
+ Args:
835
+ prompt (`str` or `List[str]`, *optional*):
836
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
837
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
838
+ The height in pixels of the generated image.
839
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
840
+ The width in pixels of the generated image.
841
+ num_inference_steps (`int`, *optional*, defaults to 16):
842
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
843
+ expense of slower inference.
844
+ guidance_scale (`float`, *optional*, defaults to 7.5):
845
+ A higher guidance scale value encourages the model to generate images closely linked to the text
846
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
847
+ negative_prompt (`str` or `List[str]`, *optional*):
848
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
849
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
850
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
851
+ The number of images to generate per prompt.
852
+ eta (`float`, *optional*, defaults to 0.0):
853
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
854
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
855
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
856
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
857
+ generation deterministic.
858
+ latents (`torch.FloatTensor`, *optional*):
859
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
860
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
861
+ tensor is generated by sampling using the supplied random `generator`.
862
+ prompt_embeds (`torch.FloatTensor`, *optional*):
863
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
864
+ provided, text embeddings are generated from the `prompt` input argument.
865
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
866
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
867
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
868
+ output_type (`str`, *optional*, defaults to `"pil"`):
869
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
870
+ return_dict (`bool`, *optional*, defaults to `True`):
871
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
872
+ plain tuple.
873
+ callback (`Callable`, *optional*):
874
+ A function that is called every `callback_steps` steps during inference. The function is called with the
875
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
876
+ callback_steps (`int`, *optional*, defaults to 1):
877
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
878
+ every step.
879
+ cross_attention_kwargs (`dict`, *optional*):
880
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
881
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
882
+ guidance_rescale (`float`, *optional*, defaults to 0.7):
883
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
884
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
885
+ using zero terminal SNR.
886
+
887
+ Examples:
888
+
889
+ Returns:
890
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
891
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
892
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
893
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
894
+ "not-safe-for-work" (nsfw) content.
895
+ """
896
+ # 0. Default height and width to unet
897
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
898
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
899
+
900
+ # 1. Check inputs. Raise error if not correct
901
+ self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
902
+
903
+ # 2. Define call parameters
904
+ if prompt is not None and isinstance(prompt, str):
905
+ batch_size = 1
906
+ elif prompt is not None and isinstance(prompt, list):
907
+ batch_size = len(prompt)
908
+ else:
909
+ batch_size = prompt_embeds.shape[0]
910
+
911
+ device = self._execution_device
912
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
913
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
914
+ # corresponds to doing no classifier free guidance.
915
+ do_classifier_free_guidance = guidance_scale > 1.0
916
+
917
+ # 3. Encode input prompt
918
+ text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
919
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
920
+ prompt,
921
+ device,
922
+ num_images_per_prompt,
923
+ do_classifier_free_guidance,
924
+ negative_prompt,
925
+ prompt_embeds=prompt_embeds,
926
+ negative_prompt_embeds=negative_prompt_embeds,
927
+ lora_scale=text_encoder_lora_scale,
928
+ )
929
+ # For classifier free guidance, we need to do two forward passes.
930
+ # Here we concatenate the unconditional and text embeddings into a single batch
931
+ # to avoid doing two forward passes
932
+ if do_classifier_free_guidance:
933
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
934
+
935
+ # 4. Prepare timesteps
936
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
937
+ timesteps = self.scheduler.timesteps
938
+
939
+ # 5. Prepare latent variables
940
+ num_channels_latents = self.unet.config.in_channels
941
+ latents = self.prepare_latents(
942
+ batch_size * num_images_per_prompt,
943
+ num_channels_latents,
944
+ height,
945
+ width,
946
+ prompt_embeds.dtype,
947
+ device,
948
+ generator,
949
+ latents,
950
+ )
951
+
952
+ # 6. Prepare extra step kwargs.
953
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
954
+
955
+ # 7. Denoising loop
956
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
957
+ with self.progress_bar(total=len(timesteps)) as progress_bar:
958
+ for i, t in enumerate(timesteps):
959
+ # expand the latents if we are doing classifier free guidance
960
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
961
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
962
+
963
+ # predict the noise residual
964
+ noise_pred = self.unet(
965
+ latent_model_input,
966
+ t,
967
+ encoder_hidden_states=prompt_embeds,
968
+ cross_attention_kwargs=cross_attention_kwargs,
969
+ return_dict=False,
970
+ )[0]
971
+
972
+ # perform guidance
973
+ if do_classifier_free_guidance:
974
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
975
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
976
+
977
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
978
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
979
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
980
+
981
+ # compute the previous noisy sample x_t -> x_t-1
982
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
983
+
984
+ # call the callback, if provided
985
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
986
+ progress_bar.update()
987
+ if callback is not None and i % callback_steps == 0:
988
+ callback(i, t, latents)
989
+
990
+ if not output_type == "latent":
991
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
992
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
993
+ else:
994
+ image = latents
995
+ has_nsfw_concept = None
996
+
997
+ if has_nsfw_concept is None:
998
+ do_denormalize = [True] * image.shape[0]
999
+ else:
1000
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1001
+
1002
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1003
+
1004
+ # Offload all models
1005
+ self.maybe_free_model_hooks()
1006
+
1007
+ if not return_dict:
1008
+ return (image, has_nsfw_concept)
1009
+
1010
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
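For reference, a small standalone sketch (not part of the committed pipeline.py) that reproduces the docstring example of `squeeze_to_len_n_starting_from_index_i` used by `SqueezedDPMSolverMultistepScheduler`: 16 DPM timesteps are squeezed to 10 by keeping the first 6 and replacing the tail with multiples of k = 665 // 5 = 133.

import numpy as np

def squeeze_to_len_n_starting_from_index_i(n, i, timestep_spacing):
    # Same logic as in pipeline.py: keep the first i timesteps, then ramp down in steps of k.
    assert i < n
    squeezed = np.flip(np.arange(n)) + 1   # [n, n-1, ..., 2, 1]
    squeezed[:i] = timestep_spacing[:i]    # keep the first i original timesteps
    k = squeezed[i - 1] // (n - i + 1)     # choose k so the last kept value equals (n - i + 1) * k
    squeezed[i:] *= k                      # tail becomes (n - i) * k, ..., 2k, k
    return squeezed

timesteps = np.array([967, 907, 846, 786, 725, 665, 604, 544, 484, 423, 363, 302, 242, 181, 121, 60])
print(squeeze_to_len_n_starting_from_index_i(10, 6, timesteps))
# -> [967 907 846 786 725 665 532 399 266 133]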
safety_checker/config.json ADDED
@@ -0,0 +1,168 @@
1
+ {
2
+ "_commit_hash": "1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9",
3
+ "_name_or_path": "/home/borys.tymchenko/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/safety_checker",
4
+ "architectures": [
5
+ "StableDiffusionSafetyChecker"
6
+ ],
7
+ "initializer_factor": 1.0,
8
+ "logit_scale_init_value": 2.6592,
9
+ "model_type": "clip",
10
+ "projection_dim": 768,
11
+ "text_config": {
12
+ "_name_or_path": "",
13
+ "add_cross_attention": false,
14
+ "architectures": null,
15
+ "attention_dropout": 0.0,
16
+ "bad_words_ids": null,
17
+ "begin_suppress_tokens": null,
18
+ "bos_token_id": 0,
19
+ "chunk_size_feed_forward": 0,
20
+ "cross_attention_hidden_size": null,
21
+ "decoder_start_token_id": null,
22
+ "diversity_penalty": 0.0,
23
+ "do_sample": false,
24
+ "dropout": 0.0,
25
+ "early_stopping": false,
26
+ "encoder_no_repeat_ngram_size": 0,
27
+ "eos_token_id": 2,
28
+ "exponential_decay_length_penalty": null,
29
+ "finetuning_task": null,
30
+ "forced_bos_token_id": null,
31
+ "forced_eos_token_id": null,
32
+ "hidden_act": "quick_gelu",
33
+ "hidden_size": 768,
34
+ "id2label": {
35
+ "0": "LABEL_0",
36
+ "1": "LABEL_1"
37
+ },
38
+ "initializer_factor": 1.0,
39
+ "initializer_range": 0.02,
40
+ "intermediate_size": 3072,
41
+ "is_decoder": false,
42
+ "is_encoder_decoder": false,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1
46
+ },
47
+ "layer_norm_eps": 1e-05,
48
+ "length_penalty": 1.0,
49
+ "max_length": 20,
50
+ "max_position_embeddings": 77,
51
+ "min_length": 0,
52
+ "model_type": "clip_text_model",
53
+ "no_repeat_ngram_size": 0,
54
+ "num_attention_heads": 12,
55
+ "num_beam_groups": 1,
56
+ "num_beams": 1,
57
+ "num_hidden_layers": 12,
58
+ "num_return_sequences": 1,
59
+ "output_attentions": false,
60
+ "output_hidden_states": false,
61
+ "output_scores": false,
62
+ "pad_token_id": 1,
63
+ "prefix": null,
64
+ "problem_type": null,
65
+ "projection_dim": 512,
66
+ "pruned_heads": {},
67
+ "remove_invalid_values": false,
68
+ "repetition_penalty": 1.0,
69
+ "return_dict": true,
70
+ "return_dict_in_generate": false,
71
+ "sep_token_id": null,
72
+ "suppress_tokens": null,
73
+ "task_specific_params": null,
74
+ "temperature": 1.0,
75
+ "tf_legacy_loss": false,
76
+ "tie_encoder_decoder": false,
77
+ "tie_word_embeddings": true,
78
+ "tokenizer_class": null,
79
+ "top_k": 50,
80
+ "top_p": 1.0,
81
+ "torch_dtype": null,
82
+ "torchscript": false,
83
+ "transformers_version": "4.30.2",
84
+ "typical_p": 1.0,
85
+ "use_bfloat16": false,
86
+ "vocab_size": 49408
87
+ },
88
+ "torch_dtype": "float32",
89
+ "transformers_version": null,
90
+ "vision_config": {
91
+ "_name_or_path": "",
92
+ "add_cross_attention": false,
93
+ "architectures": null,
94
+ "attention_dropout": 0.0,
95
+ "bad_words_ids": null,
96
+ "begin_suppress_tokens": null,
97
+ "bos_token_id": null,
98
+ "chunk_size_feed_forward": 0,
99
+ "cross_attention_hidden_size": null,
100
+ "decoder_start_token_id": null,
101
+ "diversity_penalty": 0.0,
102
+ "do_sample": false,
103
+ "dropout": 0.0,
104
+ "early_stopping": false,
105
+ "encoder_no_repeat_ngram_size": 0,
106
+ "eos_token_id": null,
107
+ "exponential_decay_length_penalty": null,
108
+ "finetuning_task": null,
109
+ "forced_bos_token_id": null,
110
+ "forced_eos_token_id": null,
111
+ "hidden_act": "quick_gelu",
112
+ "hidden_size": 1024,
113
+ "id2label": {
114
+ "0": "LABEL_0",
115
+ "1": "LABEL_1"
116
+ },
117
+ "image_size": 224,
118
+ "initializer_factor": 1.0,
119
+ "initializer_range": 0.02,
120
+ "intermediate_size": 4096,
121
+ "is_decoder": false,
122
+ "is_encoder_decoder": false,
123
+ "label2id": {
124
+ "LABEL_0": 0,
125
+ "LABEL_1": 1
126
+ },
127
+ "layer_norm_eps": 1e-05,
128
+ "length_penalty": 1.0,
129
+ "max_length": 20,
130
+ "min_length": 0,
131
+ "model_type": "clip_vision_model",
132
+ "no_repeat_ngram_size": 0,
133
+ "num_attention_heads": 16,
134
+ "num_beam_groups": 1,
135
+ "num_beams": 1,
136
+ "num_channels": 3,
137
+ "num_hidden_layers": 24,
138
+ "num_return_sequences": 1,
139
+ "output_attentions": false,
140
+ "output_hidden_states": false,
141
+ "output_scores": false,
142
+ "pad_token_id": null,
143
+ "patch_size": 14,
144
+ "prefix": null,
145
+ "problem_type": null,
146
+ "projection_dim": 512,
147
+ "pruned_heads": {},
148
+ "remove_invalid_values": false,
149
+ "repetition_penalty": 1.0,
150
+ "return_dict": true,
151
+ "return_dict_in_generate": false,
152
+ "sep_token_id": null,
153
+ "suppress_tokens": null,
154
+ "task_specific_params": null,
155
+ "temperature": 1.0,
156
+ "tf_legacy_loss": false,
157
+ "tie_encoder_decoder": false,
158
+ "tie_word_embeddings": true,
159
+ "tokenizer_class": null,
160
+ "top_k": 50,
161
+ "top_p": 1.0,
162
+ "torch_dtype": null,
163
+ "torchscript": false,
164
+ "transformers_version": "4.30.2",
165
+ "typical_p": 1.0,
166
+ "use_bfloat16": false
167
+ }
168
+ }
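
The JSON above completes safety_checker/config.json: a CLIPConfig pairing a 12-layer, 768-wide text tower with a 24-layer ViT-L/14 vision tower (hidden size 1024, 224-px input, patch size 14), i.e. the standard Stable Diffusion safety checker. A minimal loading sketch, assuming a placeholder repo id since this diff does not name one:

# Minimal sketch: load the safety checker described by the config above.
# "your-namespace/this-repo" is a placeholder, not the actual repo id.
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "your-namespace/this-repo", subfolder="safety_checker"
)
print(safety_checker.config.text_config.hidden_size)    # 768
print(safety_checker.config.vision_config.hidden_size)  # 1024
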
safety_checker/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:11cfe53105625af8c00faac32a430626641cce686454f3c39d837f14397d858b
+ size 1215981832
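
The three lines above are not the weights themselves; they are a Git LFS pointer recording only the blob's SHA-256 and its size (about 1.2 GB). A resolution sketch in Python, assuming a placeholder repo id:

# Sketch: downloading resolves the LFS pointer above to the real weight file.
# "your-namespace/this-repo" is a placeholder repo id.
from huggingface_hub import hf_hub_download

path = hf_hub_download("your-namespace/this-repo", "safety_checker/model.safetensors")
print(path)  # local cache path of the blob named by the sha256/size pointer above
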
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
+ {
+ "_class_name": "DDIMScheduler",
+ "_diffusers_version": "0.23.0",
+ "beta_end": 0.012,
+ "beta_schedule": "squaredcos_cap_v2",
+ "beta_start": 0.00085,
+ "clip_sample": false,
+ "clip_sample_range": 1.0,
+ "dynamic_thresholding_ratio": 0.995,
+ "num_train_timesteps": 1000,
+ "prediction_type": "v_prediction",
+ "rescale_betas_zero_snr": true,
+ "sample_max_value": 1.0,
+ "set_alpha_to_one": true,
+ "steps_offset": 1,
+ "thresholding": false,
+ "timestep_spacing": "trailing",
+ "trained_betas": null
+ }
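
scheduler_config.json pins a DDIMScheduler to the zero-terminal-SNR recipe: v-prediction, a squaredcos_cap_v2 beta schedule, trailing timestep spacing, and rescale_betas_zero_snr enabled. A minimal sketch of reloading it, assuming a placeholder repo id:

# Sketch: reload the scheduler exactly as pinned above
# ("your-namespace/this-repo" is a placeholder repo id).
from diffusers import DDIMScheduler

scheduler = DDIMScheduler.from_pretrained("your-namespace/this-repo", subfolder="scheduler")
print(scheduler.config.prediction_type)         # "v_prediction"
print(scheduler.config.timestep_spacing)        # "trailing"
print(scheduler.config.rescale_betas_zero_snr)  # True

With this combination, sampling is usually paired with rescaled classifier-free guidance (the guidance_rescale argument shown in the pipeline sketch at the end of this diff).
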
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
+ {
+ "_name_or_path": "/home/borys.tymchenko/qcomdiffusion/checkpoint-286000-2050048000/pipeline/text_encoder",
+ "architectures": [
+ "CLIPTextModel"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 0,
+ "dropout": 0.0,
+ "eos_token_id": 2,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 768,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 77,
+ "model_type": "clip_text_model",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 1,
+ "projection_dim": 768,
+ "torch_dtype": "float32",
+ "transformers_version": "4.30.2",
+ "vocab_size": 49408
+ }
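
text_encoder/config.json is a standard CLIP ViT-L/14 text encoder (12 layers, hidden size 768, 77-token context), taken from the local fine-tuning checkpoint given in _name_or_path. A loading sketch, assuming a placeholder repo id:

# Sketch: load the text encoder on its own ("your-namespace/this-repo" is a placeholder).
import torch
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained(
    "your-namespace/this-repo", subfolder="text_encoder", torch_dtype=torch.float32
)
print(text_encoder.config.hidden_size)              # 768
print(text_encoder.config.max_position_embeddings)  # 77
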
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:22928c6a6a99759e4a19648ba56e044d1df47b650f7879470501b71ec996a3ef
+ size 492265880
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "bos_token": {
+ "content": "<|startoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": "<|endoftext|>",
+ "unk_token": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
+ {
+ "add_prefix_space": false,
+ "bos_token": {
+ "__type": "AddedToken",
+ "content": "<|startoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "clean_up_tokenization_spaces": true,
+ "do_lower_case": true,
+ "eos_token": {
+ "__type": "AddedToken",
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "errors": "replace",
+ "model_max_length": 77,
+ "pad_token": "<|endoftext|>",
+ "tokenizer_class": "CLIPTokenizer",
+ "unk_token": {
+ "__type": "AddedToken",
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
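
Together, special_tokens_map.json and tokenizer_config.json define a lower-casing CLIPTokenizer with a 77-token model_max_length and <|endoftext|> reused as both pad and unk token. A tokenization sketch, assuming a placeholder repo id:

# Sketch: tokenize a prompt the way the pipeline will
# ("your-namespace/this-repo" is a placeholder repo id).
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("your-namespace/this-repo", subfolder="tokenizer")
batch = tokenizer(
    "a photo of an astronaut riding a horse",
    padding="max_length",                   # pad with <|endoftext|>
    max_length=tokenizer.model_max_length,  # 77
    truncation=True,
    return_tensors="pt",
)
print(batch.input_ids.shape)  # torch.Size([1, 77])
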
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,68 @@
+ {
+ "_class_name": "UNet2DConditionModel",
+ "_diffusers_version": "0.23.0",
+ "_name_or_path": "/home/borys.tymchenko/.cache/huggingface/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9/unet",
+ "act_fn": "silu",
+ "addition_embed_type": null,
+ "addition_embed_type_num_heads": 64,
+ "addition_time_embed_dim": null,
+ "attention_head_dim": 8,
+ "attention_type": "default",
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280,
+ 1280
+ ],
+ "center_input_sample": false,
+ "class_embed_type": null,
+ "class_embeddings_concat": false,
+ "conv_in_kernel": 3,
+ "conv_out_kernel": 3,
+ "cross_attention_dim": 768,
+ "cross_attention_norm": null,
+ "down_block_types": [
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "dropout": 0.0,
+ "dual_cross_attention": false,
+ "encoder_hid_dim": null,
+ "encoder_hid_dim_type": null,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "mid_block_only_cross_attention": null,
+ "mid_block_scale_factor": 1,
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_attention_heads": null,
+ "num_class_embeds": null,
+ "only_cross_attention": false,
+ "out_channels": 4,
+ "projection_class_embeddings_input_dim": null,
+ "resnet_out_scale_factor": 1.0,
+ "resnet_skip_time_act": false,
+ "resnet_time_scale_shift": "default",
+ "reverse_transformer_layers_per_block": null,
+ "sample_size": 64,
+ "time_cond_proj_dim": null,
+ "time_embedding_act_fn": null,
+ "time_embedding_dim": null,
+ "time_embedding_type": "positional",
+ "timestep_post_act": null,
+ "transformer_layers_per_block": 1,
+ "up_block_types": [
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D"
+ ],
+ "upcast_attention": false,
+ "use_linear_projection": false
+ }
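
unet/config.json is the Stable Diffusion 1.5 UNet2DConditionModel (initialized from the runwayml/stable-diffusion-v1-5 snapshot in _name_or_path): 4 latent channels, a 64x64 latent sample size, and 768-dimensional cross-attention to match the text encoder. A single forward-pass sketch with dummy inputs, assuming a placeholder repo id:

# Sketch: one denoising forward pass with dummy inputs
# ("your-namespace/this-repo" is a placeholder repo id).
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("your-namespace/this-repo", subfolder="unet")

latents = torch.randn(1, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size)  # (1, 4, 64, 64)
timestep = torch.tensor([999])
text_embeds = torch.randn(1, 77, unet.config.cross_attention_dim)  # (1, 77, 768); normally the text encoder output

with torch.no_grad():
    noise_pred = unet(latents, timestep, encoder_hidden_states=text_embeds).sample
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])
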
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d27cd69d4a0aa32105087a619f32a51bc087e133be93fe23da92f3c0bcc07d79
+ size 3438167536
vae/config.json ADDED
@@ -0,0 +1,32 @@
+ {
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.23.0",
+ "_name_or_path": "stabilityai/stable-diffusion-2-1",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 768,
+ "scaling_factor": 0.18215,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ]
+ }
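
vae/config.json takes the AutoencoderKL from stabilityai/stable-diffusion-2-1 (per _name_or_path): 4 latent channels, the usual 0.18215 scaling factor, and 8x spatial downsampling. A decoding sketch, assuming a placeholder repo id:

# Sketch: decode a latent back to pixel space ("your-namespace/this-repo" is a placeholder).
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("your-namespace/this-repo", subfolder="vae")

latents = torch.randn(1, vae.config.latent_channels, 64, 64)  # e.g. a 64x64 UNet latent
with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor).sample  # undo the 0.18215 scaling first
print(image.shape)  # torch.Size([1, 3, 512, 512]) -- 8x upsampling of the 64x64 latent
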
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
+ size 334643268
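
Taken together, these folders form a complete Stable Diffusion 1.x-style pipeline: the SD 1.5 UNet and CLIP ViT-L text encoder, the SD 2.1 VAE, and a zero-SNR v-prediction DDIM scheduler. A usage sketch, assuming a placeholder repo id and that the repo's model_index.json (not shown in this part of the diff) wires these components together:

# Sketch: assemble and run the full pipeline ("your-namespace/this-repo" is a placeholder;
# assumes a model_index.json at the repo root ties these subfolders together).
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("your-namespace/this-repo", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=30,
    guidance_scale=7.5,
    guidance_rescale=0.7,  # commonly suggested alongside rescale_betas_zero_snr + v_prediction
).images[0]
image.save("sample.png")
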