ysharma (HF staff) committed
Commit cfa3979
1 Parent(s): 7d429aa

Upload 13 files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ doc/cn_example.jpg filter=lfs diff=lfs merge=lfs -text
+ doc/md_example.jpg filter=lfs diff=lfs merge=lfs -text
+ example_image/train.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright [yyyy] [name of copyright owner]
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,37 @@
- ---
- title: Style-aligned Sdxl
- emoji: 🐨
- colorFrom: indigo
- colorTo: pink
- sdk: gradio
- sdk_version: 4.8.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Style Aligned Image Generation via Shared Attention
+
+
+ ### [Project Page](https://style-aligned-gen.github.io)   [Paper](https://style-aligned-gen.github.io/data/StyleAligned.pdf)
+
+
+ ## Setup
+
+ This code was tested with Python 3.11, [PyTorch 2.1](https://pytorch.org/) and [Diffusers 0.16](https://github.com/huggingface/diffusers).
+
+ ## Examples
+ - See the [**style_aligned_sdxl**][style_aligned] notebook for generating style-aligned images using [SDXL](https://huggingface.co/docs/diffusers/using-diffusers/sdxl).
+
+ ![alt text](doc/sa_example.jpg)
+
+
+ - See the [**style_aligned_w_controlnet**][controlnet] notebook for generating style-aligned, depth-conditioned images using SDXL with [ControlNet-Depth](https://arxiv.org/abs/2302.05543).
+
+ ![alt text](doc/cn_example.jpg)
+
+
+ - The [**style_aligned_w_multidiffusion**][multidiffusion] notebook can be used for generating style-aligned panoramas using [SD V2](https://huggingface.co/stabilityai/stable-diffusion-2) with [MultiDiffusion](https://multidiffusion.github.io/).
+
+ ![alt text](doc/md_example.jpg)
+
+ ## TODOs
+ - [ ] Add a demo.
+ - [ ] StyleAligned from an input image.
+ - [ ] Multi-style with MultiDiffusion.
+ - [ ] StyleAligned with DreamBooth.
+
+ ## Disclaimer
+ This is not an officially supported Google product.
+
+ [style_aligned]: style_aligned_sdxl.ipynb
+ [controlnet]: style_aligned_w_controlnet.ipynb
+ [multidiffusion]: style_aligned_w_multidiffusion.ipynb
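
For reference, here is a minimal sketch of the SDXL workflow described in the README above, condensed from the `style_aligned_sdxl.ipynb` notebook added in this commit. It assumes a CUDA GPU and a diffusers version with SDXL support; the prompts are the notebook's own examples.

```python
# Minimal sketch condensed from style_aligned_sdxl.ipynb (this commit).
# Assumes a CUDA GPU and a diffusers release that ships StableDiffusionXLPipeline.
import torch
from diffusers import StableDiffusionXLPipeline, DDIMScheduler

import sa_handler

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                          clip_sample=False, set_alpha_to_one=False)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16, variant="fp16", use_safetensors=True, scheduler=scheduler,
).to("cuda")

# Register the shared-attention processors: every prompt in the batch attends to the
# keys/values of the first (reference) prompt, so the whole batch shares one style.
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False, share_layer_norm=False, share_attention=True,
    adain_queries=True, adain_keys=True, adain_values=False,
)
handler.register(sa_args)

prompts = [
    "a toy train. macro photo. 3d game asset",
    "a toy airplane. macro photo. 3d game asset",
    "a toy bicycle. macro photo. 3d game asset",
]
images = pipeline(prompts).images  # list of PIL images, style-aligned across the batch
```

Calling `handler.remove()` afterwards restores the default attention processors if the same pipeline is reused for regular, non-aligned generation.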
contributing.md ADDED
@@ -0,0 +1,28 @@
1
+ # How to Contribute
2
+
3
+ We'd love to accept your patches and contributions to this project. There are
4
+ just a few small guidelines you need to follow.
5
+
6
+ ## Contributor License Agreement
7
+
8
+ Contributions to this project must be accompanied by a Contributor License
9
+ Agreement. You (or your employer) retain the copyright to your contribution;
10
+ this simply gives us permission to use and redistribute your contributions as
11
+ part of the project. Head over to <https://cla.developers.google.com/> to see
12
+ your current agreements on file or to sign a new one.
13
+
14
+ You generally only need to submit a CLA once, so if you've already submitted one
15
+ (even if it was for a different project), you probably don't need to do it
16
+ again.
17
+
18
+ ## Code Reviews
19
+
20
+ All submissions, including submissions by project members, require review. We
21
+ use GitHub pull requests for this purpose. Consult
22
+ [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23
+ information on using pull requests.
24
+
25
+ ## Community Guidelines
26
+
27
+ This project follows [Google's Open Source Community
28
+ Guidelines](https://opensource.google/conduct/).
doc/cn_example.jpg ADDED

Git LFS Details

  • SHA256: 76f94e53ddca8389ba142bcb644127a4d8b8b2b13cc8b21f3959434989968dea
  • Pointer size: 132 Bytes
  • Size of remote file: 1.08 MB
doc/md_example.jpg ADDED

Git LFS Details

  • SHA256: 2457056b024e2a1cacac71fe0332e688efb46dd535ace9df51fc326dcf10131a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.62 MB
doc/sa_example.jpg ADDED
example_image/train.png ADDED

Git LFS Details

  • SHA256: 4f6c557bfb56274d3f99f07145c42e3a2380e66849c3b8654efad202c0d09a68
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
pipeline_calls.py ADDED
@@ -0,0 +1,552 @@
1
+ # Copyright 2023 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+ from typing import Any
18
+ import torch
19
+ import numpy as np
20
+ from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
21
+ from diffusers.image_processor import PipelineImageInput
22
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
23
+ from transformers import DPTImageProcessor, DPTForDepthEstimation
24
+ from diffusers import StableDiffusionPanoramaPipeline
25
+ from PIL import Image
26
+ import copy
27
+
28
+ T = torch.Tensor
29
+ TN = T | None
30
+
31
+
32
+ def get_depth_map(image: Image, feature_processor: DPTImageProcessor, depth_estimator: DPTForDepthEstimation) -> Image:
33
+ image = feature_processor(images=image, return_tensors="pt").pixel_values.to("cuda")
34
+ with torch.no_grad(), torch.autocast("cuda"):
35
+ depth_map = depth_estimator(image).predicted_depth
36
+
37
+ depth_map = torch.nn.functional.interpolate(
38
+ depth_map.unsqueeze(1),
39
+ size=(1024, 1024),
40
+ mode="bicubic",
41
+ align_corners=False,
42
+ )
43
+ depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
44
+ depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
45
+ depth_map = (depth_map - depth_min) / (depth_max - depth_min)
46
+ image = torch.cat([depth_map] * 3, dim=1)
47
+
48
+ image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
49
+ image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
50
+ return image
51
+
52
+
53
+ def concat_zero_control(control_residual: T) -> T:
54
+ b = control_residual.shape[0] // 2
55
+ zero_residual = torch.zeros_like(control_residual[0:1])
56
+ return torch.cat((zero_residual, control_residual[:b], zero_residual, control_residual[b:]))
57
+
58
+
59
+ @torch.no_grad()
60
+ def controlnet_call(
61
+ pipeline: StableDiffusionXLControlNetPipeline,
62
+ prompt: str | list[str] = None,
63
+ prompt_2: str | list[str] | None = None,
64
+ image: PipelineImageInput = None,
65
+ height: int | None = None,
66
+ width: int | None = None,
67
+ num_inference_steps: int = 50,
68
+ guidance_scale: float = 5.0,
69
+ negative_prompt: str | list[str] | None = None,
70
+ negative_prompt_2: str | list[str] | None = None,
71
+ num_images_per_prompt: int = 1,
72
+ eta: float = 0.0,
73
+ generator: torch.Generator | None = None,
74
+ latents: TN = None,
75
+ prompt_embeds: TN = None,
76
+ negative_prompt_embeds: TN = None,
77
+ pooled_prompt_embeds: TN = None,
78
+ negative_pooled_prompt_embeds: TN = None,
79
+ cross_attention_kwargs: dict[str, Any] | None = None,
80
+ controlnet_conditioning_scale: float | list[float] = 1.0,
81
+ control_guidance_start: float | list[float] = 0.0,
82
+ control_guidance_end: float | list[float] = 1.0,
83
+ original_size: tuple[int, int] = None,
84
+ crops_coords_top_left: tuple[int, int] = (0, 0),
85
+ target_size: tuple[int, int] | None = None,
86
+ negative_original_size: tuple[int, int] | None = None,
87
+ negative_crops_coords_top_left: tuple[int, int] = (0, 0),
88
+ negative_target_size:tuple[int, int] | None = None,
89
+ clip_skip: int | None = None,
90
+ ) -> list[Image]:
91
+ controlnet = pipeline.controlnet._orig_mod if is_compiled_module(pipeline.controlnet) else pipeline.controlnet
92
+
93
+ # align format for control guidance
94
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
95
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
96
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
97
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
98
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
99
+ mult = 1
100
+ control_guidance_start, control_guidance_end = (
101
+ mult * [control_guidance_start],
102
+ mult * [control_guidance_end],
103
+ )
104
+
105
+ # 1. Check inputs. Raise error if not correct
106
+ pipeline.check_inputs(
107
+ prompt,
108
+ prompt_2,
109
+ image,
110
+ 1,
111
+ negative_prompt,
112
+ negative_prompt_2,
113
+ prompt_embeds,
114
+ negative_prompt_embeds,
115
+ pooled_prompt_embeds,
116
+ negative_pooled_prompt_embeds,
117
+ controlnet_conditioning_scale,
118
+ control_guidance_start,
119
+ control_guidance_end,
120
+ )
121
+
122
+ pipeline._guidance_scale = guidance_scale
123
+
124
+ # 2. Define call parameters
125
+ if prompt is not None and isinstance(prompt, str):
126
+ batch_size = 1
127
+ elif prompt is not None and isinstance(prompt, list):
128
+ batch_size = len(prompt)
129
+ else:
130
+ batch_size = prompt_embeds.shape[0]
131
+
132
+ device = pipeline._execution_device
133
+
134
+ # 3. Encode input prompt
135
+ text_encoder_lora_scale = (
136
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
137
+ )
138
+ (
139
+ prompt_embeds,
140
+ negative_prompt_embeds,
141
+ pooled_prompt_embeds,
142
+ negative_pooled_prompt_embeds,
143
+ ) = pipeline.encode_prompt(
144
+ prompt,
145
+ prompt_2,
146
+ device,
147
+ 1,
148
+ True,
149
+ negative_prompt,
150
+ negative_prompt_2,
151
+ prompt_embeds=prompt_embeds,
152
+ negative_prompt_embeds=negative_prompt_embeds,
153
+ pooled_prompt_embeds=pooled_prompt_embeds,
154
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
155
+ lora_scale=text_encoder_lora_scale,
156
+ clip_skip=clip_skip,
157
+ )
158
+
159
+ # 4. Prepare image
160
+ if isinstance(controlnet, ControlNetModel):
161
+ image = pipeline.prepare_image(
162
+ image=image,
163
+ width=width,
164
+ height=height,
165
+ batch_size=1,
166
+ num_images_per_prompt=1,
167
+ device=device,
168
+ dtype=controlnet.dtype,
169
+ do_classifier_free_guidance=True,
170
+ guess_mode=False,
171
+ )
172
+ height, width = image.shape[-2:]
173
+ image = torch.stack([image[0]] * num_images_per_prompt + [image[1]] * num_images_per_prompt)
174
+ else:
175
+ assert False
176
+ # 5. Prepare timesteps
177
+ pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
178
+ timesteps = pipeline.scheduler.timesteps
179
+
180
+ # 6. Prepare latent variables
181
+ num_channels_latents = pipeline.unet.config.in_channels
182
+ latents = pipeline.prepare_latents(
183
+ 1 + num_images_per_prompt,
184
+ num_channels_latents,
185
+ height,
186
+ width,
187
+ prompt_embeds.dtype,
188
+ device,
189
+ generator,
190
+ latents,
191
+ )
192
+
193
+ # 6.5 Optionally get Guidance Scale Embedding
194
+ timestep_cond = None
195
+
196
+ # 7. Prepare extra step kwargs.
197
+ extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
198
+
199
+ # 7.1 Create tensor stating which controlnets to keep
200
+ controlnet_keep = []
201
+ for i in range(len(timesteps)):
202
+ keeps = [
203
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
204
+ for s, e in zip(control_guidance_start, control_guidance_end)
205
+ ]
206
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
207
+
208
+ # 7.2 Prepare added time ids & embeddings
209
+ if isinstance(image, list):
210
+ original_size = original_size or image[0].shape[-2:]
211
+ else:
212
+ original_size = original_size or image.shape[-2:]
213
+ target_size = target_size or (height, width)
214
+
215
+ add_text_embeds = pooled_prompt_embeds
216
+ if pipeline.text_encoder_2 is None:
217
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
218
+ else:
219
+ text_encoder_projection_dim = pipeline.text_encoder_2.config.projection_dim
220
+
221
+ add_time_ids = pipeline._get_add_time_ids(
222
+ original_size,
223
+ crops_coords_top_left,
224
+ target_size,
225
+ dtype=prompt_embeds.dtype,
226
+ text_encoder_projection_dim=text_encoder_projection_dim,
227
+ )
228
+
229
+ if negative_original_size is not None and negative_target_size is not None:
230
+ negative_add_time_ids = pipeline._get_add_time_ids(
231
+ negative_original_size,
232
+ negative_crops_coords_top_left,
233
+ negative_target_size,
234
+ dtype=prompt_embeds.dtype,
235
+ text_encoder_projection_dim=text_encoder_projection_dim,
236
+ )
237
+ else:
238
+ negative_add_time_ids = add_time_ids
239
+
240
+ prompt_embeds = torch.stack([prompt_embeds[0]] + [prompt_embeds[1]] * num_images_per_prompt)
241
+ negative_prompt_embeds = torch.stack([negative_prompt_embeds[0]] + [negative_prompt_embeds[1]] * num_images_per_prompt)
242
+ negative_pooled_prompt_embeds = torch.stack([negative_pooled_prompt_embeds[0]] + [negative_pooled_prompt_embeds[1]] * num_images_per_prompt)
243
+ add_text_embeds = torch.stack([add_text_embeds[0]] + [add_text_embeds[1]] * num_images_per_prompt)
244
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
245
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
246
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
247
+
248
+ prompt_embeds = prompt_embeds.to(device)
249
+ add_text_embeds = add_text_embeds.to(device)
250
+ add_time_ids = add_time_ids.to(device).repeat(1 + num_images_per_prompt, 1)
251
+ batch_size = num_images_per_prompt + 1
252
+ # 8. Denoising loop
253
+ num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
254
+ is_unet_compiled = is_compiled_module(pipeline.unet)
255
+ is_controlnet_compiled = is_compiled_module(pipeline.controlnet)
256
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
257
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
258
+ controlnet_prompt_embeds = torch.cat((prompt_embeds[1:batch_size], prompt_embeds[1:batch_size]))
259
+ controlnet_added_cond_kwargs = {key: torch.cat((item[1:batch_size,], item[1:batch_size])) for key, item in added_cond_kwargs.items()}
260
+ with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
261
+ for i, t in enumerate(timesteps):
262
+ # Relevant thread:
263
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
264
+ if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
265
+ torch._inductor.cudagraph_mark_step_begin()
266
+ # expand the latents if we are doing classifier free guidance
267
+ latent_model_input = torch.cat([latents] * 2)
268
+ latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
269
+
270
+ # controlnet(s) inference
271
+ control_model_input = torch.cat((latent_model_input[1:batch_size,], latent_model_input[batch_size+1:]))
272
+
273
+ if isinstance(controlnet_keep[i], list):
274
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
275
+ else:
276
+ controlnet_cond_scale = controlnet_conditioning_scale
277
+ if isinstance(controlnet_cond_scale, list):
278
+ controlnet_cond_scale = controlnet_cond_scale[0]
279
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
280
+ if cond_scale > 0:
281
+ down_block_res_samples, mid_block_res_sample = pipeline.controlnet(
282
+ control_model_input,
283
+ t,
284
+ encoder_hidden_states=controlnet_prompt_embeds,
285
+ controlnet_cond=image,
286
+ conditioning_scale=cond_scale,
287
+ guess_mode=False,
288
+ added_cond_kwargs=controlnet_added_cond_kwargs,
289
+ return_dict=False,
290
+ )
291
+
292
+ mid_block_res_sample = concat_zero_control(mid_block_res_sample)
293
+ down_block_res_samples = [concat_zero_control(down_block_res_sample) for down_block_res_sample in down_block_res_samples]
294
+ else:
295
+ mid_block_res_sample = down_block_res_samples = None
296
+ # predict the noise residual
297
+ noise_pred = pipeline.unet(
298
+ latent_model_input,
299
+ t,
300
+ encoder_hidden_states=prompt_embeds,
301
+ timestep_cond=timestep_cond,
302
+ cross_attention_kwargs=cross_attention_kwargs,
303
+ down_block_additional_residuals=down_block_res_samples,
304
+ mid_block_additional_residual=mid_block_res_sample,
305
+ added_cond_kwargs=added_cond_kwargs,
306
+ return_dict=False,
307
+ )[0]
308
+
309
+ # perform guidance
310
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
311
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
312
+
313
+ # compute the previous noisy sample x_t -> x_t-1
314
+ latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
315
+
316
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
317
+ progress_bar.update()
318
+
319
+ # manually for max memory savings
320
+ if pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast:
321
+ pipeline.upcast_vae()
322
+ latents = latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype)
323
+
324
+ # make sure the VAE is in float32 mode, as it overflows in float16
325
+ needs_upcasting = pipeline.vae.dtype == torch.float16 and pipeline.vae.config.force_upcast
326
+
327
+ if needs_upcasting:
328
+ pipeline.upcast_vae()
329
+ latents = latents.to(next(iter(pipeline.vae.post_quant_conv.parameters())).dtype)
330
+
331
+ image = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]
332
+
333
+ # cast back to fp16 if needed
334
+ if needs_upcasting:
335
+ pipeline.vae.to(dtype=torch.float16)
336
+
337
+ if pipeline.watermark is not None:
338
+ image = pipeline.watermark.apply_watermark(image)
339
+
340
+ image = pipeline.image_processor.postprocess(image, output_type='pil')
341
+
342
+ # Offload all models
343
+ pipeline.maybe_free_model_hooks()
344
+ return image
345
+
346
+
347
+ @torch.no_grad()
348
+ def panorama_call(
349
+ pipeline: StableDiffusionPanoramaPipeline,
350
+ prompt: list[str],
351
+ height: int | None = 512,
352
+ width: int | None = 2048,
353
+ num_inference_steps: int = 50,
354
+ guidance_scale: float = 7.5,
355
+ view_batch_size: int = 1,
356
+ negative_prompt: str | list[str] | None = None,
357
+ num_images_per_prompt: int | None = 1,
358
+ eta: float = 0.0,
359
+ generator: torch.Generator | None = None,
360
+ reference_latent: TN = None,
361
+ latents: TN = None,
362
+ prompt_embeds: TN = None,
363
+ negative_prompt_embeds: TN = None,
364
+ cross_attention_kwargs: dict[str, Any] | None = None,
365
+ circular_padding: bool = False,
366
+ clip_skip: int | None = None,
367
+ stride=8
368
+ ) -> list[Image]:
369
+ # 0. Default height and width to unet
370
+ height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
371
+ width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
372
+
373
+ # 1. Check inputs. Raise error if not correct
374
+ pipeline.check_inputs(
375
+ prompt, height, width, 1, negative_prompt, prompt_embeds, negative_prompt_embeds
376
+ )
377
+
378
+ device = pipeline._execution_device
379
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
380
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
381
+ # corresponds to doing no classifier free guidance.
382
+ do_classifier_free_guidance = guidance_scale > 1.0
383
+
384
+ # 3. Encode input prompt
385
+ text_encoder_lora_scale = (
386
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
387
+ )
388
+ prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
389
+ prompt,
390
+ device,
391
+ num_images_per_prompt,
392
+ do_classifier_free_guidance,
393
+ negative_prompt,
394
+ prompt_embeds=prompt_embeds,
395
+ negative_prompt_embeds=negative_prompt_embeds,
396
+ lora_scale=text_encoder_lora_scale,
397
+ clip_skip=clip_skip,
398
+ )
399
+ # For classifier free guidance, we need to do two forward passes.
400
+ # Here we concatenate the unconditional and text embeddings into a single batch
401
+ # to avoid doing two forward passes
402
+
403
+ # 4. Prepare timesteps
404
+ pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
405
+ timesteps = pipeline.scheduler.timesteps
406
+
407
+ # 5. Prepare latent variables
408
+ num_channels_latents = pipeline.unet.config.in_channels
409
+ latents = pipeline.prepare_latents(
410
+ 1,
411
+ num_channels_latents,
412
+ height,
413
+ width,
414
+ prompt_embeds.dtype,
415
+ device,
416
+ generator,
417
+ latents,
418
+ )
419
+ if reference_latent is None:
420
+ reference_latent = torch.randn(1, 4, pipeline.unet.config.sample_size, pipeline.unet.config.sample_size,
421
+ generator=generator)
422
+ reference_latent = reference_latent.to(device=device, dtype=pipeline.unet.dtype)
423
+ # 6. Define panorama grid and initialize views for synthesis.
424
+ # prepare batch grid
425
+ views = pipeline.get_views(height, width, circular_padding=circular_padding, stride=stride)
426
+ views_batch = [views[i: i + view_batch_size] for i in range(0, len(views), view_batch_size)]
427
+ views_scheduler_status = [copy.deepcopy(pipeline.scheduler.__dict__)] * len(views_batch)
428
+ count = torch.zeros_like(latents)
429
+ value = torch.zeros_like(latents)
430
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
431
+ extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)
432
+
433
+ # 8. Denoising loop
434
+ # Each denoising step also includes refinement of the latents with respect to the
435
+ # views.
436
+ num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
437
+
438
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds[:1],
439
+ *[negative_prompt_embeds[1:]] * view_batch_size]
440
+ )
441
+ prompt_embeds = torch.cat([prompt_embeds[:1],
442
+ *[prompt_embeds[1:]] * view_batch_size]
443
+ )
444
+
445
+ with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
446
+ for i, t in enumerate(timesteps):
447
+ count.zero_()
448
+ value.zero_()
449
+
450
+ # generate views
451
+ # Here, we iterate through different spatial crops of the latents and denoise them. These
452
+ # denoised (latent) crops are then averaged to produce the final latent
453
+ # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the
454
+ # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113
455
+ # Batch views denoise
456
+ for j, batch_view in enumerate(views_batch):
457
+ vb_size = len(batch_view)
458
+ # get the latents corresponding to the current view coordinates
459
+ if circular_padding:
460
+ latents_for_view = []
461
+ for h_start, h_end, w_start, w_end in batch_view:
462
+ if w_end > latents.shape[3]:
463
+ # Add circular horizontal padding
464
+ latent_view = torch.cat(
465
+ (
466
+ latents[:, :, h_start:h_end, w_start:],
467
+ latents[:, :, h_start:h_end, : w_end - latents.shape[3]],
468
+ ),
469
+ dim=-1,
470
+ )
471
+ else:
472
+ latent_view = latents[:, :, h_start:h_end, w_start:w_end]
473
+ latents_for_view.append(latent_view)
474
+ latents_for_view = torch.cat(latents_for_view)
475
+ else:
476
+ latents_for_view = torch.cat(
477
+ [
478
+ latents[:, :, h_start:h_end, w_start:w_end]
479
+ for h_start, h_end, w_start, w_end in batch_view
480
+ ]
481
+ )
482
+ # rematch block's scheduler status
483
+ pipeline.scheduler.__dict__.update(views_scheduler_status[j])
484
+
485
+ # expand the latents if we are doing classifier free guidance
486
+ latent_reference_plus_view = torch.cat((reference_latent, latents_for_view))
487
+ latent_model_input = latent_reference_plus_view.repeat(2, 1, 1, 1)
488
+ prompt_embeds_input = torch.cat([negative_prompt_embeds[: 1 + vb_size],
489
+ prompt_embeds[: 1 + vb_size]]
490
+ )
491
+ latent_model_input = pipeline.scheduler.scale_model_input(latent_model_input, t)
492
+ # predict the noise residual
493
+ # return
494
+ noise_pred = pipeline.unet(
495
+ latent_model_input,
496
+ t,
497
+ encoder_hidden_states=prompt_embeds_input,
498
+ cross_attention_kwargs=cross_attention_kwargs,
499
+ ).sample
500
+
501
+ # perform guidance
502
+
503
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
504
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
505
+ # compute the previous noisy sample x_t -> x_t-1
506
+ latent_reference_plus_view = pipeline.scheduler.step(
507
+ noise_pred, t, latent_reference_plus_view, **extra_step_kwargs
508
+ ).prev_sample
509
+ if j == len(views_batch) - 1:
510
+ reference_latent = latent_reference_plus_view[:1]
511
+ latents_denoised_batch = latent_reference_plus_view[1:]
512
+ # save views scheduler status after sample
513
+ views_scheduler_status[j] = copy.deepcopy(pipeline.scheduler.__dict__)
514
+
515
+ # extract value from batch
516
+ for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
517
+ latents_denoised_batch.chunk(vb_size), batch_view
518
+ ):
519
+ if circular_padding and w_end > latents.shape[3]:
520
+ # Case for circular padding
521
+ value[:, :, h_start:h_end, w_start:] += latents_view_denoised[
522
+ :, :, h_start:h_end, : latents.shape[3] - w_start
523
+ ]
524
+ value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[
525
+ :, :, h_start:h_end,
526
+ latents.shape[3] - w_start:
527
+ ]
528
+ count[:, :, h_start:h_end, w_start:] += 1
529
+ count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1
530
+ else:
531
+ value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
532
+ count[:, :, h_start:h_end, w_start:w_end] += 1
533
+
534
+ # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
535
+ latents = torch.where(count > 0, value / count, value)
536
+
537
+ # call the callback, if provided
538
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
539
+ progress_bar.update()
540
+
541
+ if circular_padding:
542
+ image = pipeline.decode_latents_with_padding(latents)
543
+ else:
544
+ image = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor, return_dict=False)[0]
545
+ reference_image = pipeline.vae.decode(reference_latent / pipeline.vae.config.scaling_factor, return_dict=False)[0]
546
+ # image, has_nsfw_concept = pipeline.run_safety_checker(image, device, prompt_embeds.dtype)
547
+ # reference_image, _ = pipeline.run_safety_checker(reference_image, device, prompt_embeds.dtype)
548
+
549
+ image = pipeline.image_processor.postprocess(image, output_type='pil', do_denormalize=[True])
550
+ reference_image = pipeline.image_processor.postprocess(reference_image, output_type='pil', do_denormalize=[True])
551
+ pipeline.maybe_free_model_hooks()
552
+ return reference_image + image
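
To show how the entry points above are meant to be driven, here is a hedged usage sketch for `controlnet_call`. The model setup is condensed from `style_aligned_w_controlnet.ipynb` further down in this commit; the final call, its two-element prompt list (style reference first, depth-conditioned target second), and the conditioning scale are assumptions inferred from how `controlnet_call` stacks the prompt embeddings, not from the (truncated) notebook itself.

```python
# Hedged sketch: model setup follows style_aligned_w_controlnet.ipynb (below);
# the controlnet_call invocation and its prompts/scale are assumptions based on the
# function signature and batching logic in pipeline_calls.py.
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.utils import load_image
from transformers import DPTImageProcessor, DPTForDepthEstimation

import pipeline_calls
import sa_handler

depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0",
    variant="fp16", use_safetensors=True, torch_dtype=torch.float16,
).to("cuda")
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda")
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet, vae=vae, variant="fp16", use_safetensors=True, torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()

# Style-aligned shared attention, registered exactly as in the SDXL example.
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False, share_layer_norm=False, share_attention=True,
    adain_queries=True, adain_keys=True, adain_values=False,
)
handler.register(sa_args)

# Depth map for the conditioning image shipped with this commit.
image = load_image("./example_image/train.png")
depth_map = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)

# Assumed call pattern: prompt[0] acts as the style reference, prompt[1] is the
# depth-conditioned target; num_images_per_prompt controls how many conditioned
# variations share the reference style. Prompts and scale are illustrative only.
images = pipeline_calls.controlnet_call(
    pipeline,
    ["a colorful toy town. macro photo. 3d game asset",
     "a toy train. macro photo. 3d game asset"],
    image=depth_map,
    num_inference_steps=50,
    controlnet_conditioning_scale=0.8,
    num_images_per_prompt=2,
)
```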
requirements.txt CHANGED
@@ -1,4 +1,5 @@
- diffusers
- accelerate
+ diffusers==0.16.1
+ transformers
  mediapy
+ ipywidgets
  einops
sa_handler.py ADDED
@@ -0,0 +1,269 @@
1
+ # Copyright 2023 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass
19
+ from diffusers import StableDiffusionXLPipeline
20
+ import torch
21
+ import torch.nn as nn
22
+ from torch.nn import functional as nnf
23
+ from diffusers.models import attention_processor
24
+ import einops
25
+
26
+ T = torch.Tensor
27
+
28
+
29
+ @dataclass(frozen=True)
30
+ class StyleAlignedArgs:
31
+ share_group_norm: bool = True
32
+ share_layer_norm: bool = True
33
+ share_attention: bool = True
34
+ adain_queries: bool = True
35
+ adain_keys: bool = True
36
+ adain_values: bool = False
37
+ full_attention_share: bool = False
38
+ keys_scale: float = 1.
39
+ only_self_level: float = 0.
40
+
41
+
42
+ def expand_first(feat: T, scale=1., ) -> T:
43
+ b = feat.shape[0]
44
+ feat_style = torch.stack((feat[0], feat[b // 2])).unsqueeze(1)
45
+ if scale == 1:
46
+ feat_style = feat_style.expand(2, b // 2, *feat.shape[1:])
47
+ else:
48
+ feat_style = feat_style.repeat(1, b // 2, 1, 1, 1)
49
+ feat_style = torch.cat([feat_style[:, :1], scale * feat_style[:, 1:]], dim=1)
50
+ return feat_style.reshape(*feat.shape)
51
+
52
+
53
+ def concat_first(feat: T, dim=2, scale=1.) -> T:
54
+ feat_style = expand_first(feat, scale=scale)
55
+ return torch.cat((feat, feat_style), dim=dim)
56
+
57
+
58
+ def calc_mean_std(feat, eps: float = 1e-5) -> tuple[T, T]:
59
+ feat_std = (feat.var(dim=-2, keepdims=True) + eps).sqrt()
60
+ feat_mean = feat.mean(dim=-2, keepdims=True)
61
+ return feat_mean, feat_std
62
+
63
+
64
+ def adain(feat: T) -> T:
65
+ feat_mean, feat_std = calc_mean_std(feat)
66
+ feat_style_mean = expand_first(feat_mean)
67
+ feat_style_std = expand_first(feat_std)
68
+ feat = (feat - feat_mean) / feat_std
69
+ feat = feat * feat_style_std + feat_style_mean
70
+ return feat
71
+
72
+
73
+ class DefaultAttentionProcessor(nn.Module):
74
+
75
+ def __init__(self):
76
+ super().__init__()
77
+ self.processor = attention_processor.AttnProcessor2_0()
78
+
79
+ def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
80
+ attention_mask=None, **kwargs):
81
+ return self.processor(attn, hidden_states, encoder_hidden_states, attention_mask)
82
+
83
+
84
+ class SharedAttentionProcessor(DefaultAttentionProcessor):
85
+
86
+ def shared_call(
87
+ self,
88
+ attn: attention_processor.Attention,
89
+ hidden_states,
90
+ encoder_hidden_states=None,
91
+ attention_mask=None,
92
+ **kwargs
93
+ ):
94
+
95
+ residual = hidden_states
96
+ input_ndim = hidden_states.ndim
97
+ if input_ndim == 4:
98
+ batch_size, channel, height, width = hidden_states.shape
99
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
100
+ batch_size, sequence_length, _ = (
101
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
102
+ )
103
+
104
+ if attention_mask is not None:
105
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
106
+ # scaled_dot_product_attention expects attention_mask shape to be
107
+ # (batch, heads, source_length, target_length)
108
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
109
+
110
+ if attn.group_norm is not None:
111
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
112
+
113
+ query = attn.to_q(hidden_states)
114
+ key = attn.to_k(hidden_states)
115
+ value = attn.to_v(hidden_states)
116
+ inner_dim = key.shape[-1]
117
+ head_dim = inner_dim // attn.heads
118
+
119
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
120
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
121
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
122
+ # if self.step >= self.start_inject:
123
+ if self.adain_queries:
124
+ query = adain(query)
125
+ if self.adain_keys:
126
+ key = adain(key)
127
+ if self.adain_values:
128
+ value = adain(value)
129
+ if self.share_attention:
130
+ key = concat_first(key, -2, scale=self.keys_scale)
131
+ value = concat_first(value, -2)
132
+ hidden_states = nnf.scaled_dot_product_attention(
133
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
134
+ )
135
+ else:
136
+ hidden_states = nnf.scaled_dot_product_attention(
137
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
138
+ )
139
+ # hidden_states = adain(hidden_states)
140
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
141
+ hidden_states = hidden_states.to(query.dtype)
142
+
143
+ # linear proj
144
+ hidden_states = attn.to_out[0](hidden_states)
145
+ # dropout
146
+ hidden_states = attn.to_out[1](hidden_states)
147
+
148
+ if input_ndim == 4:
149
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
150
+
151
+ if attn.residual_connection:
152
+ hidden_states = hidden_states + residual
153
+
154
+ hidden_states = hidden_states / attn.rescale_output_factor
155
+ return hidden_states
156
+
157
+ def __call__(self, attn: attention_processor.Attention, hidden_states, encoder_hidden_states=None,
158
+ attention_mask=None, **kwargs):
159
+ if self.full_attention_share:
160
+ b, n, d = hidden_states.shape
161
+ hidden_states = einops.rearrange(hidden_states, '(k b) n d -> k (b n) d', k=2)
162
+ hidden_states = super().__call__(attn, hidden_states, encoder_hidden_states=encoder_hidden_states,
163
+ attention_mask=attention_mask, **kwargs)
164
+ hidden_states = einops.rearrange(hidden_states, 'k (b n) d -> (k b) n d', n=n)
165
+ else:
166
+ hidden_states = self.shared_call(attn, hidden_states, hidden_states, attention_mask, **kwargs)
167
+
168
+ return hidden_states
169
+
170
+ def __init__(self, style_aligned_args: StyleAlignedArgs):
171
+ super().__init__()
172
+ self.share_attention = style_aligned_args.share_attention
173
+ self.adain_queries = style_aligned_args.adain_queries
174
+ self.adain_keys = style_aligned_args.adain_keys
175
+ self.adain_values = style_aligned_args.adain_values
176
+ self.full_attention_share = style_aligned_args.full_attention_share
177
+ self.keys_scale = style_aligned_args.keys_scale
178
+
179
+
180
+ def _get_switch_vec(total_num_layers, level):
181
+ if level == 0:
182
+ return torch.zeros(total_num_layers, dtype=torch.bool)
183
+ if level == 1:
184
+ return torch.ones(total_num_layers, dtype=torch.bool)
185
+ to_flip = level > .5
186
+ if to_flip:
187
+ level = 1 - level
188
+ num_switch = int(level * total_num_layers)
189
+ vec = torch.arange(total_num_layers)
190
+ vec = vec % (total_num_layers // num_switch)
191
+ vec = vec == 0
192
+ if to_flip:
193
+ vec = ~vec
194
+ return vec
195
+
196
+
197
+ def init_attention_processors(pipeline: StableDiffusionXLPipeline, style_aligned_args: StyleAlignedArgs | None = None):
198
+ attn_procs = {}
199
+ unet = pipeline.unet
200
+ number_of_self, number_of_cross = 0, 0
201
+ num_self_layers = len([name for name in unet.attn_processors.keys() if 'attn1' in name])
202
+ if style_aligned_args is None:
203
+ only_self_vec = _get_switch_vec(num_self_layers, 1)
204
+ else:
205
+ only_self_vec = _get_switch_vec(num_self_layers, style_aligned_args.only_self_level)
206
+ for i, name in enumerate(unet.attn_processors.keys()):
207
+ is_self_attention = 'attn1' in name
208
+ if is_self_attention:
209
+ number_of_self += 1
210
+ if style_aligned_args is None or only_self_vec[i // 2]:
211
+ attn_procs[name] = DefaultAttentionProcessor()
212
+ else:
213
+ attn_procs[name] = SharedAttentionProcessor(style_aligned_args)
214
+
215
+ else:
216
+ number_of_cross += 1
217
+ attn_procs[name] = DefaultAttentionProcessor()
218
+
219
+ unet.set_attn_processor(attn_procs)
220
+
221
+
222
+ def register_shared_norm(pipeline: StableDiffusionXLPipeline,
223
+ share_group_norm: bool = True,
224
+ share_layer_norm: bool = True, ):
225
+ def register_norm_forward(norm_layer: nn.GroupNorm | nn.LayerNorm) -> nn.GroupNorm | nn.LayerNorm:
226
+ if not hasattr(norm_layer, 'orig_forward'):
227
+ setattr(norm_layer, 'orig_forward', norm_layer.forward)
228
+ orig_forward = norm_layer.orig_forward
229
+
230
+ def forward_(hidden_states: T) -> T:
231
+ n = hidden_states.shape[-2]
232
+ hidden_states = concat_first(hidden_states, dim=-2)
233
+ hidden_states = orig_forward(hidden_states)
234
+ return hidden_states[..., :n, :]
235
+
236
+ norm_layer.forward = forward_
237
+ return norm_layer
238
+
239
+ def get_norm_layers(pipeline_, norm_layers_: dict[str, list[nn.GroupNorm | nn.LayerNorm]]):
240
+ if isinstance(pipeline_, nn.LayerNorm) and share_layer_norm:
241
+ norm_layers_['layer'].append(pipeline_)
242
+ if isinstance(pipeline_, nn.GroupNorm) and share_group_norm:
243
+ norm_layers_['group'].append(pipeline_)
244
+ else:
245
+ for layer in pipeline_.children():
246
+ get_norm_layers(layer, norm_layers_)
247
+
248
+ norm_layers = {'group': [], 'layer': []}
249
+ get_norm_layers(pipeline.unet, norm_layers)
250
+ return [register_norm_forward(layer) for layer in norm_layers['group']] + [register_norm_forward(layer) for layer in
251
+ norm_layers['layer']]
252
+
253
+
254
+ class Handler:
255
+
256
+ def register(self, style_aligned_args: StyleAlignedArgs, ):
257
+ self.norm_layers = register_shared_norm(self.pipeline, style_aligned_args.share_group_norm,
258
+ style_aligned_args.share_layer_norm)
259
+ init_attention_processors(self.pipeline, style_aligned_args)
260
+
261
+ def remove(self):
262
+ for layer in self.norm_layers:
263
+ layer.forward = layer.orig_forward
264
+ self.norm_layers = []
265
+ init_attention_processors(self.pipeline, None)
266
+
267
+ def __init__(self, pipeline: StableDiffusionXLPipeline):
268
+ self.pipeline = pipeline
269
+ self.norm_layers = []
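
To make the tensor plumbing in the helpers above concrete, here is a toy shape check for `adain` and `concat_first`. The tensor sizes are made up for illustration; only the shapes and the "first item of each half is the style reference" convention come from the code in sa_handler.py.

```python
# Toy shape check for the shared-attention helpers defined above in sa_handler.py.
import torch
import sa_handler

# (2 * batch, heads, tokens, head_dim): the first item of each classifier-free
# guidance half is the one expand_first() treats as the style reference.
feat = torch.randn(4, 8, 77, 64)

normed = sa_handler.adain(feat)               # re-normalized toward the reference's mean/std
extended = sa_handler.concat_first(feat, -2)  # keys/values with the reference tokens appended

print(normed.shape)    # torch.Size([4, 8, 77, 64])
print(extended.shape)  # torch.Size([4, 8, 154, 64])
```

This is exactly how `SharedAttentionProcessor.shared_call` extends keys and values before `scaled_dot_product_attention`, which is what lets every image in the batch attend to the reference image's tokens.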
style_aligned_sdxl.ipynb ADDED
@@ -0,0 +1,142 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a885cf5d-c525-4f5b-a8e4-f67d2f699909",
6
+ "metadata": {},
7
+ "source": [
8
+ "## Copyright 2023 Google LLC"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "d891d022-8979-40d4-848f-ecb84c17f12c",
15
+ "metadata": {
16
+ "jp-MarkdownHeadingCollapsed": true
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "# Copyright 2023 Google LLC\n",
21
+ "#\n",
22
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
23
+ "# you may not use this file except in compliance with the License.\n",
24
+ "# You may obtain a copy of the License at\n",
25
+ "#\n",
26
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
27
+ "#\n",
28
+ "# Unless required by applicable law or agreed to in writing, software\n",
29
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
30
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
31
+ "# See the License for the specific language governing permissions and\n",
32
+ "# limitations under the License."
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "id": "540d8642-c203-471c-a66d-0d43aabb0706",
38
+ "metadata": {},
39
+ "source": [
40
+ "# StyleAligned over SDXL"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "id": "23d54ea7-f7ab-4548-9b10-ece87216dc18",
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "from diffusers import StableDiffusionXLPipeline, DDIMScheduler\n",
51
+ "import torch\n",
52
+ "import mediapy\n",
53
+ "import sa_handler"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "id": "c2f6f1e6-445f-47bc-b9db-0301caeb7490",
60
+ "metadata": {
61
+ "pycharm": {
62
+ "name": "#%%\n"
63
+ }
64
+ },
65
+ "outputs": [],
66
+ "source": [
67
+ "# init models\n",
68
+ "\n",
69
+ "scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False,\n",
70
+ " set_alpha_to_one=False)\n",
71
+ "pipeline = StableDiffusionXLPipeline.from_pretrained(\n",
72
+ " \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True,\n",
73
+ " scheduler=scheduler\n",
74
+ ").to(\"cuda\")\n",
75
+ "\n",
76
+ "handler = sa_handler.Handler(pipeline)\n",
77
+ "sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,\n",
78
+ " share_layer_norm=False,\n",
79
+ " share_attention=True,\n",
80
+ " adain_queries=True,\n",
81
+ " adain_keys=True,\n",
82
+ " adain_values=False,\n",
83
+ " )\n",
84
+ "\n",
85
+ "handler.register(sa_args, )"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "id": "5cca9256-0ce0-45c3-9cba-68c7eff1452f",
92
+ "metadata": {
93
+ "pycharm": {
94
+ "name": "#%%\n"
95
+ }
96
+ },
97
+ "outputs": [],
98
+ "source": [
99
+ "# run StyleAligned\n",
100
+ "\n",
101
+ "sets_of_prompts = [\n",
102
+ " \"a toy train. macro photo. 3d game asset\",\n",
103
+ " \"a toy airplane. macro photo. 3d game asset\",\n",
104
+ " \"a toy bicycle. macro photo. 3d game asset\",\n",
105
+ " \"a toy car. macro photo. 3d game asset\",\n",
106
+ " \"a toy boat. macro photo. 3d game asset\",\n",
107
+ "]\n",
108
+ "images = pipeline(sets_of_prompts,).images\n",
109
+ "mediapy.show_images(images)"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": null,
115
+ "id": "d819ad6d-0c19-411f-ba97-199909f64805",
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": []
119
+ }
120
+ ],
121
+ "metadata": {
122
+ "kernelspec": {
123
+ "display_name": "Python 3 (ipykernel)",
124
+ "language": "python",
125
+ "name": "python3"
126
+ },
127
+ "language_info": {
128
+ "codemirror_mode": {
129
+ "name": "ipython",
130
+ "version": 3
131
+ },
132
+ "file_extension": ".py",
133
+ "mimetype": "text/x-python",
134
+ "name": "python",
135
+ "nbconvert_exporter": "python",
136
+ "pygments_lexer": "ipython3",
137
+ "version": "3.11.5"
138
+ }
139
+ },
140
+ "nbformat": 4,
141
+ "nbformat_minor": 5
142
+ }
style_aligned_w_controlnet.ipynb ADDED
@@ -0,0 +1,200 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "f86ede39-8d9f-4da9-bc12-955f2fddd484",
6
+ "metadata": {
7
+ "pycharm": {
8
+ "name": "#%% md\n"
9
+ }
10
+ },
11
+ "source": [
12
+ "## Copyright 2023 Google LLC"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "3f3cbf47-a52b-48b1-9bd3-3435f92f2174",
19
+ "metadata": {
20
+ "pycharm": {
21
+ "name": "#%%\n"
22
+ }
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "# Copyright 2023 Google LLC\n",
27
+ "#\n",
28
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
29
+ "# you may not use this file except in compliance with the License.\n",
30
+ "# You may obtain a copy of the License at\n",
31
+ "#\n",
32
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
33
+ "#\n",
34
+ "# Unless required by applicable law or agreed to in writing, software\n",
35
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
36
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
37
+ "# See the License for the specific language governing permissions and\n",
38
+ "# limitations under the License."
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "id": "22de629b-581f-4335-9e7b-f73221d8dbcb",
44
+ "metadata": {
45
+ "pycharm": {
46
+ "name": "#%% md\n"
47
+ }
48
+ },
49
+ "source": [
50
+ "# ControlNet depth with StyleAligned over SDXL"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "id": "486b7ebb-c483-4bf0-ace8-f8092c2d1f23",
57
+ "metadata": {
58
+ "pycharm": {
59
+ "name": "#%%\n"
60
+ }
61
+ },
62
+ "outputs": [],
63
+ "source": [
64
+ "from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL\n",
65
+ "from diffusers.utils import load_image\n",
66
+ "from transformers import DPTImageProcessor, DPTForDepthEstimation\n",
67
+ "import torch\n",
68
+ "import mediapy\n",
69
+ "import sa_handler\n",
70
+ "import pipeline_calls"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "2a7e85e7-b5cf-45b2-946a-5ba1e4923586",
77
+ "metadata": {
78
+ "pycharm": {
79
+ "name": "#%%\n"
80
+ }
81
+ },
82
+ "outputs": [],
83
+ "source": [
84
+ "# init models\n",
85
+ "\n",
86
+ "depth_estimator = DPTForDepthEstimation.from_pretrained(\"Intel/dpt-hybrid-midas\").to(\"cuda\")\n",
87
+ "feature_processor = DPTImageProcessor.from_pretrained(\"Intel/dpt-hybrid-midas\")\n",
88
+ "\n",
89
+ "controlnet = ControlNetModel.from_pretrained(\n",
90
+ " \"diffusers/controlnet-depth-sdxl-1.0\",\n",
91
+ " variant=\"fp16\",\n",
92
+ " use_safetensors=True,\n",
93
+ " torch_dtype=torch.float16,\n",
94
+ ").to(\"cuda\")\n",
95
+ "vae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16).to(\"cuda\")\n",
96
+ "pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(\n",
97
+ " \"stabilityai/stable-diffusion-xl-base-1.0\",\n",
98
+ " controlnet=controlnet,\n",
99
+ " vae=vae,\n",
100
+ " variant=\"fp16\",\n",
101
+ " use_safetensors=True,\n",
102
+ " torch_dtype=torch.float16,\n",
103
+ ").to(\"cuda\")\n",
104
+ "pipeline.enable_model_cpu_offload()\n",
105
+ "\n",
106
+ "sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,\n",
107
+ " share_layer_norm=False,\n",
108
+ " share_attention=True,\n",
109
+ " adain_queries=True,\n",
110
+ " adain_keys=True,\n",
111
+ " adain_values=False,\n",
112
+ " )\n",
113
+ "handler = sa_handler.Handler(pipeline)\n",
114
+ "handler.register(sa_args, )"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "id": "94ca26b4-9061-4012-9400-8d97ef212d87",
121
+ "metadata": {
122
+ "pycharm": {
123
+ "name": "#%%\n"
124
+ }
125
+ },
126
+ "outputs": [],
127
+ "source": [
128
+ "# get depth maps\n",
129
+ "\n",
130
+ "image = load_image(\"./example_image/train.png\")\n",
131
+ "depth_image1 = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)\n",
132
+ "depth_image2 = load_image(\"./example_image/sun.png\").resize((1024, 1024))\n",
133
+ "mediapy.show_images([depth_image1, depth_image2])"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "id": "c8f56fe4-559f-49ff-a2d8-460dcfeb56a0",
140
+ "metadata": {
141
+ "pycharm": {
142
+ "name": "#%%\n"
143
+ }
144
+ },
145
+ "outputs": [],
146
+ "source": [
147
+ "# run ControlNet depth with StyleAligned\n",
148
+ "\n",
149
+ "reference_prompt = \"a poster in flat design style\"\n",
150
+ "target_prompts = [\"a train in flat design style\", \"the sun in flat design style\"]\n",
151
+ "controlnet_conditioning_scale = 0.8\n",
152
+ "num_images_per_prompt = 3 # adjust according to VRAM size\n",
153
+ "latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)\n",
154
+ "for deph_map, target_prompt in zip((depth_image1, depth_image2), target_prompts):\n",
155
+ " latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)\n",
156
+ " images = pipeline_calls.controlnet_call(pipeline, [reference_prompt, target_prompt],\n",
157
+ " image=deph_map,\n",
158
+ " num_inference_steps=50,\n",
159
+ " controlnet_conditioning_scale=controlnet_conditioning_scale,\n",
160
+ " num_images_per_prompt=num_images_per_prompt,\n",
161
+ " latents=latents)\n",
162
+ " \n",
163
+ " mediapy.show_images([images[0], deph_map] + images[1:], titles=[\"reference\", \"depth\"] + [f'result {i}' for i in range(1, len(images))])\n"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": null,
169
+ "id": "437ba4bd-6243-486b-8ba5-3b7cd661d53a",
170
+ "metadata": {
171
+ "pycharm": {
172
+ "name": "#%%\n"
173
+ }
174
+ },
175
+ "outputs": [],
176
+ "source": []
177
+ }
178
+ ],
179
+ "metadata": {
180
+ "kernelspec": {
181
+ "display_name": "Python 3 (ipykernel)",
182
+ "language": "python",
183
+ "name": "python3"
184
+ },
185
+ "language_info": {
186
+ "codemirror_mode": {
187
+ "name": "ipython",
188
+ "version": 3
189
+ },
190
+ "file_extension": ".py",
191
+ "mimetype": "text/x-python",
192
+ "name": "python",
193
+ "nbconvert_exporter": "python",
194
+ "pygments_lexer": "ipython3",
195
+ "version": "3.11.5"
196
+ }
197
+ },
198
+ "nbformat": 4,
199
+ "nbformat_minor": 5
200
+ }
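The ControlNet notebook above calls `pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)`, presumably provided by `pipeline_calls.py` in this upload. As a rough guide to what such a helper does, here is a typical DPT-based depth-conditioning function; it is a sketch of the standard pattern, not the repository's actual implementation.

import numpy as np
import torch
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor

def get_depth_map(image: Image.Image,
                  feature_processor: DPTImageProcessor,
                  depth_estimator: DPTForDepthEstimation) -> Image.Image:
    # Run monocular depth estimation on the conditioning image.
    pixel_values = feature_processor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad():
        depth = depth_estimator(pixel_values).predicted_depth
    # Resize to the 1024x1024 resolution expected by the SDXL ControlNet.
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1), size=(1024, 1024), mode="bicubic", align_corners=False)
    # Normalize to [0, 1] and replicate to three channels.
    d_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
    d_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
    depth = (depth - d_min) / (d_max - d_min)
    depth = torch.cat([depth] * 3, dim=1).permute(0, 2, 3, 1).cpu().numpy()[0]
    return Image.fromarray((depth * 255.0).clip(0, 255).astype(np.uint8))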
style_aligned_w_multidiffusion.ipynb ADDED
@@ -0,0 +1,156 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "50fa980f-1bae-40c1-a1f3-f5f89bef60d3",
6
+ "metadata": {
7
+ "pycharm": {
8
+ "name": "#%% md\n"
9
+ }
10
+ },
11
+ "source": [
12
+ "## Copyright 2023 Google LLC"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "id": "5da5f038-057f-4475-a783-95660f98238c",
19
+ "metadata": {
20
+ "pycharm": {
21
+ "name": "#%%\n"
22
+ }
23
+ },
24
+ "outputs": [],
25
+ "source": [
26
+ "# Copyright 2023 Google LLC\n",
27
+ "#\n",
28
+ "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
29
+ "# you may not use this file except in compliance with the License.\n",
30
+ "# You may obtain a copy of the License at\n",
31
+ "#\n",
32
+ "# http://www.apache.org/licenses/LICENSE-2.0\n",
33
+ "#\n",
34
+ "# Unless required by applicable law or agreed to in writing, software\n",
35
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
36
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
37
+ "# See the License for the specific language governing permissions and\n",
38
+ "# limitations under the License."
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "markdown",
43
+ "id": "c3a7c069-c441-4204-a905-59cbd9edc13a",
44
+ "metadata": {
45
+ "pycharm": {
46
+ "name": "#%% md\n"
47
+ }
48
+ },
49
+ "source": [
50
+ "# MultiDiffusion with StyleAligned over SD v2"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "id": "14178de7-d4c8-4881-ac1d-ff84bae57c6f",
57
+ "metadata": {
58
+ "pycharm": {
59
+ "name": "#%%\n"
60
+ }
61
+ },
62
+ "outputs": [],
63
+ "source": [
64
+ "import torch\n",
65
+ "from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler\n",
66
+ "import mediapy\n",
67
+ "import sa_handler\n",
68
+ "import pipeline_calls"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "id": "738cee0e-4d6e-4875-b4df-eadff6e27e7f",
75
+ "metadata": {
76
+ "pycharm": {
77
+ "name": "#%%\n"
78
+ }
79
+ },
80
+ "outputs": [],
81
+ "source": [
82
+ "# init models\n",
83
+ "model_ckpt = \"stabilityai/stable-diffusion-2-base\"\n",
84
+ "scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder=\"scheduler\")\n",
85
+ "pipeline = StableDiffusionPanoramaPipeline.from_pretrained(\n",
86
+ " model_ckpt, scheduler=scheduler, torch_dtype=torch.float16\n",
87
+ ").to(\"cuda\")\n",
88
+ "\n",
89
+ "sa_args = sa_handler.StyleAlignedArgs(share_group_norm=True,\n",
90
+ " share_layer_norm=True,\n",
91
+ " share_attention=True,\n",
92
+ " adain_queries=True,\n",
93
+ " adain_keys=True,\n",
94
+ " adain_values=False,\n",
95
+ " )\n",
96
+ "handler = sa_handler.Handler(pipeline)\n",
97
+ "handler.register(sa_args)"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "id": "ea61e789-2814-4820-8ae7-234c3c6640a0",
104
+ "metadata": {
105
+ "pycharm": {
106
+ "name": "#%%\n"
107
+ }
108
+ },
109
+ "outputs": [],
110
+ "source": [
111
+ "# run MultiDiffusion with StyleAligned\n",
112
+ "\n",
113
+ "reference_prompt = \"a beautiful papercut art design\"\n",
114
+ "target_prompts = [\"mountains in a beautiful papercut art design\", \"giraffes in a beautiful papercut art design\"]\n",
115
+ "view_batch_size = 25 # adjust according to VRAM size\n",
116
+ "reference_latent = torch.randn(1, 4, 64, 64,)\n",
117
+ "for target_prompt in target_prompts:\n",
118
+ " images = pipeline_calls.panorama_call(pipeline, [reference_prompt, target_prompt], reference_latent=reference_latent, view_batch_size=view_batch_size)\n",
119
+ " mediapy.show_images(images, titles=[\"reference\", \"result\"])"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": null,
125
+ "id": "791a9b28-f0ce-4fd0-9f3c-594281c2ae56",
126
+ "metadata": {
127
+ "pycharm": {
128
+ "name": "#%%\n"
129
+ }
130
+ },
131
+ "outputs": [],
132
+ "source": []
133
+ }
134
+ ],
135
+ "metadata": {
136
+ "kernelspec": {
137
+ "display_name": "Python 3 (ipykernel)",
138
+ "language": "python",
139
+ "name": "python3"
140
+ },
141
+ "language_info": {
142
+ "codemirror_mode": {
143
+ "name": "ipython",
144
+ "version": 3
145
+ },
146
+ "file_extension": ".py",
147
+ "mimetype": "text/x-python",
148
+ "name": "python",
149
+ "nbconvert_exporter": "python",
150
+ "pygments_lexer": "ipython3",
151
+ "version": "3.11.5"
152
+ }
153
+ },
154
+ "nbformat": 4,
155
+ "nbformat_minor": 5
156
+ }
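All three notebooks hand the style-sharing step off to `sa_handler.Handler(pipeline).register(sa_args)`. Conceptually, the flags in `StyleAlignedArgs` modify self-attention: `adain_queries`/`adain_keys` renormalize each image's queries and keys to the reference image's statistics, and `share_attention` lets every image in the batch also attend to the reference's keys and values. The sketch below only illustrates that idea; it is not the hook code in `sa_handler.py`, and the tensor layout is an assumption.

import torch
import torch.nn.functional as F

def adain(feat: torch.Tensor, ref: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Match feat's per-channel statistics to the reference features (AdaIN).
    feat_mean, feat_std = feat.mean(-2, keepdim=True), feat.std(-2, keepdim=True) + eps
    ref_mean, ref_std = ref.mean(-2, keepdim=True), ref.std(-2, keepdim=True) + eps
    return (feat - feat_mean) / feat_std * ref_std + ref_mean

def shared_attention(q, k, v):
    # q, k, v: (batch, tokens, dim); batch item 0 is the style reference.
    q_ref, k_ref, v_ref = q[:1], k[:1], v[:1]
    q = adain(q, q_ref)  # adain_queries=True
    k = adain(k, k_ref)  # adain_keys=True
    # share_attention=True: every image also attends to the reference tokens.
    k_full = torch.cat([k_ref.expand_as(k), k], dim=1)
    v_full = torch.cat([v_ref.expand_as(v), v], dim=1)
    return F.scaled_dot_product_attention(q, k_full, v_full)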