Spaces: wssb (Runtime error)

Commit d2a06b2 by wssb (bramw), 0 parent(s)

Duplicate from Salesforce/EDICT


Co-authored-by: Bram Wallace <bramw@users.noreply.huggingface.co>

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50):
  1. .gitattributes +34 -0
  2. .gitignore +1 -0
  3. README.md +14 -0
  4. app.py +146 -0
  5. app_fully_disabled.py +285 -0
  6. edict_functions.py +997 -0
  7. local_app.py +33 -0
  8. my_diffusers/__init__.py +60 -0
  9. my_diffusers/__pycache__/__init__.cpython-38.pyc +0 -0
  10. my_diffusers/__pycache__/configuration_utils.cpython-38.pyc +0 -0
  11. my_diffusers/__pycache__/modeling_utils.cpython-38.pyc +0 -0
  12. my_diffusers/__pycache__/onnx_utils.cpython-38.pyc +0 -0
  13. my_diffusers/__pycache__/optimization.cpython-38.pyc +0 -0
  14. my_diffusers/__pycache__/pipeline_utils.cpython-38.pyc +0 -0
  15. my_diffusers/__pycache__/training_utils.cpython-38.pyc +0 -0
  16. my_diffusers/commands/__init__.py +27 -0
  17. my_diffusers/commands/diffusers_cli.py +41 -0
  18. my_diffusers/commands/env.py +70 -0
  19. my_diffusers/configuration_utils.py +403 -0
  20. my_diffusers/dependency_versions_check.py +47 -0
  21. my_diffusers/dependency_versions_table.py +26 -0
  22. my_diffusers/dynamic_modules_utils.py +335 -0
  23. my_diffusers/hub_utils.py +197 -0
  24. my_diffusers/modeling_utils.py +542 -0
  25. my_diffusers/models/__init__.py +17 -0
  26. my_diffusers/models/__pycache__/__init__.cpython-38.pyc +0 -0
  27. my_diffusers/models/__pycache__/attention.cpython-38.pyc +0 -0
  28. my_diffusers/models/__pycache__/embeddings.cpython-38.pyc +0 -0
  29. my_diffusers/models/__pycache__/resnet.cpython-38.pyc +0 -0
  30. my_diffusers/models/__pycache__/unet_2d.cpython-38.pyc +0 -0
  31. my_diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc +0 -0
  32. my_diffusers/models/__pycache__/unet_blocks.cpython-38.pyc +0 -0
  33. my_diffusers/models/__pycache__/vae.cpython-38.pyc +0 -0
  34. my_diffusers/models/attention.py +333 -0
  35. my_diffusers/models/embeddings.py +116 -0
  36. my_diffusers/models/resnet.py +483 -0
  37. my_diffusers/models/unet_2d.py +246 -0
  38. my_diffusers/models/unet_2d_condition.py +273 -0
  39. my_diffusers/models/unet_blocks.py +1481 -0
  40. my_diffusers/models/vae.py +581 -0
  41. my_diffusers/onnx_utils.py +189 -0
  42. my_diffusers/optimization.py +275 -0
  43. my_diffusers/pipeline_utils.py +417 -0
  44. my_diffusers/pipelines/__init__.py +19 -0
  45. my_diffusers/pipelines/__pycache__/__init__.cpython-38.pyc +0 -0
  46. my_diffusers/pipelines/ddim/__init__.py +2 -0
  47. my_diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc +0 -0
  48. my_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc +0 -0
  49. my_diffusers/pipelines/ddim/pipeline_ddim.py +117 -0
  50. my_diffusers/pipelines/ddpm/__init__.py +2 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ hf_auth
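The ignored `hf_auth` file is the fallback that `edict_functions.py` (further down in this diff) reads when the `auth_token` environment variable is unset; keeping it out of version control is the point of this rule. A minimal local-setup sketch with a placeholder token value, not part of the committed code:

```python
import os

# Option 1: environment variable, checked first by edict_functions.py.
os.environ["auth_token"] = "hf_xxx"  # placeholder, not a real token

# Option 2: a one-line hf_auth file next to the code; .gitignore keeps it untracked.
with open("hf_auth", "w") as f:
    f.write("hf_xxx")
```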
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: EDICT
+ emoji: ⚡
+ colorFrom: indigo
+ colorTo: red
+ sdk: gradio
+ sdk_version: 3.18.0
+ app_file: app.py
+ pinned: false
+ license: bsd-3-clause
+ duplicated_from: Salesforce/EDICT
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,146 @@
+ import gradio as gr
+ import numpy as np
+ # from edict_functions import EDICT_editing
+ from PIL import Image
+ from utils import Endpoint, get_token
+ from io import BytesIO
+ import requests
+
+
+ endpoint = Endpoint()
+
+ def local_edict(x, source_text, edit_text,
+                 edit_strength, guidance_scale,
+                 steps=50, mix_weight=0.93, ):
+     x = Image.fromarray(x)
+     return_im = EDICT_editing(x,
+                               source_text,
+                               edit_text,
+                               steps=steps,
+                               mix_weight=mix_weight,
+                               init_image_strength=edit_strength,
+                               guidance_scale=guidance_scale
+                               )[0]
+     return np.array(return_im)
+
+ def encode_image(image):
+     buffered = BytesIO()
+     image.save(buffered, format="JPEG", quality=95)
+     buffered.seek(0)
+
+     return buffered
+
+
+
+ def decode_image(img_obj):
+     img = Image.open(img_obj).convert("RGB")
+     return img
+
+ def edict(x, source_text, edit_text,
+           edit_strength, guidance_scale,
+           steps=50, mix_weight=0.93, ):
+
+     url = endpoint.url
+     url = url + "/api/edit"
+     headers = {
+
+         "User-Agent": "EDICT HuggingFace Space",
+         "Auth-Token": get_token(),
+     }
+
+     data = {
+         "source_text": source_text,
+         "edit_text": edit_text,
+         "edit_strength": edit_strength,
+         "guidance_scale": guidance_scale,
+     }
+
+     image = encode_image(Image.fromarray(x))
+     files = {"image": image}
+
+     response = requests.post(url, data=data, files=files, headers=headers)
+
+     if response.status_code == 200:
+         return np.array(decode_image(BytesIO(response.content)))
+     else:
+         return "Error: " + response.text
+     # x = decode_image(response)
+     # return np.array(x)
+
+ examples = [
+     ['square_ims/american_gothic.jpg', 'A painting of two people frowning', 'A painting of two people smiling', 0.5, 3],
+     ['square_ims/colloseum.jpg', 'An old ruined building', 'A new modern office building', 0.8, 3],
+ ]
+
+
+ examples.append(['square_ims/scream.jpg', 'A painting of someone screaming', 'A painting of an alien', 0.5, 3])
+ examples.append(['square_ims/yosemite.jpg', 'Granite forest valley', 'Granite desert valley', 0.8, 3])
+ examples.append(['square_ims/einstein.jpg', 'Mouth open', 'Mouth closed', 0.8, 3])
+ examples.append(['square_ims/einstein.jpg', 'A man', 'A man in K.I.S.S. facepaint', 0.8, 3])
+ """
+ examples.extend([
+     ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Chinese New Year cupcake', 0.8, 3],
+     ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Union Jack cupcake', 0.8, 3],
+     ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Nigerian flag cupcake', 0.8, 3],
+     ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A Santa Claus cupcake', 0.8, 3],
+     ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'An Easter cupcake', 0.8, 3],
+     ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A hedgehog cupcake', 0.8, 3],
+     ['square_ims/imagenet_cake_2.jpg', 'A cupcake', 'A rose cupcake', 0.8, 3],
+ ])
+ """
+
+ for dog_i in [1, 2]:
+     for breed in ['Golden Retriever', 'Chihuahua', 'Dalmatian']:
+         examples.append([f'square_ims/imagenet_dog_{dog_i}.jpg', 'A dog', f'A {breed}', 0.8, 3])
+
+
+ description = """
+ **We have disabled image uploading since March 22, 2023.**
+
+ **Please try the examples provided below.**
+
+ A Gradio demo for [EDICT](https://arxiv.org/abs/2211.12446) (CVPR 2023)
+ """
+ # description = gr.Markdown(description)
+
+ article = """
+
+ ### Prompting Style
+
+ As with many text-to-image methods, the prompting style of EDICT can make a big difference. When in doubt, experiment! Some guidance:
+ * Parallel the *Original Description* and *Edit Description* construction as much as possible. Inserting/editing single words is often enough to effect a change while maintaining a lot of the original structure
+ * Words that will affect the entire setting (e.g. "A photo of " vs. "A painting of") can make a big difference. Playing around with them can help a lot
+
+ ### Parameters
+ Both `edit_strength` and `guidance_scale` have qualitatively similar effects: the higher the value, the more the image will change. We suggest
+ * Increasing/decreasing `edit_strength` first, particularly to alter/preserve more of the original structure/content
+ * Then changing `guidance_scale` to make the change in the edited region more or less pronounced.
+
+ Usually we find changing `edit_strength` to be enough, but feel free to play around (and report any interesting results)!
+
+ ### Misc.
+
+ Having difficulty coming up with a caption? Try [BLIP](https://huggingface.co/spaces/Salesforce/BLIP2) to automatically generate one!
+
+ As with most StableDiffusion approaches, faces/text are often problematic to render, especially if they're small. Having these in the foreground will help keep them cleaner.
+
+ A returned black image means that the [Safety Checker](https://huggingface.co/CompVis/stable-diffusion-safety-checker) triggered on the photo. This sometimes happens in odd cases (it often rejects
+ the huggingface logo or variations), but we need to keep it in for obvious reasons.
+ """
+ # article = gr.Markdown(description)
+
+ iface = gr.Interface(fn=edict, inputs=[gr.Image(interactive=False),
+                                        gr.Textbox(label="Original Description", interactive=False),
+                                        gr.Textbox(label="Edit Description", interactive=False),
+                                        # 50, # gr.Slider(5, 50, value=20, step=1),
+                                        # 0.93, # gr.Slider(0.5, 1, value=0.7, step=0.05),
+                                        gr.Slider(0.0, 1, value=0.8, step=0.05),
+                                        gr.Slider(0, 10, value=3, step=0.5),
+                                        ],
+                      examples=examples,
+                      outputs="image",
+                      description=description,
+                      article=article,
+                      cache_examples=True
+                      )
+ iface.launch()
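For local experimentation, the knobs the article text describes map directly onto `local_edict` above. A minimal sketch, assuming the commented-out `from edict_functions import EDICT_editing` import is re-enabled and the Stable Diffusion weights plus a GPU are available locally; it sweeps `edit_strength` at the fixed `guidance_scale=3` used by the bundled examples:

```python
import numpy as np
from PIL import Image

# One of the example images shipped with the Space.
img = np.array(Image.open("square_ims/yosemite.jpg"))

# Higher edit_strength changes more of the original structure; guidance_scale
# controls how pronounced the edit is within the changed region.
for edit_strength in (0.6, 0.7, 0.8, 0.9):
    edited = local_edict(img, "Granite forest valley", "Granite desert valley",
                         edit_strength, 3)  # guidance_scale=3
    Image.fromarray(edited).save(f"yosemite_desert_{edit_strength:.2f}.png")
```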
app_fully_disabled.py ADDED
@@ -0,0 +1,285 @@
1
+ from io import BytesIO
2
+
3
+ import string
4
+ import gradio as gr
5
+ import requests
6
+ from utils import Endpoint, get_token
7
+
8
+
9
+ def encode_image(image):
10
+ buffered = BytesIO()
11
+ image.save(buffered, format="JPEG")
12
+ buffered.seek(0)
13
+
14
+ return buffered
15
+
16
+
17
+ def query_chat_api(
18
+ image, prompt, decoding_method, temperature, len_penalty, repetition_penalty
19
+ ):
20
+
21
+ url = endpoint.url
22
+ url = url + "/api/generate"
23
+
24
+ headers = {
25
+ "User-Agent": "BLIP-2 HuggingFace Space",
26
+ "Auth-Token": get_token(),
27
+ }
28
+
29
+ data = {
30
+ "prompt": prompt,
31
+ "use_nucleus_sampling": decoding_method == "Nucleus sampling",
32
+ "temperature": temperature,
33
+ "length_penalty": len_penalty,
34
+ "repetition_penalty": repetition_penalty,
35
+ }
36
+
37
+ image = encode_image(image)
38
+ files = {"image": image}
39
+
40
+ response = requests.post(url, data=data, files=files, headers=headers)
41
+
42
+ if response.status_code == 200:
43
+ return response.json()
44
+ else:
45
+ return "Error: " + response.text
46
+
47
+
48
+ def query_caption_api(
49
+ image, decoding_method, temperature, len_penalty, repetition_penalty
50
+ ):
51
+
52
+ url = endpoint.url
53
+ url = url + "/api/caption"
54
+
55
+ headers = {
56
+ "User-Agent": "BLIP-2 HuggingFace Space",
57
+ "Auth-Token": get_token(),
58
+ }
59
+
60
+ data = {
61
+ "use_nucleus_sampling": decoding_method == "Nucleus sampling",
62
+ "temperature": temperature,
63
+ "length_penalty": len_penalty,
64
+ "repetition_penalty": repetition_penalty,
65
+ }
66
+
67
+ image = encode_image(image)
68
+ files = {"image": image}
69
+
70
+ response = requests.post(url, data=data, files=files, headers=headers)
71
+
72
+ if response.status_code == 200:
73
+ return response.json()
74
+ else:
75
+ return "Error: " + response.text
76
+
77
+
78
+ def postprocess_output(output):
79
+ # if the last character is not punctuation, add a full stop
80
+ if not output[0][-1] in string.punctuation:
81
+ output[0] += "."
82
+
83
+ return output
84
+
85
+
86
+ def inference_chat(
87
+ image,
88
+ text_input,
89
+ decoding_method,
90
+ temperature,
91
+ length_penalty,
92
+ repetition_penalty,
93
+ history=[],
94
+ ):
95
+ text_input = text_input
96
+ history.append(text_input)
97
+
98
+ prompt = " ".join(history)
99
+
100
+ output = query_chat_api(
101
+ image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
102
+ )
103
+ output = postprocess_output(output)
104
+ history += output
105
+
106
+ chat = [
107
+ (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)
108
+ ] # convert to tuples of list
109
+
110
+ return {chatbot: chat, state: history}
111
+
112
+
113
+ def inference_caption(
114
+ image,
115
+ decoding_method,
116
+ temperature,
117
+ length_penalty,
118
+ repetition_penalty,
119
+ ):
120
+ output = query_caption_api(
121
+ image, decoding_method, temperature, length_penalty, repetition_penalty
122
+ )
123
+
124
+ return output[0]
125
+
126
+
127
+ title = """<h1 align="center">BLIP-2</h1>"""
128
+ description = """Gradio demo for BLIP-2, image-to-text generation from Salesforce Research. To use it, simply upload your image, or click one of the examples to load them.
129
+ <br> <strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data including but not restricted to text and images is collected."""
130
+ article = """<strong>Paper</strong>: <a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a>
131
+ <br> <strong>Code</strong>: BLIP2 is now integrated into GitHub repo: <a href='https://github.com/salesforce/LAVIS' target='_blank'>LAVIS: a One-stop Library for Language and Vision</a>
132
+ <br> <strong>🤗 `transformers` integration</strong>: You can now use `transformers` to use our BLIP-2 models! Check out the <a href='https://huggingface.co/docs/transformers/main/en/model_doc/blip-2' target='_blank'> official docs </a>
133
+ <p> <strong>Project Page</strong>: <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'> BLIP2 on LAVIS</a>
134
+ <br> <strong>Description</strong>: Captioning results from <strong>BLIP2_OPT_6.7B</strong>. Chat results from <strong>BLIP2_FlanT5xxl</strong>.
135
+
136
+ <p><strong>For safety and ethical considerations, we have disabled image uploading from March 21. 2023. </strong>
137
+ <p><strong>Please try examples provided below.</strong>
138
+ """
139
+
140
+ endpoint = Endpoint()
141
+
142
+ examples = [
143
+ ["house.png", "How could someone get out of the house?"],
144
+ ["flower.jpg", "Question: What is this flower and where is its origin? Answer:"],
145
+ ["pizza.jpg", "What are steps to cook it?"],
146
+ ["sunset.jpg", "Here is a romantic message going along the photo:"],
147
+ ["forbidden_city.webp", "In what dynasties was this place built?"],
148
+ ]
149
+
150
+ with gr.Blocks(
151
+ css="""
152
+ .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
153
+ #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
154
+ """
155
+ ) as iface:
156
+ state = gr.State([])
157
+
158
+ gr.Markdown(title)
159
+ gr.Markdown(description)
160
+ gr.Markdown(article)
161
+
162
+ with gr.Row():
163
+ with gr.Column(scale=1):
164
+ image_input = gr.Image(type="pil", interactive=False)
165
+
166
+ # with gr.Row():
167
+ sampling = gr.Radio(
168
+ choices=["Beam search", "Nucleus sampling"],
169
+ value="Beam search",
170
+ label="Text Decoding Method",
171
+ interactive=True,
172
+ )
173
+
174
+ temperature = gr.Slider(
175
+ minimum=0.5,
176
+ maximum=1.0,
177
+ value=1.0,
178
+ step=0.1,
179
+ interactive=True,
180
+ label="Temperature (used with nucleus sampling)",
181
+ )
182
+
183
+ len_penalty = gr.Slider(
184
+ minimum=-1.0,
185
+ maximum=2.0,
186
+ value=1.0,
187
+ step=0.2,
188
+ interactive=True,
189
+ label="Length Penalty (set to larger for longer sequence, used with beam search)",
190
+ )
191
+
192
+ rep_penalty = gr.Slider(
193
+ minimum=1.0,
194
+ maximum=5.0,
195
+ value=1.5,
196
+ step=0.5,
197
+ interactive=True,
198
+ label="Repeat Penalty (larger value prevents repetition)",
199
+ )
200
+
201
+ with gr.Column(scale=1.8):
202
+
203
+ with gr.Column():
204
+ caption_output = gr.Textbox(lines=1, label="Caption Output")
205
+ caption_button = gr.Button(
206
+ value="Caption it!", interactive=True, variant="primary"
207
+ )
208
+ caption_button.click(
209
+ inference_caption,
210
+ [
211
+ image_input,
212
+ sampling,
213
+ temperature,
214
+ len_penalty,
215
+ rep_penalty,
216
+ ],
217
+ [caption_output],
218
+ )
219
+
220
+ gr.Markdown("""Try prompting your input for chat; e.g., an example prompt for QA: \"Question: {} Answer:\" Use proper punctuation (e.g., a question mark).""")
221
+ with gr.Row():
222
+ with gr.Column(
223
+ scale=1.5,
224
+ ):
225
+ chatbot = gr.Chatbot(
226
+ label="Chat Output (from FlanT5)",
227
+ )
228
+
229
+ # with gr.Row():
230
+ with gr.Column(scale=1):
231
+ chat_input = gr.Textbox(lines=1, label="Chat Input")
232
+ chat_input.submit(
233
+ inference_chat,
234
+ [
235
+ image_input,
236
+ chat_input,
237
+ sampling,
238
+ temperature,
239
+ len_penalty,
240
+ rep_penalty,
241
+ state,
242
+ ],
243
+ [chatbot, state],
244
+ )
245
+
246
+ with gr.Row():
247
+ clear_button = gr.Button(value="Clear", interactive=True)
248
+ clear_button.click(
249
+ lambda: ("", [], []),
250
+ [],
251
+ [chat_input, chatbot, state],
252
+ queue=False,
253
+ )
254
+
255
+ submit_button = gr.Button(
256
+ value="Submit", interactive=True, variant="primary"
257
+ )
258
+ submit_button.click(
259
+ inference_chat,
260
+ [
261
+ image_input,
262
+ chat_input,
263
+ sampling,
264
+ temperature,
265
+ len_penalty,
266
+ rep_penalty,
267
+ state,
268
+ ],
269
+ [chatbot, state],
270
+ )
271
+
272
+ image_input.change(
273
+ lambda: ("", "", []),
274
+ [],
275
+ [chatbot, caption_output, state],
276
+ queue=False,
277
+ )
278
+
279
+ examples = gr.Examples(
280
+ examples=examples,
281
+ inputs=[image_input, chat_input],
282
+ )
283
+
284
+ iface.queue(concurrency_count=1, api_open=False, max_size=10)
285
+ iface.launch(enable_queue=True)
edict_functions.py ADDED
@@ -0,0 +1,997 @@
1
+ import torch
2
+ from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer
3
+ from omegaconf import OmegaConf
4
+ import math
5
+ import imageio
6
+ from PIL import Image
7
+ import torchvision
8
+ import torch.nn.functional as F
9
+ import torch
10
+ import numpy as np
11
+ from PIL import Image
12
+ import time
13
+ import datetime
14
+ import torch
15
+ import sys
16
+ import os
17
+ from torchvision import datasets
18
+ import pickle
19
+
20
+
21
+
22
+ # StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl
23
+ use_half_prec = True
24
+ if use_half_prec:
25
+ from my_half_diffusers import AutoencoderKL, UNet2DConditionModel
26
+ from my_half_diffusers.schedulers.scheduling_utils import SchedulerOutput
27
+ from my_half_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
28
+ else:
29
+ from my_diffusers import AutoencoderKL, UNet2DConditionModel
30
+ from my_diffusers.schedulers.scheduling_utils import SchedulerOutput
31
+ from my_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
32
+ torch_dtype = torch.float16 if use_half_prec else torch.float64
33
+ np_dtype = np.float16 if use_half_prec else np.float64
34
+
35
+
36
+
37
+ import random
38
+ from tqdm.auto import tqdm
39
+ from torch import autocast
40
+ from difflib import SequenceMatcher
41
+
42
+ # Build our CLIP model
43
+ model_path_clip = "openai/clip-vit-large-patch14"
44
+ clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)
45
+ clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch_dtype)
46
+ clip = clip_model.text_model
47
+
48
+
49
+ # Getting our HF Auth token
50
+ auth_token = os.environ.get('auth_token')
51
+ if auth_token is None:
52
+ with open('hf_auth', 'r') as f:
53
+ auth_token = f.readlines()[0].strip()
54
+ model_path_diffusion = "CompVis/stable-diffusion-v1-4"
55
+ # Build our SD model
56
+ unet = UNet2DConditionModel.from_pretrained(model_path_diffusion, subfolder="unet", use_auth_token=auth_token, revision="fp16", torch_dtype=torch_dtype)
57
+ vae = AutoencoderKL.from_pretrained(model_path_diffusion, subfolder="vae", use_auth_token=auth_token, revision="fp16", torch_dtype=torch_dtype)
58
+
59
+ # Push to device (double precision unless using half precision)
60
+ device = 'cuda'
61
+ if use_half_prec:
62
+ unet.to(device)
63
+ vae.to(device)
64
+ clip.to(device)
65
+ else:
66
+ unet.double().to(device)
67
+ vae.double().to(device)
68
+ clip.double().to(device)
69
+ print("Loaded all models")
70
+
71
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
72
+ from transformers import AutoFeatureExtractor
73
+ # load safety model
74
+ safety_model_id = "CompVis/stable-diffusion-safety-checker"
75
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
76
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
77
+ def load_replacement(x):
78
+ try:
79
+ hwc = x.shape
80
+ y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
81
+ y = (np.array(y)/255.0).astype(x.dtype)
82
+ assert y.shape == x.shape
83
+ return y
84
+ except Exception:
85
+ return x
86
+ def check_safety(x_image):
87
+ safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
88
+ x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
89
+ assert x_checked_image.shape[0] == len(has_nsfw_concept)
90
+ for i in range(len(has_nsfw_concept)):
91
+ if has_nsfw_concept[i]:
92
+ # x_checked_image[i] = load_replacement(x_checked_image[i])
93
+ x_checked_image[i] *= 0 # load_replacement(x_checked_image[i])
94
+ return x_checked_image, has_nsfw_concept
95
+
96
+
97
+ def EDICT_editing(im_path,
98
+ base_prompt,
99
+ edit_prompt,
100
+ use_p2p=False,
101
+ steps=50,
102
+ mix_weight=0.93,
103
+ init_image_strength=0.8,
104
+ guidance_scale=3,
105
+ run_baseline=False,
106
+ width=512, height=512):
107
+ """
108
+ Main call of our research, performs editing with either EDICT or DDIM
109
+
110
+ Args:
111
+ im_path: path to image to run on
112
+ base_prompt: conditional prompt to deterministically noise with
113
+ edit_prompt: desired text conditioning
114
+ steps: ddim steps
115
+ mix_weight: Weight of mixing layers.
116
+ Higher means more consistent generations but divergence in inversion
117
+ Lower means opposite
118
+ This is fairly tuned and can get good results
119
+ init_image_strength: Editing strength. Higher = more dramatic edit.
120
+ Typically [0.6, 0.9] is good range.
121
+ Definitely tunable per-image/maybe best results are at a different value
122
+ guidance_scale: classifier-free guidance scale
123
+ 3 I've found is the best for both our method and basic DDIM inversion
124
+ Higher can result in more distorted results
125
+ run_baseline:
126
+ VERY IMPORTANT
127
+ False is EDICT (the default), True is the DDIM baseline
128
+ Output:
129
+ PAIR of Images (tuple)
130
+ If run_baseline=True then [0] will be edit and [1] will be original
131
+ If run_baseline=False then they will be two nearly identical edited versions
132
+ """
133
+ # Resize/center crop to 512x512 (Can do higher res. if desired)
134
+ if isinstance(im_path, str):
135
+ orig_im = load_im_into_format_from_path(im_path)
136
+ elif Image.isImageType(im_path):
137
+ width, height = im_path.size
138
+
139
+
140
+ # add max dim for sake of memory
141
+ max_dim = max(width, height)
142
+ if max_dim > 1024:
143
+ factor = 1024 / max_dim
144
+ width *= factor
145
+ height *= factor
146
+ width = int(width)
147
+ height = int(height)
148
+ im_path = im_path.resize((width, height))
149
+
150
+ min_dim = min(width, height)
151
+ if min_dim < 512:
152
+ factor = 512 / min_dim
153
+ width *= factor
154
+ height *= factor
155
+ width = int(width)
156
+ height = int(height)
157
+ im_path = im_path.resize((width, height))
158
+
159
+ width = width - (width%64)
160
+ height = height - (height%64)
161
+
162
+ orig_im = im_path # general_crop(im_path, width, height)
163
+ else:
164
+ orig_im = im_path
165
+
166
+ # compute latent pair (second one will be original latent if run_baseline=True)
167
+ latents = coupled_stablediffusion(base_prompt,
168
+ reverse=True,
169
+ init_image=orig_im,
170
+ init_image_strength=init_image_strength,
171
+ steps=steps,
172
+ mix_weight=mix_weight,
173
+ guidance_scale=guidance_scale,
174
+ run_baseline=run_baseline,
175
+ width=width, height=height)
176
+ # Denoise intermediate state with new conditioning
177
+ gen = coupled_stablediffusion(edit_prompt if (not use_p2p) else base_prompt,
178
+ None if (not use_p2p) else edit_prompt,
179
+ fixed_starting_latent=latents,
180
+ init_image_strength=init_image_strength,
181
+ steps=steps,
182
+ mix_weight=mix_weight,
183
+ guidance_scale=guidance_scale,
184
+ run_baseline=run_baseline,
185
+ width=width, height=height)
186
+
187
+ return gen
188
+
189
+
190
+ def img2img_editing(im_path,
191
+ edit_prompt,
192
+ steps=50,
193
+ init_image_strength=0.7,
194
+ guidance_scale=3):
195
+ """
196
+ Basic SDEdit/img2img, given an image add some noise and denoise with prompt
197
+ """
198
+ orig_im = load_im_into_format_from_path(im_path)
199
+
200
+ return baseline_stablediffusion(edit_prompt,
201
+ init_image_strength=init_image_strength,
202
+ steps=steps,
203
+ init_image=orig_im,
204
+ guidance_scale=guidance_scale)
205
+
206
+
207
+ def center_crop(im):
208
+ width, height = im.size # Get dimensions
209
+ min_dim = min(width, height)
210
+ left = (width - min_dim)/2
211
+ top = (height - min_dim)/2
212
+ right = (width + min_dim)/2
213
+ bottom = (height + min_dim)/2
214
+
215
+ # Crop the center of the image
216
+ im = im.crop((left, top, right, bottom))
217
+ return im
218
+
219
+
220
+
221
+ def general_crop(im, target_w, target_h):
222
+ width, height = im.size # Get dimensions
223
+ min_dim = min(width, height)
224
+ left = target_w / 2 # (width - min_dim)/2
225
+ top = target_h / 2 # (height - min_dim)/2
226
+ right = width - (target_w / 2) # (width + min_dim)/2
227
+ bottom = height - (target_h / 2) # (height + min_dim)/2
228
+
229
+ # Crop the center of the image
230
+ im = im.crop((left, top, right, bottom))
231
+ return im
232
+
233
+
234
+
235
+ def load_im_into_format_from_path(im_path):
236
+ return center_crop(Image.open(im_path)).resize((512,512))
237
+
238
+
239
+ #### P2P STUFF ####
240
+ def init_attention_weights(weight_tuples):
241
+ tokens_length = clip_tokenizer.model_max_length
242
+ weights = torch.ones(tokens_length)
243
+
244
+ for i, w in weight_tuples:
245
+ if i < tokens_length and i >= 0:
246
+ weights[i] = w
247
+
248
+
249
+ for name, module in unet.named_modules():
250
+ module_name = type(module).__name__
251
+ if module_name == "CrossAttention" and "attn2" in name:
252
+ module.last_attn_slice_weights = weights.to(device)
253
+ if module_name == "CrossAttention" and "attn1" in name:
254
+ module.last_attn_slice_weights = None
255
+
256
+
257
+ def init_attention_edit(tokens, tokens_edit):
258
+ tokens_length = clip_tokenizer.model_max_length
259
+ mask = torch.zeros(tokens_length)
260
+ indices_target = torch.arange(tokens_length, dtype=torch.long)
261
+ indices = torch.zeros(tokens_length, dtype=torch.long)
262
+
263
+ tokens = tokens.input_ids.numpy()[0]
264
+ tokens_edit = tokens_edit.input_ids.numpy()[0]
265
+
266
+ for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
267
+ if b0 < tokens_length:
268
+ if name == "equal" or (name == "replace" and a1-a0 == b1-b0):
269
+ mask[b0:b1] = 1
270
+ indices[b0:b1] = indices_target[a0:a1]
271
+
272
+ for name, module in unet.named_modules():
273
+ module_name = type(module).__name__
274
+ if module_name == "CrossAttention" and "attn2" in name:
275
+ module.last_attn_slice_mask = mask.to(device)
276
+ module.last_attn_slice_indices = indices.to(device)
277
+ if module_name == "CrossAttention" and "attn1" in name:
278
+ module.last_attn_slice_mask = None
279
+ module.last_attn_slice_indices = None
280
+
281
+
282
+ def init_attention_func():
283
+ def new_attention(self, query, key, value, sequence_length, dim):
284
+ batch_size_attention = query.shape[0]
285
+ hidden_states = torch.zeros(
286
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
287
+ )
288
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
289
+ for i in range(hidden_states.shape[0] // slice_size):
290
+ start_idx = i * slice_size
291
+ end_idx = (i + 1) * slice_size
292
+ attn_slice = (
293
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
294
+ )
295
+ attn_slice = attn_slice.softmax(dim=-1)
296
+
297
+ if self.use_last_attn_slice:
298
+ if self.last_attn_slice_mask is not None:
299
+ new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
300
+ attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
301
+ else:
302
+ attn_slice = self.last_attn_slice
303
+
304
+ self.use_last_attn_slice = False
305
+
306
+ if self.save_last_attn_slice:
307
+ self.last_attn_slice = attn_slice
308
+ self.save_last_attn_slice = False
309
+
310
+ if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
311
+ attn_slice = attn_slice * self.last_attn_slice_weights
312
+ self.use_last_attn_weights = False
313
+
314
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
315
+
316
+ hidden_states[start_idx:end_idx] = attn_slice
317
+
318
+ # reshape hidden_states
319
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
320
+ return hidden_states
321
+
322
+ for name, module in unet.named_modules():
323
+ module_name = type(module).__name__
324
+ if module_name == "CrossAttention":
325
+ module.last_attn_slice = None
326
+ module.use_last_attn_slice = False
327
+ module.use_last_attn_weights = False
328
+ module.save_last_attn_slice = False
329
+ module._attention = new_attention.__get__(module, type(module))
330
+
331
+ def use_last_tokens_attention(use=True):
332
+ for name, module in unet.named_modules():
333
+ module_name = type(module).__name__
334
+ if module_name == "CrossAttention" and "attn2" in name:
335
+ module.use_last_attn_slice = use
336
+
337
+ def use_last_tokens_attention_weights(use=True):
338
+ for name, module in unet.named_modules():
339
+ module_name = type(module).__name__
340
+ if module_name == "CrossAttention" and "attn2" in name:
341
+ module.use_last_attn_weights = use
342
+
343
+ def use_last_self_attention(use=True):
344
+ for name, module in unet.named_modules():
345
+ module_name = type(module).__name__
346
+ if module_name == "CrossAttention" and "attn1" in name:
347
+ module.use_last_attn_slice = use
348
+
349
+ def save_last_tokens_attention(save=True):
350
+ for name, module in unet.named_modules():
351
+ module_name = type(module).__name__
352
+ if module_name == "CrossAttention" and "attn2" in name:
353
+ module.save_last_attn_slice = save
354
+
355
+ def save_last_self_attention(save=True):
356
+ for name, module in unet.named_modules():
357
+ module_name = type(module).__name__
358
+ if module_name == "CrossAttention" and "attn1" in name:
359
+ module.save_last_attn_slice = save
360
+ ####################################
361
+
362
+
363
+ ##### BASELINE ALGORITHM, ONLY USED NOW FOR SDEDIT #####
364
+
365
+ @torch.no_grad()
366
+ def baseline_stablediffusion(prompt="",
367
+ prompt_edit=None,
368
+ null_prompt='',
369
+ prompt_edit_token_weights=[],
370
+ prompt_edit_tokens_start=0.0,
371
+ prompt_edit_tokens_end=1.0,
372
+ prompt_edit_spatial_start=0.0,
373
+ prompt_edit_spatial_end=1.0,
374
+ clip_start=0.0,
375
+ clip_end=1.0,
376
+ guidance_scale=7,
377
+ steps=50,
378
+ seed=1,
379
+ width=512, height=512,
380
+ init_image=None, init_image_strength=0.5,
381
+ fixed_starting_latent = None,
382
+ prev_image= None,
383
+ grid=None,
384
+ clip_guidance=None,
385
+ clip_guidance_scale=1,
386
+ num_cutouts=4,
387
+ cut_power=1,
388
+ scheduler_str='lms',
389
+ return_latent=False,
390
+ one_pass=False,
391
+ normalize_noise_pred=False):
392
+ width = width - width % 64
393
+ height = height - height % 64
394
+
395
+ #If seed is None, randomly select seed from 0 to 2^32-1
396
+ if seed is None: seed = random.randrange(2**32 - 1)
397
+ generator = torch.cuda.manual_seed(seed)
398
+
399
+ #Set inference timesteps to scheduler
400
+ scheduler_dict = {'ddim':DDIMScheduler,
401
+ 'lms':LMSDiscreteScheduler,
402
+ 'pndm':PNDMScheduler,
403
+ 'ddpm':DDPMScheduler}
404
+ scheduler_call = scheduler_dict[scheduler_str]
405
+ if scheduler_str == 'ddim':
406
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
407
+ beta_schedule="scaled_linear",
408
+ clip_sample=False, set_alpha_to_one=False)
409
+ else:
410
+ scheduler = scheduler_call(beta_schedule="scaled_linear",
411
+ num_train_timesteps=1000)
412
+
413
+ scheduler.set_timesteps(steps)
414
+ if prev_image is not None:
415
+ prev_scheduler = LMSDiscreteScheduler(beta_start=0.00085,
416
+ beta_end=0.012,
417
+ beta_schedule="scaled_linear",
418
+ num_train_timesteps=1000)
419
+ prev_scheduler.set_timesteps(steps)
420
+
421
+ #Preprocess image if it exists (img2img)
422
+ if init_image is not None:
423
+ init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
424
+ init_image = np.array(init_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
425
+ init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))
426
+
427
+ #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
428
+ if init_image.shape[1] > 3:
429
+ init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])
430
+
431
+ #Move image to GPU
432
+ init_image = init_image.to(device)
433
+
434
+ #Encode image
435
+ with autocast(device):
436
+ init_latent = vae.encode(init_image).latent_dist.sample(generator=generator) * 0.18215
437
+
438
+ t_start = steps - int(steps * init_image_strength)
439
+
440
+ else:
441
+ init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
442
+ t_start = 0
443
+
444
+ #Generate random normal noise
445
+ if fixed_starting_latent is None:
446
+ noise = torch.randn(init_latent.shape, generator=generator, device=device, dtype=unet.dtype)
447
+ if scheduler_str == 'ddim':
448
+ if init_image is not None:
449
+ raise NotImplementedError
450
+ latent = scheduler.add_noise(init_latent, noise,
451
+ 1000 - int(1000 * init_image_strength)).to(device)
452
+ else:
453
+ latent = noise
454
+ else:
455
+ latent = scheduler.add_noise(init_latent, noise,
456
+ t_start).to(device)
457
+ else:
458
+ latent = fixed_starting_latent
459
+ t_start = steps - int(steps * init_image_strength)
460
+
461
+ if prev_image is not None:
462
+ #Resize and prev_image for numpy b h w c -> torch b c h w
463
+ prev_image = prev_image.resize((width, height), resample=Image.Resampling.LANCZOS)
464
+ prev_image = np.array(prev_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
465
+ prev_image = torch.from_numpy(prev_image[np.newaxis, ...].transpose(0, 3, 1, 2))
466
+
467
+ #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
468
+ if prev_image.shape[1] > 3:
469
+ prev_image = prev_image[:, :3] * prev_image[:, 3:] + (1 - prev_image[:, 3:])
470
+
471
+ #Move image to GPU
472
+ prev_image = prev_image.to(device)
473
+
474
+ #Encode image
475
+ with autocast(device):
476
+ prev_init_latent = vae.encode(prev_image).latent_dist.sample(generator=generator) * 0.18215
477
+
478
+ t_start = steps - int(steps * init_image_strength)
479
+
480
+ prev_latent = prev_scheduler.add_noise(prev_init_latent, noise, t_start).to(device)
481
+ else:
482
+ prev_latent = None
483
+
484
+
485
+ #Process clip
486
+ with autocast(device):
487
+ tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
488
+ embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state
489
+
490
+ tokens_conditional = clip_tokenizer(prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
491
+ embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state
492
+
493
+ #Process prompt editing
494
+ assert not ((prompt_edit is not None) and (prev_image is not None))
495
+ if prompt_edit is not None:
496
+ tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
497
+ embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state
498
+ init_attention_edit(tokens_conditional, tokens_conditional_edit)
499
+ elif prev_image is not None:
500
+ init_attention_edit(tokens_conditional, tokens_conditional)
501
+
502
+
503
+ init_attention_func()
504
+ init_attention_weights(prompt_edit_token_weights)
505
+
506
+ timesteps = scheduler.timesteps[t_start:]
507
+ # print(timesteps)
508
+
509
+ assert isinstance(guidance_scale, int)
510
+ num_cycles = 1 # guidance_scale + 1
511
+
512
+ last_noise_preds = None
513
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
514
+ t_index = t_start + i
515
+
516
+ latent_model_input = latent
517
+ if scheduler_str=='lms':
518
+ sigma = scheduler.sigmas[t_index] # last is first and first is last
519
+ latent_model_input = (latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
520
+ else:
521
+ assert scheduler_str in ['ddim', 'pndm', 'ddpm']
522
+
523
+ #Predict the unconditional noise residual
524
+
525
+ if len(t.shape) == 0:
526
+ t = t[None].to(unet.device)
527
+ noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=embedding_unconditional,
528
+ ).sample
529
+
530
+ if prev_latent is not None:
531
+ prev_latent_model_input = prev_latent
532
+ prev_latent_model_input = (prev_latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
533
+ prev_noise_pred_uncond = unet(prev_latent_model_input, t,
534
+ encoder_hidden_states=embedding_unconditional,
535
+ ).sample
536
+ # noise_pred_uncond = unet(latent_model_input, t,
537
+ # encoder_hidden_states=embedding_unconditional)['sample']
538
+
539
+ #Prepare the Cross-Attention layers
540
+ if prompt_edit is not None or prev_latent is not None:
541
+ save_last_tokens_attention()
542
+ save_last_self_attention()
543
+ else:
544
+ #Use weights on non-edited prompt when edit is None
545
+ use_last_tokens_attention_weights()
546
+
547
+ #Predict the conditional noise residual and save the cross-attention layer activations
548
+ if prev_latent is not None:
549
+ raise NotImplementedError # I totally lost track of what this is
550
+ prev_noise_pred_cond = unet(prev_latent_model_input, t, encoder_hidden_states=embedding_conditional,
551
+ ).sample
552
+ else:
553
+ noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional,
554
+ ).sample
555
+
556
+ #Edit the Cross-Attention layer activations
557
+ t_scale = t / scheduler.num_train_timesteps
558
+ if prompt_edit is not None or prev_latent is not None:
559
+ if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
560
+ use_last_tokens_attention()
561
+ if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
562
+ use_last_self_attention()
563
+
564
+ #Use weights on edited prompt
565
+ use_last_tokens_attention_weights()
566
+
567
+ #Predict the edited conditional noise residual using the cross-attention masks
568
+ if prompt_edit is not None:
569
+ noise_pred_cond = unet(latent_model_input, t,
570
+ encoder_hidden_states=embedding_conditional_edit).sample
571
+
572
+ #Perform guidance
573
+ # if i%(num_cycles)==0: # cycle_i+1==num_cycles:
574
+ """
575
+ if cycle_i+1==num_cycles:
576
+ noise_pred = noise_pred_uncond
577
+ else:
578
+ noise_pred = noise_pred_cond - noise_pred_uncond
579
+
580
+ """
581
+ if last_noise_preds is not None:
582
+ # print( (last_noise_preds[0]*noise_pred_uncond).sum(), (last_noise_preds[1]*noise_pred_cond).sum())
583
+ # print(F.cosine_similarity(last_noise_preds[0].flatten(), noise_pred_uncond.flatten(), dim=0),
584
+ # F.cosine_similarity(last_noise_preds[1].flatten(), noise_pred_cond.flatten(), dim=0))
585
+ last_grad= last_noise_preds[1] - last_noise_preds[0]
586
+ new_grad = noise_pred_cond - noise_pred_uncond
587
+ # print( F.cosine_similarity(last_grad.flatten(), new_grad.flatten(), dim=0))
588
+ last_noise_preds = (noise_pred_uncond, noise_pred_cond)
589
+
590
+ use_cond_guidance = True
591
+ if use_cond_guidance:
592
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
593
+ else:
594
+ noise_pred = noise_pred_uncond
595
+ if clip_guidance is not None and t_scale >= clip_start and t_scale <= clip_end:
596
+ noise_pred, latent = new_cond_fn(latent, t, t_index,
597
+ embedding_conditional, noise_pred,clip_guidance,
598
+ clip_guidance_scale,
599
+ num_cutouts,
600
+ scheduler, unet,use_cutouts=True,
601
+ cut_power=cut_power)
602
+ if normalize_noise_pred:
603
+ noise_pred = noise_pred * noise_pred_uncond.norm() / noise_pred.norm()
604
+ if scheduler_str == 'ddim':
605
+ latent = forward_step(scheduler, noise_pred,
606
+ t,
607
+ latent).prev_sample
608
+ else:
609
+ latent = scheduler.step(noise_pred,
610
+ t_index,
611
+ latent).prev_sample
612
+
613
+ if prev_latent is not None:
614
+ prev_noise_pred = prev_noise_pred_uncond + guidance_scale * (prev_noise_pred_cond - prev_noise_pred_uncond)
615
+ prev_latent = prev_scheduler.step(prev_noise_pred, t_index, prev_latent).prev_sample
616
+ if one_pass: break
617
+
618
+ #scale and decode the image latents with vae
619
+ if return_latent: return latent
620
+ latent = latent / 0.18215
621
+ image = vae.decode(latent.to(vae.dtype)).sample
622
+
623
+ image = (image / 2 + 0.5).clamp(0, 1)
624
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
625
+
626
+ image, _ = check_safety(image)
627
+
628
+ image = (image[0] * 255).round().astype("uint8")
629
+ return Image.fromarray(image)
630
+ ####################################
631
+
632
+ #### HELPER FUNCTIONS FOR OUR METHOD #####
633
+
634
+ def get_alpha_and_beta(t, scheduler):
635
+ # want to run this for both current and previous timestep
636
+ if t.dtype==torch.long:
637
+ alpha = scheduler.alphas_cumprod[t]
638
+ return alpha, 1-alpha
639
+
640
+ if t<0:
641
+ return scheduler.final_alpha_cumprod, 1 - scheduler.final_alpha_cumprod
642
+
643
+
644
+ low = t.floor().long()
645
+ high = t.ceil().long()
646
+ rem = t - low
647
+
648
+ low_alpha = scheduler.alphas_cumprod[low]
649
+ high_alpha = scheduler.alphas_cumprod[high]
650
+ interpolated_alpha = low_alpha * rem + high_alpha * (1-rem)
651
+ interpolated_beta = 1 - interpolated_alpha
652
+ return interpolated_alpha, interpolated_beta
653
+
654
+
655
+ # A DDIM forward step function
656
+ def forward_step(
657
+ self,
658
+ model_output,
659
+ timestep: int,
660
+ sample,
661
+ eta: float = 0.0,
662
+ use_clipped_model_output: bool = False,
663
+ generator=None,
664
+ return_dict: bool = True,
665
+ use_double=False,
666
+ ) :
667
+ if self.num_inference_steps is None:
668
+ raise ValueError(
669
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
670
+ )
671
+
672
+ prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps
673
+
674
+ if timestep > self.timesteps.max():
675
+ raise NotImplementedError("Need to double check what the overflow is")
676
+
677
+ alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self)
678
+ alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self)
679
+
680
+
681
+ alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5)
682
+ first_term = (1./alpha_quotient) * sample
683
+ second_term = (1./alpha_quotient) * (beta_prod_t ** 0.5) * model_output
684
+ third_term = ((1 - alpha_prod_t_prev)**0.5) * model_output
685
+ return first_term - second_term + third_term
686
+
687
+ # A DDIM reverse step function, the inverse of above
688
+ def reverse_step(
689
+ self,
690
+ model_output,
691
+ timestep: int,
692
+ sample,
693
+ eta: float = 0.0,
694
+ use_clipped_model_output: bool = False,
695
+ generator=None,
696
+ return_dict: bool = True,
697
+ use_double=False,
698
+ ) :
699
+ if self.num_inference_steps is None:
700
+ raise ValueError(
701
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
702
+ )
703
+
704
+ prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps
705
+
706
+ if timestep > self.timesteps.max():
707
+ raise NotImplementedError
708
+ else:
709
+ alpha_prod_t = self.alphas_cumprod[timestep]
710
+
711
+ alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self)
712
+ alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self)
713
+
714
+ alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5)
715
+
716
+ first_term = alpha_quotient * sample
717
+ second_term = ((beta_prod_t)**0.5) * model_output
718
+ third_term = alpha_quotient * ((1 - alpha_prod_t_prev)**0.5) * model_output
719
+ return first_term + second_term - third_term
720
+
721
+
722
+
723
+
724
+ @torch.no_grad()
725
+ def latent_to_image(latent):
726
+ image = vae.decode(latent.to(vae.dtype)/0.18215).sample
727
+ image = prep_image_for_return(image)
728
+ return image
729
+
730
+ def prep_image_for_return(image):
731
+ image = (image / 2 + 0.5).clamp(0, 1)
732
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
733
+ image = (image[0] * 255).round().astype("uint8")
734
+ image = Image.fromarray(image)
735
+ return image
736
+
737
+ #############################
738
+
739
+ ##### MAIN EDICT FUNCTION #######
740
+ # Use EDICT_editing to perform calls
741
+
742
+ @torch.no_grad()
743
+ def coupled_stablediffusion(prompt="",
744
+ prompt_edit=None,
745
+ null_prompt='',
746
+ prompt_edit_token_weights=[],
747
+ prompt_edit_tokens_start=0.0,
748
+ prompt_edit_tokens_end=1.0,
749
+ prompt_edit_spatial_start=0.0,
750
+ prompt_edit_spatial_end=1.0,
751
+ guidance_scale=7.0, steps=50,
752
+ seed=1, width=512, height=512,
753
+ init_image=None, init_image_strength=1.0,
754
+ run_baseline=False,
755
+ use_lms=False,
756
+ leapfrog_steps=True,
757
+ reverse=False,
758
+ return_latents=False,
759
+ fixed_starting_latent=None,
760
+ beta_schedule='scaled_linear',
761
+ mix_weight=0.93):
762
+ #If seed is None, randomly select seed from 0 to 2^32-1
763
+ if seed is None: seed = random.randrange(2**32 - 1)
764
+ generator = torch.cuda.manual_seed(seed)
765
+
766
+ def image_to_latent(im):
767
+ if isinstance(im, torch.Tensor):
768
+ # assume it's the latent
769
+ # used to avoid clipping new generation before inversion
770
+ init_latent = im.to(device)
771
+ else:
772
+ #Resize and transpose for numpy b h w c -> torch b c h w
773
+ im = im.resize((width, height), resample=Image.Resampling.LANCZOS)
774
+ im = np.array(im).astype(np_dtype) / 255.0 * 2.0 - 1.0
775
+ # check if black and white
776
+ if len(im.shape) < 3:
777
+ im = np.stack([im for _ in range(3)], axis=2) # putting at end b/c channels
778
+
779
+ im = torch.from_numpy(im[np.newaxis, ...].transpose(0, 3, 1, 2))
780
+
781
+ #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
782
+ if im.shape[1] > 3:
783
+ im = im[:, :3] * im[:, 3:] + (1 - im[:, 3:])
784
+
785
+ #Move image to GPU
786
+ im = im.to(device)
787
+ #Encode image
788
+ if use_half_prec:
789
+ init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215
790
+ else:
791
+ with autocast(device):
792
+ init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215
793
+ return init_latent
794
+ assert not use_lms, "Can't invert LMS the same as DDIM"
795
+ if run_baseline: leapfrog_steps=False
796
+ #Change size to multiple of 64 to prevent size mismatches inside model
797
+ width = width - width % 64
798
+ height = height - height % 64
799
+
800
+
801
+ #Preprocess image if it exists (img2img)
802
+ if init_image is not None:
803
+ assert reverse # want to be performing deterministic noising
804
+ # can take either pair (output of generative process) or single image
805
+ if isinstance(init_image, list):
806
+ if isinstance(init_image[0], torch.Tensor):
807
+ init_latent = [t.clone() for t in init_image]
808
+ else:
809
+ init_latent = [image_to_latent(im) for im in init_image]
810
+ else:
811
+ init_latent = image_to_latent(init_image)
812
+ # this is t_start for forward, t_end for reverse
813
+ t_limit = steps - int(steps * init_image_strength)
814
+ else:
815
+ assert not reverse, 'Need image to reverse from'
816
+ init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
817
+ t_limit = 0
818
+
819
+ if reverse:
820
+ latent = init_latent
821
+ else:
822
+ #Generate random normal noise
823
+ noise = torch.randn(init_latent.shape,
824
+ generator=generator,
825
+ device=device,
826
+ dtype=torch_dtype)
827
+ if fixed_starting_latent is None:
828
+ latent = noise
829
+ else:
830
+ if isinstance(fixed_starting_latent, list):
831
+ latent = [l.clone() for l in fixed_starting_latent]
832
+ else:
833
+ latent = fixed_starting_latent.clone()
834
+ t_limit = steps - int(steps * init_image_strength)
835
+ if isinstance(latent, list): # initializing from pair of images
836
+ latent_pair = latent
837
+ else: # initializing from noise
838
+ latent_pair = [latent.clone(), latent.clone()]
839
+
840
+
841
+ if steps==0:
842
+ if init_image is not None:
843
+ return image_to_latent(init_image)
844
+ else:
845
+ image = vae.decode(latent.to(vae.dtype) / 0.18215).sample
846
+ return prep_image_for_return(image)
847
+
848
+ #Set inference timesteps to scheduler
849
+ schedulers = []
850
+ for i in range(2):
851
+ # num_raw_timesteps = max(1000, steps)
852
+ scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
853
+ beta_schedule=beta_schedule,
854
+ num_train_timesteps=1000,
855
+ clip_sample=False,
856
+ set_alpha_to_one=False)
857
+ scheduler.set_timesteps(steps)
858
+ schedulers.append(scheduler)
859
+
860
+ with autocast(device):
861
+ # CLIP Text Embeddings
862
+ tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length",
863
+ max_length=clip_tokenizer.model_max_length,
864
+ truncation=True, return_tensors="pt",
865
+ return_overflowing_tokens=True)
866
+ embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state
867
+
868
+ tokens_conditional = clip_tokenizer(prompt, padding="max_length",
869
+ max_length=clip_tokenizer.model_max_length,
870
+ truncation=True, return_tensors="pt",
871
+ return_overflowing_tokens=True)
872
+ embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state
873
+
874
+ #Process prompt editing (if running Prompt-to-Prompt)
875
+ if prompt_edit is not None:
876
+ tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length",
877
+ max_length=clip_tokenizer.model_max_length,
878
+ truncation=True, return_tensors="pt",
879
+ return_overflowing_tokens=True)
880
+ embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state
881
+
882
+ init_attention_edit(tokens_conditional, tokens_conditional_edit)
883
+
884
+ init_attention_func()
885
+ init_attention_weights(prompt_edit_token_weights)
886
+
887
+ timesteps = schedulers[0].timesteps[t_limit:]
888
+ if reverse: timesteps = timesteps.flip(0)
889
+
890
+ for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
891
+ t_scale = t / schedulers[0].num_train_timesteps
892
+
893
+ if (reverse) and (not run_baseline):
894
+ # Reverse mixing layer
895
+ new_latents = [l.clone() for l in latent_pair]
896
+ new_latents[1] = (new_latents[1].clone() - (1-mix_weight)*new_latents[0].clone()) / mix_weight
897
+ new_latents[0] = (new_latents[0].clone() - (1-mix_weight)*new_latents[1].clone()) / mix_weight
898
+ latent_pair = new_latents
899
+
900
+ # alternate EDICT steps
901
+ for latent_i in range(2):
902
+ if run_baseline and latent_i==1: continue # just have one sequence for baseline
903
+ # this modifies latent_pair[i] while using
904
+ # latent_pair[(i+1)%2]
905
+ if reverse and (not run_baseline):
906
+ if leapfrog_steps:
907
+ # what i would be from going other way
908
+ orig_i = len(timesteps) - (i+1)
909
+ offset = (orig_i+1) % 2
910
+ latent_i = (latent_i + offset) % 2
911
+ else:
912
+ # Do 1 then 0
913
+ latent_i = (latent_i+1)%2
914
+ else:
915
+ if leapfrog_steps:
916
+ offset = i%2
917
+ latent_i = (latent_i + offset) % 2
918
+
919
+ latent_j = ((latent_i+1) % 2) if not run_baseline else latent_i
920
+
921
+ latent_model_input = latent_pair[latent_j]
922
+ latent_base = latent_pair[latent_i]
923
+
924
+ #Predict the unconditional noise residual
925
+ noise_pred_uncond = unet(latent_model_input, t,
926
+ encoder_hidden_states=embedding_unconditional).sample
927
+
928
+ #Prepare the Cross-Attention layers
929
+ if prompt_edit is not None:
930
+ save_last_tokens_attention()
931
+ save_last_self_attention()
932
+ else:
933
+ #Use weights on non-edited prompt when edit is None
934
+ use_last_tokens_attention_weights()
935
+
936
+ #Predict the conditional noise residual and save the cross-attention layer activations
937
+ noise_pred_cond = unet(latent_model_input, t,
938
+ encoder_hidden_states=embedding_conditional).sample
939
+
940
+ #Edit the Cross-Attention layer activations
941
+ if prompt_edit is not None:
942
+ t_scale = t / schedulers[0].num_train_timesteps
943
+ if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
944
+ use_last_tokens_attention()
945
+ if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
946
+ use_last_self_attention()
947
+
948
+ #Use weights on edited prompt
949
+ use_last_tokens_attention_weights()
950
+
951
+ #Predict the edited conditional noise residual using the cross-attention masks
952
+ noise_pred_cond = unet(latent_model_input,
953
+ t,
954
+ encoder_hidden_states=embedding_conditional_edit).sample
955
+
956
+ #Perform guidance
957
+ grad = (noise_pred_cond - noise_pred_uncond)
958
+ noise_pred = noise_pred_uncond + guidance_scale * grad
959
+
960
+
961
+ step_call = reverse_step if reverse else forward_step
962
+ new_latent = step_call(schedulers[latent_i],
963
+ noise_pred,
964
+ t,
965
+ latent_base)# .prev_sample
966
+ new_latent = new_latent.to(latent_base.dtype)
967
+
968
+ latent_pair[latent_i] = new_latent
969
+
970
+ if (not reverse) and (not run_baseline):
971
+ # Mixing layer (contraction) during generative process
972
+ new_latents = [l.clone() for l in latent_pair]
973
+ new_latents[0] = (mix_weight*new_latents[0] + (1-mix_weight)*new_latents[1]).clone()
974
+ new_latents[1] = ((1-mix_weight)*new_latents[0] + (mix_weight)*new_latents[1]).clone()
975
+ latent_pair = new_latents
976
+
977
+ #scale and decode the image latents with vae, can return latents instead of images
978
+ if reverse or return_latents:
979
+ results = [latent_pair]
980
+ return results if len(results)>1 else results[0]
981
+
982
+ # decode latents to images
983
+ images = []
984
+ for latent_i in range(2):
985
+ latent = latent_pair[latent_i] / 0.18215
986
+ image = vae.decode(latent.to(vae.dtype)).sample
987
+ images.append(image)
988
+
989
+ # Return images
990
+ return_arr = []
991
+ for image in images:
992
+ image = prep_image_for_return(image)
993
+ return_arr.append(image)
994
+ results = [return_arr]
995
+ return results if len(results)>1 else results[0]
996
+
997
+
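The two mixing blocks above (the sequential averaging during generation and its exact inverse during the reverse pass) are what keep the EDICT latent pair invertible. A minimal sketch of just that coupling, independent of the UNet and VAE, with `p` standing in for `mix_weight`:

import torch

p = 0.93  # plays the role of mix_weight

def mix(x, y):
    # generative-direction mixing layer: sequential convex averaging
    x_new = p * x + (1 - p) * y
    y_new = (1 - p) * x_new + p * y
    return x_new, y_new

def unmix(x_new, y_new):
    # reverse-direction layer: undoes mix() exactly, in the opposite order
    y = (y_new - (1 - p) * x_new) / p
    x = (x_new - (1 - p) * y) / p
    return x, y

x, y = torch.randn(4), torch.randn(4)
x_rec, y_rec = unmix(*mix(x, y))
assert torch.allclose(x, x_rec, atol=1e-5) and torch.allclose(y, y_rec, atol=1e-5)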
local_app.py ADDED
@@ -0,0 +1,33 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ from edict_functions import EDICT_editing
4
+ from PIL import Image
5
+
6
+ def greet(name):
7
+ return "Hello " + name + "!!"
8
+
9
+
10
+ def edict(x, source_text, edit_text,
11
+ edit_strength, guidance_scale,
12
+ steps=50, mix_weight=0.93, ):
13
+ x = Image.fromarray(x)
14
+ return_im = EDICT_editing(x,
15
+ source_text,
16
+ edit_text,
17
+ steps=steps,
18
+ mix_weight=mix_weight,
19
+ init_image_strength=edit_strength,
20
+ guidance_scale=guidance_scale
21
+ )[0]
22
+ return np.array(return_im)
23
+
24
+ iface = gr.Interface(fn=edict, inputs=["image",
25
+ gr.Textbox(label="Original Description"),
26
+ gr.Textbox(label="Edit Description"),
27
+ # 50, # gr.Slider(5, 50, value=20, step=1),
28
+ # 0.93, # gr.Slider(0.5, 1, value=0.7, step=0.05),
29
+ gr.Slider(0.0, 1, value=0.8, step=0.05),
30
+ gr.Slider(0, 10, value=3, step=0.5),
31
+ ],
32
+ outputs="image")
33
+ iface.launch()
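The Gradio handler above is a thin wrapper; the same edit can be run directly from Python. A small sketch (the image path and both prompts below are placeholders, not assets shipped with this Space):

from PIL import Image
from edict_functions import EDICT_editing

im = Image.open("input.png").convert("RGB")   # placeholder input image
edited = EDICT_editing(im,
                       "A photo of a cat",    # description of the input image
                       "A photo of a dog",    # description of the desired edit
                       steps=50,
                       mix_weight=0.93,
                       init_image_strength=0.8,
                       guidance_scale=3)[0]
edited.save("edited.png")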
my_diffusers/__init__.py ADDED
@@ -0,0 +1,60 @@
1
+ from .utils import (
2
+ is_inflect_available,
3
+ is_onnx_available,
4
+ is_scipy_available,
5
+ is_transformers_available,
6
+ is_unidecode_available,
7
+ )
8
+
9
+
10
+ __version__ = "0.3.0"
11
+
12
+ from .configuration_utils import ConfigMixin
13
+ from .modeling_utils import ModelMixin
14
+ from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
15
+ from .onnx_utils import OnnxRuntimeModel
16
+ from .optimization import (
17
+ get_constant_schedule,
18
+ get_constant_schedule_with_warmup,
19
+ get_cosine_schedule_with_warmup,
20
+ get_cosine_with_hard_restarts_schedule_with_warmup,
21
+ get_linear_schedule_with_warmup,
22
+ get_polynomial_decay_schedule_with_warmup,
23
+ get_scheduler,
24
+ )
25
+ from .pipeline_utils import DiffusionPipeline
26
+ from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
27
+ from .schedulers import (
28
+ DDIMScheduler,
29
+ DDPMScheduler,
30
+ KarrasVeScheduler,
31
+ PNDMScheduler,
32
+ SchedulerMixin,
33
+ ScoreSdeVeScheduler,
34
+ )
35
+ from .utils import logging
36
+
37
+
38
+ if is_scipy_available():
39
+ from .schedulers import LMSDiscreteScheduler
40
+ else:
41
+ from .utils.dummy_scipy_objects import * # noqa F403
42
+
43
+ from .training_utils import EMAModel
44
+
45
+
46
+ if is_transformers_available():
47
+ from .pipelines import (
48
+ LDMTextToImagePipeline,
49
+ StableDiffusionImg2ImgPipeline,
50
+ StableDiffusionInpaintPipeline,
51
+ StableDiffusionPipeline,
52
+ )
53
+ else:
54
+ from .utils.dummy_transformers_objects import * # noqa F403
55
+
56
+
57
+ if is_transformers_available() and is_onnx_available():
58
+ from .pipelines import StableDiffusionOnnxPipeline
59
+ else:
60
+ from .utils.dummy_transformers_and_onnx_objects import * # noqa F403
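The vendored package is imported exactly like upstream diffusers 0.3.0. For example, the DDIM scheduler that edict_functions.py configures above can be built on its own (a sketch; "scaled_linear" is an illustrative choice for the beta schedule):

from my_diffusers import DDIMScheduler

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                          beta_schedule="scaled_linear",
                          num_train_timesteps=1000,
                          clip_sample=False, set_alpha_to_one=False)
scheduler.set_timesteps(50)
print(len(scheduler.timesteps))  # 50 inference timesteps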
my_diffusers/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.86 kB)
my_diffusers/__pycache__/configuration_utils.cpython-38.pyc ADDED
Binary file (15.5 kB)
my_diffusers/__pycache__/modeling_utils.cpython-38.pyc ADDED
Binary file (18.9 kB)
my_diffusers/__pycache__/onnx_utils.cpython-38.pyc ADDED
Binary file (6.28 kB)
my_diffusers/__pycache__/optimization.cpython-38.pyc ADDED
Binary file (10.2 kB)
my_diffusers/__pycache__/pipeline_utils.cpython-38.pyc ADDED
Binary file (14 kB)
my_diffusers/__pycache__/training_utils.cpython-38.pyc ADDED
Binary file (3.63 kB)
my_diffusers/commands/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from abc import ABC, abstractmethod
16
+ from argparse import ArgumentParser
17
+
18
+
19
+ class BaseDiffusersCLICommand(ABC):
20
+ @staticmethod
21
+ @abstractmethod
22
+ def register_subcommand(parser: ArgumentParser):
23
+ raise NotImplementedError()
24
+
25
+ @abstractmethod
26
+ def run(self):
27
+ raise NotImplementedError()
my_diffusers/commands/diffusers_cli.py ADDED
@@ -0,0 +1,41 @@
1
+ #!/usr/bin/env python
2
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from argparse import ArgumentParser
17
+
18
+ from .env import EnvironmentCommand
19
+
20
+
21
+ def main():
22
+ parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
23
+ commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
24
+
25
+ # Register commands
26
+ EnvironmentCommand.register_subcommand(commands_parser)
27
+
28
+ # Let's go
29
+ args = parser.parse_args()
30
+
31
+ if not hasattr(args, "func"):
32
+ parser.print_help()
33
+ exit(1)
34
+
35
+ # Run
36
+ service = args.func(args)
37
+ service.run()
38
+
39
+
40
+ if __name__ == "__main__":
41
+ main()
my_diffusers/commands/env.py ADDED
@@ -0,0 +1,70 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import platform
16
+ from argparse import ArgumentParser
17
+
18
+ import huggingface_hub
19
+
20
+ from .. import __version__ as version
21
+ from ..utils import is_torch_available, is_transformers_available
22
+ from . import BaseDiffusersCLICommand
23
+
24
+
25
+ def info_command_factory(_):
26
+ return EnvironmentCommand()
27
+
28
+
29
+ class EnvironmentCommand(BaseDiffusersCLICommand):
30
+ @staticmethod
31
+ def register_subcommand(parser: ArgumentParser):
32
+ download_parser = parser.add_parser("env")
33
+ download_parser.set_defaults(func=info_command_factory)
34
+
35
+ def run(self):
36
+ hub_version = huggingface_hub.__version__
37
+
38
+ pt_version = "not installed"
39
+ pt_cuda_available = "NA"
40
+ if is_torch_available():
41
+ import torch
42
+
43
+ pt_version = torch.__version__
44
+ pt_cuda_available = torch.cuda.is_available()
45
+
46
+ transformers_version = "not installed"
47
+ if is_transformers_available():
48
+ import transformers
49
+
50
+ transformers_version = transformers.__version__
51
+
52
+ info = {
53
+ "`diffusers` version": version,
54
+ "Platform": platform.platform(),
55
+ "Python version": platform.python_version(),
56
+ "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
57
+ "Huggingface_hub version": hub_version,
58
+ "Transformers version": transformers_version,
59
+ "Using GPU in script?": "<fill in>",
60
+ "Using distributed or parallel set-up in script?": "<fill in>",
61
+ }
62
+
63
+ print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
64
+ print(self.format_dict(info))
65
+
66
+ return info
67
+
68
+ @staticmethod
69
+ def format_dict(d):
70
+ return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
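This copy of the package does not install the `diffusers-cli` console script, but the environment report can be produced programmatically (a sketch, assuming the repository root is on the Python path):

from my_diffusers.commands.env import EnvironmentCommand

info = EnvironmentCommand().run()  # prints the report and returns the underlying dict
print(sorted(info.keys()))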
my_diffusers/configuration_utils.py ADDED
@@ -0,0 +1,403 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ ConfigMixinuration base class and utilities."""
17
+ import functools
18
+ import inspect
19
+ import json
20
+ import os
21
+ import re
22
+ from collections import OrderedDict
23
+ from typing import Any, Dict, Tuple, Union
24
+
25
+ from huggingface_hub import hf_hub_download
26
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
27
+ from requests import HTTPError
28
+
29
+ from . import __version__
30
+ from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+ _re_configuration_file = re.compile(r"config\.(.*)\.json")
36
+
37
+
38
+ class ConfigMixin:
39
+ r"""
40
+ Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
41
+ methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
42
+ - [`~ConfigMixin.from_config`]
43
+ - [`~ConfigMixin.save_config`]
44
+
45
+ Class attributes:
46
+ - **config_name** (`str`) -- A filename under which the config should stored when calling
47
+ [`~ConfigMixin.save_config`] (should be overridden by parent class).
48
+ - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
49
+ overridden by parent class).
50
+ """
51
+ config_name = None
52
+ ignore_for_config = []
53
+
54
+ def register_to_config(self, **kwargs):
55
+ if self.config_name is None:
56
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
57
+ kwargs["_class_name"] = self.__class__.__name__
58
+ kwargs["_diffusers_version"] = __version__
59
+
60
+ for key, value in kwargs.items():
61
+ try:
62
+ setattr(self, key, value)
63
+ except AttributeError as err:
64
+ logger.error(f"Can't set {key} with value {value} for {self}")
65
+ raise err
66
+
67
+ if not hasattr(self, "_internal_dict"):
68
+ internal_dict = kwargs
69
+ else:
70
+ previous_dict = dict(self._internal_dict)
71
+ internal_dict = {**self._internal_dict, **kwargs}
72
+ logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
73
+
74
+ self._internal_dict = FrozenDict(internal_dict)
75
+
76
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
77
+ """
78
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
79
+ [`~ConfigMixin.from_config`] class method.
80
+
81
+ Args:
82
+ save_directory (`str` or `os.PathLike`):
83
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
84
+ """
85
+ if os.path.isfile(save_directory):
86
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
87
+
88
+ os.makedirs(save_directory, exist_ok=True)
89
+
90
+ # If we save using the predefined names, we can load using `from_config`
91
+ output_config_file = os.path.join(save_directory, self.config_name)
92
+
93
+ self.to_json_file(output_config_file)
94
+ logger.info(f"ConfigMixinuration saved in {output_config_file}")
95
+
96
+ @classmethod
97
+ def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
98
+ r"""
99
+ Instantiate a Python class from a pre-defined JSON-file.
100
+
101
+ Parameters:
102
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
103
+ Can be either:
104
+
105
+ - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
106
+ organization name, like `google/ddpm-celebahq-256`.
107
+ - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
108
+ `./my_model_directory/`.
109
+
110
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
111
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
112
+ standard cache should not be used.
113
+ ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
114
+ Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
115
+ as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
116
+ checkpoint with 3 labels).
117
+ force_download (`bool`, *optional*, defaults to `False`):
118
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
119
+ cached versions if they exist.
120
+ resume_download (`bool`, *optional*, defaults to `False`):
121
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
122
+ file exists.
123
+ proxies (`Dict[str, str]`, *optional*):
124
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
125
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
126
+ output_loading_info(`bool`, *optional*, defaults to `False`):
127
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
128
+ local_files_only(`bool`, *optional*, defaults to `False`):
129
+ Whether or not to only look at local files (i.e., do not try to download the model).
130
+ use_auth_token (`str` or *bool*, *optional*):
131
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
132
+ when running `transformers-cli login` (stored in `~/.huggingface`).
133
+ revision (`str`, *optional*, defaults to `"main"`):
134
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
135
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
136
+ identifier allowed by git.
137
+ mirror (`str`, *optional*):
138
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
139
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
140
+ Please refer to the mirror site for more information.
141
+
142
+ <Tip>
143
+
144
+ Passing `use_auth_token=True` is required when you want to use a private model.
145
+
146
+ </Tip>
147
+
148
+ <Tip>
149
+
150
+ Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
151
+ use this method in a firewalled environment.
152
+
153
+ </Tip>
154
+
155
+ """
156
+ config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
157
+
158
+ init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
159
+
160
+ model = cls(**init_dict)
161
+
162
+ if return_unused_kwargs:
163
+ return model, unused_kwargs
164
+ else:
165
+ return model
166
+
167
+ @classmethod
168
+ def get_config_dict(
169
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
170
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
171
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
172
+ force_download = kwargs.pop("force_download", False)
173
+ resume_download = kwargs.pop("resume_download", False)
174
+ proxies = kwargs.pop("proxies", None)
175
+ use_auth_token = kwargs.pop("use_auth_token", None)
176
+ local_files_only = kwargs.pop("local_files_only", False)
177
+ revision = kwargs.pop("revision", None)
178
+ subfolder = kwargs.pop("subfolder", None)
179
+
180
+ user_agent = {"file_type": "config"}
181
+
182
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
183
+
184
+ if cls.config_name is None:
185
+ raise ValueError(
186
+ "`self.config_name` is not defined. Note that one should not load a config from "
187
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
188
+ )
189
+
190
+ if os.path.isfile(pretrained_model_name_or_path):
191
+ config_file = pretrained_model_name_or_path
192
+ elif os.path.isdir(pretrained_model_name_or_path):
193
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
194
+ # Load from a PyTorch checkpoint
195
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
196
+ elif subfolder is not None and os.path.isfile(
197
+ os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
198
+ ):
199
+ config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
200
+ else:
201
+ raise EnvironmentError(
202
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
203
+ )
204
+ else:
205
+ try:
206
+ # Load from URL or cache if already cached
207
+ config_file = hf_hub_download(
208
+ pretrained_model_name_or_path,
209
+ filename=cls.config_name,
210
+ cache_dir=cache_dir,
211
+ force_download=force_download,
212
+ proxies=proxies,
213
+ resume_download=resume_download,
214
+ local_files_only=local_files_only,
215
+ use_auth_token=use_auth_token,
216
+ user_agent=user_agent,
217
+ subfolder=subfolder,
218
+ revision=revision,
219
+ )
220
+
221
+ except RepositoryNotFoundError:
222
+ raise EnvironmentError(
223
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
224
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
225
+ " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
226
+ " login` and pass `use_auth_token=True`."
227
+ )
228
+ except RevisionNotFoundError:
229
+ raise EnvironmentError(
230
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
231
+ " this model name. Check the model page at"
232
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
233
+ )
234
+ except EntryNotFoundError:
235
+ raise EnvironmentError(
236
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
237
+ )
238
+ except HTTPError as err:
239
+ raise EnvironmentError(
240
+ "There was a specific connection error when trying to load"
241
+ f" {pretrained_model_name_or_path}:\n{err}"
242
+ )
243
+ except ValueError:
244
+ raise EnvironmentError(
245
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
246
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
247
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
248
+ " run the library in offline mode at"
249
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
250
+ )
251
+ except EnvironmentError:
252
+ raise EnvironmentError(
253
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
254
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
255
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
256
+ f"containing a {cls.config_name} file"
257
+ )
258
+
259
+ try:
260
+ # Load config dict
261
+ config_dict = cls._dict_from_json_file(config_file)
262
+ except (json.JSONDecodeError, UnicodeDecodeError):
263
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
264
+
265
+ return config_dict
266
+
267
+ @classmethod
268
+ def extract_init_dict(cls, config_dict, **kwargs):
269
+ expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
270
+ expected_keys.remove("self")
271
+ # remove general kwargs if present in dict
272
+ if "kwargs" in expected_keys:
273
+ expected_keys.remove("kwargs")
274
+ # remove keys to be ignored
275
+ if len(cls.ignore_for_config) > 0:
276
+ expected_keys = expected_keys - set(cls.ignore_for_config)
277
+ init_dict = {}
278
+ for key in expected_keys:
279
+ if key in kwargs:
280
+ # overwrite key
281
+ init_dict[key] = kwargs.pop(key)
282
+ elif key in config_dict:
283
+ # use value from config dict
284
+ init_dict[key] = config_dict.pop(key)
285
+
286
+ unused_kwargs = {**config_dict, **kwargs}  # dict.update() returns None, so collect the leftovers explicitly
287
+
288
+ passed_keys = set(init_dict.keys())
289
+ if len(expected_keys - passed_keys) > 0:
290
+ logger.warning(
291
+ f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
292
+ )
293
+
294
+ return init_dict, unused_kwargs
295
+
296
+ @classmethod
297
+ def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
298
+ with open(json_file, "r", encoding="utf-8") as reader:
299
+ text = reader.read()
300
+ return json.loads(text)
301
+
302
+ def __repr__(self):
303
+ return f"{self.__class__.__name__} {self.to_json_string()}"
304
+
305
+ @property
306
+ def config(self) -> Dict[str, Any]:
307
+ return self._internal_dict
308
+
309
+ def to_json_string(self) -> str:
310
+ """
311
+ Serializes this instance to a JSON string.
312
+
313
+ Returns:
314
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
315
+ """
316
+ config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
317
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
318
+
319
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
320
+ """
321
+ Save this instance to a JSON file.
322
+
323
+ Args:
324
+ json_file_path (`str` or `os.PathLike`):
325
+ Path to the JSON file in which this configuration instance's parameters will be saved.
326
+ """
327
+ with open(json_file_path, "w", encoding="utf-8") as writer:
328
+ writer.write(self.to_json_string())
329
+
330
+
331
+ class FrozenDict(OrderedDict):
332
+ def __init__(self, *args, **kwargs):
333
+ super().__init__(*args, **kwargs)
334
+
335
+ for key, value in self.items():
336
+ setattr(self, key, value)
337
+
338
+ self.__frozen = True
339
+
340
+ def __delitem__(self, *args, **kwargs):
341
+ raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
342
+
343
+ def setdefault(self, *args, **kwargs):
344
+ raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
345
+
346
+ def pop(self, *args, **kwargs):
347
+ raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
348
+
349
+ def update(self, *args, **kwargs):
350
+ raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
351
+
352
+ def __setattr__(self, name, value):
353
+ if hasattr(self, "__frozen") and self.__frozen:
354
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
355
+ super().__setattr__(name, value)
356
+
357
+ def __setitem__(self, name, value):
358
+ if hasattr(self, "__frozen") and self.__frozen:
359
+ raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
360
+ super().__setitem__(name, value)
361
+
362
+
363
+ def register_to_config(init):
364
+ r"""
365
+ Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
366
+ automatically sent to `self.register_to_config`. To ignore a specific argument accepted by the init but that
367
+ shouldn't be registered in the config, use the `ignore_for_config` class variable
368
+
369
+ Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
370
+ """
371
+
372
+ @functools.wraps(init)
373
+ def inner_init(self, *args, **kwargs):
374
+ # Ignore private kwargs in the init.
375
+ init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
376
+ init(self, *args, **init_kwargs)
377
+ if not isinstance(self, ConfigMixin):
378
+ raise RuntimeError(
379
+ f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
380
+ "not inherit from `ConfigMixin`."
381
+ )
382
+
383
+ ignore = getattr(self, "ignore_for_config", [])
384
+ # Get positional arguments aligned with kwargs
385
+ new_kwargs = {}
386
+ signature = inspect.signature(init)
387
+ parameters = {
388
+ name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
389
+ }
390
+ for arg, name in zip(args, parameters.keys()):
391
+ new_kwargs[name] = arg
392
+
393
+ # Then add all kwargs
394
+ new_kwargs.update(
395
+ {
396
+ k: init_kwargs.get(k, default)
397
+ for k, default in parameters.items()
398
+ if k not in ignore and k not in new_kwargs
399
+ }
400
+ )
401
+ getattr(self, "register_to_config")(**new_kwargs)
402
+
403
+ return inner_init
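A sketch of how a scheduler-style class is expected to combine `ConfigMixin`, `config_name`, and the `@register_to_config` decorator defined above; `ToyScheduler` is a hypothetical example, not a class from this repository:

from my_diffusers.configuration_utils import ConfigMixin, register_to_config

class ToyScheduler(ConfigMixin):
    config_name = "scheduler_config.json"

    @register_to_config
    def __init__(self, num_train_timesteps: int = 1000, beta_start: float = 0.0001):
        # the decorator records the call arguments in self.config after __init__ runs
        self.betas = [beta_start] * num_train_timesteps

s = ToyScheduler(num_train_timesteps=10)
print(s.config)                   # FrozenDict with num_train_timesteps, beta_start, ...
s.save_config("./toy_scheduler")  # writes ./toy_scheduler/scheduler_config.json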
my_diffusers/dependency_versions_check.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import sys
15
+
16
+ from .dependency_versions_table import deps
17
+ from .utils.versions import require_version, require_version_core
18
+
19
+
20
+ # define which module versions we always want to check at run time
21
+ # (usually the ones defined in `install_requires` in setup.py)
22
+ #
23
+ # order specific notes:
24
+ # - tqdm must be checked before tokenizers
25
+
26
+ pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
27
+ if sys.version_info < (3, 7):
28
+ pkgs_to_check_at_runtime.append("dataclasses")
29
+ if sys.version_info < (3, 8):
30
+ pkgs_to_check_at_runtime.append("importlib_metadata")
31
+
32
+ for pkg in pkgs_to_check_at_runtime:
33
+ if pkg in deps:
34
+ if pkg == "tokenizers":
35
+ # must be loaded here, or else tqdm check may fail
36
+ from .utils import is_tokenizers_available
37
+
38
+ if not is_tokenizers_available():
39
+ continue # not required, check version only if installed
40
+
41
+ require_version_core(deps[pkg])
42
+ else:
43
+ raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44
+
45
+
46
+ def dep_version_check(pkg, hint=None):
47
+ require_version(deps[pkg], hint)
my_diffusers/dependency_versions_table.py ADDED
@@ -0,0 +1,26 @@
1
+ # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
+ # 1. modify the `_deps` dict in setup.py
3
+ # 2. run `make deps_table_update`
4
+ deps = {
5
+ "Pillow": "Pillow",
6
+ "accelerate": "accelerate>=0.11.0",
7
+ "black": "black==22.3",
8
+ "datasets": "datasets",
9
+ "filelock": "filelock",
10
+ "flake8": "flake8>=3.8.3",
11
+ "hf-doc-builder": "hf-doc-builder>=0.3.0",
12
+ "huggingface-hub": "huggingface-hub>=0.8.1",
13
+ "importlib_metadata": "importlib_metadata",
14
+ "isort": "isort>=5.5.4",
15
+ "modelcards": "modelcards==0.1.4",
16
+ "numpy": "numpy",
17
+ "pytest": "pytest",
18
+ "pytest-timeout": "pytest-timeout",
19
+ "pytest-xdist": "pytest-xdist",
20
+ "scipy": "scipy",
21
+ "regex": "regex!=2019.12.17",
22
+ "requests": "requests",
23
+ "tensorboard": "tensorboard",
24
+ "torch": "torch>=1.4",
25
+ "transformers": "transformers>=4.21.0",
26
+ }
my_diffusers/dynamic_modules_utils.py ADDED
@@ -0,0 +1,335 @@
1
+ # coding=utf-8
2
+ # Copyright 2021 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Utilities to dynamically load objects from the Hub."""
16
+
17
+ import importlib
18
+ import os
19
+ import re
20
+ import shutil
21
+ import sys
22
+ from pathlib import Path
23
+ from typing import Dict, Optional, Union
24
+
25
+ from huggingface_hub import cached_download
26
+
27
+ from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ def init_hf_modules():
34
+ """
35
+ Creates the cache directory for modules with an init, and adds it to the Python path.
36
+ """
37
+ # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
38
+ if HF_MODULES_CACHE in sys.path:
39
+ return
40
+
41
+ sys.path.append(HF_MODULES_CACHE)
42
+ os.makedirs(HF_MODULES_CACHE, exist_ok=True)
43
+ init_path = Path(HF_MODULES_CACHE) / "__init__.py"
44
+ if not init_path.exists():
45
+ init_path.touch()
46
+
47
+
48
+ def create_dynamic_module(name: Union[str, os.PathLike]):
49
+ """
50
+ Creates a dynamic module in the cache directory for modules.
51
+ """
52
+ init_hf_modules()
53
+ dynamic_module_path = Path(HF_MODULES_CACHE) / name
54
+ # If the parent module does not exist yet, recursively create it.
55
+ if not dynamic_module_path.parent.exists():
56
+ create_dynamic_module(dynamic_module_path.parent)
57
+ os.makedirs(dynamic_module_path, exist_ok=True)
58
+ init_path = dynamic_module_path / "__init__.py"
59
+ if not init_path.exists():
60
+ init_path.touch()
61
+
62
+
63
+ def get_relative_imports(module_file):
64
+ """
65
+ Get the list of modules that are relatively imported in a module file.
66
+
67
+ Args:
68
+ module_file (`str` or `os.PathLike`): The module file to inspect.
69
+ """
70
+ with open(module_file, "r", encoding="utf-8") as f:
71
+ content = f.read()
72
+
73
+ # Imports of the form `import .xxx`
74
+ relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
75
+ # Imports of the form `from .xxx import yyy`
76
+ relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
77
+ # Unique-ify
78
+ return list(set(relative_imports))
79
+
80
+
81
+ def get_relative_import_files(module_file):
82
+ """
83
+ Get the list of all files that are needed for a given module. Note that this function recurses through the relative
84
+ imports (if a imports b and b imports c, it will return module files for b and c).
85
+
86
+ Args:
87
+ module_file (`str` or `os.PathLike`): The module file to inspect.
88
+ """
89
+ no_change = False
90
+ files_to_check = [module_file]
91
+ all_relative_imports = []
92
+
93
+ # Let's recurse through all relative imports
94
+ while not no_change:
95
+ new_imports = []
96
+ for f in files_to_check:
97
+ new_imports.extend(get_relative_imports(f))
98
+
99
+ module_path = Path(module_file).parent
100
+ new_import_files = [str(module_path / m) for m in new_imports]
101
+ new_import_files = [f for f in new_import_files if f not in all_relative_imports]
102
+ files_to_check = [f"{f}.py" for f in new_import_files]
103
+
104
+ no_change = len(new_import_files) == 0
105
+ all_relative_imports.extend(files_to_check)
106
+
107
+ return all_relative_imports
108
+
109
+
110
+ def check_imports(filename):
111
+ """
112
+ Check if the current Python environment contains all the libraries that are imported in a file.
113
+ """
114
+ with open(filename, "r", encoding="utf-8") as f:
115
+ content = f.read()
116
+
117
+ # Imports of the form `import xxx`
118
+ imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
119
+ # Imports of the form `from xxx import yyy`
120
+ imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
121
+ # Only keep the top-level module
122
+ imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
123
+
124
+ # Unique-ify and test we got them all
125
+ imports = list(set(imports))
126
+ missing_packages = []
127
+ for imp in imports:
128
+ try:
129
+ importlib.import_module(imp)
130
+ except ImportError:
131
+ missing_packages.append(imp)
132
+
133
+ if len(missing_packages) > 0:
134
+ raise ImportError(
135
+ "This modeling file requires the following packages that were not found in your environment: "
136
+ f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
137
+ )
138
+
139
+ return get_relative_imports(filename)
140
+
141
+
142
+ def get_class_in_module(class_name, module_path):
143
+ """
144
+ Import a module on the cache directory for modules and extract a class from it.
145
+ """
146
+ module_path = module_path.replace(os.path.sep, ".")
147
+ module = importlib.import_module(module_path)
148
+ return getattr(module, class_name)
149
+
150
+
151
+ def get_cached_module_file(
152
+ pretrained_model_name_or_path: Union[str, os.PathLike],
153
+ module_file: str,
154
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
155
+ force_download: bool = False,
156
+ resume_download: bool = False,
157
+ proxies: Optional[Dict[str, str]] = None,
158
+ use_auth_token: Optional[Union[bool, str]] = None,
159
+ revision: Optional[str] = None,
160
+ local_files_only: bool = False,
161
+ ):
162
+ """
163
+ Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
164
+ Transformers module.
165
+
166
+ Args:
167
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
168
+ This can be either:
169
+
170
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
171
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
172
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
173
+ - a path to a *directory* containing a configuration file saved using the
174
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
175
+
176
+ module_file (`str`):
177
+ The name of the module file containing the class to look for.
178
+ cache_dir (`str` or `os.PathLike`, *optional*):
179
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
180
+ cache should not be used.
181
+ force_download (`bool`, *optional*, defaults to `False`):
182
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
183
+ exist.
184
+ resume_download (`bool`, *optional*, defaults to `False`):
185
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
186
+ proxies (`Dict[str, str]`, *optional*):
187
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
188
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
189
+ use_auth_token (`str` or *bool*, *optional*):
190
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
191
+ when running `transformers-cli login` (stored in `~/.huggingface`).
192
+ revision (`str`, *optional*, defaults to `"main"`):
193
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
194
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
195
+ identifier allowed by git.
196
+ local_files_only (`bool`, *optional*, defaults to `False`):
197
+ If `True`, will only try to load the tokenizer configuration from local files.
198
+
199
+ <Tip>
200
+
201
+ Passing `use_auth_token=True` is required when you want to use a private model.
202
+
203
+ </Tip>
204
+
205
+ Returns:
206
+ `str`: The path to the module inside the cache.
207
+ """
208
+ # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
209
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
210
+ module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
211
+ submodule = "local"
212
+
213
+ if os.path.isfile(module_file_or_url):
214
+ resolved_module_file = module_file_or_url
215
+ else:
216
+ try:
217
+ # Load from URL or cache if already cached
218
+ resolved_module_file = cached_download(
219
+ module_file_or_url,
220
+ cache_dir=cache_dir,
221
+ force_download=force_download,
222
+ proxies=proxies,
223
+ resume_download=resume_download,
224
+ local_files_only=local_files_only,
225
+ use_auth_token=use_auth_token,
226
+ )
227
+
228
+ except EnvironmentError:
229
+ logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
230
+ raise
231
+
232
+ # Check we have all the requirements in our environment
233
+ modules_needed = check_imports(resolved_module_file)
234
+
235
+ # Now we move the module inside our cached dynamic modules.
236
+ full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
237
+ create_dynamic_module(full_submodule)
238
+ submodule_path = Path(HF_MODULES_CACHE) / full_submodule
239
+ # We always copy local files (we could hash the file to see if there was a change, and give them the name of
240
+ # that hash, to only copy when there is a modification but it seems overkill for now).
241
+ # The only reason we do the copy is to avoid putting too many folders in sys.path.
242
+ shutil.copy(resolved_module_file, submodule_path / module_file)
243
+ for module_needed in modules_needed:
244
+ module_needed = f"{module_needed}.py"
245
+ shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
246
+ return os.path.join(full_submodule, module_file)
247
+
248
+
249
+ def get_class_from_dynamic_module(
250
+ pretrained_model_name_or_path: Union[str, os.PathLike],
251
+ module_file: str,
252
+ class_name: str,
253
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
254
+ force_download: bool = False,
255
+ resume_download: bool = False,
256
+ proxies: Optional[Dict[str, str]] = None,
257
+ use_auth_token: Optional[Union[bool, str]] = None,
258
+ revision: Optional[str] = None,
259
+ local_files_only: bool = False,
260
+ **kwargs,
261
+ ):
262
+ """
263
+ Extracts a class from a module file, present in the local folder or repository of a model.
264
+
265
+ <Tip warning={true}>
266
+
267
+ Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
268
+ therefore only be called on trusted repos.
269
+
270
+ </Tip>
271
+
272
+ Args:
273
+ pretrained_model_name_or_path (`str` or `os.PathLike`):
274
+ This can be either:
275
+
276
+ - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
277
+ huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
278
+ under a user or organization name, like `dbmdz/bert-base-german-cased`.
279
+ - a path to a *directory* containing a configuration file saved using the
280
+ [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
281
+
282
+ module_file (`str`):
283
+ The name of the module file containing the class to look for.
284
+ class_name (`str`):
285
+ The name of the class to import in the module.
286
+ cache_dir (`str` or `os.PathLike`, *optional*):
287
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
288
+ cache should not be used.
289
+ force_download (`bool`, *optional*, defaults to `False`):
290
+ Whether or not to force to (re-)download the configuration files and override the cached versions if they
291
+ exist.
292
+ resume_download (`bool`, *optional*, defaults to `False`):
293
+ Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
294
+ proxies (`Dict[str, str]`, *optional*):
295
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
296
+ 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
297
+ use_auth_token (`str` or `bool`, *optional*):
298
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
299
+ when running `transformers-cli login` (stored in `~/.huggingface`).
300
+ revision (`str`, *optional*, defaults to `"main"`):
301
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
302
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
303
+ identifier allowed by git.
304
+ local_files_only (`bool`, *optional*, defaults to `False`):
305
+ If `True`, will only try to load the tokenizer configuration from local files.
306
+
307
+ <Tip>
308
+
309
+ Passing `use_auth_token=True` is required when you want to use a private model.
310
+
311
+ </Tip>
312
+
313
+ Returns:
314
+ `type`: The class, dynamically imported from the module.
315
+
316
+ Examples:
317
+
318
+ ```python
319
+ # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
320
+ # module.
321
+ cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
322
+ ```"""
323
+ # And lastly we get the class inside our newly created module
324
+ final_module = get_cached_module_file(
325
+ pretrained_model_name_or_path,
326
+ module_file,
327
+ cache_dir=cache_dir,
328
+ force_download=force_download,
329
+ resume_download=resume_download,
330
+ proxies=proxies,
331
+ use_auth_token=use_auth_token,
332
+ revision=revision,
333
+ local_files_only=local_files_only,
334
+ )
335
+ return get_class_in_module(class_name, final_module.replace(".py", ""))
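What `get_relative_imports` extracts from a module file, shown on a toy source written to a temporary file (the helper works on file paths, not strings):

import os
import tempfile

from my_diffusers.dynamic_modules_utils import get_relative_imports

src = "from .unet_blocks import UNetMidBlock2D\nimport .resnet\n"
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(src)
    path = f.name

print(sorted(get_relative_imports(path)))  # ['resnet', 'unet_blocks']
os.remove(path)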
my_diffusers/hub_utils.py ADDED
@@ -0,0 +1,197 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import os
18
+ import shutil
19
+ from pathlib import Path
20
+ from typing import Optional
21
+
22
+ from huggingface_hub import HfFolder, Repository, whoami
23
+
24
+ from .pipeline_utils import DiffusionPipeline
25
+ from .utils import is_modelcards_available, logging
26
+
27
+
28
+ if is_modelcards_available():
29
+ from modelcards import CardData, ModelCard
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
36
+
37
+
38
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
39
+ if token is None:
40
+ token = HfFolder.get_token()
41
+ if organization is None:
42
+ username = whoami(token)["name"]
43
+ return f"{username}/{model_id}"
44
+ else:
45
+ return f"{organization}/{model_id}"
46
+
47
+
48
+ def init_git_repo(args, at_init: bool = False):
49
+ """
50
+ Args:
51
+ Initializes a git repo in `args.hub_model_id`.
52
+ at_init (`bool`, *optional*, defaults to `False`):
53
+ Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
54
+ and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
55
+ """
56
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
57
+ return
58
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
59
+ use_auth_token = True if hub_token is None else hub_token
60
+ if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
61
+ repo_name = Path(args.output_dir).absolute().name
62
+ else:
63
+ repo_name = args.hub_model_id
64
+ if "/" not in repo_name:
65
+ repo_name = get_full_repo_name(repo_name, token=hub_token)
66
+
67
+ try:
68
+ repo = Repository(
69
+ args.output_dir,
70
+ clone_from=repo_name,
71
+ use_auth_token=use_auth_token,
72
+ private=args.hub_private_repo,
73
+ )
74
+ except EnvironmentError:
75
+ if args.overwrite_output_dir and at_init:
76
+ # Try again after wiping output_dir
77
+ shutil.rmtree(args.output_dir)
78
+ repo = Repository(
79
+ args.output_dir,
80
+ clone_from=repo_name,
81
+ use_auth_token=use_auth_token,
82
+ )
83
+ else:
84
+ raise
85
+
86
+ repo.git_pull()
87
+
88
+ # By default, ignore the checkpoint folders
89
+ if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
90
+ with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
91
+ writer.writelines(["checkpoint-*/"])
92
+
93
+ return repo
94
+
95
+
96
+ def push_to_hub(
97
+ args,
98
+ pipeline: DiffusionPipeline,
99
+ repo: Repository,
100
+ commit_message: Optional[str] = "End of training",
101
+ blocking: bool = True,
102
+ **kwargs,
103
+ ) -> str:
104
+ """
105
+ Parameters:
106
+ Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
107
+ commit_message (`str`, *optional*, defaults to `"End of training"`):
108
+ Message to commit while pushing.
109
+ blocking (`bool`, *optional*, defaults to `True`):
110
+ Whether the function should return only when the `git push` has finished.
111
+ kwargs:
112
+ Additional keyword arguments passed along to [`create_model_card`].
113
+ Returns:
114
+ The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
115
+ commit and an object to track the progress of the commit if `blocking=True`
116
+ """
117
+
118
+ if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
119
+ model_name = Path(args.output_dir).name
120
+ else:
121
+ model_name = args.hub_model_id.split("/")[-1]
122
+
123
+ output_dir = args.output_dir
124
+ os.makedirs(output_dir, exist_ok=True)
125
+ logger.info(f"Saving pipeline checkpoint to {output_dir}")
126
+ pipeline.save_pretrained(output_dir)
127
+
128
+ # Only push from one node.
129
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
130
+ return
131
+
132
+ # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
133
+ if (
134
+ blocking
135
+ and len(repo.command_queue) > 0
136
+ and repo.command_queue[-1] is not None
137
+ and not repo.command_queue[-1].is_done
138
+ ):
139
+ repo.command_queue[-1]._process.kill()
140
+
141
+ git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
142
+ # push separately the model card to be independent from the rest of the model
143
+ create_model_card(args, model_name=model_name)
144
+ try:
145
+ repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
146
+ except EnvironmentError as exc:
147
+ logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
148
+
149
+ return git_head_commit_url
150
+
151
+
152
+ def create_model_card(args, model_name):
153
+ if not is_modelcards_available:
154
+ raise ValueError(
155
+ "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
156
+ " install the package with `pip install modelcards`."
157
+ )
158
+
159
+ if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
160
+ return
161
+
162
+ hub_token = args.hub_token if hasattr(args, "hub_token") else None
163
+ repo_name = get_full_repo_name(model_name, token=hub_token)
164
+
165
+ model_card = ModelCard.from_template(
166
+ card_data=CardData( # Card metadata object that will be converted to YAML block
167
+ language="en",
168
+ license="apache-2.0",
169
+ library_name="diffusers",
170
+ tags=[],
171
+ datasets=args.dataset_name,
172
+ metrics=[],
173
+ ),
174
+ template_path=MODEL_CARD_TEMPLATE_PATH,
175
+ model_name=model_name,
176
+ repo_name=repo_name,
177
+ dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
178
+ learning_rate=args.learning_rate,
179
+ train_batch_size=args.train_batch_size,
180
+ eval_batch_size=args.eval_batch_size,
181
+ gradient_accumulation_steps=args.gradient_accumulation_steps
182
+ if hasattr(args, "gradient_accumulation_steps")
183
+ else None,
184
+ adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
185
+ adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
186
+ adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
187
+ adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
188
+ lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
189
+ lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
190
+ ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
191
+ ema_power=args.ema_power if hasattr(args, "ema_power") else None,
192
+ ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
193
+ mixed_precision=args.mixed_precision,
194
+ )
195
+
196
+ card_path = os.path.join(args.output_dir, "README.md")
197
+ model_card.save(card_path)
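A small sketch of the repo-naming helper above ("my-edict-model" is a placeholder id); when no organization is passed it falls back to the namespace of the logged-in Hub user:

from my_diffusers.hub_utils import get_full_repo_name

print(get_full_repo_name("my-edict-model", organization="Salesforce"))
# -> "Salesforce/my-edict-model"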
my_diffusers/modeling_utils.py ADDED
@@ -0,0 +1,542 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ from typing import Callable, List, Optional, Tuple, Union
19
+
20
+ import torch
21
+ from torch import Tensor, device
22
+
23
+ from huggingface_hub import hf_hub_download
24
+ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
25
+ from requests import HTTPError
26
+
27
+ from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
28
+
29
+
30
+ WEIGHTS_NAME = "diffusion_pytorch_model.bin"
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
+ def get_parameter_device(parameter: torch.nn.Module):
37
+ try:
38
+ return next(parameter.parameters()).device
39
+ except StopIteration:
40
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
41
+
42
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
43
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
44
+ return tuples
45
+
46
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
47
+ first_tuple = next(gen)
48
+ return first_tuple[1].device
49
+
50
+
51
+ def get_parameter_dtype(parameter: torch.nn.Module):
52
+ try:
53
+ return next(parameter.parameters()).dtype
54
+ except StopIteration:
55
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
56
+
57
+ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
58
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
59
+ return tuples
60
+
61
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
62
+ first_tuple = next(gen)
63
+ return first_tuple[1].dtype
64
+
65
+
66
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
67
+ """
68
+ Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
69
+ """
70
+ try:
71
+ return torch.load(checkpoint_file, map_location="cpu")
72
+ except Exception as e:
73
+ try:
74
+ with open(checkpoint_file) as f:
75
+ if f.read().startswith("version"):
76
+ raise OSError(
77
+ "You seem to have cloned a repository without having git-lfs installed. Please install "
78
+ "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
79
+ "you cloned."
80
+ )
81
+ else:
82
+ raise ValueError(
83
+ f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
84
+ "model. Make sure you have saved the model properly."
85
+ ) from e
86
+ except (UnicodeDecodeError, ValueError):
87
+ raise OSError(
88
+ f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
89
+ f"at '{checkpoint_file}'. "
90
+ "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
91
+ )
92
+
93
+
94
+ def _load_state_dict_into_model(model_to_load, state_dict):
95
+ # Convert old format to new format if needed from a PyTorch state_dict
96
+ # copy state_dict so _load_from_state_dict can modify it
97
+ state_dict = state_dict.copy()
98
+ error_msgs = []
99
+
100
+ # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
101
+ # so we need to apply the function recursively.
102
+ def load(module: torch.nn.Module, prefix=""):
103
+ args = (state_dict, prefix, {}, True, [], [], error_msgs)
104
+ module._load_from_state_dict(*args)
105
+
106
+ for name, child in module._modules.items():
107
+ if child is not None:
108
+ load(child, prefix + name + ".")
109
+
110
+ load(model_to_load)
111
+
112
+ return error_msgs
113
+
114
+
115
+ class ModelMixin(torch.nn.Module):
116
+ r"""
117
+ Base class for all models.
118
+
119
+ [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
120
+ and saving models.
121
+
122
+ - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
123
+ [`~modeling_utils.ModelMixin.save_pretrained`].
124
+ """
125
+ config_name = CONFIG_NAME
126
+ _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
127
+
128
+ def __init__(self):
129
+ super().__init__()
130
+
131
+ def save_pretrained(
132
+ self,
133
+ save_directory: Union[str, os.PathLike],
134
+ is_main_process: bool = True,
135
+ save_function: Callable = torch.save,
136
+ ):
137
+ """
138
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
139
+ `[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
140
+
141
+ Arguments:
142
+ save_directory (`str` or `os.PathLike`):
143
+ Directory to which to save. Will be created if it doesn't exist.
144
+ is_main_process (`bool`, *optional*, defaults to `True`):
145
+ Whether the process calling this is the main process or not. Useful when in distributed training like
146
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
147
+ the main process to avoid race conditions.
148
+ save_function (`Callable`):
149
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
150
+ need to replace `torch.save` by another method.
151
+ """
152
+ if os.path.isfile(save_directory):
153
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
154
+ return
155
+
156
+ os.makedirs(save_directory, exist_ok=True)
157
+
158
+ model_to_save = self
159
+
160
+ # Attach architecture to the config
161
+ # Save the config
162
+ if is_main_process:
163
+ model_to_save.save_config(save_directory)
164
+
165
+ # Save the model
166
+ state_dict = model_to_save.state_dict()
167
+
168
+ # Clean the folder from a previous save
169
+ for filename in os.listdir(save_directory):
170
+ full_filename = os.path.join(save_directory, filename)
171
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
172
+ # in distributed settings to avoid race conditions.
173
+ if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
174
+ os.remove(full_filename)
175
+
176
+ # Save the model
177
+ save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
178
+
179
+ logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
180
+
181
+ @classmethod
182
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
183
+ r"""
184
+ Instantiate a pretrained pytorch model from a pre-trained model configuration.
185
+
186
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
187
+ the model, you should first set it back in training mode with `model.train()`.
188
+
189
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
190
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
191
+ task.
192
+
193
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
194
+ weights are discarded.
195
+
196
+ Parameters:
197
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
198
+ Can be either:
199
+
200
+ - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
201
+ Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
202
+ - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
203
+ `./my_model_directory/`.
204
+
205
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
206
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
207
+ standard cache should not be used.
208
+ torch_dtype (`str` or `torch.dtype`, *optional*):
209
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
210
+ will be automatically derived from the model's weights.
211
+ force_download (`bool`, *optional*, defaults to `False`):
212
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
213
+ cached versions if they exist.
214
+ resume_download (`bool`, *optional*, defaults to `False`):
215
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
216
+ file exists.
217
+ proxies (`Dict[str, str]`, *optional*):
218
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
219
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
220
+ output_loading_info(`bool`, *optional*, defaults to `False`):
221
+ Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
222
+ local_files_only(`bool`, *optional*, defaults to `False`):
223
+ Whether or not to only look at local files (i.e., do not try to download the model).
224
+ use_auth_token (`str` or *bool*, *optional*):
225
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
226
+ when running `diffusers-cli login` (stored in `~/.huggingface`).
227
+ revision (`str`, *optional*, defaults to `"main"`):
228
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
229
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
230
+ identifier allowed by git.
231
+ mirror (`str`, *optional*):
232
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
233
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
234
+ Please refer to the mirror site for more information.
235
+
236
+ <Tip>
237
+
238
+ Passing `use_auth_token=True`` is required when you want to use a private model.
239
+
240
+ </Tip>
241
+
242
+ <Tip>
243
+
244
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
245
+ this method in a firewalled environment.
246
+
247
+ </Tip>
248
+
249
+ """
250
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
251
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
252
+ force_download = kwargs.pop("force_download", False)
253
+ resume_download = kwargs.pop("resume_download", False)
254
+ proxies = kwargs.pop("proxies", None)
255
+ output_loading_info = kwargs.pop("output_loading_info", False)
256
+ local_files_only = kwargs.pop("local_files_only", False)
257
+ use_auth_token = kwargs.pop("use_auth_token", None)
258
+ revision = kwargs.pop("revision", None)
259
+ from_auto_class = kwargs.pop("_from_auto", False)
260
+ torch_dtype = kwargs.pop("torch_dtype", None)
261
+ subfolder = kwargs.pop("subfolder", None)
262
+
263
+ user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
264
+
265
+ # Load config if we don't provide a configuration
266
+ config_path = pretrained_model_name_or_path
267
+ model, unused_kwargs = cls.from_config(
268
+ config_path,
269
+ cache_dir=cache_dir,
270
+ return_unused_kwargs=True,
271
+ force_download=force_download,
272
+ resume_download=resume_download,
273
+ proxies=proxies,
274
+ local_files_only=local_files_only,
275
+ use_auth_token=use_auth_token,
276
+ revision=revision,
277
+ subfolder=subfolder,
278
+ **kwargs,
279
+ )
280
+
281
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
282
+ raise ValueError(
283
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
284
+ )
285
+ elif torch_dtype is not None:
286
+ model = model.to(torch_dtype)
287
+
288
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
289
+ # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
290
+ # Load model
291
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
292
+ if os.path.isdir(pretrained_model_name_or_path):
293
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
294
+ # Load from a PyTorch checkpoint
295
+ model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
296
+ elif subfolder is not None and os.path.isfile(
297
+ os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
298
+ ):
299
+ model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
300
+ else:
301
+ raise EnvironmentError(
302
+ f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
303
+ )
304
+ else:
305
+ try:
306
+ # Load from URL or cache if already cached
307
+ model_file = hf_hub_download(
308
+ pretrained_model_name_or_path,
309
+ filename=WEIGHTS_NAME,
310
+ cache_dir=cache_dir,
311
+ force_download=force_download,
312
+ proxies=proxies,
313
+ resume_download=resume_download,
314
+ local_files_only=local_files_only,
315
+ use_auth_token=use_auth_token,
316
+ user_agent=user_agent,
317
+ subfolder=subfolder,
318
+ revision=revision,
319
+ )
320
+
321
+ except RepositoryNotFoundError:
322
+ raise EnvironmentError(
323
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
324
+ "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
325
+ "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
326
+ "login` and pass `use_auth_token=True`."
327
+ )
328
+ except RevisionNotFoundError:
329
+ raise EnvironmentError(
330
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
331
+ "this model name. Check the model page at "
332
+ f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
333
+ )
334
+ except EntryNotFoundError:
335
+ raise EnvironmentError(
336
+ f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
337
+ )
338
+ except HTTPError as err:
339
+ raise EnvironmentError(
340
+ "There was a specific connection error when trying to load"
341
+ f" {pretrained_model_name_or_path}:\n{err}"
342
+ )
343
+ except ValueError:
344
+ raise EnvironmentError(
345
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
346
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
347
+ f" directory containing a file named {WEIGHTS_NAME} or"
348
+ " \nCheckout your internet connection or see how to run the library in"
349
+ " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
350
+ )
351
+ except EnvironmentError:
352
+ raise EnvironmentError(
353
+ f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
354
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
355
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
356
+ f"containing a file named {WEIGHTS_NAME}"
357
+ )
358
+
359
+ # restore default dtype
360
+ state_dict = load_state_dict(model_file)
361
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
362
+ model,
363
+ state_dict,
364
+ model_file,
365
+ pretrained_model_name_or_path,
366
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
367
+ )
368
+
369
+ # Set model in evaluation mode to deactivate DropOut modules by default
370
+ model.eval()
371
+
372
+ if output_loading_info:
373
+ loading_info = {
374
+ "missing_keys": missing_keys,
375
+ "unexpected_keys": unexpected_keys,
376
+ "mismatched_keys": mismatched_keys,
377
+ "error_msgs": error_msgs,
378
+ }
379
+ return model, loading_info
380
+
381
+ return model
382
+
383
+ @classmethod
384
+ def _load_pretrained_model(
385
+ cls,
386
+ model,
387
+ state_dict,
388
+ resolved_archive_file,
389
+ pretrained_model_name_or_path,
390
+ ignore_mismatched_sizes=False,
391
+ ):
392
+ # Retrieve missing & unexpected_keys
393
+ model_state_dict = model.state_dict()
394
+ loaded_keys = [k for k in state_dict.keys()]
395
+
396
+ expected_keys = list(model_state_dict.keys())
397
+
398
+ original_loaded_keys = loaded_keys
399
+
400
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
401
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
402
+
403
+ # Make sure we are able to load base models as well as derived models (with heads)
404
+ model_to_load = model
405
+
406
+ def _find_mismatched_keys(
407
+ state_dict,
408
+ model_state_dict,
409
+ loaded_keys,
410
+ ignore_mismatched_sizes,
411
+ ):
412
+ mismatched_keys = []
413
+ if ignore_mismatched_sizes:
414
+ for checkpoint_key in loaded_keys:
415
+ model_key = checkpoint_key
416
+
417
+ if (
418
+ model_key in model_state_dict
419
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
420
+ ):
421
+ mismatched_keys.append(
422
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
423
+ )
424
+ del state_dict[checkpoint_key]
425
+ return mismatched_keys
426
+
427
+ if state_dict is not None:
428
+ # Whole checkpoint
429
+ mismatched_keys = _find_mismatched_keys(
430
+ state_dict,
431
+ model_state_dict,
432
+ original_loaded_keys,
433
+ ignore_mismatched_sizes,
434
+ )
435
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
436
+
437
+ if len(error_msgs) > 0:
438
+ error_msg = "\n\t".join(error_msgs)
439
+ if "size mismatch" in error_msg:
440
+ error_msg += (
441
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
442
+ )
443
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
444
+
445
+ if len(unexpected_keys) > 0:
446
+ logger.warning(
447
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
448
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
449
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
450
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
451
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
452
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
453
+ " identical (initializing a BertForSequenceClassification model from a"
454
+ " BertForSequenceClassification model)."
455
+ )
456
+ else:
457
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
458
+ if len(missing_keys) > 0:
459
+ logger.warning(
460
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
461
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
462
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
463
+ )
464
+ elif len(mismatched_keys) == 0:
465
+ logger.info(
466
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
467
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
468
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
469
+ " without further training."
470
+ )
471
+ if len(mismatched_keys) > 0:
472
+ mismatched_warning = "\n".join(
473
+ [
474
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
475
+ for key, shape1, shape2 in mismatched_keys
476
+ ]
477
+ )
478
+ logger.warning(
479
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
480
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
481
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
482
+ " able to use it for predictions and inference."
483
+ )
484
+
485
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
486
+
487
+ @property
488
+ def device(self) -> device:
489
+ """
490
+ `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
491
+ device).
492
+ """
493
+ return get_parameter_device(self)
494
+
495
+ @property
496
+ def dtype(self) -> torch.dtype:
497
+ """
498
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
499
+ """
500
+ return get_parameter_dtype(self)
501
+
502
+ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
503
+ """
504
+ Get number of (optionally, trainable or non-embeddings) parameters in the module.
505
+
506
+ Args:
507
+ only_trainable (`bool`, *optional*, defaults to `False`):
508
+ Whether or not to return only the number of trainable parameters
509
+
510
+ exclude_embeddings (`bool`, *optional*, defaults to `False`):
511
+ Whether or not to return only the number of non-embeddings parameters
512
+
513
+ Returns:
514
+ `int`: The number of parameters.
515
+ """
516
+
517
+ if exclude_embeddings:
518
+ embedding_param_names = [
519
+ f"{name}.weight"
520
+ for name, module_type in self.named_modules()
521
+ if isinstance(module_type, torch.nn.Embedding)
522
+ ]
523
+ non_embedding_parameters = [
524
+ parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
525
+ ]
526
+ return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
527
+ else:
528
+ return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
529
+
530
+
531
+ def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
532
+ """
533
+ Recursively unwraps a model from potential containers (as used in distributed training).
534
+
535
+ Args:
536
+ model (`torch.nn.Module`): The model to unwrap.
537
+ """
538
+ # since there could be multiple levels of wrapping, unwrap recursively
539
+ if hasattr(model, "module"):
540
+ return unwrap_model(model.module)
541
+ else:
542
+ return model
my_diffusers/models/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .unet_2d import UNet2DModel
16
+ from .unet_2d_condition import UNet2DConditionModel
17
+ from .vae import AutoencoderKL, VQModel
my_diffusers/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (322 Bytes). View file
 
my_diffusers/models/__pycache__/attention.cpython-38.pyc ADDED
Binary file (12.2 kB). View file
 
my_diffusers/models/__pycache__/embeddings.cpython-38.pyc ADDED
Binary file (3.71 kB). View file
 
my_diffusers/models/__pycache__/resnet.cpython-38.pyc ADDED
Binary file (14.6 kB). View file
 
my_diffusers/models/__pycache__/unet_2d.cpython-38.pyc ADDED
Binary file (7.84 kB). View file
 
my_diffusers/models/__pycache__/unet_2d_condition.cpython-38.pyc ADDED
Binary file (8.68 kB). View file
 
my_diffusers/models/__pycache__/unet_blocks.cpython-38.pyc ADDED
Binary file (23 kB). View file
 
my_diffusers/models/__pycache__/vae.cpython-38.pyc ADDED
Binary file (16.5 kB). View file
 
my_diffusers/models/attention.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+
9
+ class AttentionBlock(nn.Module):
10
+ """
11
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
12
+ to the N-d case.
13
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
14
+ Uses three q, k, v linear layers to compute attention.
15
+
16
+ Parameters:
17
+ channels (:obj:`int`): The number of channels in the input and output.
18
+ num_head_channels (:obj:`int`, *optional*):
19
+ The number of channels in each head. If None, then `num_heads` = 1.
20
+ num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
21
+ rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
22
+ eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ channels: int,
28
+ num_head_channels: Optional[int] = None,
29
+ num_groups: int = 32,
30
+ rescale_output_factor = 1.0,
31
+ eps = 1e-5,
32
+ ):
33
+ super().__init__()
34
+ self.channels = channels
35
+
36
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
37
+ self.num_head_size = num_head_channels
38
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
39
+
40
+ # define q,k,v as linear layers
41
+ self.query = nn.Linear(channels, channels)
42
+ self.key = nn.Linear(channels, channels)
43
+ self.value = nn.Linear(channels, channels)
44
+
45
+ self.rescale_output_factor = rescale_output_factor
46
+ self.proj_attn = nn.Linear(channels, channels, 1)
47
+
48
+ def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
49
+ new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
50
+ # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
51
+ new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
52
+ return new_projection
53
+
54
+ def forward(self, hidden_states):
55
+ residual = hidden_states
56
+ batch, channel, height, width = hidden_states.shape
57
+
58
+ # norm
59
+ hidden_states = self.group_norm(hidden_states)
60
+
61
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
62
+
63
+ # proj to q, k, v
64
+ query_proj = self.query(hidden_states)
65
+ key_proj = self.key(hidden_states)
66
+ value_proj = self.value(hidden_states)
67
+
68
+ # transpose
69
+ query_states = self.transpose_for_scores(query_proj)
70
+ key_states = self.transpose_for_scores(key_proj)
71
+ value_states = self.transpose_for_scores(value_proj)
72
+
73
+ # get scores
74
+ scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
75
+
76
+ attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
77
+ attention_probs = torch.softmax(attention_scores.double(), dim=-1).type(attention_scores.dtype)
78
+
79
+ # compute attention output
80
+ hidden_states = torch.matmul(attention_probs, value_states)
81
+
82
+ hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
83
+ new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
84
+ hidden_states = hidden_states.view(new_hidden_states_shape)
85
+
86
+ # compute next hidden_states
87
+ hidden_states = self.proj_attn(hidden_states)
88
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
89
+
90
+ # res connect and rescale
91
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
92
+ return hidden_states
93
+
94
+
95
+ class SpatialTransformer(nn.Module):
96
+ """
97
+ Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
98
+ standard transformer action. Finally, reshape to image.
99
+
100
+ Parameters:
101
+ in_channels (:obj:`int`): The number of channels in the input and output.
102
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
103
+ d_head (:obj:`int`): The number of channels in each head.
104
+ depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
105
+ dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
106
+ context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ in_channels: int,
112
+ n_heads: int,
113
+ d_head: int,
114
+ depth: int = 1,
115
+ dropout = 0.0,
116
+ context_dim: Optional[int] = None,
117
+ ):
118
+ super().__init__()
119
+ self.n_heads = n_heads
120
+ self.d_head = d_head
121
+ self.in_channels = in_channels
122
+ inner_dim = n_heads * d_head
123
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
124
+
125
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
126
+
127
+ self.transformer_blocks = nn.ModuleList(
128
+ [
129
+ BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
130
+ for d in range(depth)
131
+ ]
132
+ )
133
+
134
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
135
+
136
+ def _set_attention_slice(self, slice_size):
137
+ for block in self.transformer_blocks:
138
+ block._set_attention_slice(slice_size)
139
+
140
+ def forward(self, x, context=None):
141
+ # note: if no context is given, cross-attention defaults to self-attention
142
+ b, c, h, w = x.shape
143
+ x_in = x
144
+ x = self.norm(x)
145
+ x = self.proj_in(x)
146
+ x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
147
+ for block in self.transformer_blocks:
148
+ x = block(x, context=context)
149
+ x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
150
+ x = self.proj_out(x)
151
+ return x + x_in
152
+
153
+
154
+ class BasicTransformerBlock(nn.Module):
155
+ r"""
156
+ A basic Transformer block.
157
+
158
+ Parameters:
159
+ dim (:obj:`int`): The number of channels in the input and output.
160
+ n_heads (:obj:`int`): The number of heads to use for multi-head attention.
161
+ d_head (:obj:`int`): The number of channels in each head.
162
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
163
+ context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
164
+ gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
165
+ checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
166
+ """
167
+
168
+ def __init__(
169
+ self,
170
+ dim: int,
171
+ n_heads: int,
172
+ d_head: int,
173
+ dropout=0.0,
174
+ context_dim: Optional[int] = None,
175
+ gated_ff: bool = True,
176
+ checkpoint: bool = True,
177
+ ):
178
+ super().__init__()
179
+ self.attn1 = CrossAttention(
180
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
181
+ ) # is a self-attention
182
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
183
+ self.attn2 = CrossAttention(
184
+ query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
185
+ ) # is self-attn if context is none
186
+ self.norm1 = nn.LayerNorm(dim)
187
+ self.norm2 = nn.LayerNorm(dim)
188
+ self.norm3 = nn.LayerNorm(dim)
189
+ self.checkpoint = checkpoint
190
+
191
+ def _set_attention_slice(self, slice_size):
192
+ self.attn1._slice_size = slice_size
193
+ self.attn2._slice_size = slice_size
194
+
195
+ def forward(self, x, context=None):
196
+ x = x.contiguous() if x.device.type == "mps" else x
197
+ x = self.attn1(self.norm1(x)) + x
198
+ x = self.attn2(self.norm2(x), context=context) + x
199
+ x = self.ff(self.norm3(x)) + x
200
+ return x
201
+
202
+
203
+ class CrossAttention(nn.Module):
204
+ r"""
205
+ A cross attention layer.
206
+
207
+ Parameters:
208
+ query_dim (:obj:`int`): The number of channels in the query.
209
+ context_dim (:obj:`int`, *optional*):
210
+ The number of channels in the context. If not given, defaults to `query_dim`.
211
+ heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
212
+ dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
213
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
214
+ """
215
+
216
+ def __init__(
217
+ self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
218
+ ):
219
+ super().__init__()
220
+ inner_dim = dim_head * heads
221
+ context_dim = context_dim if context_dim is not None else query_dim
222
+
223
+ self.scale = dim_head**-0.5
224
+ self.heads = heads
225
+ # for slice_size > 0 the attention score computation
226
+ # is split across the batch axis to save memory
227
+ # You can set slice_size with `set_attention_slice`
228
+ self._slice_size = None
229
+
230
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
231
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
232
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
233
+
234
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
235
+
236
+ def reshape_heads_to_batch_dim(self, tensor):
237
+ batch_size, seq_len, dim = tensor.shape
238
+ head_size = self.heads
239
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
240
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
241
+ return tensor
242
+
243
+ def reshape_batch_dim_to_heads(self, tensor):
244
+ batch_size, seq_len, dim = tensor.shape
245
+ head_size = self.heads
246
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
247
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
248
+ return tensor
249
+
250
+ def forward(self, x, context=None, mask=None):
251
+ batch_size, sequence_length, dim = x.shape
252
+
253
+ q = self.to_q(x)
254
+ context = context if context is not None else x
255
+ k = self.to_k(context)
256
+ v = self.to_v(context)
257
+
258
+ q = self.reshape_heads_to_batch_dim(q)
259
+ k = self.reshape_heads_to_batch_dim(k)
260
+ v = self.reshape_heads_to_batch_dim(v)
261
+
262
+ # TODO(PVP) - mask is currently never used. Remember to re-implement when used
263
+
264
+ # attention, what we cannot get enough of
265
+ hidden_states = self._attention(q, k, v, sequence_length, dim)
266
+
267
+ return self.to_out(hidden_states)
268
+
269
+ def _attention(self, query, key, value, sequence_length, dim):
270
+ batch_size_attention = query.shape[0]
271
+ hidden_states = torch.zeros(
272
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
273
+ )
274
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
275
+ for i in range(hidden_states.shape[0] // slice_size):
276
+ start_idx = i * slice_size
277
+ end_idx = (i + 1) * slice_size
278
+ attn_slice = (
279
+ torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
280
+ )
281
+ attn_slice = attn_slice.softmax(dim=-1)
282
+ attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
283
+
284
+ hidden_states[start_idx:end_idx] = attn_slice
285
+
286
+ # reshape hidden_states
287
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
288
+ return hidden_states
289
+
290
+
291
+ class FeedForward(nn.Module):
292
+ r"""
293
+ A feed-forward layer.
294
+
295
+ Parameters:
296
+ dim (:obj:`int`): The number of channels in the input.
297
+ dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
298
+ mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
299
+ glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
300
+ dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
301
+ """
302
+
303
+ def __init__(
304
+ self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout = 0.0
305
+ ):
306
+ super().__init__()
307
+ inner_dim = int(dim * mult)
308
+ dim_out = dim_out if dim_out is not None else dim
309
+ project_in = GEGLU(dim, inner_dim)
310
+
311
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
312
+
313
+ def forward(self, x):
314
+ return self.net(x)
315
+
316
+
317
+ # feedforward
318
+ class GEGLU(nn.Module):
319
+ r"""
320
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
321
+
322
+ Parameters:
323
+ dim_in (:obj:`int`): The number of channels in the input.
324
+ dim_out (:obj:`int`): The number of channels in the output.
325
+ """
326
+
327
+ def __init__(self, dim_in: int, dim_out: int):
328
+ super().__init__()
329
+ self.proj = nn.Linear(dim_in, dim_out * 2)
330
+
331
+ def forward(self, x):
332
+ x, gate = self.proj(x).chunk(2, dim=-1)
333
+ return x * F.gelu(gate)
my_diffusers/models/embeddings.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+
16
+ import numpy as np
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
+ def get_timestep_embedding(
22
+ timesteps: torch.Tensor,
23
+ embedding_dim: int,
24
+ flip_sin_to_cos: bool = False,
25
+ downscale_freq_shift: float = 1,
26
+ scale: float = 1,
27
+ max_period: int = 10000,
28
+ ):
29
+ # print(timesteps)
30
+ """
31
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
32
+
33
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
34
+ These may be fractional.
35
+ :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
36
+ embeddings. :return: an [N x dim] Tensor of positional embeddings.
37
+ """
38
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
39
+
40
+ half_dim = embedding_dim // 2
41
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float64)
42
+ exponent = exponent / (half_dim - downscale_freq_shift)
43
+
44
+ emb = torch.exp(exponent).to(device=timesteps.device)
45
+ emb = timesteps[:, None].double() * emb[None, :]
46
+
47
+ # scale embeddings
48
+ emb = scale * emb
49
+
50
+ # concat sine and cosine embeddings
51
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
52
+
53
+ # flip sine and cosine embeddings
54
+ if flip_sin_to_cos:
55
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
56
+
57
+ # zero pad
58
+ if embedding_dim % 2 == 1:
59
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
60
+ return emb
61
+
62
+
63
+ class TimestepEmbedding(nn.Module):
64
+ def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
65
+ super().__init__()
66
+
67
+ self.linear_1 = nn.Linear(channel, time_embed_dim)
68
+ self.act = None
69
+ if act_fn == "silu":
70
+ self.act = nn.SiLU()
71
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
72
+
73
+ def forward(self, sample):
74
+ sample = self.linear_1(sample)
75
+
76
+ if self.act is not None:
77
+ sample = self.act(sample)
78
+
79
+ sample = self.linear_2(sample)
80
+ return sample
81
+
82
+
83
+ class Timesteps(nn.Module):
84
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
85
+ super().__init__()
86
+ self.num_channels = num_channels
87
+ self.flip_sin_to_cos = flip_sin_to_cos
88
+ self.downscale_freq_shift = downscale_freq_shift
89
+
90
+ def forward(self, timesteps):
91
+ t_emb = get_timestep_embedding(
92
+ timesteps,
93
+ self.num_channels,
94
+ flip_sin_to_cos=self.flip_sin_to_cos,
95
+ downscale_freq_shift=self.downscale_freq_shift,
96
+ )
97
+ return t_emb
98
+
99
+
100
+ class GaussianFourierProjection(nn.Module):
101
+ """Gaussian Fourier embeddings for noise levels."""
102
+
103
+ def __init__(self, embedding_size: int = 256, scale: float = 1.0):
104
+ super().__init__()
105
+ self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
106
+
107
+ # to delete later
108
+ self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
109
+
110
+ self.weight = self.W
111
+
112
+ def forward(self, x):
113
+ x = torch.log(x)
114
+ x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
115
+ out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
116
+ return out
my_diffusers/models/resnet.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class Upsample2D(nn.Module):
10
+ """
11
+ An upsampling layer with an optional convolution.
12
+
13
+ :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
14
+ applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
15
+ upsampling occurs in the inner-two dimensions.
16
+ """
17
+
18
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
19
+ super().__init__()
20
+ self.channels = channels
21
+ self.out_channels = out_channels or channels
22
+ self.use_conv = use_conv
23
+ self.use_conv_transpose = use_conv_transpose
24
+ self.name = name
25
+
26
+ conv = None
27
+ if use_conv_transpose:
28
+ conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
29
+ elif use_conv:
30
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
31
+
32
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
33
+ if name == "conv":
34
+ self.conv = conv
35
+ else:
36
+ self.Conv2d_0 = conv
37
+
38
+ def forward(self, x):
39
+ assert x.shape[1] == self.channels
40
+ if self.use_conv_transpose:
41
+ return self.conv(x)
42
+
43
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
44
+
45
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
46
+ if self.use_conv:
47
+ if self.name == "conv":
48
+ x = self.conv(x)
49
+ else:
50
+ x = self.Conv2d_0(x)
51
+
52
+ return x
53
+
54
+
55
+ class Downsample2D(nn.Module):
56
+ """
57
+ A downsampling layer with an optional convolution.
58
+
59
+ :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
60
+ applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
61
+ downsampling occurs in the inner-two dimensions.
62
+ """
63
+
64
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
65
+ super().__init__()
66
+ self.channels = channels
67
+ self.out_channels = out_channels or channels
68
+ self.use_conv = use_conv
69
+ self.padding = padding
70
+ stride = 2
71
+ self.name = name
72
+
73
+ if use_conv:
74
+ conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
75
+ else:
76
+ assert self.channels == self.out_channels
77
+ conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
78
+
79
+ # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
80
+ if name == "conv":
81
+ self.Conv2d_0 = conv
82
+ self.conv = conv
83
+ elif name == "Conv2d_0":
84
+ self.conv = conv
85
+ else:
86
+ self.conv = conv
87
+
88
+ def forward(self, x):
89
+ assert x.shape[1] == self.channels
90
+ if self.use_conv and self.padding == 0:
91
+ pad = (0, 1, 0, 1)
92
+ x = F.pad(x, pad, mode="constant", value=0)
93
+
94
+ assert x.shape[1] == self.channels
95
+ x = self.conv(x)
96
+
97
+ return x
98
+
99
+
100
+ class FirUpsample2D(nn.Module):
101
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
102
+ super().__init__()
103
+ out_channels = out_channels if out_channels else channels
104
+ if use_conv:
105
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
106
+ self.use_conv = use_conv
107
+ self.fir_kernel = fir_kernel
108
+ self.out_channels = out_channels
109
+
110
+ def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
111
+ """Fused `upsample_2d()` followed by `Conv2d()`.
112
+
113
+ Args:
114
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
115
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
116
+ order.
117
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
118
+ C]`.
119
+ weight: Weight tensor of the shape `[filterH, filterW, inChannels,
120
+ outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
121
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
122
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
123
+ factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
124
+
125
+ Returns:
126
+ Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
127
+ `x`.
128
+ """
129
+
130
+ assert isinstance(factor, int) and factor >= 1
131
+
132
+ # Setup filter kernel.
133
+ if kernel is None:
134
+ kernel = [1] * factor
135
+
136
+ # setup kernel
137
+ kernel = np.asarray(kernel, dtype=np.float64)
138
+ if kernel.ndim == 1:
139
+ kernel = np.outer(kernel, kernel)
140
+ kernel /= np.sum(kernel)
141
+
142
+ kernel = kernel * (gain * (factor**2))
143
+
144
+ if self.use_conv:
145
+ convH = weight.shape[2]
146
+ convW = weight.shape[3]
147
+ inC = weight.shape[1]
148
+
149
+ p = (kernel.shape[0] - factor) - (convW - 1)
150
+
151
+ stride = (factor, factor)
152
+ # Determine data dimensions.
153
+ stride = [1, 1, factor, factor]
154
+ output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
155
+ output_padding = (
156
+ output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
157
+ output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
158
+ )
159
+ assert output_padding[0] >= 0 and output_padding[1] >= 0
160
+ inC = weight.shape[1]
161
+ num_groups = x.shape[1] // inC
162
+
163
+ # Transpose weights.
164
+ weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
165
+ weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
166
+ weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
167
+
168
+ x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
169
+
170
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
171
+ else:
172
+ p = kernel.shape[0] - factor
173
+ x = upfirdn2d_native(
174
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
175
+ )
176
+
177
+ return x
178
+
179
+ def forward(self, x):
180
+ if self.use_conv:
181
+ height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
182
+ height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
183
+ else:
184
+ height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
185
+
186
+ return height
187
+
188
+
189
+ class FirDownsample2D(nn.Module):
190
+ def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
191
+ super().__init__()
192
+ out_channels = out_channels if out_channels else channels
193
+ if use_conv:
194
+ self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
195
+ self.fir_kernel = fir_kernel
196
+ self.use_conv = use_conv
197
+ self.out_channels = out_channels
198
+
199
+ def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
200
+ """Fused `Conv2d()` followed by `downsample_2d()`.
201
+
202
+ Args:
203
+ Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
204
+ efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
205
+ order.
206
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
207
+ filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
208
+ numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
209
+ factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
210
+ Scaling factor for signal magnitude (default: 1.0).
211
+
212
+ Returns:
213
+ Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
214
+ datatype as `x`.
215
+ """
216
+
217
+ assert isinstance(factor, int) and factor >= 1
218
+ if kernel is None:
219
+ kernel = [1] * factor
220
+
221
+ # setup kernel
222
+ kernel = np.asarray(kernel, dtype=np.float64)
223
+ if kernel.ndim == 1:
224
+ kernel = np.outer(kernel, kernel)
225
+ kernel /= np.sum(kernel)
226
+
227
+ kernel = kernel * gain
228
+
229
+ if self.use_conv:
230
+ _, _, convH, convW = weight.shape
231
+ p = (kernel.shape[0] - factor) + (convW - 1)
232
+ s = [factor, factor]
233
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
234
+ x = F.conv2d(x, weight, stride=s, padding=0)
235
+ else:
236
+ p = kernel.shape[0] - factor
237
+ x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
238
+
239
+ return x
240
+
241
+ def forward(self, x):
242
+ if self.use_conv:
243
+ x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
244
+ x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
245
+ else:
246
+ x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
247
+
248
+ return x
249
+
250
+
251
+ class ResnetBlock2D(nn.Module):
252
+ def __init__(
253
+ self,
254
+ *,
255
+ in_channels,
256
+ out_channels=None,
257
+ conv_shortcut=False,
258
+ dropout=0.0,
259
+ temb_channels=512,
260
+ groups=32,
261
+ groups_out=None,
262
+ pre_norm=True,
263
+ eps=1e-6,
264
+ non_linearity="swish",
265
+ time_embedding_norm="default",
266
+ kernel=None,
267
+ output_scale_factor=1.0,
268
+ use_nin_shortcut=None,
269
+ up=False,
270
+ down=False,
271
+ ):
272
+ super().__init__()
273
+ self.pre_norm = pre_norm
274
+ self.pre_norm = True
275
+ self.in_channels = in_channels
276
+ out_channels = in_channels if out_channels is None else out_channels
277
+ self.out_channels = out_channels
278
+ self.use_conv_shortcut = conv_shortcut
279
+ self.time_embedding_norm = time_embedding_norm
280
+ self.up = up
281
+ self.down = down
282
+ self.output_scale_factor = output_scale_factor
283
+
284
+ if groups_out is None:
285
+ groups_out = groups
286
+
287
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
288
+
289
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
290
+
291
+ if temb_channels is not None:
292
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
293
+ else:
294
+ self.time_emb_proj = None
295
+
296
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
297
+ self.dropout = torch.nn.Dropout(dropout)
298
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
299
+
300
+ if non_linearity == "swish":
301
+ self.nonlinearity = lambda x: F.silu(x)
302
+ elif non_linearity == "mish":
303
+ self.nonlinearity = Mish()
304
+ elif non_linearity == "silu":
305
+ self.nonlinearity = nn.SiLU()
306
+
307
+ self.upsample = self.downsample = None
308
+ if self.up:
309
+ if kernel == "fir":
310
+ fir_kernel = (1, 3, 3, 1)
311
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
312
+ elif kernel == "sde_vp":
313
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
314
+ else:
315
+ self.upsample = Upsample2D(in_channels, use_conv=False)
316
+ elif self.down:
317
+ if kernel == "fir":
318
+ fir_kernel = (1, 3, 3, 1)
319
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
320
+ elif kernel == "sde_vp":
321
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
322
+ else:
323
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
324
+
325
+ self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
326
+
327
+ self.conv_shortcut = None
328
+ if self.use_nin_shortcut:
329
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
330
+
331
+ def forward(self, x, temb):
332
+ hidden_states = x
333
+
334
+ # cast hidden states to float64 for the normalization,
335
+ # then cast back to the original dtype
336
+ hidden_states = self.norm1(hidden_states.double()).type(hidden_states.dtype)
337
+ hidden_states = self.nonlinearity(hidden_states)
338
+
339
+ if self.upsample is not None:
340
+ x = self.upsample(x)
341
+ hidden_states = self.upsample(hidden_states)
342
+ elif self.downsample is not None:
343
+ x = self.downsample(x)
344
+ hidden_states = self.downsample(hidden_states)
345
+
346
+ hidden_states = self.conv1(hidden_states)
347
+
348
+ if temb is not None:
349
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
350
+ hidden_states = hidden_states + temb
351
+
352
+ # cast hidden states to float64 for the normalization,
353
+ # then cast back to the original dtype
354
+ hidden_states = self.norm2(hidden_states.double()).type(hidden_states.dtype)
355
+ hidden_states = self.nonlinearity(hidden_states)
356
+
357
+ hidden_states = self.dropout(hidden_states)
358
+ hidden_states = self.conv2(hidden_states)
359
+
360
+ if self.conv_shortcut is not None:
361
+ x = self.conv_shortcut(x)
362
+
363
+ out = (x + hidden_states) / self.output_scale_factor
364
+
365
+ return out
366
+
367
+
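# Editorial sketch (not part of the committed file): a minimal ResnetBlock2D call with
# illustrative shapes. The module is moved to float64 because the forward pass above
# casts activations to double before each GroupNorm, so parameters must share that dtype.
import torch
block = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=128).double()
x = torch.randn(2, 64, 32, 32, dtype=torch.float64)   # (batch, channels, height, width)
temb = torch.randn(2, 128, dtype=torch.float64)       # time embedding fed to time_emb_proj
out = block(x, temb)                                  # residual output, shape (2, 64, 32, 32)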
368
+ class Mish(torch.nn.Module):
369
+ def forward(self, x):
370
+ return x * torch.tanh(torch.nn.functional.softplus(x))
371
+
372
+
373
+ def upsample_2d(x, kernel=None, factor=2, gain=1):
374
+ r"""Upsample2D a batch of 2D images with the given filter.
375
+
376
+ Args:
377
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
378
+ filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
379
+ `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a
380
+ multiple of the upsampling factor.
381
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
382
+ C]`.
383
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
384
+ (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
385
+ factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
386
+
387
+ Returns:
388
+ Tensor of the shape `[N, C, H * factor, W * factor]`
389
+ """
390
+ assert isinstance(factor, int) and factor >= 1
391
+ if kernel is None:
392
+ kernel = [1] * factor
393
+
394
+ kernel = np.asarray(kernel, dtype=np.float64)
395
+ if kernel.ndim == 1:
396
+ kernel = np.outer(kernel, kernel)
397
+ kernel /= np.sum(kernel)
398
+
399
+ kernel = kernel * (gain * (factor**2))
400
+ p = kernel.shape[0] - factor
401
+ return upfirdn2d_native(
402
+ x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
403
+ )
404
+
405
+
406
+ def downsample_2d(x, kernel=None, factor=2, gain=1):
407
+ r"""Downsample2D a batch of 2D images with the given filter.
408
+
409
+ Args:
410
+ Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
411
+ given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
412
+ specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
413
+ shape is a multiple of the downsampling factor.
414
+ x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
415
+ C]`.
416
+ kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
417
+ (separable). The default is `[1] * factor`, which corresponds to average pooling.
418
+ factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
419
+
420
+ Returns:
421
+ Tensor of the shape `[N, C, H // factor, W // factor]`
422
+ """
423
+
424
+ assert isinstance(factor, int) and factor >= 1
425
+ if kernel is None:
426
+ kernel = [1] * factor
427
+
428
+ kernel = np.asarray(kernel, dtype=np.float64)
429
+ if kernel.ndim == 1:
430
+ kernel = np.outer(kernel, kernel)
431
+ kernel /= np.sum(kernel)
432
+
433
+ kernel = kernel * gain
434
+ p = kernel.shape[0] - factor
435
+ return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
436
+
437
+
438
+ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
439
+ up_x = up_y = up
440
+ down_x = down_y = down
441
+ pad_x0 = pad_y0 = pad[0]
442
+ pad_x1 = pad_y1 = pad[1]
443
+
444
+ _, channel, in_h, in_w = input.shape
445
+ input = input.reshape(-1, in_h, in_w, 1)
446
+
447
+ _, in_h, in_w, minor = input.shape
448
+ kernel_h, kernel_w = kernel.shape
449
+
450
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
451
+
452
+ # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
453
+ if input.device.type == "mps":
454
+ out = out.to("cpu")
455
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
456
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
457
+
458
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
459
+ out = out.to(input.device) # Move back to mps if necessary
460
+ out = out[
461
+ :,
462
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
463
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
464
+ :,
465
+ ]
466
+
467
+ out = out.permute(0, 3, 1, 2)
468
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
469
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
470
+ out = F.conv2d(out, w)
471
+ out = out.reshape(
472
+ -1,
473
+ minor,
474
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
475
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
476
+ )
477
+ out = out.permute(0, 2, 3, 1)
478
+ out = out[:, ::down_y, ::down_x, :]
479
+
480
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
481
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
482
+
483
+ return out.view(-1, channel, out_h, out_w)
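The FIR resampling helpers above are self-contained. A minimal usage sketch (editorial, not part of the committed file); the shapes are illustrative, and float64 inputs are assumed because the filter tensor is created with dtype float64:

import torch
# upsample_2d / downsample_2d as defined in my_diffusers/models/resnet.py above
x = torch.randn(1, 3, 64, 64, dtype=torch.float64)    # (N, C, H, W)
y = downsample_2d(x, kernel=(1, 3, 3, 1), factor=2)   # -> (1, 3, 32, 32)
z = upsample_2d(x, kernel=(1, 3, 3, 1), factor=2)     # -> (1, 3, 128, 128)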
my_diffusers/models/unet_2d.py ADDED
@@ -0,0 +1,246 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from ..configuration_utils import ConfigMixin, register_to_config
8
+ from ..modeling_utils import ModelMixin
9
+ from ..utils import BaseOutput
10
+ from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
11
+ from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
12
+
13
+
14
+ @dataclass
15
+ class UNet2DOutput(BaseOutput):
16
+ """
17
+ Args:
18
+ sample (`torch.DoubleTensor` of shape `(batch_size, num_channels, height, width)`):
19
+ Hidden states output. Output of last layer of model.
20
+ """
21
+
22
+ sample: torch.DoubleTensor
23
+
24
+
25
+ class UNet2DModel(ModelMixin, ConfigMixin):
26
+ r"""
27
+ UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
28
+
29
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
30
+ implements for all the model (such as downloading or saving, etc.)
31
+
32
+ Parameters:
33
+ sample_size (`int`, *optional*):
34
+ Input sample size.
35
+ in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
36
+ out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
37
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
38
+ time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
39
+ freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
40
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
41
+ Whether to flip sin to cos for fourier time embedding.
42
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
43
+ Tuple of downsample block types.
44
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
45
+ Tuple of upsample block types.
46
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
47
+ Tuple of block output channels.
49
+ layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
50
+ mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
51
+ downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
52
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
53
+ attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
54
+ norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
55
+ norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
56
+ """
57
+
58
+ @register_to_config
59
+ def __init__(
60
+ self,
61
+ sample_size: Optional[int] = None,
62
+ in_channels: int = 3,
63
+ out_channels: int = 3,
64
+ center_input_sample: bool = False,
65
+ time_embedding_type: str = "positional",
66
+ freq_shift: int = 0,
67
+ flip_sin_to_cos: bool = True,
68
+ down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
69
+ up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
70
+ block_out_channels: Tuple[int] = (224, 448, 672, 896),
71
+ layers_per_block: int = 2,
72
+ mid_block_scale_factor = 1,
73
+ downsample_padding: int = 1,
74
+ act_fn: str = "silu",
75
+ attention_head_dim: int = 8,
76
+ norm_num_groups: int = 32,
77
+ norm_eps = 1e-5,
78
+ ):
79
+ super().__init__()
80
+
81
+ self.sample_size = sample_size
82
+ time_embed_dim = block_out_channels[0] * 4
83
+
84
+ # input
85
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
86
+
87
+ # time
88
+ if time_embedding_type == "fourier":
89
+ self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
90
+ timestep_input_dim = 2 * block_out_channels[0]
91
+ elif time_embedding_type == "positional":
92
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
93
+ timestep_input_dim = block_out_channels[0]
94
+
95
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
96
+
97
+ self.down_blocks = nn.ModuleList([])
98
+ self.mid_block = None
99
+ self.up_blocks = nn.ModuleList([])
100
+
101
+ # down
102
+ output_channel = block_out_channels[0]
103
+ for i, down_block_type in enumerate(down_block_types):
104
+ input_channel = output_channel
105
+ output_channel = block_out_channels[i]
106
+ is_final_block = i == len(block_out_channels) - 1
107
+
108
+ down_block = get_down_block(
109
+ down_block_type,
110
+ num_layers=layers_per_block,
111
+ in_channels=input_channel,
112
+ out_channels=output_channel,
113
+ temb_channels=time_embed_dim,
114
+ add_downsample=not is_final_block,
115
+ resnet_eps=norm_eps,
116
+ resnet_act_fn=act_fn,
117
+ attn_num_head_channels=attention_head_dim,
118
+ downsample_padding=downsample_padding,
119
+ )
120
+ self.down_blocks.append(down_block)
121
+
122
+ # mid
123
+ self.mid_block = UNetMidBlock2D(
124
+ in_channels=block_out_channels[-1],
125
+ temb_channels=time_embed_dim,
126
+ resnet_eps=norm_eps,
127
+ resnet_act_fn=act_fn,
128
+ output_scale_factor=mid_block_scale_factor,
129
+ resnet_time_scale_shift="default",
130
+ attn_num_head_channels=attention_head_dim,
131
+ resnet_groups=norm_num_groups,
132
+ )
133
+
134
+ # up
135
+ reversed_block_out_channels = list(reversed(block_out_channels))
136
+ output_channel = reversed_block_out_channels[0]
137
+ for i, up_block_type in enumerate(up_block_types):
138
+ prev_output_channel = output_channel
139
+ output_channel = reversed_block_out_channels[i]
140
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
141
+
142
+ is_final_block = i == len(block_out_channels) - 1
143
+
144
+ up_block = get_up_block(
145
+ up_block_type,
146
+ num_layers=layers_per_block + 1,
147
+ in_channels=input_channel,
148
+ out_channels=output_channel,
149
+ prev_output_channel=prev_output_channel,
150
+ temb_channels=time_embed_dim,
151
+ add_upsample=not is_final_block,
152
+ resnet_eps=norm_eps,
153
+ resnet_act_fn=act_fn,
154
+ attn_num_head_channels=attention_head_dim,
155
+ )
156
+ self.up_blocks.append(up_block)
157
+ prev_output_channel = output_channel
158
+
159
+ # out
160
+ num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
161
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
162
+ self.conv_act = nn.SiLU()
163
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
164
+
165
+ def forward(
166
+ self,
167
+ sample: torch.DoubleTensor,
168
+ timestep: Union[torch.Tensor, float, int],
169
+ return_dict: bool = True,
170
+ ) -> Union[UNet2DOutput, Tuple]:
171
+ """r
172
+ Args:
173
+ sample (`torch.DoubleTensor`): (batch, channel, height, width) noisy inputs tensor
174
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
175
+ return_dict (`bool`, *optional*, defaults to `True`):
176
+ Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
177
+
178
+ Returns:
179
+ [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
180
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
181
+ """
182
+ # 0. center input if necessary
183
+ if self.config.center_input_sample:
184
+ sample = 2 * sample - 1.0
185
+
186
+ # 1. time
187
+ timesteps = timestep
188
+ if not torch.is_tensor(timesteps):
189
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
190
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
191
+ timesteps = timesteps[None].to(sample.device)
192
+
193
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
194
+ timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
195
+
196
+ t_emb = self.time_proj(timesteps)
197
+ emb = self.time_embedding(t_emb)
198
+
199
+ # 2. pre-process
200
+ skip_sample = sample
201
+ sample = self.conv_in(sample)
202
+
203
+ # 3. down
204
+ down_block_res_samples = (sample,)
205
+ for downsample_block in self.down_blocks:
206
+ if hasattr(downsample_block, "skip_conv"):
207
+ sample, res_samples, skip_sample = downsample_block(
208
+ hidden_states=sample, temb=emb, skip_sample=skip_sample
209
+ )
210
+ else:
211
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
212
+
213
+ down_block_res_samples += res_samples
214
+
215
+ # 4. mid
216
+ sample = self.mid_block(sample, emb)
217
+
218
+ # 5. up
219
+ skip_sample = None
220
+ for upsample_block in self.up_blocks:
221
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
222
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
223
+
224
+ if hasattr(upsample_block, "skip_conv"):
225
+ sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
226
+ else:
227
+ sample = upsample_block(sample, res_samples, emb)
228
+
229
+ # 6. post-process
230
+ # make sure hidden states is in float32
231
+ # when running in half-precision
232
+ sample = self.conv_norm_out(sample.double()).type(sample.dtype)
233
+ sample = self.conv_act(sample)
234
+ sample = self.conv_out(sample)
235
+
236
+ if skip_sample is not None:
237
+ sample += skip_sample
238
+
239
+ if self.config.time_embedding_type == "fourier":
240
+ timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
241
+ sample = sample / timesteps
242
+
243
+ if not return_dict:
244
+ return (sample,)
245
+
246
+ return UNet2DOutput(sample=sample)
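A minimal configuration sketch for UNet2DModel (editorial example; the channel counts and block types below are illustrative assumptions, much smaller than the defaults listed above):

from my_diffusers.models.unet_2d import UNet2DModel  # assumes this repo's package is importable

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    layers_per_block=1,
)
# Forward usage (not executed here): model(noisy_sample, timestep).sample, keeping the module
# and inputs in one dtype, since the blocks cast activations to double inside normalization.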
my_diffusers/models/unet_2d_condition.py ADDED
@@ -0,0 +1,273 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from ..configuration_utils import ConfigMixin, register_to_config
8
+ from ..modeling_utils import ModelMixin
9
+ from ..utils import BaseOutput
10
+ from .embeddings import TimestepEmbedding, Timesteps
11
+ from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
12
+
13
+
14
+ @dataclass
15
+ class UNet2DConditionOutput(BaseOutput):
16
+ """
17
+ Args:
18
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
19
+ Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
20
+ """
21
+
22
+ sample: torch.FloatTensor
23
+
24
+
25
+ class UNet2DConditionModel(ModelMixin, ConfigMixin):
26
+ r"""
27
+ UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
28
+ and returns sample shaped output.
29
+
30
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
31
+ implements for all the model (such as downloading or saving, etc.)
32
+
33
+ Parameters:
34
+ sample_size (`int`, *optional*): The size of the input sample.
35
+ in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
36
+ out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
37
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
38
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
39
+ Whether to flip the sin to cos in the time embedding.
40
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
41
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
42
+ The tuple of downsample blocks to use.
43
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
44
+ The tuple of upsample blocks to use.
45
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
46
+ The tuple of output channels for each block.
47
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
48
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
49
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
50
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
51
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
52
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
53
+ cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
54
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
55
+ """
56
+
57
+ @register_to_config
58
+ def __init__(
59
+ self,
60
+ sample_size: Optional[int] = None,
61
+ in_channels: int = 4,
62
+ out_channels: int = 4,
63
+ center_input_sample: bool = False,
64
+ flip_sin_to_cos: bool = True,
65
+ freq_shift: int = 0,
66
+ down_block_types: Tuple[str] = (
67
+ "CrossAttnDownBlock2D",
68
+ "CrossAttnDownBlock2D",
69
+ "CrossAttnDownBlock2D",
70
+ "DownBlock2D",
71
+ ),
72
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
73
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
74
+ layers_per_block: int = 2,
75
+ downsample_padding: int = 1,
76
+ mid_block_scale_factor: float = 1,
77
+ act_fn: str = "silu",
78
+ norm_num_groups: int = 32,
79
+ norm_eps: float = 1e-5,
80
+ cross_attention_dim: int = 1280,
81
+ attention_head_dim: int = 8,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.sample_size = sample_size
86
+ time_embed_dim = block_out_channels[0] * 4
87
+
88
+ # input
89
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
90
+
91
+ # time
92
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
93
+ timestep_input_dim = block_out_channels[0]
94
+
95
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
96
+
97
+ self.down_blocks = nn.ModuleList([])
98
+ self.mid_block = None
99
+ self.up_blocks = nn.ModuleList([])
100
+
101
+ # down
102
+ output_channel = block_out_channels[0]
103
+ for i, down_block_type in enumerate(down_block_types):
104
+ input_channel = output_channel
105
+ output_channel = block_out_channels[i]
106
+ is_final_block = i == len(block_out_channels) - 1
107
+
108
+ down_block = get_down_block(
109
+ down_block_type,
110
+ num_layers=layers_per_block,
111
+ in_channels=input_channel,
112
+ out_channels=output_channel,
113
+ temb_channels=time_embed_dim,
114
+ add_downsample=not is_final_block,
115
+ resnet_eps=norm_eps,
116
+ resnet_act_fn=act_fn,
117
+ cross_attention_dim=cross_attention_dim,
118
+ attn_num_head_channels=attention_head_dim,
119
+ downsample_padding=downsample_padding,
120
+ )
121
+ self.down_blocks.append(down_block)
122
+
123
+ # mid
124
+ self.mid_block = UNetMidBlock2DCrossAttn(
125
+ in_channels=block_out_channels[-1],
126
+ temb_channels=time_embed_dim,
127
+ resnet_eps=norm_eps,
128
+ resnet_act_fn=act_fn,
129
+ output_scale_factor=mid_block_scale_factor,
130
+ resnet_time_scale_shift="default",
131
+ cross_attention_dim=cross_attention_dim,
132
+ attn_num_head_channels=attention_head_dim,
133
+ resnet_groups=norm_num_groups,
134
+ )
135
+
136
+ # up
137
+ reversed_block_out_channels = list(reversed(block_out_channels))
138
+ output_channel = reversed_block_out_channels[0]
139
+ for i, up_block_type in enumerate(up_block_types):
140
+ prev_output_channel = output_channel
141
+ output_channel = reversed_block_out_channels[i]
142
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
143
+
144
+ is_final_block = i == len(block_out_channels) - 1
145
+
146
+ up_block = get_up_block(
147
+ up_block_type,
148
+ num_layers=layers_per_block + 1,
149
+ in_channels=input_channel,
150
+ out_channels=output_channel,
151
+ prev_output_channel=prev_output_channel,
152
+ temb_channels=time_embed_dim,
153
+ add_upsample=not is_final_block,
154
+ resnet_eps=norm_eps,
155
+ resnet_act_fn=act_fn,
156
+ cross_attention_dim=cross_attention_dim,
157
+ attn_num_head_channels=attention_head_dim,
158
+ )
159
+ self.up_blocks.append(up_block)
160
+ prev_output_channel = output_channel
161
+
162
+ # out
163
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
164
+ self.conv_act = nn.SiLU()
165
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
166
+
167
+ def set_attention_slice(self, slice_size):
168
+ if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
169
+ raise ValueError(
170
+ f"Make sure slice_size {slice_size} is a divisor of "
171
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
172
+ )
173
+ if slice_size is not None and slice_size > self.config.attention_head_dim:
174
+ raise ValueError(
175
+ f"Chunk_size {slice_size} has to be smaller or equal to "
176
+ f"the number of heads used in cross_attention {self.config.attention_head_dim}"
177
+ )
178
+
179
+ for block in self.down_blocks:
180
+ if hasattr(block, "attentions") and block.attentions is not None:
181
+ block.set_attention_slice(slice_size)
182
+
183
+ self.mid_block.set_attention_slice(slice_size)
184
+
185
+ for block in self.up_blocks:
186
+ if hasattr(block, "attentions") and block.attentions is not None:
187
+ block.set_attention_slice(slice_size)
188
+
189
+ def forward(
190
+ self,
191
+ sample: torch.FloatTensor,
192
+ timestep: Union[torch.Tensor, float, int],
193
+ encoder_hidden_states: torch.Tensor,
194
+ return_dict: bool = True,
195
+ ) -> Union[UNet2DConditionOutput, Tuple]:
196
+ """r
197
+ Args:
198
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
199
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
200
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
201
+ return_dict (`bool`, *optional*, defaults to `True`):
202
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
203
+
204
+ Returns:
205
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
206
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
207
+ returning a tuple, the first element is the sample tensor.
208
+ """
209
+ # 0. center input if necessary
210
+ if self.config.center_input_sample:
211
+ sample = 2 * sample - 1.0
212
+
213
+ # 1. time
214
+ timesteps = timestep
215
+ if not torch.is_tensor(timesteps):
216
+ timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
217
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
218
+ timesteps = timesteps.to(dtype=torch.float64)
219
+ timesteps = timesteps[None].to(device=sample.device)
220
+
221
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
222
+ timesteps = timesteps.expand(sample.shape[0])
223
+
224
+ t_emb = self.time_proj(timesteps)
225
+ # print(t_emb.dtype)
226
+ t_emb = t_emb.to(sample.dtype).to(sample.device)
227
+ emb = self.time_embedding(t_emb)
228
+
229
+ # 2. pre-process
230
+ sample = self.conv_in(sample)
231
+
232
+ # 3. down
233
+ down_block_res_samples = (sample,)
234
+ for downsample_block in self.down_blocks:
235
+ if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
236
+ # print(sample.dtype, emb.dtype, encoder_hidden_states.dtype)
237
+ sample, res_samples = downsample_block(
238
+ hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
239
+ )
240
+ else:
241
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
242
+
243
+ down_block_res_samples += res_samples
244
+
245
+ # 4. mid
246
+ sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
247
+
248
+ # 5. up
249
+ for upsample_block in self.up_blocks:
250
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
251
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
252
+
253
+ if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
254
+ sample = upsample_block(
255
+ hidden_states=sample,
256
+ temb=emb,
257
+ res_hidden_states_tuple=res_samples,
258
+ encoder_hidden_states=encoder_hidden_states,
259
+ )
260
+ else:
261
+ sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
262
+
263
+ # 6. post-process
264
+ # make sure hidden states is in float32
265
+ # when running in half-precision
266
+ sample = self.conv_norm_out(sample.double()).type(sample.dtype)
267
+ sample = self.conv_act(sample)
268
+ sample = self.conv_out(sample)
269
+
270
+ if not return_dict:
271
+ return (sample,)
272
+
273
+ return UNet2DConditionOutput(sample=sample)
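A minimal configuration sketch for UNet2DConditionModel (editorial example; the sizes below are illustrative assumptions, far smaller than a real Stable Diffusion UNet):

from my_diffusers.models.unet_2d_condition import UNet2DConditionModel

model = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    layers_per_block=1,
    cross_attention_dim=64,
    attention_head_dim=8,
)
model.set_attention_slice(4)  # 4 divides attention_head_dim=8, so it passes the checks above
# Forward usage (not executed here): model(sample, timestep, encoder_hidden_states).sample,
# where encoder_hidden_states has shape (batch, sequence_length, cross_attention_dim) and all
# tensors share the module's dtype.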
my_diffusers/models/unet_blocks.py ADDED
@@ -0,0 +1,1481 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from .attention import AttentionBlock, SpatialTransformer
21
+ from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
22
+
23
+
24
+ def get_down_block(
25
+ down_block_type,
26
+ num_layers,
27
+ in_channels,
28
+ out_channels,
29
+ temb_channels,
30
+ add_downsample,
31
+ resnet_eps,
32
+ resnet_act_fn,
33
+ attn_num_head_channels,
34
+ cross_attention_dim=None,
35
+ downsample_padding=None,
36
+ ):
37
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
38
+ if down_block_type == "DownBlock2D":
39
+ return DownBlock2D(
40
+ num_layers=num_layers,
41
+ in_channels=in_channels,
42
+ out_channels=out_channels,
43
+ temb_channels=temb_channels,
44
+ add_downsample=add_downsample,
45
+ resnet_eps=resnet_eps,
46
+ resnet_act_fn=resnet_act_fn,
47
+ downsample_padding=downsample_padding,
48
+ )
49
+ elif down_block_type == "AttnDownBlock2D":
50
+ return AttnDownBlock2D(
51
+ num_layers=num_layers,
52
+ in_channels=in_channels,
53
+ out_channels=out_channels,
54
+ temb_channels=temb_channels,
55
+ add_downsample=add_downsample,
56
+ resnet_eps=resnet_eps,
57
+ resnet_act_fn=resnet_act_fn,
58
+ downsample_padding=downsample_padding,
59
+ attn_num_head_channels=attn_num_head_channels,
60
+ )
61
+ elif down_block_type == "CrossAttnDownBlock2D":
62
+ if cross_attention_dim is None:
63
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
64
+ return CrossAttnDownBlock2D(
65
+ num_layers=num_layers,
66
+ in_channels=in_channels,
67
+ out_channels=out_channels,
68
+ temb_channels=temb_channels,
69
+ add_downsample=add_downsample,
70
+ resnet_eps=resnet_eps,
71
+ resnet_act_fn=resnet_act_fn,
72
+ downsample_padding=downsample_padding,
73
+ cross_attention_dim=cross_attention_dim,
74
+ attn_num_head_channels=attn_num_head_channels,
75
+ )
76
+ elif down_block_type == "SkipDownBlock2D":
77
+ return SkipDownBlock2D(
78
+ num_layers=num_layers,
79
+ in_channels=in_channels,
80
+ out_channels=out_channels,
81
+ temb_channels=temb_channels,
82
+ add_downsample=add_downsample,
83
+ resnet_eps=resnet_eps,
84
+ resnet_act_fn=resnet_act_fn,
85
+ downsample_padding=downsample_padding,
86
+ )
87
+ elif down_block_type == "AttnSkipDownBlock2D":
88
+ return AttnSkipDownBlock2D(
89
+ num_layers=num_layers,
90
+ in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ temb_channels=temb_channels,
93
+ add_downsample=add_downsample,
94
+ resnet_eps=resnet_eps,
95
+ resnet_act_fn=resnet_act_fn,
96
+ downsample_padding=downsample_padding,
97
+ attn_num_head_channels=attn_num_head_channels,
98
+ )
99
+ elif down_block_type == "DownEncoderBlock2D":
100
+ return DownEncoderBlock2D(
101
+ num_layers=num_layers,
102
+ in_channels=in_channels,
103
+ out_channels=out_channels,
104
+ add_downsample=add_downsample,
105
+ resnet_eps=resnet_eps,
106
+ resnet_act_fn=resnet_act_fn,
107
+ downsample_padding=downsample_padding,
108
+ )
109
+
110
+
111
+ def get_up_block(
112
+ up_block_type,
113
+ num_layers,
114
+ in_channels,
115
+ out_channels,
116
+ prev_output_channel,
117
+ temb_channels,
118
+ add_upsample,
119
+ resnet_eps,
120
+ resnet_act_fn,
121
+ attn_num_head_channels,
122
+ cross_attention_dim=None,
123
+ ):
124
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
125
+ if up_block_type == "UpBlock2D":
126
+ return UpBlock2D(
127
+ num_layers=num_layers,
128
+ in_channels=in_channels,
129
+ out_channels=out_channels,
130
+ prev_output_channel=prev_output_channel,
131
+ temb_channels=temb_channels,
132
+ add_upsample=add_upsample,
133
+ resnet_eps=resnet_eps,
134
+ resnet_act_fn=resnet_act_fn,
135
+ )
136
+ elif up_block_type == "CrossAttnUpBlock2D":
137
+ if cross_attention_dim is None:
138
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
139
+ return CrossAttnUpBlock2D(
140
+ num_layers=num_layers,
141
+ in_channels=in_channels,
142
+ out_channels=out_channels,
143
+ prev_output_channel=prev_output_channel,
144
+ temb_channels=temb_channels,
145
+ add_upsample=add_upsample,
146
+ resnet_eps=resnet_eps,
147
+ resnet_act_fn=resnet_act_fn,
148
+ cross_attention_dim=cross_attention_dim,
149
+ attn_num_head_channels=attn_num_head_channels,
150
+ )
151
+ elif up_block_type == "AttnUpBlock2D":
152
+ return AttnUpBlock2D(
153
+ num_layers=num_layers,
154
+ in_channels=in_channels,
155
+ out_channels=out_channels,
156
+ prev_output_channel=prev_output_channel,
157
+ temb_channels=temb_channels,
158
+ add_upsample=add_upsample,
159
+ resnet_eps=resnet_eps,
160
+ resnet_act_fn=resnet_act_fn,
161
+ attn_num_head_channels=attn_num_head_channels,
162
+ )
163
+ elif up_block_type == "SkipUpBlock2D":
164
+ return SkipUpBlock2D(
165
+ num_layers=num_layers,
166
+ in_channels=in_channels,
167
+ out_channels=out_channels,
168
+ prev_output_channel=prev_output_channel,
169
+ temb_channels=temb_channels,
170
+ add_upsample=add_upsample,
171
+ resnet_eps=resnet_eps,
172
+ resnet_act_fn=resnet_act_fn,
173
+ )
174
+ elif up_block_type == "AttnSkipUpBlock2D":
175
+ return AttnSkipUpBlock2D(
176
+ num_layers=num_layers,
177
+ in_channels=in_channels,
178
+ out_channels=out_channels,
179
+ prev_output_channel=prev_output_channel,
180
+ temb_channels=temb_channels,
181
+ add_upsample=add_upsample,
182
+ resnet_eps=resnet_eps,
183
+ resnet_act_fn=resnet_act_fn,
184
+ attn_num_head_channels=attn_num_head_channels,
185
+ )
186
+ elif up_block_type == "UpDecoderBlock2D":
187
+ return UpDecoderBlock2D(
188
+ num_layers=num_layers,
189
+ in_channels=in_channels,
190
+ out_channels=out_channels,
191
+ add_upsample=add_upsample,
192
+ resnet_eps=resnet_eps,
193
+ resnet_act_fn=resnet_act_fn,
194
+ )
195
+ raise ValueError(f"{up_block_type} does not exist.")
196
+
197
+
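# Editorial sketch (not part of the committed file): the two factories above map a block-type
# string to a configured block; the keyword values below are illustrative assumptions.
# down_block = get_down_block(
#     "CrossAttnDownBlock2D", num_layers=2, in_channels=320, out_channels=320,
#     temb_channels=1280, add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
#     attn_num_head_channels=8, cross_attention_dim=768, downsample_padding=1,
# )
# up_block = get_up_block(
#     "UpBlock2D", num_layers=3, in_channels=1280, out_channels=1280, prev_output_channel=1280,
#     temb_channels=1280, add_upsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
#     attn_num_head_channels=8,
# )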
198
+ class UNetMidBlock2D(nn.Module):
199
+ def __init__(
200
+ self,
201
+ in_channels: int,
202
+ temb_channels: int,
203
+ dropout: float = 0.0,
204
+ num_layers: int = 1,
205
+ resnet_eps: float = 1e-6,
206
+ resnet_time_scale_shift: str = "default",
207
+ resnet_act_fn: str = "swish",
208
+ resnet_groups: int = 32,
209
+ resnet_pre_norm: bool = True,
210
+ attn_num_head_channels=1,
211
+ attention_type="default",
212
+ output_scale_factor=1.0,
213
+ **kwargs,
214
+ ):
215
+ super().__init__()
216
+
217
+ self.attention_type = attention_type
218
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
219
+
220
+ # there is always at least one resnet
221
+ resnets = [
222
+ ResnetBlock2D(
223
+ in_channels=in_channels,
224
+ out_channels=in_channels,
225
+ temb_channels=temb_channels,
226
+ eps=resnet_eps,
227
+ groups=resnet_groups,
228
+ dropout=dropout,
229
+ time_embedding_norm=resnet_time_scale_shift,
230
+ non_linearity=resnet_act_fn,
231
+ output_scale_factor=output_scale_factor,
232
+ pre_norm=resnet_pre_norm,
233
+ )
234
+ ]
235
+ attentions = []
236
+
237
+ for _ in range(num_layers):
238
+ attentions.append(
239
+ AttentionBlock(
240
+ in_channels,
241
+ num_head_channels=attn_num_head_channels,
242
+ rescale_output_factor=output_scale_factor,
243
+ eps=resnet_eps,
244
+ num_groups=resnet_groups,
245
+ )
246
+ )
247
+ resnets.append(
248
+ ResnetBlock2D(
249
+ in_channels=in_channels,
250
+ out_channels=in_channels,
251
+ temb_channels=temb_channels,
252
+ eps=resnet_eps,
253
+ groups=resnet_groups,
254
+ dropout=dropout,
255
+ time_embedding_norm=resnet_time_scale_shift,
256
+ non_linearity=resnet_act_fn,
257
+ output_scale_factor=output_scale_factor,
258
+ pre_norm=resnet_pre_norm,
259
+ )
260
+ )
261
+
262
+ self.attentions = nn.ModuleList(attentions)
263
+ self.resnets = nn.ModuleList(resnets)
264
+
265
+ def forward(self, hidden_states, temb=None, encoder_states=None):
266
+ hidden_states = self.resnets[0](hidden_states, temb)
267
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
268
+ if self.attention_type == "default":
269
+ hidden_states = attn(hidden_states)
270
+ else:
271
+ hidden_states = attn(hidden_states, encoder_states)
272
+ hidden_states = resnet(hidden_states, temb)
273
+
274
+ return hidden_states
275
+
276
+
277
+ class UNetMidBlock2DCrossAttn(nn.Module):
278
+ def __init__(
279
+ self,
280
+ in_channels: int,
281
+ temb_channels: int,
282
+ dropout: float = 0.0,
283
+ num_layers: int = 1,
284
+ resnet_eps: float = 1e-6,
285
+ resnet_time_scale_shift: str = "default",
286
+ resnet_act_fn: str = "swish",
287
+ resnet_groups: int = 32,
288
+ resnet_pre_norm: bool = True,
289
+ attn_num_head_channels=1,
290
+ attention_type="default",
291
+ output_scale_factor=1.0,
292
+ cross_attention_dim=1280,
293
+ **kwargs,
294
+ ):
295
+ super().__init__()
296
+
297
+ self.attention_type = attention_type
298
+ self.attn_num_head_channels = attn_num_head_channels
299
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
300
+
301
+ # there is always at least one resnet
302
+ resnets = [
303
+ ResnetBlock2D(
304
+ in_channels=in_channels,
305
+ out_channels=in_channels,
306
+ temb_channels=temb_channels,
307
+ eps=resnet_eps,
308
+ groups=resnet_groups,
309
+ dropout=dropout,
310
+ time_embedding_norm=resnet_time_scale_shift,
311
+ non_linearity=resnet_act_fn,
312
+ output_scale_factor=output_scale_factor,
313
+ pre_norm=resnet_pre_norm,
314
+ )
315
+ ]
316
+ attentions = []
317
+
318
+ for _ in range(num_layers):
319
+ attentions.append(
320
+ SpatialTransformer(
321
+ in_channels,
322
+ attn_num_head_channels,
323
+ in_channels // attn_num_head_channels,
324
+ depth=1,
325
+ context_dim=cross_attention_dim,
326
+ )
327
+ )
328
+ resnets.append(
329
+ ResnetBlock2D(
330
+ in_channels=in_channels,
331
+ out_channels=in_channels,
332
+ temb_channels=temb_channels,
333
+ eps=resnet_eps,
334
+ groups=resnet_groups,
335
+ dropout=dropout,
336
+ time_embedding_norm=resnet_time_scale_shift,
337
+ non_linearity=resnet_act_fn,
338
+ output_scale_factor=output_scale_factor,
339
+ pre_norm=resnet_pre_norm,
340
+ )
341
+ )
342
+
343
+ self.attentions = nn.ModuleList(attentions)
344
+ self.resnets = nn.ModuleList(resnets)
345
+
346
+ def set_attention_slice(self, slice_size):
347
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
348
+ raise ValueError(
349
+ f"Make sure slice_size {slice_size} is a divisor of "
350
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
351
+ )
352
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
353
+ raise ValueError(
354
+ f"Chunk_size {slice_size} has to be smaller or equal to "
355
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
356
+ )
357
+
358
+ for attn in self.attentions:
359
+ attn._set_attention_slice(slice_size)
360
+
361
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
362
+ hidden_states = self.resnets[0](hidden_states, temb)
363
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
364
+ hidden_states = attn(hidden_states, encoder_hidden_states)
365
+ hidden_states = resnet(hidden_states, temb)
366
+
367
+ return hidden_states
368
+
369
+
370
+ class AttnDownBlock2D(nn.Module):
371
+ def __init__(
372
+ self,
373
+ in_channels: int,
374
+ out_channels: int,
375
+ temb_channels: int,
376
+ dropout: float = 0.0,
377
+ num_layers: int = 1,
378
+ resnet_eps: float = 1e-6,
379
+ resnet_time_scale_shift: str = "default",
380
+ resnet_act_fn: str = "swish",
381
+ resnet_groups: int = 32,
382
+ resnet_pre_norm: bool = True,
383
+ attn_num_head_channels=1,
384
+ attention_type="default",
385
+ output_scale_factor=1.0,
386
+ downsample_padding=1,
387
+ add_downsample=True,
388
+ ):
389
+ super().__init__()
390
+ resnets = []
391
+ attentions = []
392
+
393
+ self.attention_type = attention_type
394
+
395
+ for i in range(num_layers):
396
+ in_channels = in_channels if i == 0 else out_channels
397
+ resnets.append(
398
+ ResnetBlock2D(
399
+ in_channels=in_channels,
400
+ out_channels=out_channels,
401
+ temb_channels=temb_channels,
402
+ eps=resnet_eps,
403
+ groups=resnet_groups,
404
+ dropout=dropout,
405
+ time_embedding_norm=resnet_time_scale_shift,
406
+ non_linearity=resnet_act_fn,
407
+ output_scale_factor=output_scale_factor,
408
+ pre_norm=resnet_pre_norm,
409
+ )
410
+ )
411
+ attentions.append(
412
+ AttentionBlock(
413
+ out_channels,
414
+ num_head_channels=attn_num_head_channels,
415
+ rescale_output_factor=output_scale_factor,
416
+ eps=resnet_eps,
417
+ )
418
+ )
419
+
420
+ self.attentions = nn.ModuleList(attentions)
421
+ self.resnets = nn.ModuleList(resnets)
422
+
423
+ if add_downsample:
424
+ self.downsamplers = nn.ModuleList(
425
+ [
426
+ Downsample2D(
427
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
428
+ )
429
+ ]
430
+ )
431
+ else:
432
+ self.downsamplers = None
433
+
434
+ def forward(self, hidden_states, temb=None):
435
+ output_states = ()
436
+
437
+ for resnet, attn in zip(self.resnets, self.attentions):
438
+ hidden_states = resnet(hidden_states, temb)
439
+ hidden_states = attn(hidden_states)
440
+ output_states += (hidden_states,)
441
+
442
+ if self.downsamplers is not None:
443
+ for downsampler in self.downsamplers:
444
+ hidden_states = downsampler(hidden_states)
445
+
446
+ output_states += (hidden_states,)
447
+
448
+ return hidden_states, output_states
449
+
450
+
451
+ class CrossAttnDownBlock2D(nn.Module):
452
+ def __init__(
453
+ self,
454
+ in_channels: int,
455
+ out_channels: int,
456
+ temb_channels: int,
457
+ dropout: float = 0.0,
458
+ num_layers: int = 1,
459
+ resnet_eps: float = 1e-6,
460
+ resnet_time_scale_shift: str = "default",
461
+ resnet_act_fn: str = "swish",
462
+ resnet_groups: int = 32,
463
+ resnet_pre_norm: bool = True,
464
+ attn_num_head_channels=1,
465
+ cross_attention_dim=1280,
466
+ attention_type="default",
467
+ output_scale_factor=1.0,
468
+ downsample_padding=1,
469
+ add_downsample=True,
470
+ ):
471
+ super().__init__()
472
+ resnets = []
473
+ attentions = []
474
+
475
+ self.attention_type = attention_type
476
+ self.attn_num_head_channels = attn_num_head_channels
477
+
478
+ for i in range(num_layers):
479
+ in_channels = in_channels if i == 0 else out_channels
480
+ resnets.append(
481
+ ResnetBlock2D(
482
+ in_channels=in_channels,
483
+ out_channels=out_channels,
484
+ temb_channels=temb_channels,
485
+ eps=resnet_eps,
486
+ groups=resnet_groups,
487
+ dropout=dropout,
488
+ time_embedding_norm=resnet_time_scale_shift,
489
+ non_linearity=resnet_act_fn,
490
+ output_scale_factor=output_scale_factor,
491
+ pre_norm=resnet_pre_norm,
492
+ )
493
+ )
494
+ attentions.append(
495
+ SpatialTransformer(
496
+ out_channels,
497
+ attn_num_head_channels,
498
+ out_channels // attn_num_head_channels,
499
+ depth=1,
500
+ context_dim=cross_attention_dim,
501
+ )
502
+ )
503
+ self.attentions = nn.ModuleList(attentions)
504
+ self.resnets = nn.ModuleList(resnets)
505
+
506
+ if add_downsample:
507
+ self.downsamplers = nn.ModuleList(
508
+ [
509
+ Downsample2D(
510
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
511
+ )
512
+ ]
513
+ )
514
+ else:
515
+ self.downsamplers = None
516
+
517
+ def set_attention_slice(self, slice_size):
518
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
519
+ raise ValueError(
520
+ f"Make sure slice_size {slice_size} is a divisor of "
521
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
522
+ )
523
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
524
+ raise ValueError(
525
+ f"Chunk_size {slice_size} has to be smaller or equal to "
526
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
527
+ )
528
+
529
+ for attn in self.attentions:
530
+ attn._set_attention_slice(slice_size)
531
+
532
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
533
+ output_states = ()
534
+
535
+ for resnet, attn in zip(self.resnets, self.attentions):
536
+ hidden_states = resnet(hidden_states, temb)
537
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
538
+ output_states += (hidden_states,)
539
+
540
+ if self.downsamplers is not None:
541
+ for downsampler in self.downsamplers:
542
+ hidden_states = downsampler(hidden_states)
543
+
544
+ output_states += (hidden_states,)
545
+
546
+ return hidden_states, output_states
547
+
548
+
549
+ class DownBlock2D(nn.Module):
550
+ def __init__(
551
+ self,
552
+ in_channels: int,
553
+ out_channels: int,
554
+ temb_channels: int,
555
+ dropout: float = 0.0,
556
+ num_layers: int = 1,
557
+ resnet_eps: float = 1e-6,
558
+ resnet_time_scale_shift: str = "default",
559
+ resnet_act_fn: str = "swish",
560
+ resnet_groups: int = 32,
561
+ resnet_pre_norm: bool = True,
562
+ output_scale_factor=1.0,
563
+ add_downsample=True,
564
+ downsample_padding=1,
565
+ ):
566
+ super().__init__()
567
+ resnets = []
568
+
569
+ for i in range(num_layers):
570
+ in_channels = in_channels if i == 0 else out_channels
571
+ resnets.append(
572
+ ResnetBlock2D(
573
+ in_channels=in_channels,
574
+ out_channels=out_channels,
575
+ temb_channels=temb_channels,
576
+ eps=resnet_eps,
577
+ groups=resnet_groups,
578
+ dropout=dropout,
579
+ time_embedding_norm=resnet_time_scale_shift,
580
+ non_linearity=resnet_act_fn,
581
+ output_scale_factor=output_scale_factor,
582
+ pre_norm=resnet_pre_norm,
583
+ )
584
+ )
585
+
586
+ self.resnets = nn.ModuleList(resnets)
587
+
588
+ if add_downsample:
589
+ self.downsamplers = nn.ModuleList(
590
+ [
591
+ Downsample2D(
592
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
593
+ )
594
+ ]
595
+ )
596
+ else:
597
+ self.downsamplers = None
598
+
599
+ def forward(self, hidden_states, temb=None):
600
+ output_states = ()
601
+
602
+ for resnet in self.resnets:
603
+ hidden_states = resnet(hidden_states, temb)
604
+ output_states += (hidden_states,)
605
+
606
+ if self.downsamplers is not None:
607
+ for downsampler in self.downsamplers:
608
+ hidden_states = downsampler(hidden_states)
609
+
610
+ output_states += (hidden_states,)
611
+
612
+ return hidden_states, output_states
613
+
614
+
615
+ class DownEncoderBlock2D(nn.Module):
616
+ def __init__(
617
+ self,
618
+ in_channels: int,
619
+ out_channels: int,
620
+ dropout: float = 0.0,
621
+ num_layers: int = 1,
622
+ resnet_eps: float = 1e-6,
623
+ resnet_time_scale_shift: str = "default",
624
+ resnet_act_fn: str = "swish",
625
+ resnet_groups: int = 32,
626
+ resnet_pre_norm: bool = True,
627
+ output_scale_factor=1.0,
628
+ add_downsample=True,
629
+ downsample_padding=1,
630
+ ):
631
+ super().__init__()
632
+ resnets = []
633
+
634
+ for i in range(num_layers):
635
+ in_channels = in_channels if i == 0 else out_channels
636
+ resnets.append(
637
+ ResnetBlock2D(
638
+ in_channels=in_channels,
639
+ out_channels=out_channels,
640
+ temb_channels=None,
641
+ eps=resnet_eps,
642
+ groups=resnet_groups,
643
+ dropout=dropout,
644
+ time_embedding_norm=resnet_time_scale_shift,
645
+ non_linearity=resnet_act_fn,
646
+ output_scale_factor=output_scale_factor,
647
+ pre_norm=resnet_pre_norm,
648
+ )
649
+ )
650
+
651
+ self.resnets = nn.ModuleList(resnets)
652
+
653
+ if add_downsample:
654
+ self.downsamplers = nn.ModuleList(
655
+ [
656
+ Downsample2D(
657
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
658
+ )
659
+ ]
660
+ )
661
+ else:
662
+ self.downsamplers = None
663
+
664
+ def forward(self, hidden_states):
665
+ for resnet in self.resnets:
666
+ hidden_states = resnet(hidden_states, temb=None)
667
+
668
+ if self.downsamplers is not None:
669
+ for downsampler in self.downsamplers:
670
+ hidden_states = downsampler(hidden_states)
671
+
672
+ return hidden_states
673
+
674
+
675
+ class AttnDownEncoderBlock2D(nn.Module):
676
+ def __init__(
677
+ self,
678
+ in_channels: int,
679
+ out_channels: int,
680
+ dropout: float = 0.0,
681
+ num_layers: int = 1,
682
+ resnet_eps: float = 1e-6,
683
+ resnet_time_scale_shift: str = "default",
684
+ resnet_act_fn: str = "swish",
685
+ resnet_groups: int = 32,
686
+ resnet_pre_norm: bool = True,
687
+ attn_num_head_channels=1,
688
+ output_scale_factor=1.0,
689
+ add_downsample=True,
690
+ downsample_padding=1,
691
+ ):
692
+ super().__init__()
693
+ resnets = []
694
+ attentions = []
695
+
696
+ for i in range(num_layers):
697
+ in_channels = in_channels if i == 0 else out_channels
698
+ resnets.append(
699
+ ResnetBlock2D(
700
+ in_channels=in_channels,
701
+ out_channels=out_channels,
702
+ temb_channels=None,
703
+ eps=resnet_eps,
704
+ groups=resnet_groups,
705
+ dropout=dropout,
706
+ time_embedding_norm=resnet_time_scale_shift,
707
+ non_linearity=resnet_act_fn,
708
+ output_scale_factor=output_scale_factor,
709
+ pre_norm=resnet_pre_norm,
710
+ )
711
+ )
712
+ attentions.append(
713
+ AttentionBlock(
714
+ out_channels,
715
+ num_head_channels=attn_num_head_channels,
716
+ rescale_output_factor=output_scale_factor,
717
+ eps=resnet_eps,
718
+ num_groups=resnet_groups,
719
+ )
720
+ )
721
+
722
+ self.attentions = nn.ModuleList(attentions)
723
+ self.resnets = nn.ModuleList(resnets)
724
+
725
+ if add_downsample:
726
+ self.downsamplers = nn.ModuleList(
727
+ [
728
+ Downsample2D(
729
+ in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
730
+ )
731
+ ]
732
+ )
733
+ else:
734
+ self.downsamplers = None
735
+
736
+ def forward(self, hidden_states):
737
+ for resnet, attn in zip(self.resnets, self.attentions):
738
+ hidden_states = resnet(hidden_states, temb=None)
739
+ hidden_states = attn(hidden_states)
740
+
741
+ if self.downsamplers is not None:
742
+ for downsampler in self.downsamplers:
743
+ hidden_states = downsampler(hidden_states)
744
+
745
+ return hidden_states
746
+
747
+
748
+ class AttnSkipDownBlock2D(nn.Module):
749
+ def __init__(
750
+ self,
751
+ in_channels: int,
752
+ out_channels: int,
753
+ temb_channels: int,
754
+ dropout: float = 0.0,
755
+ num_layers: int = 1,
756
+ resnet_eps: float = 1e-6,
757
+ resnet_time_scale_shift: str = "default",
758
+ resnet_act_fn: str = "swish",
759
+ resnet_pre_norm: bool = True,
760
+ attn_num_head_channels=1,
761
+ attention_type="default",
762
+ output_scale_factor=np.sqrt(2.0),
763
+ downsample_padding=1,
764
+ add_downsample=True,
765
+ ):
766
+ super().__init__()
767
+ self.attentions = nn.ModuleList([])
768
+ self.resnets = nn.ModuleList([])
769
+
770
+ self.attention_type = attention_type
771
+
772
+ for i in range(num_layers):
773
+ in_channels = in_channels if i == 0 else out_channels
774
+ self.resnets.append(
775
+ ResnetBlock2D(
776
+ in_channels=in_channels,
777
+ out_channels=out_channels,
778
+ temb_channels=temb_channels,
779
+ eps=resnet_eps,
780
+ groups=min(in_channels // 4, 32),
781
+ groups_out=min(out_channels // 4, 32),
782
+ dropout=dropout,
783
+ time_embedding_norm=resnet_time_scale_shift,
784
+ non_linearity=resnet_act_fn,
785
+ output_scale_factor=output_scale_factor,
786
+ pre_norm=resnet_pre_norm,
787
+ )
788
+ )
789
+ self.attentions.append(
790
+ AttentionBlock(
791
+ out_channels,
792
+ num_head_channels=attn_num_head_channels,
793
+ rescale_output_factor=output_scale_factor,
794
+ eps=resnet_eps,
795
+ )
796
+ )
797
+
798
+ if add_downsample:
799
+ self.resnet_down = ResnetBlock2D(
800
+ in_channels=out_channels,
801
+ out_channels=out_channels,
802
+ temb_channels=temb_channels,
803
+ eps=resnet_eps,
804
+ groups=min(out_channels // 4, 32),
805
+ dropout=dropout,
806
+ time_embedding_norm=resnet_time_scale_shift,
807
+ non_linearity=resnet_act_fn,
808
+ output_scale_factor=output_scale_factor,
809
+ pre_norm=resnet_pre_norm,
810
+ use_nin_shortcut=True,
811
+ down=True,
812
+ kernel="fir",
813
+ )
814
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
815
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
816
+ else:
817
+ self.resnet_down = None
818
+ self.downsamplers = None
819
+ self.skip_conv = None
820
+
821
+ def forward(self, hidden_states, temb=None, skip_sample=None):
822
+ output_states = ()
823
+
824
+ for resnet, attn in zip(self.resnets, self.attentions):
825
+ hidden_states = resnet(hidden_states, temb)
826
+ hidden_states = attn(hidden_states)
827
+ output_states += (hidden_states,)
828
+
829
+ if self.downsamplers is not None:
830
+ hidden_states = self.resnet_down(hidden_states, temb)
831
+ for downsampler in self.downsamplers:
832
+ skip_sample = downsampler(skip_sample)
833
+
834
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
835
+
836
+ output_states += (hidden_states,)
837
+
838
+ return hidden_states, output_states, skip_sample
839
+
840
+
841
+ class SkipDownBlock2D(nn.Module):
842
+ def __init__(
843
+ self,
844
+ in_channels: int,
845
+ out_channels: int,
846
+ temb_channels: int,
847
+ dropout: float = 0.0,
848
+ num_layers: int = 1,
849
+ resnet_eps: float = 1e-6,
850
+ resnet_time_scale_shift: str = "default",
851
+ resnet_act_fn: str = "swish",
852
+ resnet_pre_norm: bool = True,
853
+ output_scale_factor=np.sqrt(2.0),
854
+ add_downsample=True,
855
+ downsample_padding=1,
856
+ ):
857
+ super().__init__()
858
+ self.resnets = nn.ModuleList([])
859
+
860
+ for i in range(num_layers):
861
+ in_channels = in_channels if i == 0 else out_channels
862
+ self.resnets.append(
863
+ ResnetBlock2D(
864
+ in_channels=in_channels,
865
+ out_channels=out_channels,
866
+ temb_channels=temb_channels,
867
+ eps=resnet_eps,
868
+ groups=min(in_channels // 4, 32),
869
+ groups_out=min(out_channels // 4, 32),
870
+ dropout=dropout,
871
+ time_embedding_norm=resnet_time_scale_shift,
872
+ non_linearity=resnet_act_fn,
873
+ output_scale_factor=output_scale_factor,
874
+ pre_norm=resnet_pre_norm,
875
+ )
876
+ )
877
+
878
+ if add_downsample:
879
+ self.resnet_down = ResnetBlock2D(
880
+ in_channels=out_channels,
881
+ out_channels=out_channels,
882
+ temb_channels=temb_channels,
883
+ eps=resnet_eps,
884
+ groups=min(out_channels // 4, 32),
885
+ dropout=dropout,
886
+ time_embedding_norm=resnet_time_scale_shift,
887
+ non_linearity=resnet_act_fn,
888
+ output_scale_factor=output_scale_factor,
889
+ pre_norm=resnet_pre_norm,
890
+ use_nin_shortcut=True,
891
+ down=True,
892
+ kernel="fir",
893
+ )
894
+ self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
895
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
896
+ else:
897
+ self.resnet_down = None
898
+ self.downsamplers = None
899
+ self.skip_conv = None
900
+
901
+ def forward(self, hidden_states, temb=None, skip_sample=None):
902
+ output_states = ()
903
+
904
+ for resnet in self.resnets:
905
+ hidden_states = resnet(hidden_states, temb)
906
+ output_states += (hidden_states,)
907
+
908
+ if self.downsamplers is not None:
909
+ hidden_states = self.resnet_down(hidden_states, temb)
910
+ for downsampler in self.downsamplers:
911
+ skip_sample = downsampler(skip_sample)
912
+
913
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
914
+
915
+ output_states += (hidden_states,)
916
+
917
+ return hidden_states, output_states, skip_sample
918
+
919
+
920
+ class AttnUpBlock2D(nn.Module):
921
+ def __init__(
922
+ self,
923
+ in_channels: int,
924
+ prev_output_channel: int,
925
+ out_channels: int,
926
+ temb_channels: int,
927
+ dropout: float = 0.0,
928
+ num_layers: int = 1,
929
+ resnet_eps: float = 1e-6,
930
+ resnet_time_scale_shift: str = "default",
931
+ resnet_act_fn: str = "swish",
932
+ resnet_groups: int = 32,
933
+ resnet_pre_norm: bool = True,
934
+ attention_type="default",
935
+ attn_num_head_channels=1,
936
+ output_scale_factor=1.0,
937
+ add_upsample=True,
938
+ ):
939
+ super().__init__()
940
+ resnets = []
941
+ attentions = []
942
+
943
+ self.attention_type = attention_type
944
+
945
+ for i in range(num_layers):
946
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
947
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
948
+
949
+ resnets.append(
950
+ ResnetBlock2D(
951
+ in_channels=resnet_in_channels + res_skip_channels,
952
+ out_channels=out_channels,
953
+ temb_channels=temb_channels,
954
+ eps=resnet_eps,
955
+ groups=resnet_groups,
956
+ dropout=dropout,
957
+ time_embedding_norm=resnet_time_scale_shift,
958
+ non_linearity=resnet_act_fn,
959
+ output_scale_factor=output_scale_factor,
960
+ pre_norm=resnet_pre_norm,
961
+ )
962
+ )
963
+ attentions.append(
964
+ AttentionBlock(
965
+ out_channels,
966
+ num_head_channels=attn_num_head_channels,
967
+ rescale_output_factor=output_scale_factor,
968
+ eps=resnet_eps,
969
+ )
970
+ )
971
+
972
+ self.attentions = nn.ModuleList(attentions)
973
+ self.resnets = nn.ModuleList(resnets)
974
+
975
+ if add_upsample:
976
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
977
+ else:
978
+ self.upsamplers = None
979
+
980
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
981
+ for resnet, attn in zip(self.resnets, self.attentions):
982
+
983
+ # pop res hidden states
984
+ res_hidden_states = res_hidden_states_tuple[-1]
985
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
986
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
987
+
988
+ hidden_states = resnet(hidden_states, temb)
989
+ hidden_states = attn(hidden_states)
990
+
991
+ if self.upsamplers is not None:
992
+ for upsampler in self.upsamplers:
993
+ hidden_states = upsampler(hidden_states)
994
+
995
+ return hidden_states
996
+
997
+
998
+ class CrossAttnUpBlock2D(nn.Module):
999
+ def __init__(
1000
+ self,
1001
+ in_channels: int,
1002
+ out_channels: int,
1003
+ prev_output_channel: int,
1004
+ temb_channels: int,
1005
+ dropout: float = 0.0,
1006
+ num_layers: int = 1,
1007
+ resnet_eps: float = 1e-6,
1008
+ resnet_time_scale_shift: str = "default",
1009
+ resnet_act_fn: str = "swish",
1010
+ resnet_groups: int = 32,
1011
+ resnet_pre_norm: bool = True,
1012
+ attn_num_head_channels=1,
1013
+ cross_attention_dim=1280,
1014
+ attention_type="default",
1015
+ output_scale_factor=1.0,
1016
+ downsample_padding=1,
1017
+ add_upsample=True,
1018
+ ):
1019
+ super().__init__()
1020
+ resnets = []
1021
+ attentions = []
1022
+
1023
+ self.attention_type = attention_type
1024
+ self.attn_num_head_channels = attn_num_head_channels
1025
+
1026
+ for i in range(num_layers):
1027
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1028
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1029
+
1030
+ resnets.append(
1031
+ ResnetBlock2D(
1032
+ in_channels=resnet_in_channels + res_skip_channels,
1033
+ out_channels=out_channels,
1034
+ temb_channels=temb_channels,
1035
+ eps=resnet_eps,
1036
+ groups=resnet_groups,
1037
+ dropout=dropout,
1038
+ time_embedding_norm=resnet_time_scale_shift,
1039
+ non_linearity=resnet_act_fn,
1040
+ output_scale_factor=output_scale_factor,
1041
+ pre_norm=resnet_pre_norm,
1042
+ )
1043
+ )
1044
+ attentions.append(
1045
+ SpatialTransformer(
1046
+ out_channels,
1047
+ attn_num_head_channels,
1048
+ out_channels // attn_num_head_channels,
1049
+ depth=1,
1050
+ context_dim=cross_attention_dim,
1051
+ )
1052
+ )
1053
+ self.attentions = nn.ModuleList(attentions)
1054
+ self.resnets = nn.ModuleList(resnets)
1055
+
1056
+ if add_upsample:
1057
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1058
+ else:
1059
+ self.upsamplers = None
1060
+
1061
+ def set_attention_slice(self, slice_size):
1062
+ if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
1063
+ raise ValueError(
1064
+ f"Make sure slice_size {slice_size} is a divisor of "
1065
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1066
+ )
1067
+ if slice_size is not None and slice_size > self.attn_num_head_channels:
1068
+ raise ValueError(
1069
+ f"Chunk_size {slice_size} has to be smaller or equal to "
1070
+ f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1071
+ )
1072
+
1073
+ for attn in self.attentions:
1074
+ attn._set_attention_slice(slice_size)
1075
+
1076
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
1077
+ for resnet, attn in zip(self.resnets, self.attentions):
1078
+
1079
+ # pop res hidden states
1080
+ res_hidden_states = res_hidden_states_tuple[-1]
1081
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1082
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1083
+
1084
+ hidden_states = resnet(hidden_states, temb)
1085
+ hidden_states = attn(hidden_states, context=encoder_hidden_states)
1086
+
1087
+ if self.upsamplers is not None:
1088
+ for upsampler in self.upsamplers:
1089
+ hidden_states = upsampler(hidden_states)
1090
+
1091
+ return hidden_states
1092
+
1093
+
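Editor's note (illustrative, not part of the commit): the set_attention_slice guard above only accepts slice sizes that divide the configured head count and do not exceed it. A minimal sketch of the valid values, assuming attn_num_head_channels=8 purely for illustration:

attn_num_head_channels = 8  # illustrative value, not taken from the commit
valid = [s for s in range(1, attn_num_head_channels + 1) if attn_num_head_channels % s == 0]
# valid == [1, 2, 4, 8]; any other slice_size would trigger the ValueError checks above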
1094
+ class UpBlock2D(nn.Module):
1095
+ def __init__(
1096
+ self,
1097
+ in_channels: int,
1098
+ prev_output_channel: int,
1099
+ out_channels: int,
1100
+ temb_channels: int,
1101
+ dropout: float = 0.0,
1102
+ num_layers: int = 1,
1103
+ resnet_eps: float = 1e-6,
1104
+ resnet_time_scale_shift: str = "default",
1105
+ resnet_act_fn: str = "swish",
1106
+ resnet_groups: int = 32,
1107
+ resnet_pre_norm: bool = True,
1108
+ output_scale_factor=1.0,
1109
+ add_upsample=True,
1110
+ ):
1111
+ super().__init__()
1112
+ resnets = []
1113
+
1114
+ for i in range(num_layers):
1115
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1116
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1117
+
1118
+ resnets.append(
1119
+ ResnetBlock2D(
1120
+ in_channels=resnet_in_channels + res_skip_channels,
1121
+ out_channels=out_channels,
1122
+ temb_channels=temb_channels,
1123
+ eps=resnet_eps,
1124
+ groups=resnet_groups,
1125
+ dropout=dropout,
1126
+ time_embedding_norm=resnet_time_scale_shift,
1127
+ non_linearity=resnet_act_fn,
1128
+ output_scale_factor=output_scale_factor,
1129
+ pre_norm=resnet_pre_norm,
1130
+ )
1131
+ )
1132
+
1133
+ self.resnets = nn.ModuleList(resnets)
1134
+
1135
+ if add_upsample:
1136
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1137
+ else:
1138
+ self.upsamplers = None
1139
+
1140
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1141
+ for resnet in self.resnets:
1142
+
1143
+ # pop res hidden states
1144
+ res_hidden_states = res_hidden_states_tuple[-1]
1145
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1146
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1147
+
1148
+ hidden_states = resnet(hidden_states, temb)
1149
+
1150
+ if self.upsamplers is not None:
1151
+ for upsampler in self.upsamplers:
1152
+ hidden_states = upsampler(hidden_states)
1153
+
1154
+ return hidden_states
1155
+
1156
+
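Editor's note (illustrative, not part of the commit): a minimal sketch of wiring the UpBlock2D defined above with dummy tensors, assuming the repository root is importable as my_diffusers; the channel and embedding sizes are illustrative choices that satisfy the default 32-group GroupNorm.

import torch
from my_diffusers.models.unet_blocks import UpBlock2D

block = UpBlock2D(
    in_channels=64, prev_output_channel=128, out_channels=64,
    temb_channels=512, num_layers=3, add_upsample=True,
)
hidden = torch.randn(1, 128, 16, 16)                          # output of the previous (deeper) block
skips = tuple(torch.randn(1, 64, 16, 16) for _ in range(3))   # popped from the right, one per resnet
temb = torch.randn(1, 512)                                    # time embedding
out = block(hidden, skips, temb)                              # -> (1, 64, 32, 32) after the final upsample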
1157
+ class UpDecoderBlock2D(nn.Module):
1158
+ def __init__(
1159
+ self,
1160
+ in_channels: int,
1161
+ out_channels: int,
1162
+ dropout: float = 0.0,
1163
+ num_layers: int = 1,
1164
+ resnet_eps: float = 1e-6,
1165
+ resnet_time_scale_shift: str = "default",
1166
+ resnet_act_fn: str = "swish",
1167
+ resnet_groups: int = 32,
1168
+ resnet_pre_norm: bool = True,
1169
+ output_scale_factor=1.0,
1170
+ add_upsample=True,
1171
+ ):
1172
+ super().__init__()
1173
+ resnets = []
1174
+
1175
+ for i in range(num_layers):
1176
+ input_channels = in_channels if i == 0 else out_channels
1177
+
1178
+ resnets.append(
1179
+ ResnetBlock2D(
1180
+ in_channels=input_channels,
1181
+ out_channels=out_channels,
1182
+ temb_channels=None,
1183
+ eps=resnet_eps,
1184
+ groups=resnet_groups,
1185
+ dropout=dropout,
1186
+ time_embedding_norm=resnet_time_scale_shift,
1187
+ non_linearity=resnet_act_fn,
1188
+ output_scale_factor=output_scale_factor,
1189
+ pre_norm=resnet_pre_norm,
1190
+ )
1191
+ )
1192
+
1193
+ self.resnets = nn.ModuleList(resnets)
1194
+
1195
+ if add_upsample:
1196
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1197
+ else:
1198
+ self.upsamplers = None
1199
+
1200
+ def forward(self, hidden_states):
1201
+ for resnet in self.resnets:
1202
+ hidden_states = resnet(hidden_states, temb=None)
1203
+
1204
+ if self.upsamplers is not None:
1205
+ for upsampler in self.upsamplers:
1206
+ hidden_states = upsampler(hidden_states)
1207
+
1208
+ return hidden_states
1209
+
1210
+
1211
+ class AttnUpDecoderBlock2D(nn.Module):
1212
+ def __init__(
1213
+ self,
1214
+ in_channels: int,
1215
+ out_channels: int,
1216
+ dropout: float = 0.0,
1217
+ num_layers: int = 1,
1218
+ resnet_eps: float = 1e-6,
1219
+ resnet_time_scale_shift: str = "default",
1220
+ resnet_act_fn: str = "swish",
1221
+ resnet_groups: int = 32,
1222
+ resnet_pre_norm: bool = True,
1223
+ attn_num_head_channels=1,
1224
+ output_scale_factor=1.0,
1225
+ add_upsample=True,
1226
+ ):
1227
+ super().__init__()
1228
+ resnets = []
1229
+ attentions = []
1230
+
1231
+ for i in range(num_layers):
1232
+ input_channels = in_channels if i == 0 else out_channels
1233
+
1234
+ resnets.append(
1235
+ ResnetBlock2D(
1236
+ in_channels=input_channels,
1237
+ out_channels=out_channels,
1238
+ temb_channels=None,
1239
+ eps=resnet_eps,
1240
+ groups=resnet_groups,
1241
+ dropout=dropout,
1242
+ time_embedding_norm=resnet_time_scale_shift,
1243
+ non_linearity=resnet_act_fn,
1244
+ output_scale_factor=output_scale_factor,
1245
+ pre_norm=resnet_pre_norm,
1246
+ )
1247
+ )
1248
+ attentions.append(
1249
+ AttentionBlock(
1250
+ out_channels,
1251
+ num_head_channels=attn_num_head_channels,
1252
+ rescale_output_factor=output_scale_factor,
1253
+ eps=resnet_eps,
1254
+ num_groups=resnet_groups,
1255
+ )
1256
+ )
1257
+
1258
+ self.attentions = nn.ModuleList(attentions)
1259
+ self.resnets = nn.ModuleList(resnets)
1260
+
1261
+ if add_upsample:
1262
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1263
+ else:
1264
+ self.upsamplers = None
1265
+
1266
+ def forward(self, hidden_states):
1267
+ for resnet, attn in zip(self.resnets, self.attentions):
1268
+ hidden_states = resnet(hidden_states, temb=None)
1269
+ hidden_states = attn(hidden_states)
1270
+
1271
+ if self.upsamplers is not None:
1272
+ for upsampler in self.upsamplers:
1273
+ hidden_states = upsampler(hidden_states)
1274
+
1275
+ return hidden_states
1276
+
1277
+
1278
+ class AttnSkipUpBlock2D(nn.Module):
1279
+ def __init__(
1280
+ self,
1281
+ in_channels: int,
1282
+ prev_output_channel: int,
1283
+ out_channels: int,
1284
+ temb_channels: int,
1285
+ dropout: float = 0.0,
1286
+ num_layers: int = 1,
1287
+ resnet_eps: float = 1e-6,
1288
+ resnet_time_scale_shift: str = "default",
1289
+ resnet_act_fn: str = "swish",
1290
+ resnet_pre_norm: bool = True,
1291
+ attn_num_head_channels=1,
1292
+ attention_type="default",
1293
+ output_scale_factor=np.sqrt(2.0),
1294
+ upsample_padding=1,
1295
+ add_upsample=True,
1296
+ ):
1297
+ super().__init__()
1298
+ self.attentions = nn.ModuleList([])
1299
+ self.resnets = nn.ModuleList([])
1300
+
1301
+ self.attention_type = attention_type
1302
+
1303
+ for i in range(num_layers):
1304
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1305
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1306
+
1307
+ self.resnets.append(
1308
+ ResnetBlock2D(
1309
+ in_channels=resnet_in_channels + res_skip_channels,
1310
+ out_channels=out_channels,
1311
+ temb_channels=temb_channels,
1312
+ eps=resnet_eps,
1313
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1314
+ groups_out=min(out_channels // 4, 32),
1315
+ dropout=dropout,
1316
+ time_embedding_norm=resnet_time_scale_shift,
1317
+ non_linearity=resnet_act_fn,
1318
+ output_scale_factor=output_scale_factor,
1319
+ pre_norm=resnet_pre_norm,
1320
+ )
1321
+ )
1322
+
1323
+ self.attentions.append(
1324
+ AttentionBlock(
1325
+ out_channels,
1326
+ num_head_channels=attn_num_head_channels,
1327
+ rescale_output_factor=output_scale_factor,
1328
+ eps=resnet_eps,
1329
+ )
1330
+ )
1331
+
1332
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1333
+ if add_upsample:
1334
+ self.resnet_up = ResnetBlock2D(
1335
+ in_channels=out_channels,
1336
+ out_channels=out_channels,
1337
+ temb_channels=temb_channels,
1338
+ eps=resnet_eps,
1339
+ groups=min(out_channels // 4, 32),
1340
+ groups_out=min(out_channels // 4, 32),
1341
+ dropout=dropout,
1342
+ time_embedding_norm=resnet_time_scale_shift,
1343
+ non_linearity=resnet_act_fn,
1344
+ output_scale_factor=output_scale_factor,
1345
+ pre_norm=resnet_pre_norm,
1346
+ use_nin_shortcut=True,
1347
+ up=True,
1348
+ kernel="fir",
1349
+ )
1350
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1351
+ self.skip_norm = torch.nn.GroupNorm(
1352
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1353
+ )
1354
+ self.act = nn.SiLU()
1355
+ else:
1356
+ self.resnet_up = None
1357
+ self.skip_conv = None
1358
+ self.skip_norm = None
1359
+ self.act = None
1360
+
1361
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1362
+ for resnet in self.resnets:
1363
+ # pop res hidden states
1364
+ res_hidden_states = res_hidden_states_tuple[-1]
1365
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1366
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1367
+
1368
+ hidden_states = resnet(hidden_states, temb)
1369
+
1370
+ hidden_states = self.attentions[0](hidden_states)
1371
+
1372
+ if skip_sample is not None:
1373
+ skip_sample = self.upsampler(skip_sample)
1374
+ else:
1375
+ skip_sample = 0
1376
+
1377
+ if self.resnet_up is not None:
1378
+ skip_sample_states = self.skip_norm(hidden_states)
1379
+ skip_sample_states = self.act(skip_sample_states)
1380
+ skip_sample_states = self.skip_conv(skip_sample_states)
1381
+
1382
+ skip_sample = skip_sample + skip_sample_states
1383
+
1384
+ hidden_states = self.resnet_up(hidden_states, temb)
1385
+
1386
+ return hidden_states, skip_sample
1387
+
1388
+
1389
+ class SkipUpBlock2D(nn.Module):
1390
+ def __init__(
1391
+ self,
1392
+ in_channels: int,
1393
+ prev_output_channel: int,
1394
+ out_channels: int,
1395
+ temb_channels: int,
1396
+ dropout: float = 0.0,
1397
+ num_layers: int = 1,
1398
+ resnet_eps: float = 1e-6,
1399
+ resnet_time_scale_shift: str = "default",
1400
+ resnet_act_fn: str = "swish",
1401
+ resnet_pre_norm: bool = True,
1402
+ output_scale_factor=np.sqrt(2.0),
1403
+ add_upsample=True,
1404
+ upsample_padding=1,
1405
+ ):
1406
+ super().__init__()
1407
+ self.resnets = nn.ModuleList([])
1408
+
1409
+ for i in range(num_layers):
1410
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1411
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1412
+
1413
+ self.resnets.append(
1414
+ ResnetBlock2D(
1415
+ in_channels=resnet_in_channels + res_skip_channels,
1416
+ out_channels=out_channels,
1417
+ temb_channels=temb_channels,
1418
+ eps=resnet_eps,
1419
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1420
+ groups_out=min(out_channels // 4, 32),
1421
+ dropout=dropout,
1422
+ time_embedding_norm=resnet_time_scale_shift,
1423
+ non_linearity=resnet_act_fn,
1424
+ output_scale_factor=output_scale_factor,
1425
+ pre_norm=resnet_pre_norm,
1426
+ )
1427
+ )
1428
+
1429
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1430
+ if add_upsample:
1431
+ self.resnet_up = ResnetBlock2D(
1432
+ in_channels=out_channels,
1433
+ out_channels=out_channels,
1434
+ temb_channels=temb_channels,
1435
+ eps=resnet_eps,
1436
+ groups=min(out_channels // 4, 32),
1437
+ groups_out=min(out_channels // 4, 32),
1438
+ dropout=dropout,
1439
+ time_embedding_norm=resnet_time_scale_shift,
1440
+ non_linearity=resnet_act_fn,
1441
+ output_scale_factor=output_scale_factor,
1442
+ pre_norm=resnet_pre_norm,
1443
+ use_nin_shortcut=True,
1444
+ up=True,
1445
+ kernel="fir",
1446
+ )
1447
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1448
+ self.skip_norm = torch.nn.GroupNorm(
1449
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1450
+ )
1451
+ self.act = nn.SiLU()
1452
+ else:
1453
+ self.resnet_up = None
1454
+ self.skip_conv = None
1455
+ self.skip_norm = None
1456
+ self.act = None
1457
+
1458
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1459
+ for resnet in self.resnets:
1460
+ # pop res hidden states
1461
+ res_hidden_states = res_hidden_states_tuple[-1]
1462
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1463
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1464
+
1465
+ hidden_states = resnet(hidden_states, temb)
1466
+
1467
+ if skip_sample is not None:
1468
+ skip_sample = self.upsampler(skip_sample)
1469
+ else:
1470
+ skip_sample = 0
1471
+
1472
+ if self.resnet_up is not None:
1473
+ skip_sample_states = self.skip_norm(hidden_states)
1474
+ skip_sample_states = self.act(skip_sample_states)
1475
+ skip_sample_states = self.skip_conv(skip_sample_states)
1476
+
1477
+ skip_sample = skip_sample + skip_sample_states
1478
+
1479
+ hidden_states = self.resnet_up(hidden_states, temb)
1480
+
1481
+ return hidden_states, skip_sample
my_diffusers/models/vae.py ADDED
@@ -0,0 +1,581 @@
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from ..configuration_utils import ConfigMixin, register_to_config
9
+ from ..modeling_utils import ModelMixin
10
+ from ..utils import BaseOutput
11
+ from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
12
+
13
+
14
+ @dataclass
15
+ class DecoderOutput(BaseOutput):
16
+ """
17
+ Output of decoding method.
18
+
19
+ Args:
20
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
21
+ Decoded output sample of the model. Output of the last layer of the model.
22
+ """
23
+
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ @dataclass
28
+ class VQEncoderOutput(BaseOutput):
29
+ """
30
+ Output of VQModel encoding method.
31
+
32
+ Args:
33
+ latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
34
+ Encoded output sample of the model. Output of the last layer of the model.
35
+ """
36
+
37
+ latents: torch.FloatTensor
38
+
39
+
40
+ @dataclass
41
+ class AutoencoderKLOutput(BaseOutput):
42
+ """
43
+ Output of AutoencoderKL encoding method.
44
+
45
+ Args:
46
+ latent_dist (`DiagonalGaussianDistribution`):
47
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
48
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
49
+ """
50
+
51
+ latent_dist: "DiagonalGaussianDistribution"
52
+
53
+
54
+ class Encoder(nn.Module):
55
+ def __init__(
56
+ self,
57
+ in_channels=3,
58
+ out_channels=3,
59
+ down_block_types=("DownEncoderBlock2D",),
60
+ block_out_channels=(64,),
61
+ layers_per_block=2,
62
+ act_fn="silu",
63
+ double_z=True,
64
+ ):
65
+ super().__init__()
66
+ self.layers_per_block = layers_per_block
67
+
68
+ self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
69
+
70
+ self.mid_block = None
71
+ self.down_blocks = nn.ModuleList([])
72
+
73
+ # down
74
+ output_channel = block_out_channels[0]
75
+ for i, down_block_type in enumerate(down_block_types):
76
+ input_channel = output_channel
77
+ output_channel = block_out_channels[i]
78
+ is_final_block = i == len(block_out_channels) - 1
79
+
80
+ down_block = get_down_block(
81
+ down_block_type,
82
+ num_layers=self.layers_per_block,
83
+ in_channels=input_channel,
84
+ out_channels=output_channel,
85
+ add_downsample=not is_final_block,
86
+ resnet_eps=1e-6,
87
+ downsample_padding=0,
88
+ resnet_act_fn=act_fn,
89
+ attn_num_head_channels=None,
90
+ temb_channels=None,
91
+ )
92
+ self.down_blocks.append(down_block)
93
+
94
+ # mid
95
+ self.mid_block = UNetMidBlock2D(
96
+ in_channels=block_out_channels[-1],
97
+ resnet_eps=1e-6,
98
+ resnet_act_fn=act_fn,
99
+ output_scale_factor=1,
100
+ resnet_time_scale_shift="default",
101
+ attn_num_head_channels=None,
102
+ resnet_groups=32,
103
+ temb_channels=None,
104
+ )
105
+
106
+ # out
107
+ num_groups_out = 32
108
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
109
+ self.conv_act = nn.SiLU()
110
+
111
+ conv_out_channels = 2 * out_channels if double_z else out_channels
112
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
113
+
114
+ def forward(self, x):
115
+ sample = x
116
+ sample = self.conv_in(sample)
117
+
118
+ # down
119
+ for down_block in self.down_blocks:
120
+ sample = down_block(sample)
121
+
122
+ # middle
123
+ sample = self.mid_block(sample)
124
+
125
+ # post-process
126
+ sample = self.conv_norm_out(sample)
127
+ sample = self.conv_act(sample)
128
+ sample = self.conv_out(sample)
129
+
130
+ return sample
131
+
132
+
133
+ class Decoder(nn.Module):
134
+ def __init__(
135
+ self,
136
+ in_channels=3,
137
+ out_channels=3,
138
+ up_block_types=("UpDecoderBlock2D",),
139
+ block_out_channels=(64,),
140
+ layers_per_block=2,
141
+ act_fn="silu",
142
+ ):
143
+ super().__init__()
144
+ self.layers_per_block = layers_per_block
145
+
146
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
147
+
148
+ self.mid_block = None
149
+ self.up_blocks = nn.ModuleList([])
150
+
151
+ # mid
152
+ self.mid_block = UNetMidBlock2D(
153
+ in_channels=block_out_channels[-1],
154
+ resnet_eps=1e-6,
155
+ resnet_act_fn=act_fn,
156
+ output_scale_factor=1,
157
+ resnet_time_scale_shift="default",
158
+ attn_num_head_channels=None,
159
+ resnet_groups=32,
160
+ temb_channels=None,
161
+ )
162
+
163
+ # up
164
+ reversed_block_out_channels = list(reversed(block_out_channels))
165
+ output_channel = reversed_block_out_channels[0]
166
+ for i, up_block_type in enumerate(up_block_types):
167
+ prev_output_channel = output_channel
168
+ output_channel = reversed_block_out_channels[i]
169
+
170
+ is_final_block = i == len(block_out_channels) - 1
171
+
172
+ up_block = get_up_block(
173
+ up_block_type,
174
+ num_layers=self.layers_per_block + 1,
175
+ in_channels=prev_output_channel,
176
+ out_channels=output_channel,
177
+ prev_output_channel=None,
178
+ add_upsample=not is_final_block,
179
+ resnet_eps=1e-6,
180
+ resnet_act_fn=act_fn,
181
+ attn_num_head_channels=None,
182
+ temb_channels=None,
183
+ )
184
+ self.up_blocks.append(up_block)
185
+ prev_output_channel = output_channel
186
+
187
+ # out
188
+ num_groups_out = 32
189
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
190
+ self.conv_act = nn.SiLU()
191
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
192
+
193
+ def forward(self, z):
194
+ sample = z
195
+ sample = self.conv_in(sample)
196
+
197
+ # middle
198
+ sample = self.mid_block(sample)
199
+
200
+ # up
201
+ for up_block in self.up_blocks:
202
+ sample = up_block(sample)
203
+
204
+ # post-process
205
+ sample = self.conv_norm_out(sample)
206
+ sample = self.conv_act(sample)
207
+ sample = self.conv_out(sample)
208
+
209
+ return sample
210
+
211
+
212
+ class VectorQuantizer(nn.Module):
213
+ """
214
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
215
+ multiplications and allows for post-hoc remapping of indices.
216
+ """
217
+
218
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
219
+ # backwards compatibility we use the buggy version by default, but you can
220
+ # specify legacy=False to fix it.
221
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
222
+ super().__init__()
223
+ self.n_e = n_e
224
+ self.e_dim = e_dim
225
+ self.beta = beta
226
+ self.legacy = legacy
227
+
228
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
229
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
230
+
231
+ self.remap = remap
232
+ if self.remap is not None:
233
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
234
+ self.re_embed = self.used.shape[0]
235
+ self.unknown_index = unknown_index # "random" or "extra" or integer
236
+ if self.unknown_index == "extra":
237
+ self.unknown_index = self.re_embed
238
+ self.re_embed = self.re_embed + 1
239
+ print(
240
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
241
+ f"Using {self.unknown_index} for unknown indices."
242
+ )
243
+ else:
244
+ self.re_embed = n_e
245
+
246
+ self.sane_index_shape = sane_index_shape
247
+
248
+ def remap_to_used(self, inds):
249
+ ishape = inds.shape
250
+ assert len(ishape) > 1
251
+ inds = inds.reshape(ishape[0], -1)
252
+ used = self.used.to(inds)
253
+ match = (inds[:, :, None] == used[None, None, ...]).long()
254
+ new = match.argmax(-1)
255
+ unknown = match.sum(2) < 1
256
+ if self.unknown_index == "random":
257
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
258
+ else:
259
+ new[unknown] = self.unknown_index
260
+ return new.reshape(ishape)
261
+
262
+ def unmap_to_all(self, inds):
263
+ ishape = inds.shape
264
+ assert len(ishape) > 1
265
+ inds = inds.reshape(ishape[0], -1)
266
+ used = self.used.to(inds)
267
+ if self.re_embed > self.used.shape[0]: # extra token
268
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
269
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
270
+ return back.reshape(ishape)
271
+
272
+ def forward(self, z):
273
+ # reshape z -> (batch, height, width, channel) and flatten
274
+ z = z.permute(0, 2, 3, 1).contiguous()
275
+ z_flattened = z.view(-1, self.e_dim)
276
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
277
+
278
+ d = (
279
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
280
+ + torch.sum(self.embedding.weight**2, dim=1)
281
+ - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
282
+ )
283
+
284
+ min_encoding_indices = torch.argmin(d, dim=1)
285
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
286
+ perplexity = None
287
+ min_encodings = None
288
+
289
+ # compute loss for embedding
290
+ if not self.legacy:
291
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
292
+ else:
293
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
294
+
295
+ # preserve gradients
296
+ z_q = z + (z_q - z).detach()
297
+
298
+ # reshape back to match original input shape
299
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
300
+
301
+ if self.remap is not None:
302
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
303
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
304
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
305
+
306
+ if self.sane_index_shape:
307
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
308
+
309
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
310
+
311
+ def get_codebook_entry(self, indices, shape):
312
+ # shape specifying (batch, height, width, channel)
313
+ if self.remap is not None:
314
+ indices = indices.reshape(shape[0], -1) # add batch axis
315
+ indices = self.unmap_to_all(indices)
316
+ indices = indices.reshape(-1) # flatten again
317
+
318
+ # get quantized latent vectors
319
+ z_q = self.embedding(indices)
320
+
321
+ if shape is not None:
322
+ z_q = z_q.view(shape)
323
+ # reshape back to match original input shape
324
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
325
+
326
+ return z_q
327
+
328
+
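Editor's note (illustrative, not part of the commit): a small sketch of the quantization path above, assuming the repository root is importable as my_diffusers; the codebook size and latent shape are illustrative.

import torch
from my_diffusers.models.vae import VectorQuantizer

vq = VectorQuantizer(n_e=16, e_dim=3, beta=0.25, remap=None, sane_index_shape=True)
z = torch.randn(1, 3, 8, 8)                                   # (batch, channel, height, width)
z_q, loss, (perplexity, min_encodings, indices) = vq(z)
# z_q matches z's shape, gradients flow via the straight-through estimator above,
# and indices has shape (1, 8, 8) because sane_index_shape=True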
329
+ class DiagonalGaussianDistribution(object):
330
+ def __init__(self, parameters, deterministic=False):
331
+ self.parameters = parameters
332
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
333
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
334
+ self.deterministic = deterministic
335
+ self.std = torch.exp(0.5 * self.logvar)
336
+ self.var = torch.exp(self.logvar)
337
+ if self.deterministic:
338
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
339
+
340
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
341
+ device = self.parameters.device
342
+ sample_device = "cpu" if device.type == "mps" else device
343
+ sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
344
+ x = self.mean + self.std * sample
345
+ return x
346
+
347
+ def kl(self, other=None):
348
+ if self.deterministic:
349
+ return torch.Tensor([0.0])
350
+ else:
351
+ if other is None:
352
+ return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
353
+ else:
354
+ return 0.5 * torch.sum(
355
+ torch.pow(self.mean - other.mean, 2) / other.var
356
+ + self.var / other.var
357
+ - 1.0
358
+ - self.logvar
359
+ + other.logvar,
360
+ dim=[1, 2, 3],
361
+ )
362
+
363
+ def nll(self, sample, dims=[1, 2, 3]):
364
+ if self.deterministic:
365
+ return torch.Tensor([0.0])
366
+ logtwopi = np.log(2.0 * np.pi)
367
+ return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
368
+
369
+ def mode(self):
370
+ return self.mean
371
+
372
+
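Editor's note (illustrative, not part of the commit): the distribution above expects the mean and log-variance stacked along the channel axis, which is exactly what AutoencoderKL.encode produces further down. A minimal sketch:

import torch
from my_diffusers.models.vae import DiagonalGaussianDistribution

moments = torch.randn(1, 8, 32, 32)   # 2 * latent_channels, e.g. latent_channels=4
dist = DiagonalGaussianDistribution(moments)
z = dist.sample()                     # reparameterized: mean + std * noise
mode = dist.mode()                    # just the mean
kl = dist.kl()                        # KL against a standard normal, one value per batch element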
373
+ class VQModel(ModelMixin, ConfigMixin):
374
+ r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
375
+ Kavukcuoglu.
376
+
377
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
378
+ implements for all models (such as downloading or saving).
379
+
380
+ Parameters:
381
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
382
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
383
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
384
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
385
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
386
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
387
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
388
+ obj:`(64,)`): Tuple of block output channels.
389
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
390
+ latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
391
+ sample_size (`int`, *optional*, defaults to `32`): TODO
392
+ num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
393
+ """
394
+
395
+ @register_to_config
396
+ def __init__(
397
+ self,
398
+ in_channels: int = 3,
399
+ out_channels: int = 3,
400
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
401
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
402
+ block_out_channels: Tuple[int] = (64,),
403
+ layers_per_block: int = 1,
404
+ act_fn: str = "silu",
405
+ latent_channels: int = 3,
406
+ sample_size: int = 32,
407
+ num_vq_embeddings: int = 256,
408
+ ):
409
+ super().__init__()
410
+
411
+ # pass init params to Encoder
412
+ self.encoder = Encoder(
413
+ in_channels=in_channels,
414
+ out_channels=latent_channels,
415
+ down_block_types=down_block_types,
416
+ block_out_channels=block_out_channels,
417
+ layers_per_block=layers_per_block,
418
+ act_fn=act_fn,
419
+ double_z=False,
420
+ )
421
+
422
+ self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
423
+ self.quantize = VectorQuantizer(
424
+ num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
425
+ )
426
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
427
+
428
+ # pass init params to Decoder
429
+ self.decoder = Decoder(
430
+ in_channels=latent_channels,
431
+ out_channels=out_channels,
432
+ up_block_types=up_block_types,
433
+ block_out_channels=block_out_channels,
434
+ layers_per_block=layers_per_block,
435
+ act_fn=act_fn,
436
+ )
437
+
438
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
439
+ h = self.encoder(x)
440
+ h = self.quant_conv(h)
441
+
442
+ if not return_dict:
443
+ return (h,)
444
+
445
+ return VQEncoderOutput(latents=h)
446
+
447
+ def decode(
448
+ self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
449
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
450
+ # also go through quantization layer
451
+ if not force_not_quantize:
452
+ quant, emb_loss, info = self.quantize(h)
453
+ else:
454
+ quant = h
455
+ quant = self.post_quant_conv(quant)
456
+ dec = self.decoder(quant)
457
+
458
+ if not return_dict:
459
+ return (dec,)
460
+
461
+ return DecoderOutput(sample=dec)
462
+
463
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
464
+ r"""
465
+ Args:
466
+ sample (`torch.FloatTensor`): Input sample.
467
+ return_dict (`bool`, *optional*, defaults to `True`):
468
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
469
+ """
470
+ x = sample
471
+ h = self.encode(x).latents
472
+ dec = self.decode(h).sample
473
+
474
+ if not return_dict:
475
+ return (dec,)
476
+
477
+ return DecoderOutput(sample=dec)
478
+
479
+
480
+ class AutoencoderKL(ModelMixin, ConfigMixin):
481
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
482
+ and Max Welling.
483
+
484
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
485
+ implements for all models (such as downloading or saving).
486
+
487
+ Parameters:
488
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
489
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
490
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
491
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
492
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
493
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
494
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
495
+ obj:`(64,)`): Tuple of block output channels.
496
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
497
+ latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
498
+ sample_size (`int`, *optional*, defaults to `32`): TODO
499
+ """
500
+
501
+ @register_to_config
502
+ def __init__(
503
+ self,
504
+ in_channels: int = 3,
505
+ out_channels: int = 3,
506
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
507
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
508
+ block_out_channels: Tuple[int] = (64,),
509
+ layers_per_block: int = 1,
510
+ act_fn: str = "silu",
511
+ latent_channels: int = 4,
512
+ sample_size: int = 32,
513
+ ):
514
+ super().__init__()
515
+
516
+ # pass init params to Encoder
517
+ self.encoder = Encoder(
518
+ in_channels=in_channels,
519
+ out_channels=latent_channels,
520
+ down_block_types=down_block_types,
521
+ block_out_channels=block_out_channels,
522
+ layers_per_block=layers_per_block,
523
+ act_fn=act_fn,
524
+ double_z=True,
525
+ )
526
+
527
+ # pass init params to Decoder
528
+ self.decoder = Decoder(
529
+ in_channels=latent_channels,
530
+ out_channels=out_channels,
531
+ up_block_types=up_block_types,
532
+ block_out_channels=block_out_channels,
533
+ layers_per_block=layers_per_block,
534
+ act_fn=act_fn,
535
+ )
536
+
537
+ self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
538
+ self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
539
+
540
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
541
+ h = self.encoder(x)
542
+ moments = self.quant_conv(h)
543
+ posterior = DiagonalGaussianDistribution(moments)
544
+
545
+ if not return_dict:
546
+ return (posterior,)
547
+
548
+ return AutoencoderKLOutput(latent_dist=posterior)
549
+
550
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
551
+ z = self.post_quant_conv(z)
552
+ dec = self.decoder(z)
553
+
554
+ if not return_dict:
555
+ return (dec,)
556
+
557
+ return DecoderOutput(sample=dec)
558
+
559
+ def forward(
560
+ self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
561
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
562
+ r"""
563
+ Args:
564
+ sample (`torch.FloatTensor`): Input sample.
565
+ sample_posterior (`bool`, *optional*, defaults to `False`):
566
+ Whether to sample from the posterior.
567
+ return_dict (`bool`, *optional*, defaults to `True`):
568
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
569
+ """
570
+ x = sample
571
+ posterior = self.encode(x).latent_dist
572
+ if sample_posterior:
573
+ z = posterior.sample()
574
+ else:
575
+ z = posterior.mode()
576
+ dec = self.decode(z).sample
577
+
578
+ if not return_dict:
579
+ return (dec,)
580
+
581
+ return DecoderOutput(sample=dec)
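Editor's note (illustrative, not part of the commit): a hedged round-trip sketch for the AutoencoderKL above using its default single-block configuration; real checkpoints use larger block_out_channels, so the sizes below are purely illustrative.

import torch
from my_diffusers.models.vae import AutoencoderKL

vae = AutoencoderKL()                       # defaults: block_out_channels=(64,), latent_channels=4
x = torch.randn(1, 3, 32, 32)
posterior = vae.encode(x).latent_dist
z = posterior.sample()                      # (1, 4, 32, 32) with the default single block (no downsampling)
recon = vae.decode(z).sample                # (1, 3, 32, 32)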
my_diffusers/onnx_utils.py ADDED
@@ -0,0 +1,189 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import os
19
+ import shutil
20
+ from pathlib import Path
21
+ from typing import Optional, Union
22
+
23
+ import numpy as np
24
+
25
+ from huggingface_hub import hf_hub_download
26
+
27
+ from .utils import is_onnx_available, logging
28
+
29
+
30
+ if is_onnx_available():
31
+ import onnxruntime as ort
32
+
33
+
34
+ ONNX_WEIGHTS_NAME = "model.onnx"
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ class OnnxRuntimeModel:
41
+ base_model_prefix = "onnx_model"
42
+
43
+ def __init__(self, model=None, **kwargs):
44
+ logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
45
+ self.model = model
46
+ self.model_save_dir = kwargs.get("model_save_dir", None)
47
+ self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
48
+
49
+ def __call__(self, **kwargs):
50
+ inputs = {k: np.array(v) for k, v in kwargs.items()}
51
+ return self.model.run(None, inputs)
52
+
53
+ @staticmethod
54
+ def load_model(path: Union[str, Path], provider=None):
55
+ """
56
+ Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
57
+
58
+ Arguments:
59
+ path (`str` or `Path`):
60
+ Directory from which to load
61
+ provider(`str`, *optional*):
62
+ Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
63
+ """
64
+ if provider is None:
65
+ logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
66
+ provider = "CPUExecutionProvider"
67
+
68
+ return ort.InferenceSession(path, providers=[provider])
69
+
70
+ def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
71
+ """
72
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
73
+ [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
74
+ latest_model_name.
75
+
76
+ Arguments:
77
+ save_directory (`str` or `Path`):
78
+ Directory where to save the model file.
79
+ file_name(`str`, *optional*):
80
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
81
+ model with a different name.
82
+ """
83
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
84
+
85
+ src_path = self.model_save_dir.joinpath(self.latest_model_name)
86
+ dst_path = Path(save_directory).joinpath(model_file_name)
87
+ if not src_path.samefile(dst_path):
88
+ shutil.copyfile(src_path, dst_path)
89
+
90
+ def save_pretrained(
91
+ self,
92
+ save_directory: Union[str, os.PathLike],
93
+ **kwargs,
94
+ ):
95
+ """
96
+ Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
97
+ method.:
98
+
99
+ Arguments:
100
+ save_directory (`str` or `os.PathLike`):
101
+ Directory to which to save. Will be created if it doesn't exist.
102
+ """
103
+ if os.path.isfile(save_directory):
104
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
105
+ return
106
+
107
+ os.makedirs(save_directory, exist_ok=True)
108
+
109
+ # saving model weights/files
110
+ self._save_pretrained(save_directory, **kwargs)
111
+
112
+ @classmethod
113
+ def _from_pretrained(
114
+ cls,
115
+ model_id: Union[str, Path],
116
+ use_auth_token: Optional[Union[bool, str, None]] = None,
117
+ revision: Optional[Union[str, None]] = None,
118
+ force_download: bool = False,
119
+ cache_dir: Optional[str] = None,
120
+ file_name: Optional[str] = None,
121
+ provider: Optional[str] = None,
122
+ **kwargs,
123
+ ):
124
+ """
125
+ Load a model from a directory or the HF Hub.
126
+
127
+ Arguments:
128
+ model_id (`str` or `Path`):
129
+ Directory from which to load
130
+ use_auth_token (`str` or `bool`):
131
+ Is needed to load models from a private or gated repository
132
+ revision (`str`):
133
+ Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
134
+ cache_dir (`Union[str, Path]`, *optional*):
135
+ Path to a directory in which a downloaded pretrained model configuration should be cached if the
136
+ standard cache should not be used.
137
+ force_download (`bool`, *optional*, defaults to `False`):
138
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
139
+ cached versions if they exist.
140
+ file_name(`str`):
141
+ Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
142
+ different model files from the same repository or directory.
143
+ provider(`str`):
144
+ The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
145
+ kwargs (`Dict`, *optional*):
146
+ kwargs will be passed to the model during initialization
147
+ """
148
+ model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
149
+ # load model from local directory
150
+ if os.path.isdir(model_id):
151
+ model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
152
+ kwargs["model_save_dir"] = Path(model_id)
153
+ # load model from hub
154
+ else:
155
+ # download model
156
+ model_cache_path = hf_hub_download(
157
+ repo_id=model_id,
158
+ filename=model_file_name,
159
+ use_auth_token=use_auth_token,
160
+ revision=revision,
161
+ cache_dir=cache_dir,
162
+ force_download=force_download,
163
+ )
164
+ kwargs["model_save_dir"] = Path(model_cache_path).parent
165
+ kwargs["latest_model_name"] = Path(model_cache_path).name
166
+ model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
167
+ return cls(model=model, **kwargs)
168
+
169
+ @classmethod
170
+ def from_pretrained(
171
+ cls,
172
+ model_id: Union[str, Path],
173
+ force_download: bool = True,
174
+ use_auth_token: Optional[str] = None,
175
+ cache_dir: Optional[str] = None,
176
+ **model_kwargs,
177
+ ):
178
+ revision = None
179
+ if len(str(model_id).split("@")) == 2:
180
+ model_id, revision = model_id.split("@")
181
+
182
+ return cls._from_pretrained(
183
+ model_id=model_id,
184
+ revision=revision,
185
+ cache_dir=cache_dir,
186
+ force_download=force_download,
187
+ use_auth_token=use_auth_token,
188
+ **model_kwargs,
189
+ )
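Editor's note (illustrative, not part of the commit): a hedged usage sketch for the OnnxRuntimeModel wrapper above; the directory name and the input names are hypothetical and depend on how the ONNX graph was exported.

import numpy as np
from my_diffusers.onnx_utils import OnnxRuntimeModel

model = OnnxRuntimeModel.from_pretrained("./onnx_unet", provider="CPUExecutionProvider")
outputs = model(
    sample=np.zeros((1, 4, 64, 64), dtype=np.float32),        # hypothetical graph inputs
    timestep=np.array([10], dtype=np.int64),
    encoder_hidden_states=np.zeros((1, 77, 768), dtype=np.float32),
)
model.save_pretrained("./onnx_unet_copy")                     # copies model.onnx into the new directory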
my_diffusers/optimization.py ADDED
@@ -0,0 +1,275 @@
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch optimization for diffusion models."""
16
+
17
+ import math
18
+ from enum import Enum
19
+ from typing import Optional, Union
20
+
21
+ from torch.optim import Optimizer
22
+ from torch.optim.lr_scheduler import LambdaLR
23
+
24
+ from .utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+ class SchedulerType(Enum):
31
+ LINEAR = "linear"
32
+ COSINE = "cosine"
33
+ COSINE_WITH_RESTARTS = "cosine_with_restarts"
34
+ POLYNOMIAL = "polynomial"
35
+ CONSTANT = "constant"
36
+ CONSTANT_WITH_WARMUP = "constant_with_warmup"
37
+
38
+
39
+ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
40
+ """
41
+ Create a schedule with a constant learning rate, using the learning rate set in optimizer.
42
+
43
+ Args:
44
+ optimizer ([`~torch.optim.Optimizer`]):
45
+ The optimizer for which to schedule the learning rate.
46
+ last_epoch (`int`, *optional*, defaults to -1):
47
+ The index of the last epoch when resuming training.
48
+
49
+ Return:
50
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
51
+ """
52
+ return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
53
+
54
+
55
+ def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
56
+ """
57
+ Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
58
+ increases linearly between 0 and the initial lr set in the optimizer.
59
+
60
+ Args:
61
+ optimizer ([`~torch.optim.Optimizer`]):
62
+ The optimizer for which to schedule the learning rate.
63
+ num_warmup_steps (`int`):
64
+ The number of steps for the warmup phase.
65
+ last_epoch (`int`, *optional*, defaults to -1):
66
+ The index of the last epoch when resuming training.
67
+
68
+ Return:
69
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
70
+ """
71
+
72
+ def lr_lambda(current_step: int):
73
+ if current_step < num_warmup_steps:
74
+ return float(current_step) / float(max(1.0, num_warmup_steps))
75
+ return 1.0
76
+
77
+ return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
78
+
79
+
80
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
81
+ """
82
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
83
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
84
+
85
+ Args:
86
+ optimizer ([`~torch.optim.Optimizer`]):
87
+ The optimizer for which to schedule the learning rate.
88
+ num_warmup_steps (`int`):
89
+ The number of steps for the warmup phase.
90
+ num_training_steps (`int`):
91
+ The total number of training steps.
92
+ last_epoch (`int`, *optional*, defaults to -1):
93
+ The index of the last epoch when resuming training.
94
+
95
+ Return:
96
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
97
+ """
98
+
99
+ def lr_lambda(current_step: int):
100
+ if current_step < num_warmup_steps:
101
+ return float(current_step) / float(max(1, num_warmup_steps))
102
+ return max(
103
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
104
+ )
105
+
106
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
107
+
108
+
109
+ def get_cosine_schedule_with_warmup(
110
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
111
+ ):
112
+ """
113
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
114
+ initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
115
+ initial lr set in the optimizer.
116
+
117
+ Args:
118
+ optimizer ([`~torch.optim.Optimizer`]):
119
+ The optimizer for which to schedule the learning rate.
120
+ num_warmup_steps (`int`):
121
+ The number of steps for the warmup phase.
122
+ num_training_steps (`int`):
123
+ The total number of training steps.
124
+ num_cycles (`float`, *optional*, defaults to 0.5):
125
+ The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
126
+ following a half-cosine).
127
+ last_epoch (`int`, *optional*, defaults to -1):
128
+ The index of the last epoch when resuming training.
129
+
130
+ Return:
131
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
132
+ """
133
+
134
+ def lr_lambda(current_step):
135
+ if current_step < num_warmup_steps:
136
+ return float(current_step) / float(max(1, num_warmup_steps))
137
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
138
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
139
+
140
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
141
+
142
+
143
+ def get_cosine_with_hard_restarts_schedule_with_warmup(
144
+ optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
145
+ ):
146
+ """
147
+ Create a schedule with a learning rate that decreases following the values of the cosine function between the
148
+ initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
149
+ linearly between 0 and the initial lr set in the optimizer.
150
+
151
+ Args:
152
+ optimizer ([`~torch.optim.Optimizer`]):
153
+ The optimizer for which to schedule the learning rate.
154
+ num_warmup_steps (`int`):
155
+ The number of steps for the warmup phase.
156
+ num_training_steps (`int`):
157
+ The total number of training steps.
158
+ num_cycles (`int`, *optional*, defaults to 1):
159
+ The number of hard restarts to use.
160
+ last_epoch (`int`, *optional*, defaults to -1):
161
+ The index of the last epoch when resuming training.
162
+
163
+ Return:
164
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
165
+ """
166
+
167
+ def lr_lambda(current_step):
168
+ if current_step < num_warmup_steps:
169
+ return float(current_step) / float(max(1, num_warmup_steps))
170
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
171
+ if progress >= 1.0:
172
+ return 0.0
173
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
174
+
175
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
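The difference from the plain cosine schedule is the `% 1.0`, which restarts the cosine wave `num_cycles` times. A small check of that behaviour (the values are assumptions):

```py
import math

num_warmup_steps, num_training_steps, num_cycles = 0, 100, 2

def multiplier(step):
    # mirrors the lr_lambda defined above
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    if progress >= 1.0:
        return 0.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0))))

assert abs(multiplier(0) - 1.0) < 1e-9   # start of the first cosine wave
assert multiplier(49) < 0.01             # almost fully decayed at the end of wave one
assert abs(multiplier(50) - 1.0) < 1e-9  # hard restart: back to the full lr
```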
176
+
177
+
178
+ def get_polynomial_decay_schedule_with_warmup(
179
+ optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
180
+ ):
181
+ """
182
+ Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
183
+ optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
184
+ initial lr set in the optimizer.
185
+
186
+ Args:
187
+ optimizer ([`~torch.optim.Optimizer`]):
188
+ The optimizer for which to schedule the learning rate.
189
+ num_warmup_steps (`int`):
190
+ The number of steps for the warmup phase.
191
+ num_training_steps (`int`):
192
+ The total number of training steps.
193
+ lr_end (`float`, *optional*, defaults to 1e-7):
194
+ The final learning rate at the end of the decay.
195
+ power (`float`, *optional*, defaults to 1.0):
196
+ The power of the polynomial decay.
197
+ last_epoch (`int`, *optional*, defaults to -1):
198
+ The index of the last epoch when resuming training.
199
+
200
+ Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
201
+ implementation at
202
+ https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
203
+
204
+ Return:
205
+ `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
206
+
207
+ """
208
+
209
+ lr_init = optimizer.defaults["lr"]
210
+ if not (lr_init > lr_end):
211
+ raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
212
+
213
+ def lr_lambda(current_step: int):
214
+ if current_step < num_warmup_steps:
215
+ return float(current_step) / float(max(1, num_warmup_steps))
216
+ elif current_step > num_training_steps:
217
+ return lr_end / lr_init # as LambdaLR multiplies by lr_init
218
+ else:
219
+ lr_range = lr_init - lr_end
220
+ decay_steps = num_training_steps - num_warmup_steps
221
+ pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
222
+ decay = lr_range * pct_remaining**power + lr_end
223
+ return decay / lr_init # as LambdaLR multiplies by lr_init
224
+
225
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
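With the default `power=1.0` the polynomial decay reduces to a linear interpolation from the initial lr down to `lr_end`; a quick check of the arithmetic above (the values are assumptions):

```py
lr_init, lr_end = 1e-3, 1e-7
num_warmup_steps, num_training_steps, power = 0, 100, 1.0

def decayed_lr(step):
    # mirrors the non-warmup branch of lr_lambda above (before division by lr_init)
    lr_range = lr_init - lr_end
    decay_steps = num_training_steps - num_warmup_steps
    pct_remaining = 1 - (step - num_warmup_steps) / decay_steps
    return lr_range * pct_remaining**power + lr_end

assert abs(decayed_lr(0) - lr_init) < 1e-12
assert abs(decayed_lr(50) - (lr_init + lr_end) / 2) < 1e-12  # halfway: midpoint of the range
assert abs(decayed_lr(100) - lr_end) < 1e-12
```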
226
+
227
+
228
+ TYPE_TO_SCHEDULER_FUNCTION = {
229
+ SchedulerType.LINEAR: get_linear_schedule_with_warmup,
230
+ SchedulerType.COSINE: get_cosine_schedule_with_warmup,
231
+ SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
232
+ SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
233
+ SchedulerType.CONSTANT: get_constant_schedule,
234
+ SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
235
+ }
236
+
237
+
238
+ def get_scheduler(
239
+ name: Union[str, SchedulerType],
240
+ optimizer: Optimizer,
241
+ num_warmup_steps: Optional[int] = None,
242
+ num_training_steps: Optional[int] = None,
243
+ ):
244
+ """
245
+ Unified API to get any scheduler from its name.
246
+
247
+ Args:
248
+ name (`str` or `SchedulerType`):
249
+ The name of the scheduler to use.
250
+ optimizer (`torch.optim.Optimizer`):
251
+ The optimizer that will be used during training.
252
+ num_warmup_steps (`int`, *optional*):
253
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
254
+ optional); the function will raise an error if it's unset and the scheduler type requires it.
255
+ num_training_steps (`int`, *optional*):
256
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
257
+ optional); the function will raise an error if it's unset and the scheduler type requires it.
258
+ """
259
+ name = SchedulerType(name)
260
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
261
+ if name == SchedulerType.CONSTANT:
262
+ return schedule_func(optimizer)
263
+
264
+ # All other schedulers require `num_warmup_steps`
265
+ if num_warmup_steps is None:
266
+ raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
267
+
268
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
269
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
270
+
271
+ # All other schedulers require `num_training_steps`
272
+ if num_training_steps is None:
273
+ raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
274
+
275
+ return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
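A usage sketch of this unified entry point (the toy model, optimizer settings, and step counts are assumptions; the string `"cosine"` relies on the lowercase `SchedulerType` values defined earlier in this module):

```py
import torch

model = torch.nn.Linear(8, 8)  # toy model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

scheduler = get_scheduler(
    "cosine",                  # equivalently: SchedulerType.COSINE
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)
# Omitting num_warmup_steps or num_training_steps raises a ValueError for schedulers that need them.
```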
my_diffusers/pipeline_utils.py ADDED
@@ -0,0 +1,417 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import importlib
18
+ import inspect
19
+ import os
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ import diffusers
27
+ import PIL
28
+ from huggingface_hub import snapshot_download
29
+ from PIL import Image
30
+ from tqdm.auto import tqdm
31
+
32
+ from .configuration_utils import ConfigMixin
33
+ from .utils import DIFFUSERS_CACHE, BaseOutput, logging
34
+
35
+
36
+ INDEX_FILE = "diffusion_pytorch_model.bin"
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+
42
+ LOADABLE_CLASSES = {
43
+ "diffusers": {
44
+ "ModelMixin": ["save_pretrained", "from_pretrained"],
45
+ "SchedulerMixin": ["save_config", "from_config"],
46
+ "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
47
+ "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
48
+ },
49
+ "transformers": {
50
+ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
51
+ "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
52
+ "PreTrainedModel": ["save_pretrained", "from_pretrained"],
53
+ "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
54
+ },
55
+ }
56
+
57
+ ALL_IMPORTABLE_CLASSES = {}
58
+ for library in LOADABLE_CLASSES:
59
+ ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
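After this loop the per-library tables are flattened into a single lookup from class name to `[save_method, load_method]`, e.g. (taken from the dictionaries above):

```py
assert ALL_IMPORTABLE_CLASSES["ModelMixin"] == ["save_pretrained", "from_pretrained"]
assert ALL_IMPORTABLE_CLASSES["SchedulerMixin"] == ["save_config", "from_config"]
```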
60
+
61
+
62
+ @dataclass
63
+ class ImagePipelineOutput(BaseOutput):
64
+ """
65
+ Output class for image pipelines.
66
+
67
+ Args:
68
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
69
+ List of denoised PIL images of length `batch_size` or a numpy array of shape `(batch_size, height, width,
70
+ num_channels)`. The PIL images or numpy array represent the denoised images of the diffusion pipeline.
71
+ """
72
+
73
+ images: Union[List[PIL.Image.Image], np.ndarray]
74
+
75
+
76
+ class DiffusionPipeline(ConfigMixin):
77
+ r"""
78
+ Base class for all diffusion pipelines.
79
+
80
+ [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
81
+ and provides methods for loading, downloading and saving models, as well as a few methods common to all pipelines to:
82
+
83
+ - move all PyTorch modules to the device of your choice
84
+ - enable or disable the progress bar for the denoising iteration
85
+
86
+ Class attributes:
87
+
88
+ - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
89
+ components of the diffusion pipeline.
90
+ """
91
+ config_name = "model_index.json"
92
+
93
+ def register_modules(self, **kwargs):
94
+ # import it here to avoid circular import
95
+ from diffusers import pipelines
96
+
97
+ for name, module in kwargs.items():
98
+ # retrieve the library name
99
+ library = module.__module__.split(".")[0]
100
+
101
+ # check if the module is a pipeline module
102
+ pipeline_dir = module.__module__.split(".")[-2]
103
+ path = module.__module__.split(".")
104
+ is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
105
+
106
+ # if library is not in LOADABLE_CLASSES, then it is a custom module.
107
+ # Or if it's a pipeline module, then the module is inside the pipeline
108
+ # folder so we set the library to module name.
109
+ if library not in LOADABLE_CLASSES or is_pipeline_module:
110
+ library = pipeline_dir
111
+
112
+ # retrieve the class name
113
+ class_name = module.__class__.__name__
114
+
115
+ register_dict = {name: (library, class_name)}
116
+
117
+ # save model index config
118
+ self.register_to_config(**register_dict)
119
+
120
+ # set models
121
+ setattr(self, name, module)
122
+
123
+ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
124
+ """
125
+ Save all variables of the pipeline that can be saved and loaded as well as the pipeline's configuration file to
126
+ a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
127
+ method. The pipeline can easily be re-loaded using the [`~DiffusionPipeline.from_pretrained`] class method.
128
+
129
+ Arguments:
130
+ save_directory (`str` or `os.PathLike`):
131
+ Directory to which to save. Will be created if it doesn't exist.
132
+ """
133
+ self.save_config(save_directory)
134
+
135
+ model_index_dict = dict(self.config)
136
+ model_index_dict.pop("_class_name")
137
+ model_index_dict.pop("_diffusers_version")
138
+ model_index_dict.pop("_module", None)
139
+
140
+ for pipeline_component_name in model_index_dict.keys():
141
+ sub_model = getattr(self, pipeline_component_name)
142
+ model_cls = sub_model.__class__
143
+
144
+ save_method_name = None
145
+ # search for the model's base class in LOADABLE_CLASSES
146
+ for library_name, library_classes in LOADABLE_CLASSES.items():
147
+ library = importlib.import_module(library_name)
148
+ for base_class, save_load_methods in library_classes.items():
149
+ class_candidate = getattr(library, base_class)
150
+ if issubclass(model_cls, class_candidate):
151
+ # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
152
+ save_method_name = save_load_methods[0]
153
+ break
154
+ if save_method_name is not None:
155
+ break
156
+
157
+ save_method = getattr(sub_model, save_method_name)
158
+ save_method(os.path.join(save_directory, pipeline_component_name))
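A round-trip sketch of this method (the repo id is the one used in the `from_pretrained` docstring below; the local path and import path are assumptions):

```py
from my_diffusers.pipeline_utils import DiffusionPipeline  # assuming my_diffusers is importable

pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
pipe.save_pretrained("./my_pipeline_directory")            # one subfolder per registered component
reloaded = DiffusionPipeline.from_pretrained("./my_pipeline_directory")
```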
159
+
160
+ def to(self, torch_device: Optional[Union[str, torch.device]] = None):
161
+ if torch_device is None:
162
+ return self
163
+
164
+ module_names, _ = self.extract_init_dict(dict(self.config))
165
+ for name in module_names.keys():
166
+ module = getattr(self, name)
167
+ if isinstance(module, torch.nn.Module):
168
+ module.to(torch_device)
169
+ return self
170
+
171
+ @property
172
+ def device(self) -> torch.device:
173
+ r"""
174
+ Returns:
175
+ `torch.device`: The torch device on which the pipeline is located.
176
+ """
177
+ module_names, _ = self.extract_init_dict(dict(self.config))
178
+ for name in module_names.keys():
179
+ module = getattr(self, name)
180
+ if isinstance(module, torch.nn.Module):
181
+ return module.device
182
+ return torch.device("cpu")
183
+
184
+ @classmethod
185
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
186
+ r"""
187
+ Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
188
+
189
+ The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
190
+
191
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
192
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
193
+ task.
194
+
195
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
196
+ weights are discarded.
197
+
198
+ Parameters:
199
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
200
+ Can be either:
201
+
202
+ - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
203
+ https://huggingface.co/. Valid repo ids have to be located under a user or organization name, like
204
+ `CompVis/ldm-text2im-large-256`.
205
+ - A path to a *directory* containing pipeline weights saved using
206
+ [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
207
+ torch_dtype (`str` or `torch.dtype`, *optional*):
208
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
209
+ will be automatically derived from the model's weights.
210
+ force_download (`bool`, *optional*, defaults to `False`):
211
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
212
+ cached versions if they exist.
213
+ resume_download (`bool`, *optional*, defaults to `False`):
214
+ Whether or not to delete incompletely received files. Will attempt to resume the download if such a
215
+ file exists.
216
+ proxies (`Dict[str, str]`, *optional*):
217
+ A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
218
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
219
+ output_loading_info(`bool`, *optional*, defaults to `False`):
220
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
221
+ local_files_only(`bool`, *optional*, defaults to `False`):
222
+ Whether or not to only look at local files (i.e., do not try to download the model).
223
+ use_auth_token (`str` or *bool*, *optional*):
224
+ The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
225
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
226
+ revision (`str`, *optional*, defaults to `"main"`):
227
+ The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
228
+ git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
229
+ identifier allowed by git.
230
+ mirror (`str`, *optional*):
231
+ Mirror source to accelerate downloads in China. If you are from China and have an accessibility
232
+ problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
233
+ Please refer to the mirror site for more information.
234
+
235
+ kwargs (remaining dictionary of keyword arguments, *optional*):
236
+ Can be used to overwrite loadable and saveable variables, *i.e.* the pipeline components, of the
237
+ specific pipeline class. The overwritten components are then directly passed to the pipeline's `__init__`
238
+ method. See example below for more information.
239
+
240
+ <Tip>
241
+
242
+ Passing `use_auth_token=True` is required when you want to use a private model, *e.g.*
243
+ `"CompVis/stable-diffusion-v1-4"`
244
+
245
+ </Tip>
246
+
247
+ <Tip>
248
+
249
+ Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
250
+ this method in a firewalled environment.
251
+
252
+ </Tip>
253
+
254
+ Examples:
255
+
256
+ ```py
257
+ >>> from diffusers import DiffusionPipeline
258
+
259
+ >>> # Download pipeline from huggingface.co and cache.
260
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
261
+
262
+ >>> # Download pipeline that requires an authorization token
263
+ >>> # For more information on access tokens, please refer to this section
264
+ >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
265
+ >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
266
+
267
+ >>> # Download pipeline, but overwrite scheduler
268
+ >>> from diffusers import LMSDiscreteScheduler
269
+
270
+ >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
271
+ >>> pipeline = DiffusionPipeline.from_pretrained(
272
+ ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
273
+ ... )
274
+ ```
275
+ """
276
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
277
+ resume_download = kwargs.pop("resume_download", False)
278
+ proxies = kwargs.pop("proxies", None)
279
+ local_files_only = kwargs.pop("local_files_only", False)
280
+ use_auth_token = kwargs.pop("use_auth_token", None)
281
+ revision = kwargs.pop("revision", None)
282
+ torch_dtype = kwargs.pop("torch_dtype", None)
283
+ provider = kwargs.pop("provider", None)
284
+
285
+ # 1. Download the checkpoints and configs
286
+ # use snapshot_download here so that from_pretrained works with hub repo ids as well as local folders
287
+ if not os.path.isdir(pretrained_model_name_or_path):
288
+ cached_folder = snapshot_download(
289
+ pretrained_model_name_or_path,
290
+ cache_dir=cache_dir,
291
+ resume_download=resume_download,
292
+ proxies=proxies,
293
+ local_files_only=local_files_only,
294
+ use_auth_token=use_auth_token,
295
+ revision=revision,
296
+ )
297
+ else:
298
+ cached_folder = pretrained_model_name_or_path
299
+
300
+ config_dict = cls.get_config_dict(cached_folder)
301
+
302
+ # 2. Load the pipeline class, if using custom module then load it from the hub
303
+ # if we load from explicit class, let's use it
304
+ if cls != DiffusionPipeline:
305
+ pipeline_class = cls
306
+ else:
307
+ diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
308
+ pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
309
+
310
+ # some modules can be passed directly to the init
311
+ # in this case they are already instantiated in `kwargs`
312
+ # extract them here
313
+ expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
314
+ passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
315
+
316
+ init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
317
+
318
+ init_kwargs = {}
319
+
320
+ # import it here to avoid circular import
321
+ from diffusers import pipelines
322
+
323
+ # 3. Load each module in the pipeline
324
+ for name, (library_name, class_name) in init_dict.items():
325
+ is_pipeline_module = hasattr(pipelines, library_name)
326
+ loaded_sub_model = None
327
+
328
+ # if the model is in a pipeline module, then we load it from the pipeline
329
+ if name in passed_class_obj:
330
+ # 1. check that passed_class_obj has correct parent class
331
+ if not is_pipeline_module:
332
+ library = importlib.import_module(library_name)
333
+ class_obj = getattr(library, class_name)
334
+ importable_classes = LOADABLE_CLASSES[library_name]
335
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
336
+
337
+ expected_class_obj = None
338
+ for class_name, class_candidate in class_candidates.items():
339
+ if issubclass(class_obj, class_candidate):
340
+ expected_class_obj = class_candidate
341
+
342
+ if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
343
+ raise ValueError(
344
+ f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
345
+ f" {expected_class_obj}"
346
+ )
347
+ else:
348
+ logger.warn(
349
+ f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
350
+ " has the correct type"
351
+ )
352
+
353
+ # set passed class object
354
+ loaded_sub_model = passed_class_obj[name]
355
+ elif is_pipeline_module:
356
+ pipeline_module = getattr(pipelines, library_name)
357
+ class_obj = getattr(pipeline_module, class_name)
358
+ importable_classes = ALL_IMPORTABLE_CLASSES
359
+ class_candidates = {c: class_obj for c in importable_classes.keys()}
360
+ else:
361
+ # else we just import it from the library.
362
+ library = importlib.import_module(library_name)
363
+ class_obj = getattr(library, class_name)
364
+ importable_classes = LOADABLE_CLASSES[library_name]
365
+ class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
366
+
367
+ if loaded_sub_model is None:
368
+ load_method_name = None
369
+ for class_name, class_candidate in class_candidates.items():
370
+ if issubclass(class_obj, class_candidate):
371
+ load_method_name = importable_classes[class_name][1]
372
+
373
+ load_method = getattr(class_obj, load_method_name)
374
+
375
+ loading_kwargs = {}
376
+ if issubclass(class_obj, torch.nn.Module):
377
+ loading_kwargs["torch_dtype"] = torch_dtype
378
+ if issubclass(class_obj, diffusers.OnnxRuntimeModel):
379
+ loading_kwargs["provider"] = provider
380
+
381
+ # check if the module is in a subdirectory
382
+ if os.path.isdir(os.path.join(cached_folder, name)):
383
+ loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
384
+ else:
385
+ # else load from the root directory
386
+ loaded_sub_model = load_method(cached_folder, **loading_kwargs)
387
+
388
+ init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
389
+
390
+ # 4. Instantiate the pipeline
391
+ model = pipeline_class(**init_kwargs)
392
+ return model
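A loading sketch that exercises the `torch_dtype` forwarding above (the local folder is a hypothetical one produced by `save_pretrained`; the import path is an assumption):

```py
import torch
from my_diffusers.pipeline_utils import DiffusionPipeline  # assuming my_diffusers is importable

pipe = DiffusionPipeline.from_pretrained(
    "./my_pipeline_directory",   # hypothetical folder created by save_pretrained
    torch_dtype=torch.float16,   # passed on to every torch.nn.Module sub-model's loader
)
```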
393
+
394
+ @staticmethod
395
+ def numpy_to_pil(images):
396
+ """
397
+ Convert a numpy image or a batch of images to a PIL image.
398
+ """
399
+ if images.ndim == 3:
400
+ images = images[None, ...]
401
+ images = (images * 255).round().astype("uint8")
402
+ pil_images = [Image.fromarray(image) for image in images]
403
+
404
+ return pil_images
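`numpy_to_pil` expects float images in `[0, 1]` with shape `(H, W, C)` or `(N, H, W, C)`; a quick sketch on dummy data (the shapes and import path are assumptions):

```py
import numpy as np
from my_diffusers.pipeline_utils import DiffusionPipeline  # assuming my_diffusers is importable

fake_batch = np.random.rand(2, 64, 64, 3).astype("float32")  # dummy "denoised" images
pil_images = DiffusionPipeline.numpy_to_pil(fake_batch)
assert len(pil_images) == 2 and pil_images[0].size == (64, 64)
```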
405
+
406
+ def progress_bar(self, iterable):
407
+ if not hasattr(self, "_progress_bar_config"):
408
+ self._progress_bar_config = {}
409
+ elif not isinstance(self._progress_bar_config, dict):
410
+ raise ValueError(
411
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
412
+ )
413
+
414
+ return tqdm(iterable, **self._progress_bar_config)
415
+
416
+ def set_progress_bar_config(self, **kwargs):
417
+ self._progress_bar_config = kwargs
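Any keyword accepted by `tqdm` can be stored here; a sketch for silencing or relabelling the denoising bar (the repo id is reused from the docstring example, the import path is an assumption):

```py
from my_diffusers.pipeline_utils import DiffusionPipeline  # assuming my_diffusers is importable

pipe = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
pipe.set_progress_bar_config(disable=True)                  # silence the bar entirely
# pipe.set_progress_bar_config(desc="denoising", leave=False)
```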
my_diffusers/pipelines/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ from ..utils import is_onnx_available, is_transformers_available
2
+ from .ddim import DDIMPipeline
3
+ from .ddpm import DDPMPipeline
4
+ from .latent_diffusion_uncond import LDMPipeline
5
+ from .pndm import PNDMPipeline
6
+ from .score_sde_ve import ScoreSdeVePipeline
7
+ from .stochastic_karras_ve import KarrasVePipeline
8
+
9
+
10
+ if is_transformers_available():
11
+ from .latent_diffusion import LDMTextToImagePipeline
12
+ from .stable_diffusion import (
13
+ StableDiffusionImg2ImgPipeline,
14
+ StableDiffusionInpaintPipeline,
15
+ StableDiffusionPipeline,
16
+ )
17
+
18
+ if is_transformers_available() and is_onnx_available():
19
+ from .stable_diffusion import StableDiffusionOnnxPipeline
my_diffusers/pipelines/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (828 Bytes).
 
my_diffusers/pipelines/ddim/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # flake8: noqa
2
+ from .pipeline_ddim import DDIMPipeline
my_diffusers/pipelines/ddim/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (216 Bytes).
 
my_diffusers/pipelines/ddim/__pycache__/pipeline_ddim.cpython-38.pyc ADDED
Binary file (3.96 kB).
 
my_diffusers/pipelines/ddim/pipeline_ddim.py ADDED
@@ -0,0 +1,117 @@
1
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+
14
+ # limitations under the License.
15
+
16
+
17
+ import warnings
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+
22
+ from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
23
+
24
+
25
+ class DDIMPipeline(DiffusionPipeline):
26
+ r"""
27
+ This pipeline inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
28
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
29
+
30
+ Parameters:
31
+ unet ([`UNet2DModel`]): U-Net architecture to denoise the encoded image.
32
+ scheduler ([`SchedulerMixin`]):
33
+ A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
34
+ [`DDPMScheduler`], or [`DDIMScheduler`].
35
+ """
36
+
37
+ def __init__(self, unet, scheduler):
38
+ super().__init__()
39
+ scheduler = scheduler.set_format("pt")
40
+ self.register_modules(unet=unet, scheduler=scheduler)
41
+
42
+ @torch.no_grad()
43
+ def __call__(
44
+ self,
45
+ batch_size: int = 1,
46
+ generator: Optional[torch.Generator] = None,
47
+ eta: float = 0.0,
48
+ num_inference_steps: int = 50,
49
+ output_type: Optional[str] = "pil",
50
+ return_dict: bool = True,
51
+ **kwargs,
52
+ ) -> Union[ImagePipelineOutput, Tuple]:
53
+ r"""
54
+ Args:
55
+ batch_size (`int`, *optional*, defaults to 1):
56
+ The number of images to generate.
57
+ generator (`torch.Generator`, *optional*):
58
+ A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
59
+ deterministic.
60
+ eta (`float`, *optional*, defaults to 0.0):
61
+ The eta parameter which controls the scale of the variance (0 is DDIM and 1 is one type of DDPM).
62
+ num_inference_steps (`int`, *optional*, defaults to 50):
63
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
64
+ expense of slower inference.
65
+ output_type (`str`, *optional*, defaults to `"pil"`):
66
+ The output format of the generated image. Choose between
67
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.ndarray`.
68
+ return_dict (`bool`, *optional*, defaults to `True`):
69
+ Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
70
+
71
+ Returns:
72
+ [`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipelines.utils.ImagePipelineOutput`] if
73
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
74
+ generated images.
75
+ """
76
+
77
+ if "torch_device" in kwargs:
78
+ device = kwargs.pop("torch_device")
79
+ warnings.warn(
80
+ "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
81
+ " Consider using `pipe.to(torch_device)` instead."
82
+ )
83
+
84
+ # Set device as before (to be removed in 0.3.0)
85
+ if device is None:
86
+ device = "cuda" if torch.cuda.is_available() else "cpu"
87
+ self.to(device)
88
+
89
+ # eta corresponds to η in the DDIM paper and should be in [0, 1]
90
+
91
+ # Sample gaussian noise to begin loop
92
+ image = torch.randn(
93
+ (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
94
+ generator=generator,
95
+ )
96
+ image = image.to(self.device)
97
+
98
+ # set step values
99
+ self.scheduler.set_timesteps(num_inference_steps)
100
+
101
+ for t in self.progress_bar(self.scheduler.timesteps):
102
+ # 1. predict noise model_output
103
+ model_output = self.unet(image, t).sample
104
+
105
+ # 2. predict previous mean of image x_t-1 and add variance depending on eta
106
+ # do x_t -> x_t-1
107
+ image = self.scheduler.step(model_output, t, image, eta).prev_sample
108
+
109
+ image = (image / 2 + 0.5).clamp(0, 1)
110
+ image = image.cpu().permute(0, 2, 3, 1).numpy()
111
+ if output_type == "pil":
112
+ image = self.numpy_to_pil(image)
113
+
114
+ if not return_dict:
115
+ return (image,)
116
+
117
+ return ImagePipelineOutput(images=image)
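A usage sketch for this pipeline (the checkpoint path and import path are hypothetical; any folder or repo containing a compatible `unet` and DDIM-style `scheduler` should work):

```py
import torch
from my_diffusers.pipelines.ddim import DDIMPipeline  # assuming my_diffusers is importable

pipe = DDIMPipeline.from_pretrained("path/to/ddim-checkpoint")  # hypothetical checkpoint
pipe.to("cuda" if torch.cuda.is_available() else "cpu")

generator = torch.Generator().manual_seed(0)
output = pipe(batch_size=4, num_inference_steps=50, eta=0.0, generator=generator)
output.images[0].save("ddim_sample.png")
```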
my_diffusers/pipelines/ddpm/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # flake8: noqa
2
+ from .pipeline_ddpm import DDPMPipeline