kaz-sony commited on
Commit
832c977
1 Parent(s): 7c10426
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/*.jpg filter=lfs diff=lfs merge=lfs -text
37
+ extern/splatting-0.0.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "extern/ZoeDepth"]
2
+ path = extern/ZoeDepth
3
+ url = git@github.com:isl-org/ZoeDepth.git
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Sony Research Inc.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
NOTICE ADDED
@@ -0,0 +1,263 @@
1
+ This repository contains files and parts of code adapted or modified from third-party repositories under other licenses. Below is a list of those repositories. Adapted files are identified in the top lines of each file.
2
+
3
+ -----------------------------
4
+ Moore-AnimateAnyone
5
+ Apache License, Version 2.0
6
+ Copyright @2023-2024 Moore Threads Technology Co., Ltd.
7
+ https://github.com/MooreThreads/Moore-AnimateAnyone
8
+
9
+ -----------------------------
10
+ magic-animate
11
+ BSD 3-Clause License
12
+ Copyright (c) Bytedance Inc.
13
+ https://github.com/magic-research/magic-animate
14
+
15
+ -----------------------------
16
+ AnimateDiff
17
+ Apache License, Version 2.0
18
+ https://github.com/guoyww/AnimateDiff
19
+
20
+ -----------------------------
21
+ Diffusers
22
+ Apache License, Version 2.0
23
+ Copyright (c) Hugging Face Inc.
24
+ https://github.com/huggingface/diffusers
25
+
26
+
27
+ ================================================================================
28
+
29
+
30
+ Apache License
31
+ Version 2.0, January 2004
32
+ http://www.apache.org/licenses/
33
+
34
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
35
+
36
+ 1. Definitions.
37
+
38
+ "License" shall mean the terms and conditions for use, reproduction,
39
+ and distribution as defined by Sections 1 through 9 of this document.
40
+
41
+ "Licensor" shall mean the copyright owner or entity authorized by
42
+ the copyright owner that is granting the License.
43
+
44
+ "Legal Entity" shall mean the union of the acting entity and all
45
+ other entities that control, are controlled by, or are under common
46
+ control with that entity. For the purposes of this definition,
47
+ "control" means (i) the power, direct or indirect, to cause the
48
+ direction or management of such entity, whether by contract or
49
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
50
+ outstanding shares, or (iii) beneficial ownership of such entity.
51
+
52
+ "You" (or "Your") shall mean an individual or Legal Entity
53
+ exercising permissions granted by this License.
54
+
55
+ "Source" form shall mean the preferred form for making modifications,
56
+ including but not limited to software source code, documentation
57
+ source, and configuration files.
58
+
59
+ "Object" form shall mean any form resulting from mechanical
60
+ transformation or translation of a Source form, including but
61
+ not limited to compiled object code, generated documentation,
62
+ and conversions to other media types.
63
+
64
+ "Work" shall mean the work of authorship, whether in Source or
65
+ Object form, made available under the License, as indicated by a
66
+ copyright notice that is included in or attached to the work
67
+ (an example is provided in the Appendix below).
68
+
69
+ "Derivative Works" shall mean any work, whether in Source or Object
70
+ form, that is based on (or derived from) the Work and for which the
71
+ editorial revisions, annotations, elaborations, or other modifications
72
+ represent, as a whole, an original work of authorship. For the purposes
73
+ of this License, Derivative Works shall not include works that remain
74
+ separable from, or merely link (or bind by name) to the interfaces of,
75
+ the Work and Derivative Works thereof.
76
+
77
+ "Contribution" shall mean any work of authorship, including
78
+ the original version of the Work and any modifications or additions
79
+ to that Work or Derivative Works thereof, that is intentionally
80
+ submitted to Licensor for inclusion in the Work by the copyright owner
81
+ or by an individual or Legal Entity authorized to submit on behalf of
82
+ the copyright owner. For the purposes of this definition, "submitted"
83
+ means any form of electronic, verbal, or written communication sent
84
+ to the Licensor or its representatives, including but not limited to
85
+ communication on electronic mailing lists, source code control systems,
86
+ and issue tracking systems that are managed by, or on behalf of, the
87
+ Licensor for the purpose of discussing and improving the Work, but
88
+ excluding communication that is conspicuously marked or otherwise
89
+ designated in writing by the copyright owner as "Not a Contribution."
90
+
91
+ "Contributor" shall mean Licensor and any individual or Legal Entity
92
+ on behalf of whom a Contribution has been received by Licensor and
93
+ subsequently incorporated within the Work.
94
+
95
+ 2. Grant of Copyright License. Subject to the terms and conditions of
96
+ this License, each Contributor hereby grants to You a perpetual,
97
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
98
+ copyright license to reproduce, prepare Derivative Works of,
99
+ publicly display, publicly perform, sublicense, and distribute the
100
+ Work and such Derivative Works in Source or Object form.
101
+
102
+ 3. Grant of Patent License. Subject to the terms and conditions of
103
+ this License, each Contributor hereby grants to You a perpetual,
104
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
105
+ (except as stated in this section) patent license to make, have made,
106
+ use, offer to sell, sell, import, and otherwise transfer the Work,
107
+ where such license applies only to those patent claims licensable
108
+ by such Contributor that are necessarily infringed by their
109
+ Contribution(s) alone or by combination of their Contribution(s)
110
+ with the Work to which such Contribution(s) was submitted. If You
111
+ institute patent litigation against any entity (including a
112
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
113
+ or a Contribution incorporated within the Work constitutes direct
114
+ or contributory patent infringement, then any patent licenses
115
+ granted to You under this License for that Work shall terminate
116
+ as of the date such litigation is filed.
117
+
118
+ 4. Redistribution. You may reproduce and distribute copies of the
119
+ Work or Derivative Works thereof in any medium, with or without
120
+ modifications, and in Source or Object form, provided that You
121
+ meet the following conditions:
122
+
123
+ (a) You must give any other recipients of the Work or
124
+ Derivative Works a copy of this License; and
125
+
126
+ (b) You must cause any modified files to carry prominent notices
127
+ stating that You changed the files; and
128
+
129
+ (c) You must retain, in the Source form of any Derivative Works
130
+ that You distribute, all copyright, patent, trademark, and
131
+ attribution notices from the Source form of the Work,
132
+ excluding those notices that do not pertain to any part of
133
+ the Derivative Works; and
134
+
135
+ (d) If the Work includes a "NOTICE" text file as part of its
136
+ distribution, then any Derivative Works that You distribute must
137
+ include a readable copy of the attribution notices contained
138
+ within such NOTICE file, excluding those notices that do not
139
+ pertain to any part of the Derivative Works, in at least one
140
+ of the following places: within a NOTICE text file distributed
141
+ as part of the Derivative Works; within the Source form or
142
+ documentation, if provided along with the Derivative Works; or,
143
+ within a display generated by the Derivative Works, if and
144
+ wherever such third-party notices normally appear. The contents
145
+ of the NOTICE file are for informational purposes only and
146
+ do not modify the License. You may add Your own attribution
147
+ notices within Derivative Works that You distribute, alongside
148
+ or as an addendum to the NOTICE text from the Work, provided
149
+ that such additional attribution notices cannot be construed
150
+ as modifying the License.
151
+
152
+ You may add Your own copyright statement to Your modifications and
153
+ may provide additional or different license terms and conditions
154
+ for use, reproduction, or distribution of Your modifications, or
155
+ for any such Derivative Works as a whole, provided Your use,
156
+ reproduction, and distribution of the Work otherwise complies with
157
+ the conditions stated in this License.
158
+
159
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
160
+ any Contribution intentionally submitted for inclusion in the Work
161
+ by You to the Licensor shall be under the terms and conditions of
162
+ this License, without any additional terms or conditions.
163
+ Notwithstanding the above, nothing herein shall supersede or modify
164
+ the terms of any separate license agreement you may have executed
165
+ with Licensor regarding such Contributions.
166
+
167
+ 6. Trademarks. This License does not grant permission to use the trade
168
+ names, trademarks, service marks, or product names of the Licensor,
169
+ except as required for reasonable and customary use in describing the
170
+ origin of the Work and reproducing the content of the NOTICE file.
171
+
172
+ 7. Disclaimer of Warranty. Unless required by applicable law or
173
+ agreed to in writing, Licensor provides the Work (and each
174
+ Contributor provides its Contributions) on an "AS IS" BASIS,
175
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
176
+ implied, including, without limitation, any warranties or conditions
177
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
178
+ PARTICULAR PURPOSE. You are solely responsible for determining the
179
+ appropriateness of using or redistributing the Work and assume any
180
+ risks associated with Your exercise of permissions under this License.
181
+
182
+ 8. Limitation of Liability. In no event and under no legal theory,
183
+ whether in tort (including negligence), contract, or otherwise,
184
+ unless required by applicable law (such as deliberate and grossly
185
+ negligent acts) or agreed to in writing, shall any Contributor be
186
+ liable to You for damages, including any direct, indirect, special,
187
+ incidental, or consequential damages of any character arising as a
188
+ result of this License or out of the use or inability to use the
189
+ Work (including but not limited to damages for loss of goodwill,
190
+ work stoppage, computer failure or malfunction, or any and all
191
+ other commercial damages or losses), even if such Contributor
192
+ has been advised of the possibility of such damages.
193
+
194
+ 9. Accepting Warranty or Additional Liability. While redistributing
195
+ the Work or Derivative Works thereof, You may choose to offer,
196
+ and charge a fee for, acceptance of support, warranty, indemnity,
197
+ or other liability obligations and/or rights consistent with this
198
+ License. However, in accepting such obligations, You may act only
199
+ on Your own behalf and on Your sole responsibility, not on behalf
200
+ of any other Contributor, and only if You agree to indemnify,
201
+ defend, and hold each Contributor harmless for any liability
202
+ incurred by, or claims asserted against, such Contributor by reason
203
+ of your accepting any such warranty or additional liability.
204
+
205
+ END OF TERMS AND CONDITIONS
206
+
207
+ APPENDIX: How to apply the Apache License to your work.
208
+
209
+ To apply the Apache License to your work, attach the following
210
+ boilerplate notice, with the fields enclosed by brackets "[]"
211
+ replaced with your own identifying information. (Don't include
212
+ the brackets!) The text should be enclosed in the appropriate
213
+ comment syntax for the file format. We also recommend that a
214
+ file or class name and description of purpose be included on the
215
+ same "printed page" as the copyright notice for easier
216
+ identification within third-party archives.
217
+
218
+ Copyright [yyyy] [name of copyright owner]
219
+
220
+ Licensed under the Apache License, Version 2.0 (the "License");
221
+ you may not use this file except in compliance with the License.
222
+ You may obtain a copy of the License at
223
+
224
+ http://www.apache.org/licenses/LICENSE-2.0
225
+
226
+ Unless required by applicable law or agreed to in writing, software
227
+ distributed under the License is distributed on an "AS IS" BASIS,
228
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
229
+ See the License for the specific language governing permissions and
230
+ limitations under the License.
231
+
232
+
233
+ ================================================================================
234
+
235
+
236
+ BSD 3-Clause License
237
+
238
+ Copyright 2023 MagicAnimate Team All rights reserved.
239
+
240
+ Redistribution and use in source and binary forms, with or without
241
+ modification, are permitted provided that the following conditions are met:
242
+
243
+ 1. Redistributions of source code must retain the above copyright notice, this
244
+ list of conditions and the following disclaimer.
245
+
246
+ 2. Redistributions in binary form must reproduce the above copyright notice,
247
+ this list of conditions and the following disclaimer in the documentation
248
+ and/or other materials provided with the distribution.
249
+
250
+ 3. Neither the name of the copyright holder nor the names of its
251
+ contributors may be used to endorse or promote products derived from
252
+ this software without specific prior written permission.
253
+
254
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
255
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
256
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
257
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
258
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
259
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
260
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
261
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
262
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
263
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,13 +1,15 @@
  ---
- title: Genwarp
- emoji: 📚
- colorFrom: yellow
- colorTo: gray
+ title: GenWarp
+ emoji: 🌃
+ colorFrom: purple
+ colorTo: blue
  sdk: gradio
  sdk_version: 4.42.0
  app_file: app.py
  pinned: false
  license: mit
+ preload_from_hub:
+ - Sony/genwarp
+ - stabilityai/sd-vae-ft-mse diffusion_pytorch_model.safetensors
+ - lambdalabs/sd-image-variations-diffusers image_encoder/pytorch_model.bin
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,376 @@
1
+ import sys
2
+ import os
3
+ from subprocess import check_call
4
+ import tempfile
5
+
6
+ from os.path import basename, splitext, join
7
+ from io import BytesIO
8
+
9
+ import numpy as np
10
+ from scipy.spatial import KDTree
11
+ from PIL import Image
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torchvision.transforms.functional import to_tensor, to_pil_image
16
+ from einops import rearrange
17
+ import gradio as gr
18
+ from huggingface_hub import hf_hub_download
19
+
20
+ from extern.ZoeDepth.zoedepth.utils.misc import colorize
21
+
22
+ from gradio_model3dgscamera import Model3DGSCamera
23
+
24
+ IMAGE_SIZE = 512
25
+ NEAR, FAR = 0.01, 100
26
+ FOVY = np.deg2rad(55)
27
+
28
+ def download_models():
29
+ models = [
30
+ {
31
+ 'repo': 'stabilityai/sd-vae-ft-mse',
32
+ 'sub': None,
33
+ 'dst': 'checkpoints/sd-vae-ft-mse',
34
+ 'files': ['config.json', 'diffusion_pytorch_model.safetensors'],
35
+ 'token': None
36
+ },
37
+ {
38
+ 'repo': 'lambdalabs/sd-image-variations-diffusers',
39
+ 'sub': 'image_encoder',
40
+ 'dst': 'checkpoints',
41
+ 'files': ['config.json', 'pytorch_model.bin'],
42
+ 'token': None
43
+ },
44
+ {
45
+ 'repo': 'Sony/genwarp',
46
+ 'sub': 'multi1',
47
+ 'dst': 'checkpoints',
48
+ 'files': ['config.json', 'denoising_unet.pth', 'pose_guider.pth', 'reference_unet.pth'],
49
+ 'token': None
50
+ }
51
+ ]
52
+
53
+ for model in models:
54
+ for file in model['files']:
55
+ hf_hub_download(
56
+ repo_id=model['repo'],
57
+ subfolder=model['sub'],
58
+ filename=file,
59
+ local_dir=model['dst'],
60
+ token=model['token']
61
+ )
62
+
63
+ # Crop the image to the shorter side.
64
+ def crop(img: Image.Image) -> Image.Image:
65
+ W, H = img.size
66
+ if W < H:
67
+ left, right = 0, W
68
+ top, bottom = np.ceil((H - W) / 2.), np.floor((H - W) / 2.) + W
69
+ else:
70
+ left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
71
+ top, bottom = 0, H
72
+ return img.crop((left, top, right, bottom))
73
+
74
+ def unproject(depth):
75
+ fovy_deg = 55
76
+ H, W = depth.shape[2:4]
77
+
78
+ mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
79
+
80
+ viewport_mtx = get_viewport_matrix(
81
+ IMAGE_SIZE, IMAGE_SIZE,
82
+ batch_size=1
83
+ ).to(depth)
84
+
85
+ # Projection matrix.
86
+ fovy = torch.ones(1) * FOVY
87
+ proj_mtx = get_projection_matrix(
88
+ fovy=fovy,
89
+ aspect_wh=1.,
90
+ near=NEAR,
91
+ far=FAR
92
+ ).to(depth)
93
+
94
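+ # Source camera at the origin looking down +z, with y pointing down (image convention).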
+ view_mtx = camera_lookat(
95
+ torch.tensor([[0., 0., 0.]]),
96
+ torch.tensor([[0., 0., 1.]]),
97
+ torch.tensor([[0., -1., 0.]])
98
+ ).to(depth)
99
+
100
+ scr_mtx = (viewport_mtx @ proj_mtx).to(depth)
101
+
102
+ grid = torch.stack(torch.meshgrid(
103
+ torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
104
+ ).to(depth)[None] # BHW2
105
+
106
+ screen = F.pad(grid, (0, 1), 'constant', 0)
107
+ screen = F.pad(screen, (0, 1), 'constant', 1)
108
+ screen_flat = rearrange(screen, 'b h w c -> b (h w) c')
109
+
110
+ eye = screen_flat @ torch.linalg.inv_ex(
111
+ scr_mtx.float()
112
+ )[0].mT.to(depth)
113
+ eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
114
+ eye[..., 3] = 1
115
+
116
+ points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
117
+ points = points[0, :, :3]
118
+
119
+ # Translate to the origin.
120
+ points[..., 2] -= mean_depth
121
+ camera_pos = (0, 0, -mean_depth)
122
+ view_mtx = camera_lookat(
123
+ torch.tensor([[0., 0., -mean_depth]]),
124
+ torch.tensor([[0., 0., 0.]]),
125
+ torch.tensor([[0., -1., 0.]])
126
+ ).to(depth)
127
+
128
+ return points, camera_pos, view_mtx, proj_mtx
129
+
130
+ def calc_dist2(points: np.ndarray):
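+ # Mean squared distance from each point to its 3 nearest neighbours (k=4 includes
+ # the point itself); used below to set the initial splat scales.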
131
+ dists, _ = KDTree(points).query(points, k=4)
132
+ mean_dists = (dists[:, 1:] ** 2).mean(1)
133
+ return mean_dists
134
+
135
+ def save_as_splat(
136
+ filepath: str,
137
+ xyz: np.ndarray,
138
+ rgb: np.ndarray
139
+ ):
140
+ # To gaussian splat
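+ # Assumed record layout (common .splat convention): 3x float32 position,
+ # 3x float32 scale, 4x uint8 RGBA, 4x uint8 quaternion mapped to 0-255.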
141
+ inv_sigmoid = lambda x: np.log(x / (1 - x))
142
+ dist2 = np.clip(calc_dist2(xyz), a_min=0.0000001, a_max=None)
143
+ scales = np.repeat(np.log(np.sqrt(dist2))[..., np.newaxis], 3, axis=1)
144
+ rots = np.zeros((xyz.shape[0], 4))
145
+ rots[:, 0] = 1
146
+ opacities = inv_sigmoid(0.1 * np.ones((xyz.shape[0], 1)))
147
+
148
+ sorted_indices = np.argsort((
149
+ -np.exp(np.sum(scales, axis=-1, keepdims=True))
150
+ / (1 + np.exp(-opacities))
151
+ ).squeeze())
152
+
153
+ buffer = BytesIO()
154
+ for idx in sorted_indices:
155
+ position = xyz[idx]
156
+ scale = np.exp(scales[idx]).astype(np.float32)
157
+ rot = rots[idx].astype(np.float32)
158
+ color = np.concatenate(
159
+ (rgb[idx], 1 / (1 + np.exp(-opacities[idx]))),
160
+ axis=-1
161
+ )
162
+ buffer.write(position.tobytes())
163
+ buffer.write(scale.tobytes())
164
+ buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
165
+ buffer.write(
166
+ ((rot / np.linalg.norm(rot)) * 128 + 128)
167
+ .clip(0, 255)
168
+ .astype(np.uint8)
169
+ .tobytes()
170
+ )
171
+
172
+ with open(filepath, "wb") as f:
173
+ f.write(buffer.getvalue())
174
+
175
+ def view_from_rt(position, rotation):
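+ # Build a world-to-view matrix from a camera position and XYZ Euler angles, then
+ # flip the y/z axes (B) to switch between camera coordinate conventions.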
176
+ t = np.array(position)
177
+ euler = np.array(rotation)
178
+
179
+ cx = np.cos(euler[0])
180
+ sx = np.sin(euler[0])
181
+ cy = np.cos(euler[1])
182
+ sy = np.sin(euler[1])
183
+ cz = np.cos(euler[2])
184
+ sz = np.sin(euler[2])
185
+ R = np.array([
186
+ cy * cz + sy * sx * sz,
187
+ -cy * sz + sy * sx * cz,
188
+ sy * cx,
189
+ cx * sz,
190
+ cx * cz,
191
+ -sx,
192
+ -sy * cz + cy * sx * sz,
193
+ sy * sz + cy * sx * cz,
194
+ cy * cx
195
+ ])
196
+ view_mtx = np.array([
197
+ [R[0], R[1], R[2], 0],
198
+ [R[3], R[4], R[5], 0],
199
+ [R[6], R[7], R[8], 0],
200
+ [
201
+ -t[0] * R[0] - t[1] * R[3] - t[2] * R[6],
202
+ -t[0] * R[1] - t[1] * R[4] - t[2] * R[7],
203
+ -t[0] * R[2] - t[1] * R[5] - t[2] * R[8],
204
+ 1
205
+ ]
206
+ ]).T
207
+
208
+ B = np.array([
209
+ [1, 0, 0, 0],
210
+ [0, -1, 0, 0],
211
+ [0, 0, -1, 0],
212
+ [0, 0, 0, 1]
213
+ ])
214
+ return B @ view_mtx
215
+
216
+
217
+ # Setup.
218
+ download_models()
219
+
220
+ mde = torch.hub.load(
221
+ './extern/ZoeDepth',
222
+ 'ZoeD_N',
223
+ source='local',
224
+ pretrained=True,
225
+ trust_repo=True
226
+ )
227
+
228
+ import spaces
229
+
230
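+ # Install the bundled splatting wheel (shipped via Git LFS) at runtime before
+ # importing genwarp, which presumably uses it for the forward-warping ops.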
+ check_call([
231
+ sys.executable, '-m', 'pip', 'install',
232
+ 'extern/splatting-0.0.1-py3-none-any.whl'
233
+ ])
234
+
235
+ from genwarp import GenWarp
236
+ from genwarp.ops import (
237
+ camera_lookat, get_projection_matrix, get_viewport_matrix
238
+ )
239
+
240
+ # GenWarp
241
+ genwarp_cfg = dict(
242
+ pretrained_model_path='checkpoints',
243
+ checkpoint_name='multi1',
244
+ half_precision_weights=True
245
+ )
246
+ genwarp_nvs = GenWarp(cfg=genwarp_cfg, device='cpu')
247
+
248
+
249
+ with tempfile.TemporaryDirectory() as tmpdir:
250
+ with gr.Blocks(
251
+ title='GenWarp Demo',
252
+ css='img {display: inline;}'
253
+ ) as demo:
254
+ # Internal states.
255
+ src_image = gr.State()
256
+ src_depth = gr.State()
257
+ proj_mtx = gr.State()
258
+ src_view_mtx = gr.State()
259
+
260
+ # Blocks.
261
+ gr.Markdown(
262
+ """
263
+ # GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping
264
+ [![Project Site](https://img.shields.io/badge/Project-Web-green)](https://genwarp-nvs.github.io/) &nbsp;
265
+ [![Spaces](https://img.shields.io/badge/Spaces-Demo-yellow?logo=huggingface)](https://huggingface.co/spaces/Sony/GenWarp) &nbsp;
266
+ [![Github](https://img.shields.io/badge/Github-Repo-orange?logo=github)](https://github.com/sony/genwarp/) &nbsp;
267
+ [![Models](https://img.shields.io/badge/Models-checkpoints-blue?logo=huggingface)](https://huggingface.co/Sony/genwarp) &nbsp;
268
+ [![arXiv](https://img.shields.io/badge/arXiv-2405.17251-red?logo=arxiv)](https://arxiv.org/abs/2405.17251)
269
+
270
+ ## Introduction
271
+ This is an official demo for the paper "[GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping](https://genwarp-nvs.github.io/)". GenWarp generates novel-view images from a single input image, conditioned on camera poses. This demo offers basic inference with the model; for details, please refer to the [paper](https://arxiv.org/abs/2405.17251).
272
+
273
+ ## How to Use
274
+ 1. Upload a reference image to "Reference Input"
275
+ - You can also select an image from "Examples"
275
+ 2. Move the camera to your desired view in the "Unprojected 3DGS" 3D viewer
276
+ 3. Hit the "Generate a novel view" button and check the result
278
+
279
+ """
280
+ )
281
+ file = gr.File(label='Reference Input', file_types=['image'])
282
+ examples = gr.Examples(
283
+ examples=['./assets/pexels-heyho-5998120_19mm.jpg',
284
+ './assets/pexels-itsterrymag-12639296_24mm.jpg'],
285
+ inputs=file
286
+ )
287
+ with gr.Row():
288
+ image_widget = gr.Image(
289
+ label='Reference View', type='filepath',
290
+ interactive=False
291
+ )
292
+ depth_widget = gr.Image(label='Estimated Depth', type='pil')
293
+ viewer = Model3DGSCamera(
294
+ label = 'Unprojected 3DGS',
295
+ width=IMAGE_SIZE,
296
+ height=IMAGE_SIZE,
297
+ camera_width=IMAGE_SIZE,
298
+ camera_height=IMAGE_SIZE,
299
+ camera_fx=IMAGE_SIZE / (np.tan(FOVY / 2.)) / 2.,
300
+ camera_fy=IMAGE_SIZE / (np.tan(FOVY / 2.)) / 2.,
301
+ camera_near=NEAR,
302
+ camera_far=FAR
303
+ )
304
+ button = gr.Button('Generate a novel view', size='lg', variant='primary')
305
+ with gr.Row():
306
+ warped_widget = gr.Image(
307
+ label='Warped Image', type='pil', interactive=False
308
+ )
309
+ gen_widget = gr.Image(
310
+ label='Generated View', type='pil', interactive=False
311
+ )
312
+
313
+ # Callbacks
314
+ @spaces.GPU
315
+ def cb_mde(image_file: str):
316
+ image = to_tensor(crop(Image.open(
317
+ image_file
318
+ ).convert('RGB')).resize((IMAGE_SIZE, IMAGE_SIZE)))[None].cuda()
319
+ depth = mde.cuda().infer(image)
320
+ depth_image = to_pil_image(colorize(depth[0]))
321
+ return to_pil_image(image[0]), depth_image, image.cpu().detach(), depth.cpu().detach()
322
+
323
+ @spaces.GPU
324
+ def cb_3d(image, depth, image_file):
325
+ xyz, camera_pos, view_mtx, proj_mtx = unproject(depth.cuda())
326
+ rgb = rearrange(image, 'b c h w -> b (h w) c')[0]
327
+ splat_file = join(tmpdir, f'./{splitext(basename(image_file))[0]}.splat')
328
+ save_as_splat(splat_file, xyz.cpu().detach().numpy(), rgb.cpu().detach().numpy())
329
+ return (splat_file, camera_pos, None), view_mtx.cpu().detach(), proj_mtx.cpu().detach()
330
+
331
+ @spaces.GPU
332
+ def cb_generate(viewer, image, depth, src_view_mtx, proj_mtx):
333
+ image = image.cuda()
334
+ depth = depth.cuda()
335
+ src_view_mtx = src_view_mtx.cuda()
336
+ proj_mtx = proj_mtx.cuda()
337
+ src_camera_pos = viewer[1]
338
+ src_camera_rot = viewer[2]
339
+ tar_view_mtx = view_from_rt(src_camera_pos, src_camera_rot)
340
+ tar_view_mtx = torch.from_numpy(tar_view_mtx).to(image)
341
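+ # Relative view matrix taking the source camera to the target camera chosen in the viewer.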
+ rel_view_mtx = (
342
+ tar_view_mtx @ torch.linalg.inv(src_view_mtx.to(image))
343
+ ).to(image)
344
+
345
+ # GenWarp.
346
+ renders = genwarp_nvs.to('cuda')(
347
+ src_image=image.half(),
348
+ src_depth=depth.half(),
349
+ rel_view_mtx=rel_view_mtx.half(),
350
+ src_proj_mtx=proj_mtx.half(),
351
+ tar_proj_mtx=proj_mtx.half()
352
+ )
353
+
354
+ warped = renders['warped']
355
+ synthesized = renders['synthesized']
356
+ warped_pil = to_pil_image(warped[0])
357
+ synthesized_pil = to_pil_image(synthesized[0])
358
+
359
+ return warped_pil, synthesized_pil
360
+
361
+ # Events
362
+ file.change(
363
+ fn=cb_mde,
364
+ inputs=file,
365
+ outputs=[image_widget, depth_widget, src_image, src_depth]
366
+ ).then(
367
+ fn=cb_3d,
368
+ inputs=[src_image, src_depth, image_widget],
369
+ outputs=[viewer, src_view_mtx, proj_mtx])
370
+ button.click(
371
+ fn=cb_generate,
372
+ inputs=[viewer, src_image, src_depth, src_view_mtx, proj_mtx],
373
+ outputs=[warped_widget, gen_widget])
374
+
375
+ if __name__ == '__main__':
376
+ demo.launch()
assets/NOTICE ADDED
@@ -0,0 +1,8 @@
1
+ Images are taken from Pexels
2
+ https://www.pexels.com/
3
+
4
+ pexels-itsterrymag-12639296_24mm.jpg
5
+ https://www.pexels.com/ja-jp/photo/12639296/
6
+
7
+ pexels-heyho-5998120_19mm.jpg
8
+ https://www.pexels.com/ja-jp/photo/5998120/
assets/pexels-heyho-5998120_19mm.jpg ADDED

Git LFS Details

  • SHA256: 0a1c15147a7a5ba9e77c9685658720884e7ae14e0663564f8121b70e39201bcf
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
assets/pexels-itsterrymag-12639296_24mm.jpg ADDED

Git LFS Details

  • SHA256: 74a6d1e3e651f2ebe91f38f2863036a457631f1743270d13d6a7cc35546719ec
  • Pointer size: 131 Bytes
  • Size of remote file: 126 kB
extern/ZoeDepth ADDED
@@ -0,0 +1 @@
1
+ Subproject commit edb6daf45458569e24f50250ef1ed08c015f17a7
extern/splatting-0.0.1-py3-none-any.whl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26d488928a774f4677a0f6cdd9f2a2a63ee73502d90676f507444cc21ecd069d
3
+ size 5189840
genwarp/GenWarp.py ADDED
@@ -0,0 +1,546 @@
1
+ from os.path import join
2
+ from typing import Union, Optional, List, Dict, Tuple, Any
3
+ from dataclasses import dataclass
4
+ import inspect
5
+
6
+ from omegaconf import OmegaConf, DictConfig
7
+ from jaxtyping import Float
8
+
9
+ import torch
10
+ from torch import Tensor
11
+ import torch.nn.functional as F
12
+ from einops import rearrange, repeat
13
+
14
+ from diffusers import AutoencoderKL, DDIMScheduler
15
+ from diffusers.image_processor import VaeImageProcessor
16
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
17
+
18
+ from .models import (
19
+ PoseGuider,
20
+ UNet2DConditionModel,
21
+ UNet3DConditionModel,
22
+ ReferenceAttentionControl
23
+ )
24
+ from .ops import get_viewport_matrix, forward_warper
25
+
26
+ class GenWarp():
27
+ @dataclass
28
+ class Config():
29
+ pretrained_model_path: str = ''
30
+ checkpoint_name: str = ''
31
+ half_precision_weights: bool = False
32
+ height: int = 512
33
+ width: int = 512
34
+ num_inference_steps: int = 20
35
+ guidance_scale: float = 3.5
36
+
37
+ cfg: Config
38
+
39
+ class Embedder():
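+ # NeRF-style frequency encoder: maps each input coordinate x to
+ # [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(N-1) x), cos(2^(N-1) x)].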
40
+ def __init__(self, **kwargs) -> None:
41
+ self.kwargs = kwargs
42
+ self.create_embedding_fn()
43
+
44
+ def create_embedding_fn(self) -> None:
45
+ embed_fns = []
46
+ d = self.kwargs['input_dims']
47
+ out_dim = 0
48
+ if self.kwargs['include_input']:
49
+ embed_fns.append(lambda x : x)
50
+ out_dim += d
51
+
52
+ max_freq = self.kwargs['max_freq_log2']
53
+ N_freqs = self.kwargs['num_freqs']
54
+
55
+ if self.kwargs['log_sampling']:
56
+ freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
57
+ else:
58
+ freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
59
+
60
+ for freq in freq_bands:
61
+ for p_fn in self.kwargs['periodic_fns']:
62
+ embed_fns.append(lambda x, p_fn=p_fn, freq=freq : p_fn(x * freq))
63
+ out_dim += d
64
+
65
+ self.embed_fns = embed_fns
66
+ self.out_dim = out_dim
67
+
68
+ def embed(self, inputs) -> Tensor:
69
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
70
+
71
+ def __init__(
72
+ self,
73
+ cfg: Optional[Union[dict, DictConfig]] = None,
74
+ device: Optional[str] = 'cpu'
75
+ ) -> None:
76
+ self.cfg = OmegaConf.structured(self.Config(**cfg))
77
+ self.model_path = join(
78
+ self.cfg.pretrained_model_path, self.cfg.checkpoint_name
79
+ )
80
+ self.device = device
81
+ self.configure()
82
+
83
+ def configure(self) -> None:
84
+ print(f"Loading GenWarp...")
85
+
86
+ # Configurations.
87
+ self.dtype = (
88
+ torch.float16 if self.cfg.half_precision_weights else torch.float32
89
+ )
90
+ self.viewport_mtx: Float[Tensor, 'B 4 4'] = get_viewport_matrix(
91
+ self.cfg.width, self.cfg.height,
92
+ batch_size=1, device=self.device
93
+ ).to(self.dtype)
94
+
95
+ # Load models.
96
+ self.load_models()
97
+
98
+ # Timestep
99
+ self.scheduler.set_timesteps(
100
+ self.cfg.num_inference_steps, device=self.device)
101
+ self.num_train_timesteps = self.scheduler.config.num_train_timesteps
102
+
103
+ print(f"Loaded GenWarp.")
104
+
105
+ def load_models(self) -> None:
106
+ # VAE.
107
+ self.vae = AutoencoderKL.from_pretrained(
108
+ join(self.cfg.pretrained_model_path, 'sd-vae-ft-mse')
109
+ ).to(self.device, dtype=self.dtype)
110
+
111
+ # Image processor.
112
+ self.vae_scale_factor = \
113
+ 2 ** (len(self.vae.config.block_out_channels) - 1)
114
+ self.vae_image_processor = VaeImageProcessor(
115
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
116
+ )
117
+ self.clip_image_processor = CLIPImageProcessor()
118
+
119
+ # Image encoder.
120
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
121
+ join(self.cfg.pretrained_model_path, 'image_encoder')
122
+ ).to(self.device, dtype=self.dtype)
123
+
124
+ # Reference Unet.
125
+ self.reference_unet = UNet2DConditionModel.from_config(
126
+ UNet2DConditionModel.load_config(
127
+ join(self.model_path, 'config.json')
128
+ )).to(self.device, dtype=self.dtype)
129
+ self.reference_unet.load_state_dict(torch.load(
130
+ join(self.model_path, 'reference_unet.pth'),
131
+ map_location='cpu'),
132
+ )
133
+
134
+ # Denoising Unet.
135
+ self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
136
+ join(self.model_path, 'config.json'),
137
+ join(self.model_path, 'denoising_unet.pth')
138
+ ).to(self.device, dtype=self.dtype)
139
+ self.unet_in_channels = self.denoising_unet.config.in_channels
140
+
141
+ # Pose guider.
142
+ self.pose_guider = PoseGuider(
143
+ conditioning_embedding_channels=320,
144
+ conditioning_channels=11,
145
+ ).to(self.device, dtype=self.dtype)
146
+ self.pose_guider.load_state_dict(torch.load(
147
+ join(self.model_path, 'pose_guider.pth'),
148
+ map_location='cpu'),
149
+ )
150
+
151
+ # Noise scheduler
152
+ sched_kwargs = OmegaConf.to_container(OmegaConf.create({
153
+ 'num_train_timesteps': 1000,
154
+ 'beta_start': 0.00085,
155
+ 'beta_end': 0.012,
156
+ 'beta_schedule': 'scaled_linear',
157
+ 'steps_offset': 1,
158
+ 'clip_sample': False
159
+ }))
160
+ sched_kwargs.update(
161
+ rescale_betas_zero_snr=True,
162
+ timestep_spacing='trailing',
163
+ prediction_type='v_prediction',
164
+ )
165
+ self.scheduler = DDIMScheduler(**sched_kwargs)
166
+
167
+ self.vae.requires_grad_(False)
168
+ self.image_encoder.requires_grad_(False)
169
+ self.reference_unet.requires_grad_(False)
170
+ self.denoising_unet.requires_grad_(False)
171
+ self.pose_guider.requires_grad_(False)
172
+
173
+ # Coordinates embedding.
174
+ self.embedder = self.get_embedder(2)
175
+
176
+ def to(self, device: str):
177
+ self.device = device
178
+ self.viewport_mtx = self.viewport_mtx.to(device)
179
+ self.vae = self.vae.to(device)
180
+ self.image_encoder = self.image_encoder.to(device)
181
+ self.reference_unet = self.reference_unet.to(device)
182
+ self.denoising_unet = self.denoising_unet.to(device)
183
+ self.pose_guider = self.pose_guider.to(device)
184
+
185
+ return self
186
+
187
+ def get_embedder(self, multires):
188
+ embed_kwargs = {
189
+ 'include_input' : True,
190
+ 'input_dims' : 2,
191
+ 'max_freq_log2' : multires-1,
192
+ 'num_freqs' : multires,
193
+ 'log_sampling' : True,
194
+ 'periodic_fns' : [torch.sin, torch.cos],
195
+ }
196
+
197
+ embedder_obj = self.Embedder(**embed_kwargs)
198
+ embed = lambda x, eo=embedder_obj : eo.embed(x)
199
+ return embed
200
+
201
+ def __call__(
202
+ self,
203
+ src_image: Float[Tensor, 'B C H W'],
204
+ src_depth: Float[Tensor, 'B C H W'],
205
+ rel_view_mtx: Float[Tensor, 'B 4 4'],
206
+ src_proj_mtx: Float[Tensor, 'B 4 4'],
207
+ tar_proj_mtx: Float[Tensor, 'B 4 4'],
208
+ ) -> Dict[str, Tensor]:
209
+ """ Perform NVS.
210
+ """
211
+ batch_size = src_image.shape[0]
212
+
213
+ # Rearrange and resize.
214
+ src_image = self.preprocess_image(src_image)
215
+ src_depth = self.preprocess_image(src_depth)
216
+ viewport_mtx = repeat(
217
+ self.viewport_mtx, 'b h w -> (repeat b) h w',
218
+ repeat=batch_size)
219
+
220
+ pipe_args = dict(
221
+ src_image=src_image,
222
+ src_depth=src_depth,
223
+ rel_view_mtx=rel_view_mtx,
224
+ src_proj_mtx=src_proj_mtx,
225
+ tar_proj_mtx=tar_proj_mtx,
226
+ viewport_mtx=viewport_mtx
227
+ )
228
+
229
+ # Prepare inputs.
230
+ conditions, renders = self.prepare_conditions(**pipe_args)
231
+
232
+ # NVS.
233
+ latents_clean = self.perform_nvs(
234
+ **pipe_args,
235
+ **conditions
236
+ )
237
+
238
+ # Decode to images.
239
+ synthesized = self.decode_latents(latents_clean)
240
+
241
+ inference_out = {
242
+ 'synthesized': synthesized,
243
+ 'warped': renders['warped'],
244
+ 'mask': renders['mask'],
245
+ 'correspondence': conditions['correspondence']
246
+ }
247
+
248
+ return inference_out
249
+
250
+ def preprocess_image(
251
+ self,
252
+ image: Float[Tensor, 'B C H W']
253
+ ) -> Float[Tensor, 'B C H W']:
254
+ image = F.interpolate(
255
+ image, (self.cfg.height, self.cfg.width)
256
+ )
257
+ return image
258
+
259
+ def get_image_prompt(
260
+ self,
261
+ src_image: Float[Tensor, 'B C H W']
262
+ ) -> Float[Tensor, '2 B L']:
263
+ ref_image_for_clip = self.vae_image_processor.preprocess(
264
+ src_image, height=224, width=224
265
+ )
266
+ ref_image_for_clip = ref_image_for_clip * 0.5 + 0.5
267
+
268
+ clip_image = self.clip_image_processor.preprocess(
269
+ ref_image_for_clip, return_tensors='pt'
270
+ ).pixel_values
271
+
272
+ clip_image_embeds = self.image_encoder(
273
+ clip_image.to(self.device, dtype=self.image_encoder.dtype)
274
+ ).image_embeds
275
+
276
+ image_prompt_embeds = clip_image_embeds.unsqueeze(1)
277
+ uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
278
+
279
+ image_prompt_embeds = torch.cat(
280
+ [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
281
+ )
282
+
283
+ return image_prompt_embeds
284
+
285
+ def encode_images(
286
+ self,
287
+ rgb: Float[Tensor, 'B C H W']
288
+ ) -> Float[Tensor, 'B C H W']:
289
+ rgb = self.vae_image_processor.preprocess(rgb)
290
+ latents = self.vae.encode(rgb).latent_dist.mean
291
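+ # 0.18215 is the Stable Diffusion VAE latent scaling factor (inverted in decode_latents).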
+ latents = latents * 0.18215
292
+ return latents
293
+
294
+ def decode_latents(
295
+ self,
296
+ latents: Float[Tensor, 'B C H W']
297
+ ) -> Float[Tensor, 'B C H W']:
298
+ latents = 1 / 0.18215 * latents
299
+ rgb = []
300
+ for frame_idx in range(latents.shape[0]):
301
+ rgb.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
302
+ rgb = torch.cat(rgb)
303
+ rgb = (rgb / 2 + 0.5).clamp(0, 1)
304
+ return rgb.squeeze(2)
305
+
306
+ def get_reference_controls(
307
+ self,
308
+ batch_size: int
309
+ ) -> Tuple[ReferenceAttentionControl, ReferenceAttentionControl]:
310
+ reader = ReferenceAttentionControl(
311
+ self.denoising_unet,
312
+ do_classifier_free_guidance=True,
313
+ mode='read',
314
+ batch_size=batch_size,
315
+ fusion_blocks='full',
316
+ feature_fusion_type='attention_full_sharing'
317
+ )
318
+ writer = ReferenceAttentionControl(
319
+ self.reference_unet,
320
+ do_classifier_free_guidance=True,
321
+ mode='write',
322
+ batch_size=batch_size,
323
+ fusion_blocks='full',
324
+ feature_fusion_type='attention_full_sharing'
325
+ )
326
+
327
+ return reader, writer
328
+
329
+ def prepare_extra_step_kwargs(
330
+ self,
331
+ generator,
332
+ eta
333
+ ) -> Dict[str, Any]:
334
+ accepts_eta = 'eta' in set(
335
+ inspect.signature(self.scheduler.step).parameters.keys()
336
+ )
337
+ extra_step_kwargs = {}
338
+ if accepts_eta:
339
+ extra_step_kwargs['eta'] = eta
340
+
341
+ # check if the scheduler accepts generator
342
+ accepts_generator = 'generator' in set(
343
+ inspect.signature(self.scheduler.step).parameters.keys()
344
+ )
345
+ if accepts_generator:
346
+ extra_step_kwargs['generator'] = generator
347
+ return extra_step_kwargs
348
+
349
+ def get_pose_features(
350
+ self,
351
+ src_embed: Float[Tensor, 'B C H W'],
352
+ trg_embed: Float[Tensor, 'B C H W'],
353
+ do_classifier_guidance: bool = True
354
+ ) -> Tuple[Tensor, Tensor]:
355
+ pose_cond_tensor = src_embed.unsqueeze(2)
356
+ pose_cond_tensor = pose_cond_tensor.to(
357
+ device=self.device, dtype=self.pose_guider.dtype
358
+ )
359
+ pose_cond_tensor_2 = trg_embed.unsqueeze(2)
360
+ pose_cond_tensor_2 = pose_cond_tensor_2.to(
361
+ device=self.device, dtype=self.pose_guider.dtype
362
+ )
363
+ pose_fea = self.pose_guider(pose_cond_tensor)
364
+ pose_fea_2 = self.pose_guider(pose_cond_tensor_2)
365
+
366
+ if do_classifier_guidance:
367
+ pose_fea = torch.cat([pose_fea] * 2)
368
+ pose_fea_2 = torch.cat([pose_fea_2] * 2)
369
+
370
+ return pose_fea, pose_fea_2
371
+
372
+ @torch.no_grad()
373
+ def prepare_conditions(
374
+ self,
375
+ src_image: Float[Tensor, 'B C H W'],
376
+ src_depth: Float[Tensor, 'B C H W'],
377
+ rel_view_mtx: Float[Tensor, 'B 4 4'],
378
+ src_proj_mtx: Float[Tensor, 'B 4 4'],
379
+ tar_proj_mtx: Float[Tensor, 'B 4 4'],
380
+ viewport_mtx: Float[Tensor, 'B 4 4']
381
+ ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
382
+ # Prepare inputs.
383
+ B = src_image.shape[0]
384
+ H, W = src_image.shape[2:4]
385
+ src_scr_mtx = (viewport_mtx @ src_proj_mtx).to(src_proj_mtx)
386
+ mvp_mtx = (tar_proj_mtx @ rel_view_mtx).to(rel_view_mtx)
387
+
388
+ # Coordinate grids.
389
+ grid: Float[Tensor, 'H W C'] = torch.stack(torch.meshgrid(
390
+ torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
391
+ ).to(self.device, dtype=self.dtype)
392
+
393
+ # Unproject depth.
394
+ screen = F.pad(grid, (0, 1), 'constant', 0) # z=0 (z doesn't matter)
395
+ screen = F.pad(screen, (0, 1), 'constant', 1) # w=1
396
+ screen = repeat(screen, 'h w c -> b h w c', b=B)
397
+ screen_flat = rearrange(screen, 'b h w c -> b (h w) c')
398
+ # To eye coordinates.
399
+ eye = screen_flat @ torch.linalg.inv_ex(
400
+ src_scr_mtx.float()
401
+ )[0].mT.to(self.dtype)
402
+ # Overwrite depth.
403
+ eye = eye * rearrange(src_depth, 'b c h w -> b (h w) c')
404
+ eye[..., 3] = 1
405
+
406
+ # Coordinates embedding.
407
+ coords = torch.stack((grid[..., 0]/H, grid[..., 1]/W), dim=-1)
408
+ embed = repeat(self.embedder(coords), 'h w c -> b c h w', b=B)
409
+
410
+ # Warping.
411
+ input_image: Float[Tensor, 'B C H W'] = torch.cat(
412
+ [embed, src_image], dim=1
413
+ )
414
+ output_image = forward_warper(
415
+ input_image, screen_flat[..., :2], eye,
416
+ mvp_mtx=mvp_mtx, viewport_mtx=viewport_mtx
417
+ )
418
+ warped_embed = output_image['warped'][:, :embed.shape[1]]
419
+ warped_image = output_image['warped'][:, embed.shape[1]:]
420
+ # mask == 1 where there's no pixel
421
+ mask = output_image['mask']
422
+ correspondence = output_image['correspondence']
423
+
424
+ # Conditions.
425
+ src_coords_embed = torch.cat(
426
+ [embed, torch.zeros_like(mask, device=mask.device)], dim=1)
427
+ trg_coords_embed = torch.cat([warped_embed, mask], dim=1)
428
+
429
+ # Outputs.
430
+ conditions = dict(
431
+ src_coords_embed=src_coords_embed,
432
+ trg_coords_embed=trg_coords_embed,
433
+ correspondence=correspondence
434
+ )
435
+
436
+ renders = dict(
437
+ warped=warped_image,
438
+ mask=1 - mask # mask == 1 where there's a pixel
439
+ )
440
+
441
+ return conditions, renders
442
+
443
+ def perform_nvs(
444
+ self,
445
+ src_image,
446
+ src_coords_embed,
447
+ trg_coords_embed,
448
+ correspondence,
449
+ eta: float=0.0,
450
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]]=None,
451
+ **kwargs,
452
+ ) -> Float[Tensor, 'B C H W']:
453
+ batch_size = src_image.shape[0]
454
+
455
+ # For the cross attention.
456
+ reference_control_reader, reference_control_writer = \
457
+ self.get_reference_controls(batch_size)
458
+
459
+ # Prepare extra step kwargs.
460
+ extra_step_kwargs = self.prepare_extra_step_kwargs(
461
+ generator, eta
462
+ )
463
+
464
+ with torch.no_grad():
465
+ # Create fake inputs. It'll be replaced by pure noise.
466
+ latents = torch.randn(
467
+ batch_size,
468
+ self.unet_in_channels,
469
+ self.cfg.height // self.vae_scale_factor,
470
+ self.cfg.width // self.vae_scale_factor
471
+ ).to(self.device, dtype=src_image.dtype)
472
+ initial_t = torch.tensor(
473
+ [self.num_train_timesteps - 1] * batch_size
474
+ ).to(self.device, dtype=torch.long)
475
+
476
+ # Add noise.
477
+ noise = torch.randn_like(latents)
478
+ latents_noisy_start = self.scheduler.add_noise(
479
+ latents, noise, initial_t
480
+ )
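+ # With rescale_betas_zero_snr the terminal alpha_cumprod is zero, so noising at the
+ # last timestep yields pure noise and the random `latents` above are effectively discarded.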
481
+ latents_noisy_start = latents_noisy_start.unsqueeze(2)
482
+
483
+ # Prepare clip image embeds.
484
+ image_prompt_embeds = self.get_image_prompt(src_image)
485
+
486
+ # Prepare ref image latents.
487
+ ref_image_latents = self.encode_images(src_image)
488
+
489
+ # Prepare pose condition image.
490
+ pose_fea, pose_fea_2 = self.get_pose_features(
491
+ src_coords_embed, trg_coords_embed
492
+ )
493
+
494
+ pose_fea = pose_fea[:, :, 0, ...]
495
+
496
+ # Forward reference images.
497
+ self.reference_unet(
498
+ ref_image_latents.repeat(2, 1, 1, 1),
499
+ torch.zeros(batch_size * 2).to(ref_image_latents),
500
+ encoder_hidden_states=image_prompt_embeds,
501
+ pose_cond_fea=pose_fea,
502
+ return_dict=False,
503
+ )
504
+ # Update the denoising net with reference features.
505
+ reference_control_reader.update(
506
+ reference_control_writer,
507
+ correspondence=correspondence
508
+ )
509
+
510
+ timesteps = self.scheduler.timesteps
511
+ latents_noisy = latents_noisy_start
512
+ for t in timesteps:
513
+ # Prepare latents.
514
+ latent_model_input = torch.cat([latents_noisy] * 2)
515
+ latent_model_input = self.scheduler.scale_model_input(
516
+ latent_model_input, t
517
+ )
518
+
519
+ # Denoise.
520
+ noise_pred = self.denoising_unet(
521
+ latent_model_input,
522
+ t,
523
+ encoder_hidden_states=image_prompt_embeds,
524
+ pose_cond_fea=pose_fea_2,
525
+ return_dict=False,
526
+ )[0]
527
+
528
+ # CFG.
529
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
530
+ noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
531
+ noise_pred_text - noise_pred_uncond
532
+ )
533
+
534
+ # t -> t-1
535
+ latents_noisy = self.scheduler.step(
536
+ noise_pred, t, latents_noisy, **extra_step_kwargs,
537
+ return_dict=False
538
+ )[0]
539
+
540
+ # After the final step the latents are fully denoised.
541
+ latents_clean = latents_noisy
542
+
543
+ reference_control_reader.clear()
544
+ reference_control_writer.clear()
545
+
546
+ return latents_clean.squeeze(2)
genwarp/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .GenWarp import GenWarp
genwarp/models/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .pose_guider import PoseGuider
2
+ from .unet_2d_condition import UNet2DConditionModel
3
+ from .unet_3d import UNet3DConditionModel
4
+ from .mutual_self_attention import ReferenceAttentionControl
genwarp/models/attention.py ADDED
@@ -0,0 +1,499 @@
1
+ # This code is adapted from below and then modified.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # -----------------------------------------------------------------------------
8
+ # Diffusers
9
+ # Apache License, Version 2.0
10
+ # Copyright (c) Hugging Face Inc.
11
+ # https://github.com/huggingface/diffusers
12
+ # ==============================================================================
13
+
14
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
15
+
16
+ from typing import Any, Dict, Optional
17
+
18
+ import torch
19
+ from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
20
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
21
+ from einops import rearrange
22
+ from torch import nn
23
+
24
+
25
+ class BasicTransformerBlock(nn.Module):
26
+ r"""
27
+ A basic Transformer block.
28
+
29
+ Parameters:
30
+ dim (`int`): The number of channels in the input and output.
31
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
32
+ attention_head_dim (`int`): The number of channels in each head.
33
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
34
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
35
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
36
+ num_embeds_ada_norm (:
37
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
38
+ attention_bias (:
39
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
40
+ only_cross_attention (`bool`, *optional*):
41
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
42
+ double_self_attention (`bool`, *optional*):
43
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
44
+ upcast_attention (`bool`, *optional*):
45
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
46
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
47
+ Whether to use learnable elementwise affine parameters for normalization.
48
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
49
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
50
+ final_dropout (`bool` *optional*, defaults to False):
51
+ Whether to apply a final dropout after the last feed-forward layer.
52
+ attention_type (`str`, *optional*, defaults to `"default"`):
53
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
54
+ positional_embeddings (`str`, *optional*, defaults to `None`):
55
+ The type of positional embeddings to apply to.
56
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
57
+ The maximum number of positional embeddings to apply.
58
+ """
59
+
60
+ def __init__(
61
+ self,
62
+ dim: int,
63
+ num_attention_heads: int,
64
+ attention_head_dim: int,
65
+ dropout=0.0,
66
+ cross_attention_dim: Optional[int] = None,
67
+ activation_fn: str = "geglu",
68
+ num_embeds_ada_norm: Optional[int] = None,
69
+ attention_bias: bool = False,
70
+ only_cross_attention: bool = False,
71
+ double_self_attention: bool = False,
72
+ upcast_attention: bool = False,
73
+ norm_elementwise_affine: bool = True,
74
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
75
+ norm_eps: float = 1e-5,
76
+ final_dropout: bool = False,
77
+ attention_type: str = "default",
78
+ positional_embeddings: Optional[str] = None,
79
+ num_positional_embeddings: Optional[int] = None,
80
+ ):
81
+ super().__init__()
82
+ self.only_cross_attention = only_cross_attention
83
+
84
+ self.use_ada_layer_norm_zero = (
85
+ num_embeds_ada_norm is not None
86
+ ) and norm_type == "ada_norm_zero"
87
+ self.use_ada_layer_norm = (
88
+ num_embeds_ada_norm is not None
89
+ ) and norm_type == "ada_norm"
90
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
91
+ self.use_layer_norm = norm_type == "layer_norm"
92
+
93
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
94
+ raise ValueError(
95
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
96
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
97
+ )
98
+
99
+ if positional_embeddings and (num_positional_embeddings is None):
100
+ raise ValueError(
101
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
102
+ )
103
+
104
+ if positional_embeddings == "sinusoidal":
105
+ self.pos_embed = SinusoidalPositionalEmbedding(
106
+ dim, max_seq_length=num_positional_embeddings
107
+ )
108
+ else:
109
+ self.pos_embed = None
110
+
111
+ # Define 3 blocks. Each block has its own normalization layer.
112
+ # 1. Self-Attn
113
+ if self.use_ada_layer_norm:
114
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
115
+ elif self.use_ada_layer_norm_zero:
116
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
117
+ else:
118
+ self.norm1 = nn.LayerNorm(
119
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
120
+ )
121
+
122
+ self.attn1 = Attention(
123
+ query_dim=dim,
124
+ heads=num_attention_heads,
125
+ dim_head=attention_head_dim,
126
+ dropout=dropout,
127
+ bias=attention_bias,
128
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
129
+ upcast_attention=upcast_attention,
130
+ )
131
+
132
+ # 2. Cross-Attn
133
+ if cross_attention_dim is not None or double_self_attention:
134
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
135
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
136
+ # the second cross attention block.
137
+ self.norm2 = (
138
+ AdaLayerNorm(dim, num_embeds_ada_norm)
139
+ if self.use_ada_layer_norm
140
+ else nn.LayerNorm(
141
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
142
+ )
143
+ )
144
+ self.attn2 = Attention(
145
+ query_dim=dim,
146
+ cross_attention_dim=cross_attention_dim
147
+ if not double_self_attention
148
+ else None,
149
+ heads=num_attention_heads,
150
+ dim_head=attention_head_dim,
151
+ dropout=dropout,
152
+ bias=attention_bias,
153
+ upcast_attention=upcast_attention,
154
+ ) # is self-attn if encoder_hidden_states is none
155
+ else:
156
+ self.norm2 = None
157
+ self.attn2 = None
158
+
159
+ # 3. Feed-forward
160
+ if not self.use_ada_layer_norm_single:
161
+ self.norm3 = nn.LayerNorm(
162
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
163
+ )
164
+
165
+ self.ff = FeedForward(
166
+ dim,
167
+ dropout=dropout,
168
+ activation_fn=activation_fn,
169
+ final_dropout=final_dropout,
170
+ )
171
+
172
+ # 4. Fuser
173
+ if attention_type == "gated" or attention_type == "gated-text-image":
174
+ self.fuser = GatedSelfAttentionDense(
175
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
176
+ )
177
+
178
+ # 5. Scale-shift for PixArt-Alpha.
179
+ if self.use_ada_layer_norm_single:
180
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
181
+
182
+ # let chunk size default to None
183
+ self._chunk_size = None
184
+ self._chunk_dim = 0
185
+
186
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
187
+ # Sets chunk feed-forward
188
+ self._chunk_size = chunk_size
189
+ self._chunk_dim = dim
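+ # Note: `_chunk_size` mirrors diffusers' chunked feed-forward API, which splits the
+ # feed-forward input along `_chunk_dim` to reduce peak memory. The forward pass below
+ # calls `self.ff` directly, so the stored value is kept only for API compatibility.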
190
+
191
+ def forward(
192
+ self,
193
+ hidden_states: torch.FloatTensor,
194
+ attention_mask: Optional[torch.FloatTensor] = None,
195
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
196
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
197
+ timestep: Optional[torch.LongTensor] = None,
198
+ cross_attention_kwargs: Dict[str, Any] = None,
199
+ class_labels: Optional[torch.LongTensor] = None,
200
+ ) -> torch.FloatTensor:
201
+ # Notice that normalization is always applied before the real computation in the following blocks.
202
+ # 0. Self-Attention
203
+ batch_size = hidden_states.shape[0]
204
+
205
+ if self.use_ada_layer_norm:
206
+ norm_hidden_states = self.norm1(hidden_states, timestep)
207
+ elif self.use_ada_layer_norm_zero:
208
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
209
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
210
+ )
211
+ elif self.use_layer_norm:
212
+ norm_hidden_states = self.norm1(hidden_states)
213
+ elif self.use_ada_layer_norm_single:
214
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
215
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
216
+ ).chunk(6, dim=1)
217
+ norm_hidden_states = self.norm1(hidden_states)
218
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
219
+ norm_hidden_states = norm_hidden_states.squeeze(1)
220
+ else:
221
+ raise ValueError("Incorrect norm used")
222
+
223
+ if self.pos_embed is not None:
224
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
225
+
226
+ # 1. Retrieve lora scale.
227
+ lora_scale = (
228
+ cross_attention_kwargs.get("scale", 1.0)
229
+ if cross_attention_kwargs is not None
230
+ else 1.0
231
+ )
232
+
233
+ # 2. Prepare GLIGEN inputs
234
+ cross_attention_kwargs = (
235
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
236
+ )
237
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
238
+
239
+ attn_output = self.attn1(
240
+ norm_hidden_states,
241
+ encoder_hidden_states=encoder_hidden_states
242
+ if self.only_cross_attention
243
+ else None,
244
+ attention_mask=attention_mask,
245
+ **cross_attention_kwargs,
246
+ )
247
+ if self.use_ada_layer_norm_zero:
248
+ attn_output = gate_msa.unsqueeze(1) * attn_output
249
+ elif self.use_ada_layer_norm_single:
250
+ attn_output = gate_msa * attn_output
251
+
252
+ hidden_states = attn_output + hidden_states
253
+ if hidden_states.ndim == 4:
254
+ hidden_states = hidden_states.squeeze(1)
255
+
256
+ # 2.5 GLIGEN Control
257
+ if gligen_kwargs is not None:
258
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
259
+
260
+ # 3. Cross-Attention
261
+ if self.attn2 is not None:
262
+ if self.use_ada_layer_norm:
263
+ norm_hidden_states = self.norm2(hidden_states, timestep)
264
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
265
+ norm_hidden_states = self.norm2(hidden_states)
266
+ elif self.use_ada_layer_norm_single:
267
+ # For PixArt norm2 isn't applied here:
268
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
269
+ norm_hidden_states = hidden_states
270
+ else:
271
+ raise ValueError("Incorrect norm")
272
+
273
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
274
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
275
+
276
+ attn_output = self.attn2(
277
+ norm_hidden_states,
278
+ encoder_hidden_states=encoder_hidden_states,
279
+ attention_mask=encoder_attention_mask,
280
+ **cross_attention_kwargs,
281
+ )
282
+ hidden_states = attn_output + hidden_states
283
+
284
+ # 4. Feed-forward
285
+ if not self.use_ada_layer_norm_single:
286
+ norm_hidden_states = self.norm3(hidden_states)
287
+
288
+ if self.use_ada_layer_norm_zero:
289
+ norm_hidden_states = (
290
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
291
+ )
292
+
293
+ if self.use_ada_layer_norm_single:
294
+ norm_hidden_states = self.norm2(hidden_states)
295
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
296
+
297
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
298
+
299
+ if self.use_ada_layer_norm_zero:
300
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
301
+ elif self.use_ada_layer_norm_single:
302
+ ff_output = gate_mlp * ff_output
303
+
304
+ hidden_states = ff_output + hidden_states
305
+ if hidden_states.ndim == 4:
306
+ hidden_states = hidden_states.squeeze(1)
307
+
308
+ return hidden_states
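+ # Shape note: for a 2D UNet block, `hidden_states` is (batch, sequence_length, dim) with
+ # sequence_length = H * W of the feature map, `encoder_hidden_states` is
+ # (batch, num_tokens, cross_attention_dim), and the block returns a tensor with the same
+ # shape as `hidden_states`.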
309
+
310
+
311
+ class WarpedFeatureInjector(nn.Module):
312
+ def __init__(self, dim: int):
313
+ super().__init__()
314
+
315
+ self.dim = dim
316
+ # Additional convolutional layers
317
+ self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False)
318
+ self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False)
319
+ self.conv3 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False)
320
+ # Initialize convolutional layers
321
+ nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
322
+ nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
323
+ nn.init.kaiming_normal_(self.conv3.weight, mode='fan_out', nonlinearity='relu')
324
+
325
+ # Zero convolution
326
+ self.out_conv = nn.Conv2d(
327
+ dim, dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False
328
+ )
329
+ nn.init.zeros_(self.out_conv.weight.data)
330
+ def forward(self, x):
331
+ # Apply convolutional layers
332
+ x = self.conv1(x)
333
+ x = self.conv2(x)
334
+ x = self.conv3(x)
335
+
336
+ # Apply zero convolution
337
+ x = self.out_conv(x)
338
+
339
+ return x
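+ # Note: because `out_conv` is zero-initialized, this injector returns zeros at the start
+ # of training (a ControlNet-style zero convolution), so warped reference features are
+ # blended in only as the final convolution learns non-zero weights.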
340
+
341
+
342
+
343
+ class TemporalBasicTransformerBlock(nn.Module):
344
+ def __init__(
345
+ self,
346
+ dim: int,
347
+ num_attention_heads: int,
348
+ attention_head_dim: int,
349
+ dropout=0.0,
350
+ cross_attention_dim: Optional[int] = None,
351
+ activation_fn: str = "geglu",
352
+ num_embeds_ada_norm: Optional[int] = None,
353
+ attention_bias: bool = False,
354
+ only_cross_attention: bool = False,
355
+ upcast_attention: bool = False,
356
+ unet_use_cross_frame_attention=None,
357
+ unet_use_temporal_attention=None,
358
+ use_zero_convs=False,
359
+ ):
360
+ super().__init__()
361
+ self.only_cross_attention = only_cross_attention
362
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
363
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
364
+ self.unet_use_temporal_attention = unet_use_temporal_attention
365
+
366
+ if use_zero_convs:
367
+ # self.zero_conv = nn.Conv2d(
368
+ # dim, dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False
369
+ # )
370
+ # nn.init.zeros_(self.zero_conv.weight.data)
371
+ self.zero_conv = WarpedFeatureInjector(dim)
372
+
373
+ else:
374
+ self.zero_conv = None
375
+
376
+ # SC-Attn
377
+ self.attn1 = Attention(
378
+ query_dim=dim,
379
+ heads=num_attention_heads,
380
+ dim_head=attention_head_dim,
381
+ dropout=dropout,
382
+ bias=attention_bias,
383
+ upcast_attention=upcast_attention,
384
+ )
385
+ self.norm1 = (
386
+ AdaLayerNorm(dim, num_embeds_ada_norm)
387
+ if self.use_ada_layer_norm
388
+ else nn.LayerNorm(dim)
389
+ )
390
+
391
+ # Cross-Attn
392
+ if cross_attention_dim is not None:
393
+ self.attn2 = Attention(
394
+ query_dim=dim,
395
+ cross_attention_dim=cross_attention_dim,
396
+ heads=num_attention_heads,
397
+ dim_head=attention_head_dim,
398
+ dropout=dropout,
399
+ bias=attention_bias,
400
+ upcast_attention=upcast_attention,
401
+ )
402
+ else:
403
+ self.attn2 = None
404
+
405
+ if cross_attention_dim is not None:
406
+ self.norm2 = (
407
+ AdaLayerNorm(dim, num_embeds_ada_norm)
408
+ if self.use_ada_layer_norm
409
+ else nn.LayerNorm(dim)
410
+ )
411
+ else:
412
+ self.norm2 = None
413
+
414
+ # Feed-forward
415
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
416
+ self.norm3 = nn.LayerNorm(dim)
417
+ self.use_ada_layer_norm_zero = False
418
+
419
+ # Temp-Attn
420
+ assert unet_use_temporal_attention is not None
421
+ if unet_use_temporal_attention:
422
+ self.attn_temp = Attention(
423
+ query_dim=dim,
424
+ heads=num_attention_heads,
425
+ dim_head=attention_head_dim,
426
+ dropout=dropout,
427
+ bias=attention_bias,
428
+ upcast_attention=upcast_attention,
429
+ )
430
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
431
+ self.norm_temp = (
432
+ AdaLayerNorm(dim, num_embeds_ada_norm)
433
+ if self.use_ada_layer_norm
434
+ else nn.LayerNorm(dim)
435
+ )
436
+
437
+ def forward(
438
+ self,
439
+ hidden_states,
440
+ encoder_hidden_states=None,
441
+ timestep=None,
442
+ attention_mask=None,
443
+ video_length=None,
444
+ ):
445
+ norm_hidden_states = (
446
+ self.norm1(hidden_states, timestep)
447
+ if self.use_ada_layer_norm
448
+ else self.norm1(hidden_states)
449
+ )
450
+
451
+ if self.unet_use_cross_frame_attention:
452
+ hidden_states = (
453
+ self.attn1(
454
+ norm_hidden_states,
455
+ attention_mask=attention_mask,
456
+ video_length=video_length,
457
+ )
458
+ + hidden_states
459
+ )
460
+ else:
461
+ hidden_states = (
462
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
463
+ + hidden_states
464
+ )
465
+
466
+ if self.attn2 is not None:
467
+ # Cross-Attention
468
+ norm_hidden_states = (
469
+ self.norm2(hidden_states, timestep)
470
+ if self.use_ada_layer_norm
471
+ else self.norm2(hidden_states)
472
+ )
473
+ hidden_states = (
474
+ self.attn2(
475
+ norm_hidden_states,
476
+ encoder_hidden_states=encoder_hidden_states,
477
+ attention_mask=attention_mask,
478
+ )
479
+ + hidden_states
480
+ )
481
+
482
+ # Feed-forward
483
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
484
+
485
+ # Temporal-Attention
486
+ if self.unet_use_temporal_attention:
487
+ d = hidden_states.shape[1]
488
+ hidden_states = rearrange(
489
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
490
+ )
491
+ norm_hidden_states = (
492
+ self.norm_temp(hidden_states, timestep)
493
+ if self.use_ada_layer_norm
494
+ else self.norm_temp(hidden_states)
495
+ )
496
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
497
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
498
+
499
+ return hidden_states
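+ # Note: in the temporal-attention branch above, the rearrange "(b f) d c -> (b d) f c"
+ # groups tokens by spatial position so that attention runs across the `video_length`
+ # frames at each location; the output projection of `attn_temp` is zero-initialized, so
+ # this branch contributes almost nothing at initialization.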
genwarp/models/motion_module.py ADDED
@@ -0,0 +1,399 @@
1
+ # This code is adapted from below.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # -----------------------------------------------------------------------------
8
+ # AnimateDiff
9
+ # Apache License, Version 2.0
10
+ # https://github.com/guoyww/AnimateDiff
11
+ # ==============================================================================
12
+
13
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
14
+ import math
15
+ from dataclasses import dataclass
16
+ from typing import Callable, Optional
17
+
18
+ import torch
19
+ from torch import nn
20
+ from diffusers.models.attention import FeedForward
21
+ from diffusers.models.attention_processor import Attention, AttnProcessor
22
+ from diffusers.utils import BaseOutput
23
+ from diffusers.utils.import_utils import is_xformers_available
24
+ from einops import rearrange, repeat
25
+
26
+ def zero_module(module):
27
+ # Zero out the parameters of a module and return it.
28
+ for p in module.parameters():
29
+ p.detach().zero_()
30
+ return module
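+ # Example: zero_module(nn.Linear(4, 4)) returns the same layer with every parameter set
+ # to zero, so a residual branch that ends in it initially adds nothing to the main path.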
31
+
32
+
33
+ @dataclass
34
+ class TemporalTransformer3DModelOutput(BaseOutput):
35
+ sample: torch.FloatTensor
36
+
37
+
38
+ if is_xformers_available():
39
+ import xformers
40
+ import xformers.ops
41
+ else:
42
+ xformers = None
43
+
44
+
45
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
46
+ if motion_module_type == "Vanilla":
47
+ return VanillaTemporalModule(
48
+ in_channels=in_channels,
49
+ **motion_module_kwargs,
50
+ )
51
+ else:
52
+ raise ValueError
53
+
54
+
55
+ class VanillaTemporalModule(nn.Module):
56
+ def __init__(
57
+ self,
58
+ in_channels,
59
+ num_attention_heads=8,
60
+ num_transformer_block=2,
61
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
62
+ cross_frame_attention_mode=None,
63
+ temporal_position_encoding=False,
64
+ temporal_position_encoding_max_len=24,
65
+ temporal_attention_dim_div=1,
66
+ zero_initialize=True,
67
+ ):
68
+ super().__init__()
69
+
70
+ self.temporal_transformer = TemporalTransformer3DModel(
71
+ in_channels=in_channels,
72
+ num_attention_heads=num_attention_heads,
73
+ attention_head_dim=in_channels
74
+ // num_attention_heads
75
+ // temporal_attention_dim_div,
76
+ num_layers=num_transformer_block,
77
+ attention_block_types=attention_block_types,
78
+ cross_frame_attention_mode=cross_frame_attention_mode,
79
+ temporal_position_encoding=temporal_position_encoding,
80
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
81
+ )
82
+
83
+ if zero_initialize:
84
+ self.temporal_transformer.proj_out = zero_module(
85
+ self.temporal_transformer.proj_out
86
+ )
87
+
88
+ def forward(
89
+ self,
90
+ input_tensor,
91
+ temb,
92
+ encoder_hidden_states,
93
+ attention_mask=None,
94
+ anchor_frame_idx=None,
95
+ ):
96
+ hidden_states = input_tensor
97
+ hidden_states = self.temporal_transformer(
98
+ hidden_states, encoder_hidden_states, attention_mask
99
+ )
100
+
101
+ output = hidden_states
102
+ return output
103
+
104
+
105
+ class TemporalTransformer3DModel(nn.Module):
106
+ def __init__(
107
+ self,
108
+ in_channels,
109
+ num_attention_heads,
110
+ attention_head_dim,
111
+ num_layers,
112
+ attention_block_types=(
113
+ "Temporal_Self",
114
+ "Temporal_Self",
115
+ ),
116
+ dropout=0.0,
117
+ norm_num_groups=32,
118
+ cross_attention_dim=768,
119
+ activation_fn="geglu",
120
+ attention_bias=False,
121
+ upcast_attention=False,
122
+ cross_frame_attention_mode=None,
123
+ temporal_position_encoding=False,
124
+ temporal_position_encoding_max_len=24,
125
+ ):
126
+ super().__init__()
127
+
128
+ inner_dim = num_attention_heads * attention_head_dim
129
+
130
+ self.norm = torch.nn.GroupNorm(
131
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
132
+ )
133
+ self.proj_in = nn.Linear(in_channels, inner_dim)
134
+
135
+ self.transformer_blocks = nn.ModuleList(
136
+ [
137
+ TemporalTransformerBlock(
138
+ dim=inner_dim,
139
+ num_attention_heads=num_attention_heads,
140
+ attention_head_dim=attention_head_dim,
141
+ attention_block_types=attention_block_types,
142
+ dropout=dropout,
143
+ norm_num_groups=norm_num_groups,
144
+ cross_attention_dim=cross_attention_dim,
145
+ activation_fn=activation_fn,
146
+ attention_bias=attention_bias,
147
+ upcast_attention=upcast_attention,
148
+ cross_frame_attention_mode=cross_frame_attention_mode,
149
+ temporal_position_encoding=temporal_position_encoding,
150
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
151
+ )
152
+ for d in range(num_layers)
153
+ ]
154
+ )
155
+ self.proj_out = nn.Linear(inner_dim, in_channels)
156
+
157
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
158
+ assert (
159
+ hidden_states.dim() == 5
160
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
161
+ video_length = hidden_states.shape[2]
162
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
163
+
164
+ batch, channel, height, weight = hidden_states.shape
165
+ residual = hidden_states
166
+
167
+ hidden_states = self.norm(hidden_states)
168
+ inner_dim = hidden_states.shape[1]
169
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
170
+ batch, height * weight, inner_dim
171
+ )
172
+ hidden_states = self.proj_in(hidden_states)
173
+
174
+ # Transformer Blocks
175
+ for block in self.transformer_blocks:
176
+ hidden_states = block(
177
+ hidden_states,
178
+ encoder_hidden_states=encoder_hidden_states,
179
+ video_length=video_length,
180
+ )
181
+
182
+ # output
183
+ hidden_states = self.proj_out(hidden_states)
184
+ hidden_states = (
185
+ hidden_states.reshape(batch, height, weight, inner_dim)
186
+ .permute(0, 3, 1, 2)
187
+ .contiguous()
188
+ )
189
+
190
+ output = hidden_states + residual
191
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
192
+
193
+ return output
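+ # Shape note: the 5D input (b, c, f, h, w) is folded to per-frame feature maps, flattened
+ # to token sequences of length h * w, run through the temporal transformer blocks,
+ # projected back, and added to the input as a residual. (The local variable named
+ # `weight` above actually holds the spatial width.)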
194
+
195
+
196
+ class TemporalTransformerBlock(nn.Module):
197
+ def __init__(
198
+ self,
199
+ dim,
200
+ num_attention_heads,
201
+ attention_head_dim,
202
+ attention_block_types=(
203
+ "Temporal_Self",
204
+ "Temporal_Self",
205
+ ),
206
+ dropout=0.0,
207
+ norm_num_groups=32,
208
+ cross_attention_dim=768,
209
+ activation_fn="geglu",
210
+ attention_bias=False,
211
+ upcast_attention=False,
212
+ cross_frame_attention_mode=None,
213
+ temporal_position_encoding=False,
214
+ temporal_position_encoding_max_len=24,
215
+ ):
216
+ super().__init__()
217
+
218
+ attention_blocks = []
219
+ norms = []
220
+
221
+ for block_name in attention_block_types:
222
+ attention_blocks.append(
223
+ VersatileAttention(
224
+ attention_mode=block_name.split("_")[0],
225
+ cross_attention_dim=cross_attention_dim
226
+ if block_name.endswith("_Cross")
227
+ else None,
228
+ query_dim=dim,
229
+ heads=num_attention_heads,
230
+ dim_head=attention_head_dim,
231
+ dropout=dropout,
232
+ bias=attention_bias,
233
+ upcast_attention=upcast_attention,
234
+ cross_frame_attention_mode=cross_frame_attention_mode,
235
+ temporal_position_encoding=temporal_position_encoding,
236
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
237
+ )
238
+ )
239
+ norms.append(nn.LayerNorm(dim))
240
+
241
+ self.attention_blocks = nn.ModuleList(attention_blocks)
242
+ self.norms = nn.ModuleList(norms)
243
+
244
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
245
+ self.ff_norm = nn.LayerNorm(dim)
246
+
247
+ def forward(
248
+ self,
249
+ hidden_states,
250
+ encoder_hidden_states=None,
251
+ attention_mask=None,
252
+ video_length=None,
253
+ ):
254
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
255
+ norm_hidden_states = norm(hidden_states)
256
+ hidden_states = (
257
+ attention_block(
258
+ norm_hidden_states,
259
+ encoder_hidden_states=encoder_hidden_states
260
+ if attention_block.is_cross_attention
261
+ else None,
262
+ video_length=video_length,
263
+ )
264
+ + hidden_states
265
+ )
266
+
267
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
268
+
269
+ output = hidden_states
270
+ return output
271
+
272
+
273
+ class PositionalEncoding(nn.Module):
274
+ def __init__(self, d_model, dropout=0.0, max_len=24):
275
+ super().__init__()
276
+ self.dropout = nn.Dropout(p=dropout)
277
+ position = torch.arange(max_len).unsqueeze(1)
278
+ div_term = torch.exp(
279
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
280
+ )
281
+ pe = torch.zeros(1, max_len, d_model)
282
+ pe[0, :, 0::2] = torch.sin(position * div_term)
283
+ pe[0, :, 1::2] = torch.cos(position * div_term)
284
+ self.register_buffer("pe", pe)
285
+
286
+ def forward(self, x):
287
+ x = x + self.pe[:, : x.size(1)]
288
+ return self.dropout(x)
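+ # Note: this is the standard sinusoidal positional encoding, added along the frame axis
+ # so temporal attention can tell frames apart; `max_len` bounds the number of frames it
+ # can encode (24 by default).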
289
+
290
+
291
+ class VersatileAttention(Attention):
292
+ def __init__(
293
+ self,
294
+ attention_mode=None,
295
+ cross_frame_attention_mode=None,
296
+ temporal_position_encoding=False,
297
+ temporal_position_encoding_max_len=24,
298
+ *args,
299
+ **kwargs,
300
+ ):
301
+ super().__init__(*args, **kwargs)
302
+ assert attention_mode == "Temporal"
303
+
304
+ self.attention_mode = attention_mode
305
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
306
+
307
+ self.pos_encoder = (
308
+ PositionalEncoding(
309
+ kwargs["query_dim"],
310
+ dropout=0.0,
311
+ max_len=temporal_position_encoding_max_len,
312
+ )
313
+ if (temporal_position_encoding and attention_mode == "Temporal")
314
+ else None
315
+ )
316
+
317
+ def extra_repr(self):
318
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
319
+
320
+ def set_use_memory_efficient_attention_xformers(
321
+ self,
322
+ use_memory_efficient_attention_xformers: bool,
323
+ attention_op: Optional[Callable] = None,
324
+ ):
325
+ if use_memory_efficient_attention_xformers:
326
+ if not is_xformers_available():
327
+ raise ModuleNotFoundError(
328
+ (
329
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
330
+ " xformers"
331
+ ),
332
+ name="xformers",
333
+ )
334
+ elif not torch.cuda.is_available():
335
+ raise ValueError(
336
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
337
+ " only available for GPU "
338
+ )
339
+ else:
340
+ try:
341
+ # Make sure we can run the memory efficient attention
342
+ _ = xformers.ops.memory_efficient_attention(
343
+ torch.randn((1, 2, 40), device="cuda"),
344
+ torch.randn((1, 2, 40), device="cuda"),
345
+ torch.randn((1, 2, 40), device="cuda"),
346
+ )
347
+ except Exception as e:
348
+ raise e
349
+
350
+ # XFormersAttnProcessor corrupts video generation and was only needed with PyTorch 1.13.
351
+ # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13.
352
+ # You don't need XFormersAttnProcessor here.
353
+ # processor = XFormersAttnProcessor(
354
+ # attention_op=attention_op,
355
+ # )
356
+ processor = AttnProcessor()
357
+ else:
358
+ processor = AttnProcessor()
359
+
360
+ self.set_processor(processor)
361
+
362
+ def forward(
363
+ self,
364
+ hidden_states,
365
+ encoder_hidden_states=None,
366
+ attention_mask=None,
367
+ video_length=None,
368
+ **cross_attention_kwargs,
369
+ ):
370
+ if self.attention_mode == "Temporal":
371
+ d = hidden_states.shape[1] # d means HxW
372
+ hidden_states = rearrange(
373
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
374
+ )
375
+
376
+ if self.pos_encoder is not None:
377
+ hidden_states = self.pos_encoder(hidden_states)
378
+
379
+ encoder_hidden_states = (
380
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
381
+ if encoder_hidden_states is not None
382
+ else encoder_hidden_states
383
+ )
384
+
385
+ else:
386
+ raise NotImplementedError
387
+
388
+ hidden_states = self.processor(
389
+ self,
390
+ hidden_states,
391
+ encoder_hidden_states=encoder_hidden_states,
392
+ attention_mask=attention_mask,
393
+ **cross_attention_kwargs,
394
+ )
395
+
396
+ if self.attention_mode == "Temporal":
397
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
398
+
399
+ return hidden_states
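+
+
+ if __name__ == "__main__":
+     # Illustrative smoke test (a sketch, not part of the adapted code); the sizes are
+     # arbitrary and it only assumes torch, einops and diffusers, which this module
+     # already imports. With the default zero-initialized output projection, the module
+     # acts as an identity at initialization.
+     module = get_motion_module(
+         in_channels=64, motion_module_type="Vanilla", motion_module_kwargs={}
+     )
+     video_feat = torch.randn(1, 64, 4, 8, 8)  # (batch, channels, frames, height, width)
+     out = module(video_feat, temb=None, encoder_hidden_states=None)
+     print(out.shape)  # torch.Size([1, 64, 4, 8, 8])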
genwarp/models/mutual_self_attention.py ADDED
@@ -0,0 +1,420 @@
1
+ # This code is adapted from below and then modified.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # -----------------------------------------------------------------------------
8
+ # magic-animate
9
+ # BSD 3-Clause License
10
+ # Copyright (c) Bytedance Inc.
11
+ # https://github.com/magic-research/magic-animate
12
+ # ==============================================================================
13
+
14
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
15
+ from typing import Any, Dict, Optional
16
+ import math
17
+
18
+ import torch
19
+ from einops import rearrange
20
+
21
+ from .attention import TemporalBasicTransformerBlock
22
+ from .attention import BasicTransformerBlock
23
+
24
+ def torch_dfs(model: torch.nn.Module):
25
+ result = [model]
26
+ for child in model.children():
27
+ result += torch_dfs(child)
28
+ return result
29
+
30
+
31
+ class ReferenceAttentionControl:
32
+ def __init__(
33
+ self,
34
+ unet,
35
+ mode="write",
36
+ do_classifier_free_guidance=False,
37
+ attention_auto_machine_weight=float("inf"),
38
+ gn_auto_machine_weight=1.0,
39
+ style_fidelity=1.0,
40
+ reference_attn=True,
41
+ reference_adain=False,
42
+ fusion_blocks="midup",
43
+ batch_size=1,
44
+ feature_fusion_type=None,
45
+ ) -> None:
46
+ self.unet = unet
47
+ assert mode in ["read", "write"]
48
+ assert fusion_blocks in ["midup", "full"]
49
+ self.reference_attn = reference_attn
50
+ self.reference_adain = reference_adain
51
+ self.fusion_blocks = fusion_blocks
52
+ self.feature_fusion_type = feature_fusion_type
53
+
54
+ self.mode = mode
55
+ self.do_classifier_free_guidance = do_classifier_free_guidance
56
+ self.attention_auto_machine_weight = attention_auto_machine_weight
57
+ self.gn_auto_machine_weight = gn_auto_machine_weight
58
+ self.style_fidelity = style_fidelity
59
+ self.batch_size = batch_size
60
+
61
+ self.register_reference_hooks(
62
+ mode,
63
+ do_classifier_free_guidance,
64
+ attention_auto_machine_weight,
65
+ gn_auto_machine_weight,
66
+ style_fidelity,
67
+ reference_attn,
68
+ reference_adain,
69
+ fusion_blocks,
70
+ batch_size=batch_size,
71
+ )
72
+
73
+ def rehook(self):
74
+ self.register_reference_hooks(
75
+ self.mode,
76
+ self.do_classifier_free_guidance,
77
+ self.attention_auto_machine_weight,
78
+ self.gn_auto_machine_weight,
79
+ self.style_fidelity,
80
+ self.reference_attn,
81
+ self.reference_adain,
82
+ self.fusion_blocks,
83
+ self.batch_size,
84
+ )
85
+
86
+ def register_reference_hooks(
87
+ self,
88
+ mode,
89
+ do_classifier_free_guidance,
90
+ attention_auto_machine_weight,
91
+ gn_auto_machine_weight,
92
+ style_fidelity,
93
+ reference_attn,
94
+ reference_adain,
95
+ dtype=torch.float16,
96
+ batch_size=1,
97
+ num_images_per_prompt=1,
98
+ device=torch.device("cpu"),
99
+ fusion_blocks="midup",
100
+ ):
101
+ do_classifier_free_guidance = do_classifier_free_guidance
102
+ attention_auto_machine_weight = attention_auto_machine_weight
103
+ gn_auto_machine_weight = gn_auto_machine_weight
104
+ style_fidelity = style_fidelity
105
+ reference_attn = reference_attn
106
+ reference_adain = reference_adain
107
+ fusion_blocks = fusion_blocks
108
+ num_images_per_prompt = num_images_per_prompt
109
+ dtype = dtype
110
+ feature_fusion_type = self.feature_fusion_type
111
+
112
+ if do_classifier_free_guidance:
113
+ uc_mask = (
114
+ torch.Tensor(
115
+ [1] * batch_size * num_images_per_prompt * 16
116
+ + [0] * batch_size * num_images_per_prompt * 16
117
+ )
118
+ .to(device)
119
+ .bool()
120
+ )
121
+ else:
122
+ uc_mask = (
123
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
124
+ .to(device)
125
+ .bool()
126
+ )
127
+
128
+ def hacked_basic_transformer_inner_forward(
129
+ self,
130
+ hidden_states: torch.FloatTensor,
131
+ attention_mask: Optional[torch.FloatTensor] = None,
132
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
133
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
134
+ timestep: Optional[torch.LongTensor] = None,
135
+ cross_attention_kwargs: Dict[str, Any] = None,
136
+ class_labels: Optional[torch.LongTensor] = None,
137
+ video_length=None,
138
+ ):
139
+ if self.use_ada_layer_norm: # False
140
+ norm_hidden_states = self.norm1(hidden_states, timestep)
141
+ elif self.use_ada_layer_norm_zero:
142
+ (
143
+ norm_hidden_states,
144
+ gate_msa,
145
+ shift_mlp,
146
+ scale_mlp,
147
+ gate_mlp,
148
+ ) = self.norm1(
149
+ hidden_states,
150
+ timestep,
151
+ class_labels,
152
+ hidden_dtype=hidden_states.dtype,
153
+ )
154
+ else:
155
+ norm_hidden_states = self.norm1(hidden_states)
156
+
157
+ cross_attention_kwargs = (
158
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
159
+ )
160
+ if self.only_cross_attention:
161
+ attn_output = self.attn1(
162
+ norm_hidden_states,
163
+ encoder_hidden_states=encoder_hidden_states
164
+ if self.only_cross_attention
165
+ else None,
166
+ attention_mask=attention_mask,
167
+ **cross_attention_kwargs,
168
+ )
169
+ else:
170
+ if mode == "write":
171
+ self.bank.append(norm_hidden_states.clone())
172
+ self.bank_unnorm.append(hidden_states.clone())
173
+ attn_output = self.attn1(
174
+ norm_hidden_states,
175
+ encoder_hidden_states=encoder_hidden_states
176
+ if self.only_cross_attention
177
+ else None,
178
+ attention_mask=attention_mask,
179
+ **cross_attention_kwargs,
180
+ )
181
+ if mode == "read":
182
+ bank_fea = [
183
+ rearrange(
184
+ d.unsqueeze(1).repeat(1, video_length, 1, 1),
185
+ "b t l c -> (b t) l c",
186
+ )
187
+ for d in self.bank
188
+ ]
189
+
190
+ bank_fea_unnorm = [
191
+ rearrange(
192
+ d.unsqueeze(1).repeat(1, video_length, 1, 1),
193
+ "b t l c -> (b t) l c",
194
+ )
195
+ for d in self.bank_unnorm
196
+ ]
197
+
198
+
199
+ modify_norm_hidden_states = torch.cat(
200
+ [norm_hidden_states] + bank_fea, dim=1
201
+ )
202
+
203
+ if feature_fusion_type == 'attention_full_sharing':
204
+ # Full sharing for ablation exp.
205
+ hidden_states_uc = (
206
+ self.attn1(
207
+ norm_hidden_states,
208
+ encoder_hidden_states=modify_norm_hidden_states,
209
+ attention_mask=None,
210
+ )
211
+ + hidden_states
212
+ )
213
+ else:
214
+ raise ValueError("feature_fusion_type is not valid")
215
+
216
+ if do_classifier_free_guidance:
217
+ hidden_states_c = hidden_states_uc.clone()
218
+ _uc_mask = uc_mask.clone()
219
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
220
+ _uc_mask = (
221
+ torch.Tensor(
222
+ [1] * (hidden_states.shape[0] // 2)
223
+ + [0] * (hidden_states.shape[0] // 2)
224
+ )
225
+ .to(device)
226
+ .bool()
227
+ )
228
+ hidden_states_c[_uc_mask] = (
229
+ self.attn1(
230
+ norm_hidden_states[_uc_mask],
231
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
232
+ attention_mask=None,
233
+ )
234
+ + hidden_states[_uc_mask]
235
+ )
236
+ hidden_states = hidden_states_c.clone()
237
+ else:
238
+ hidden_states = hidden_states_uc
239
+
240
+ if self.attn2 is not None:
241
+ # Cross-Attention
242
+ norm_hidden_states = (
243
+ self.norm2(hidden_states, timestep)
244
+ if self.use_ada_layer_norm
245
+ else self.norm2(hidden_states)
246
+ )
247
+ hidden_states = (
248
+ self.attn2(
249
+ norm_hidden_states,
250
+ encoder_hidden_states=encoder_hidden_states,
251
+ attention_mask=None,
252
+ )
253
+ + hidden_states
254
+ )
255
+
256
+ # Feed-forward
257
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
258
+
259
+ # Temporal-Attention
260
+ if self.unet_use_temporal_attention:
261
+ d = hidden_states.shape[1]
262
+ hidden_states = rearrange(
263
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
264
+ )
265
+ norm_hidden_states = (
266
+ self.norm_temp(hidden_states, timestep)
267
+ if self.use_ada_layer_norm
268
+ else self.norm_temp(hidden_states)
269
+ )
270
+ hidden_states = (
271
+ self.attn_temp(norm_hidden_states) + hidden_states
272
+ )
273
+ hidden_states = rearrange(
274
+ hidden_states, "(b d) f c -> (b f) d c", d=d
275
+ )
276
+
277
+ return hidden_states
278
+
279
+ if self.use_ada_layer_norm_zero:
280
+ attn_output = gate_msa.unsqueeze(1) * attn_output
281
+ hidden_states = attn_output + hidden_states
282
+
283
+ if self.attn2 is not None:
284
+ norm_hidden_states = (
285
+ self.norm2(hidden_states, timestep)
286
+ if self.use_ada_layer_norm
287
+ else self.norm2(hidden_states)
288
+ )
289
+
290
+ attn_output = self.attn2(
291
+ norm_hidden_states,
292
+ encoder_hidden_states=encoder_hidden_states,
293
+ attention_mask=encoder_attention_mask,
294
+ **cross_attention_kwargs,
295
+ )
296
+ hidden_states = attn_output + hidden_states
297
+
298
+ norm_hidden_states = self.norm3(hidden_states)
299
+
300
+ if self.use_ada_layer_norm_zero:
301
+ norm_hidden_states = (
302
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
303
+ )
304
+
305
+ ff_output = self.ff(norm_hidden_states)
306
+
307
+ if self.use_ada_layer_norm_zero:
308
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
309
+
310
+ hidden_states = ff_output + hidden_states
311
+
312
+ return hidden_states
313
+
314
+ if self.reference_attn:
315
+ if self.fusion_blocks == "midup":
316
+ attn_modules = [
317
+ module
318
+ for module in (
319
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
320
+ )
321
+ if isinstance(module, BasicTransformerBlock)
322
+ or isinstance(module, TemporalBasicTransformerBlock)
323
+ ]
324
+ elif self.fusion_blocks == "full":
325
+ attn_modules = [
326
+ module
327
+ for module in torch_dfs(self.unet)
328
+ if isinstance(module, BasicTransformerBlock)
329
+ or isinstance(module, TemporalBasicTransformerBlock)
330
+ ]
331
+ attn_modules = sorted(
332
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
333
+ )
334
+
335
+ for i, module in enumerate(attn_modules):
336
+ module._original_inner_forward = module.forward
337
+ if isinstance(module, BasicTransformerBlock):
338
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
339
+ module, BasicTransformerBlock
340
+ )
341
+ if isinstance(module, TemporalBasicTransformerBlock):
342
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
343
+ module, TemporalBasicTransformerBlock
344
+ )
345
+
346
+ module.bank = []
347
+ module.bank_unnorm = []
348
+ module.correspondence = None
349
+ module.attn_weight = float(i) / float(len(attn_modules))
350
+
351
+ def update(self, writer, correspondence=None, dtype=torch.float16):
352
+ if self.reference_attn:
353
+ if self.fusion_blocks == "midup":
354
+ reader_attn_modules = [
355
+ module
356
+ for module in (
357
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
358
+ )
359
+ if isinstance(module, TemporalBasicTransformerBlock)
360
+ ]
361
+ writer_attn_modules = [
362
+ module
363
+ for module in (
364
+ torch_dfs(writer.unet.mid_block)
365
+ + torch_dfs(writer.unet.up_blocks)
366
+ )
367
+ if isinstance(module, BasicTransformerBlock)
368
+ ]
369
+ elif self.fusion_blocks == "full":
370
+ reader_attn_modules = [
371
+ module
372
+ for module in torch_dfs(self.unet)
373
+ if isinstance(module, TemporalBasicTransformerBlock)
374
+ ]
375
+ writer_attn_modules = [
376
+ module
377
+ for module in torch_dfs(writer.unet)
378
+ if isinstance(module, BasicTransformerBlock)
379
+ ]
380
+ reader_attn_modules = sorted(
381
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
382
+ )
383
+ writer_attn_modules = sorted(
384
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
385
+ )
386
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
387
+ r.bank = [v.clone().to(dtype) for v in w.bank]
388
+ r.bank_unnorm = [v.clone().to(dtype) for v in w.bank_unnorm]
389
+ if correspondence is not None:
390
+ r.correspondence = [correspondence]
391
+ else:
392
+ r.correspondence = None
393
+
394
+ def clear(self):
395
+ if self.reference_attn:
396
+ if self.fusion_blocks == "midup":
397
+ reader_attn_modules = [
398
+ module
399
+ for module in (
400
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
401
+ )
402
+ if isinstance(module, BasicTransformerBlock)
403
+ or isinstance(module, TemporalBasicTransformerBlock)
404
+ ]
405
+ elif self.fusion_blocks == "full":
406
+ reader_attn_modules = [
407
+ module
408
+ for module in torch_dfs(self.unet)
409
+ if isinstance(module, BasicTransformerBlock)
410
+ or isinstance(module, TemporalBasicTransformerBlock)
411
+ ]
412
+ reader_attn_modules = sorted(
413
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
414
+ )
415
+ for r in reader_attn_modules:
416
+ r.bank.clear()
417
+ r.bank_unnorm.clear()
418
+ if r.correspondence is not None:
419
+ r.correspondence.clear()
420
+ r.correspondence = None
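+
+ # Note on the overall mechanism: in "write" mode the hooked blocks of the reference UNet
+ # append their normalized and unnormalized self-attention inputs to `bank` / `bank_unnorm`;
+ # `update()` copies those banks into the matching blocks of the denoising UNet, whose
+ # hooked forward ("read" mode) concatenates the banked features to its own tokens as extra
+ # keys/values for self-attention; `clear()` empties the banks between denoising calls.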
genwarp/models/pose_guider.py ADDED
@@ -0,0 +1,63 @@
1
+ # This code is adapted from below and then modified.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # ==============================================================================
8
+
9
+ from typing import Tuple
10
+
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+
15
+ from .motion_module import zero_module
16
+ from .resnet import InflatedConv3d
17
+
18
+ class PoseGuider(ModelMixin):
19
+ def __init__(
20
+ self,
21
+ conditioning_embedding_channels: int,
22
+ conditioning_channels: int = 3,
23
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
24
+ ):
25
+ super().__init__()
26
+ self.conv_in = InflatedConv3d(
27
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
28
+ )
29
+
30
+ self.blocks = nn.ModuleList([])
31
+
32
+ for i in range(len(block_out_channels) - 1):
33
+ channel_in = block_out_channels[i]
34
+ channel_out = block_out_channels[i + 1]
35
+ self.blocks.append(
36
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
37
+ )
38
+ self.blocks.append(
39
+ InflatedConv3d(
40
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
41
+ )
42
+ )
43
+
44
+ self.conv_out = zero_module(
45
+ InflatedConv3d(
46
+ block_out_channels[-1],
47
+ conditioning_embedding_channels,
48
+ kernel_size=3,
49
+ padding=1,
50
+ )
51
+ )
52
+
53
+ def forward(self, conditioning):
54
+ embedding = self.conv_in(conditioning)
55
+ embedding = F.silu(embedding)
56
+
57
+ for block in self.blocks:
58
+ embedding = block(embedding)
59
+ embedding = F.silu(embedding)
60
+
61
+ embedding = self.conv_out(embedding)
62
+
63
+ return embedding
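+
+
+ if __name__ == "__main__":
+     # Illustrative smoke test (a sketch, not part of the adapted code); the input sizes
+     # are arbitrary. The guider downsamples the conditioning video 8x spatially, and since
+     # `conv_out` is zero-initialized its embedding is all zeros before training.
+     import torch
+     guider = PoseGuider(conditioning_embedding_channels=64)
+     cond = torch.randn(1, 3, 4, 64, 64)  # (batch, channels, frames, height, width)
+     print(guider(cond).shape)  # torch.Size([1, 64, 4, 8, 8])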
genwarp/models/resnet.py ADDED
@@ -0,0 +1,265 @@
1
+ # This code is adapted from below.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # -----------------------------------------------------------------------------
8
+ # Diffusers
9
+ # Apache License, Version 2.0
10
+ # Copyright (c) Hugging Face Inc.
11
+ # https://github.com/huggingface/diffusers
12
+ # ==============================================================================
13
+
14
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from einops import rearrange
20
+
21
+
22
+ class InflatedConv3d(nn.Conv2d):
23
+ def forward(self, x):
24
+ video_length = x.shape[2]
25
+
26
+ x = rearrange(x, "b c f h w -> (b f) c h w")
27
+ x = super().forward(x)
28
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
29
+
30
+ return x
31
+
32
+
33
+ class InflatedGroupNorm(nn.GroupNorm):
34
+ def forward(self, x):
35
+ video_length = x.shape[2]
36
+
37
+ x = rearrange(x, "b c f h w -> (b f) c h w")
38
+ x = super().forward(x)
39
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
40
+
41
+ return x
42
+
43
+
44
+ class Upsample3D(nn.Module):
45
+ def __init__(
46
+ self,
47
+ channels,
48
+ use_conv=False,
49
+ use_conv_transpose=False,
50
+ out_channels=None,
51
+ name="conv",
52
+ ):
53
+ super().__init__()
54
+ self.channels = channels
55
+ self.out_channels = out_channels or channels
56
+ self.use_conv = use_conv
57
+ self.use_conv_transpose = use_conv_transpose
58
+ self.name = name
59
+
60
+ conv = None
61
+ if use_conv_transpose:
62
+ raise NotImplementedError
63
+ elif use_conv:
64
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
65
+
66
+ def forward(self, hidden_states, output_size=None):
67
+ assert hidden_states.shape[1] == self.channels
68
+
69
+ if self.use_conv_transpose:
70
+ raise NotImplementedError
71
+
72
+ # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
73
+ dtype = hidden_states.dtype
74
+ if dtype == torch.bfloat16:
75
+ hidden_states = hidden_states.to(torch.float32)
76
+
77
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
78
+ if hidden_states.shape[0] >= 64:
79
+ hidden_states = hidden_states.contiguous()
80
+
81
+ # if `output_size` is passed we force the interpolation output
82
+ # size and do not make use of `scale_factor=2`
83
+ if output_size is None:
84
+ hidden_states = F.interpolate(
85
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
86
+ )
87
+ else:
88
+ hidden_states = F.interpolate(
89
+ hidden_states, size=output_size, mode="nearest"
90
+ )
91
+
92
+ # If the input is bfloat16, we cast back to bfloat16
93
+ if dtype == torch.bfloat16:
94
+ hidden_states = hidden_states.to(dtype)
95
+
96
+ # if self.use_conv:
97
+ # if self.name == "conv":
98
+ # hidden_states = self.conv(hidden_states)
99
+ # else:
100
+ # hidden_states = self.Conv2d_0(hidden_states)
101
+ hidden_states = self.conv(hidden_states)
102
+
103
+ return hidden_states
104
+
105
+
106
+ class Downsample3D(nn.Module):
107
+ def __init__(
108
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
109
+ ):
110
+ super().__init__()
111
+ self.channels = channels
112
+ self.out_channels = out_channels or channels
113
+ self.use_conv = use_conv
114
+ self.padding = padding
115
+ stride = 2
116
+ self.name = name
117
+
118
+ if use_conv:
119
+ self.conv = InflatedConv3d(
120
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
121
+ )
122
+ else:
123
+ raise NotImplementedError
124
+
125
+ def forward(self, hidden_states):
126
+ assert hidden_states.shape[1] == self.channels
127
+ if self.use_conv and self.padding == 0:
128
+ raise NotImplementedError
129
+
130
+ assert hidden_states.shape[1] == self.channels
131
+ hidden_states = self.conv(hidden_states)
132
+
133
+ return hidden_states
134
+
135
+
136
+ class ResnetBlock3D(nn.Module):
137
+ def __init__(
138
+ self,
139
+ *,
140
+ in_channels,
141
+ out_channels=None,
142
+ conv_shortcut=False,
143
+ dropout=0.0,
144
+ temb_channels=512,
145
+ groups=32,
146
+ groups_out=None,
147
+ pre_norm=True,
148
+ eps=1e-6,
149
+ non_linearity="swish",
150
+ time_embedding_norm="default",
151
+ output_scale_factor=1.0,
152
+ use_in_shortcut=None,
153
+ use_inflated_groupnorm=None,
154
+ ):
155
+ super().__init__()
156
+ self.pre_norm = pre_norm
157
+ self.pre_norm = True
158
+ self.in_channels = in_channels
159
+ out_channels = in_channels if out_channels is None else out_channels
160
+ self.out_channels = out_channels
161
+ self.use_conv_shortcut = conv_shortcut
162
+ self.time_embedding_norm = time_embedding_norm
163
+ self.output_scale_factor = output_scale_factor
164
+
165
+ if groups_out is None:
166
+ groups_out = groups
167
+
168
+ assert use_inflated_groupnorm != None
169
+ if use_inflated_groupnorm:
170
+ self.norm1 = InflatedGroupNorm(
171
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
172
+ )
173
+ else:
174
+ self.norm1 = torch.nn.GroupNorm(
175
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
176
+ )
177
+
178
+ self.conv1 = InflatedConv3d(
179
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
180
+ )
181
+
182
+ if temb_channels is not None:
183
+ if self.time_embedding_norm == "default":
184
+ time_emb_proj_out_channels = out_channels
185
+ elif self.time_embedding_norm == "scale_shift":
186
+ time_emb_proj_out_channels = out_channels * 2
187
+ else:
188
+ raise ValueError(
189
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
190
+ )
191
+
192
+ self.time_emb_proj = torch.nn.Linear(
193
+ temb_channels, time_emb_proj_out_channels
194
+ )
195
+ else:
196
+ self.time_emb_proj = None
197
+
198
+ if use_inflated_groupnorm:
199
+ self.norm2 = InflatedGroupNorm(
200
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
201
+ )
202
+ else:
203
+ self.norm2 = torch.nn.GroupNorm(
204
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
205
+ )
206
+ self.dropout = torch.nn.Dropout(dropout)
207
+ self.conv2 = InflatedConv3d(
208
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
209
+ )
210
+
211
+ if non_linearity == "swish":
212
+ self.nonlinearity = lambda x: F.silu(x)
213
+ elif non_linearity == "mish":
214
+ self.nonlinearity = Mish()
215
+ elif non_linearity == "silu":
216
+ self.nonlinearity = nn.SiLU()
217
+
218
+ self.use_in_shortcut = (
219
+ self.in_channels != self.out_channels
220
+ if use_in_shortcut is None
221
+ else use_in_shortcut
222
+ )
223
+
224
+ self.conv_shortcut = None
225
+ if self.use_in_shortcut:
226
+ self.conv_shortcut = InflatedConv3d(
227
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
228
+ )
229
+
230
+ def forward(self, input_tensor, temb):
231
+ hidden_states = input_tensor
232
+
233
+ hidden_states = self.norm1(hidden_states)
234
+ hidden_states = self.nonlinearity(hidden_states)
235
+
236
+ hidden_states = self.conv1(hidden_states)
237
+
238
+ if temb is not None:
239
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
240
+
241
+ if temb is not None and self.time_embedding_norm == "default":
242
+ hidden_states = hidden_states + temb
243
+
244
+ hidden_states = self.norm2(hidden_states)
245
+
246
+ if temb is not None and self.time_embedding_norm == "scale_shift":
247
+ scale, shift = torch.chunk(temb, 2, dim=1)
248
+ hidden_states = hidden_states * (1 + scale) + shift
249
+
250
+ hidden_states = self.nonlinearity(hidden_states)
251
+
252
+ hidden_states = self.dropout(hidden_states)
253
+ hidden_states = self.conv2(hidden_states)
254
+
255
+ if self.conv_shortcut is not None:
256
+ input_tensor = self.conv_shortcut(input_tensor)
257
+
258
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
259
+
260
+ return output_tensor
261
+
262
+
263
+ class Mish(torch.nn.Module):
264
+ def forward(self, hidden_states):
265
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
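+
+
+ if __name__ == "__main__":
+     # Illustrative smoke test (a sketch, not part of the adapted code); the sizes are
+     # arbitrary. A 3D residual block applied to a small random video feature map with a
+     # random timestep embedding preserves the input shape.
+     block = ResnetBlock3D(
+         in_channels=64, temb_channels=32, use_inflated_groupnorm=True
+     )
+     x = torch.randn(1, 64, 4, 8, 8)  # (batch, channels, frames, height, width)
+     temb = torch.randn(1, 32)
+     print(block(x, temb).shape)  # torch.Size([1, 64, 4, 8, 8])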
genwarp/models/transformer_2d.py ADDED
@@ -0,0 +1,409 @@
1
+ # This code is adapted from below and then modified.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # -----------------------------------------------------------------------------
8
+ # Diffusers
9
+ # Apache License, Version 2.0
10
+ # Copyright (c) Hugging Face Inc.
11
+ # https://github.com/huggingface/diffusers
12
+ # ==============================================================================
13
+
14
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, Optional
17
+
18
+ import torch
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ # from diffusers.models.embeddings import CaptionProjection
21
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
22
+ from diffusers.models.modeling_utils import ModelMixin
23
+ from diffusers.models.normalization import AdaLayerNormSingle
24
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
25
+ from torch import nn
26
+
27
+ from .attention import BasicTransformerBlock
28
+
29
+
30
+ @dataclass
31
+ class Transformer2DModelOutput(BaseOutput):
32
+ """
33
+ The output of [`Transformer2DModel`].
34
+
35
+ Args:
36
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
37
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
38
+ distributions for the unnoised latent pixels.
39
+ """
40
+
41
+ sample: torch.FloatTensor
42
+ ref_feature: torch.FloatTensor
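+     # Note: `ref_feature` does not exist in the upstream diffusers Transformer2DModelOutput;
+     # it is added here to expose intermediate hidden states for the reference-feature path.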
43
+
44
+
45
+ class Transformer2DModel(ModelMixin, ConfigMixin):
46
+ """
47
+ A 2D Transformer model for image-like data.
48
+
49
+ Parameters:
50
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
51
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
52
+ in_channels (`int`, *optional*):
53
+ The number of channels in the input and output (specify if the input is **continuous**).
54
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
55
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
56
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
57
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
58
+ This is fixed during training since it is used to learn a number of position embeddings.
59
+ num_vector_embeds (`int`, *optional*):
60
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
61
+ Includes the class for the masked latent pixel.
62
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
63
+ num_embeds_ada_norm ( `int`, *optional*):
64
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
65
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
66
+ added to the hidden states.
67
+
68
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
69
+ attention_bias (`bool`, *optional*):
70
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
71
+ """
72
+
73
+ _supports_gradient_checkpointing = True
74
+
75
+ @register_to_config
76
+ def __init__(
77
+ self,
78
+ num_attention_heads: int = 16,
79
+ attention_head_dim: int = 88,
80
+ in_channels: Optional[int] = None,
81
+ out_channels: Optional[int] = None,
82
+ num_layers: int = 1,
83
+ dropout: float = 0.0,
84
+ norm_num_groups: int = 32,
85
+ cross_attention_dim: Optional[int] = None,
86
+ attention_bias: bool = False,
87
+ sample_size: Optional[int] = None,
88
+ num_vector_embeds: Optional[int] = None,
89
+ patch_size: Optional[int] = None,
90
+ activation_fn: str = "geglu",
91
+ num_embeds_ada_norm: Optional[int] = None,
92
+ use_linear_projection: bool = False,
93
+ only_cross_attention: bool = False,
94
+ double_self_attention: bool = False,
95
+ upcast_attention: bool = False,
96
+ norm_type: str = "layer_norm",
97
+ norm_elementwise_affine: bool = True,
98
+ norm_eps: float = 1e-5,
99
+ attention_type: str = "default",
100
+ caption_channels: int = None,
101
+ ):
102
+ super().__init__()
103
+ self.use_linear_projection = use_linear_projection
104
+ self.num_attention_heads = num_attention_heads
105
+ self.attention_head_dim = attention_head_dim
106
+ inner_dim = num_attention_heads * attention_head_dim
107
+
108
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
109
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
110
+
111
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
112
+ # Define whether input is continuous or discrete depending on configuration
113
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
114
+ self.is_input_vectorized = num_vector_embeds is not None
115
+ self.is_input_patches = in_channels is not None and patch_size is not None
116
+
117
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
118
+ deprecation_message = (
119
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
120
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
121
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
122
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
123
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
124
+ )
125
+ deprecate(
126
+ "norm_type!=num_embeds_ada_norm",
127
+ "1.0.0",
128
+ deprecation_message,
129
+ standard_warn=False,
130
+ )
131
+ norm_type = "ada_norm"
132
+
133
+ if self.is_input_continuous and self.is_input_vectorized:
134
+ raise ValueError(
135
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
136
+ " sure that either `in_channels` or `num_vector_embeds` is None."
137
+ )
138
+ elif self.is_input_vectorized and self.is_input_patches:
139
+ raise ValueError(
140
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
141
+ " sure that either `num_vector_embeds` or `num_patches` is None."
142
+ )
143
+ elif (
144
+ not self.is_input_continuous
145
+ and not self.is_input_vectorized
146
+ and not self.is_input_patches
147
+ ):
148
+ raise ValueError(
149
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
150
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
151
+ )
152
+
153
+ # 2. Define input layers
154
+ self.in_channels = in_channels
155
+
156
+ self.norm = torch.nn.GroupNorm(
157
+ num_groups=norm_num_groups,
158
+ num_channels=in_channels,
159
+ eps=1e-6,
160
+ affine=True,
161
+ )
162
+ if use_linear_projection:
163
+ self.proj_in = linear_cls(in_channels, inner_dim)
164
+ else:
165
+ self.proj_in = conv_cls(
166
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
167
+ )
168
+
169
+ # 3. Define transformers blocks
170
+ self.transformer_blocks = nn.ModuleList(
171
+ [
172
+ BasicTransformerBlock(
173
+ inner_dim,
174
+ num_attention_heads,
175
+ attention_head_dim,
176
+ dropout=dropout,
177
+ cross_attention_dim=cross_attention_dim,
178
+ activation_fn=activation_fn,
179
+ num_embeds_ada_norm=num_embeds_ada_norm,
180
+ attention_bias=attention_bias,
181
+ only_cross_attention=only_cross_attention,
182
+ double_self_attention=double_self_attention,
183
+ upcast_attention=upcast_attention,
184
+ norm_type=norm_type,
185
+ norm_elementwise_affine=norm_elementwise_affine,
186
+ norm_eps=norm_eps,
187
+ attention_type=attention_type,
188
+ )
189
+ for d in range(num_layers)
190
+ ]
191
+ )
192
+
193
+ # 4. Define output layers
194
+ self.out_channels = in_channels if out_channels is None else out_channels
195
+ # TODO: should use out_channels for continuous projections
196
+ if use_linear_projection:
197
+ self.proj_out = linear_cls(inner_dim, in_channels)
198
+ else:
199
+ self.proj_out = conv_cls(
200
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
201
+ )
202
+
203
+ # 5. PixArt-Alpha blocks.
204
+ self.adaln_single = None
205
+ self.use_additional_conditions = False
206
+ if norm_type == "ada_norm_single":
207
+ self.use_additional_conditions = self.config.sample_size == 128
208
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
209
+ # additional conditions until we find better name
210
+ self.adaln_single = AdaLayerNormSingle(
211
+ inner_dim, use_additional_conditions=self.use_additional_conditions
212
+ )
213
+
214
+ self.caption_projection = None
215
+ # if caption_channels is not None:
216
+ # self.caption_projection = CaptionProjection(
217
+ # in_features=caption_channels, hidden_size=inner_dim
218
+ # )
219
+
220
+ self.gradient_checkpointing = False
221
+
222
+ def _set_gradient_checkpointing(self, module, value=False):
223
+ if hasattr(module, "gradient_checkpointing"):
224
+ module.gradient_checkpointing = value
225
+
226
+ def forward(
227
+ self,
228
+ hidden_states: torch.Tensor,
229
+ encoder_hidden_states: Optional[torch.Tensor] = None,
230
+ timestep: Optional[torch.LongTensor] = None,
231
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
232
+ class_labels: Optional[torch.LongTensor] = None,
233
+ cross_attention_kwargs: Dict[str, Any] = None,
234
+ attention_mask: Optional[torch.Tensor] = None,
235
+ encoder_attention_mask: Optional[torch.Tensor] = None,
236
+ return_dict: bool = True,
237
+ ):
238
+ """
239
+ The [`Transformer2DModel`] forward method.
240
+
241
+ Args:
242
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
243
+ Input `hidden_states`.
244
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
245
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
246
+ self-attention.
247
+ timestep ( `torch.LongTensor`, *optional*):
248
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
249
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
250
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
251
+ `AdaLayerZeroNorm`.
252
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
253
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
254
+ `self.processor` in
255
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
256
+ attention_mask ( `torch.Tensor`, *optional*):
257
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
258
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
259
+ negative values to the attention scores corresponding to "discard" tokens.
260
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
261
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
262
+
263
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
264
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
265
+
266
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
267
+ above. This bias will be added to the cross-attention scores.
268
+ return_dict (`bool`, *optional*, defaults to `True`):
269
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
270
+ tuple.
271
+
272
+ Returns:
273
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
274
+ `tuple` where the first element is the sample tensor.
275
+ """
276
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
277
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
278
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
279
+ # expects mask of shape:
280
+ # [batch, key_tokens]
281
+ # adds singleton query_tokens dimension:
282
+ # [batch, 1, key_tokens]
283
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
284
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
285
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
286
+ if attention_mask is not None and attention_mask.ndim == 2:
287
+ # assume that mask is expressed as:
288
+ # (1 = keep, 0 = discard)
289
+ # convert mask into a bias that can be added to attention scores:
290
+ # (keep = +0, discard = -10000.0)
291
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
292
+ attention_mask = attention_mask.unsqueeze(1)
293
+
294
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
295
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
296
+ encoder_attention_mask = (
297
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
298
+ ) * -10000.0
299
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
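# Illustrative note (added for clarity, not in the original diff): with the
# conversion above, a boolean-style mask such as [[1, 1, 0]] becomes the
# additive bias [[[0.0, 0.0, -10000.0]]] with a singleton query dimension,
# which broadcasts over attention scores of shape
# [batch (* heads), query_tokens, key_tokens].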
300
+
301
+ # Retrieve lora scale.
302
+ lora_scale = (
303
+ cross_attention_kwargs.get("scale", 1.0)
304
+ if cross_attention_kwargs is not None
305
+ else 1.0
306
+ )
307
+
308
+ # 1. Input
309
+ batch, _, height, width = hidden_states.shape
310
+ residual = hidden_states
311
+
312
+ hidden_states = self.norm(hidden_states)
313
+ if not self.use_linear_projection:
314
+ hidden_states = (
315
+ self.proj_in(hidden_states, scale=lora_scale)
316
+ if not USE_PEFT_BACKEND
317
+ else self.proj_in(hidden_states)
318
+ )
319
+ inner_dim = hidden_states.shape[1]
320
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
321
+ batch, height * width, inner_dim
322
+ )
323
+ else:
324
+ inner_dim = hidden_states.shape[1]
325
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
326
+ batch, height * width, inner_dim
327
+ )
328
+ hidden_states = (
329
+ self.proj_in(hidden_states, scale=lora_scale)
330
+ if not USE_PEFT_BACKEND
331
+ else self.proj_in(hidden_states)
332
+ )
333
+
334
+ # 2. Blocks
335
+ if self.caption_projection is not None:
336
+ batch_size = hidden_states.shape[0]
337
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
338
+ encoder_hidden_states = encoder_hidden_states.view(
339
+ batch_size, -1, hidden_states.shape[-1]
340
+ )
341
+
342
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
343
+ for block in self.transformer_blocks:
344
+ if self.training and self.gradient_checkpointing:
345
+
346
+ def create_custom_forward(module, return_dict=None):
347
+ def custom_forward(*inputs):
348
+ if return_dict is not None:
349
+ return module(*inputs, return_dict=return_dict)
350
+ else:
351
+ return module(*inputs)
352
+
353
+ return custom_forward
354
+
355
+ ckpt_kwargs: Dict[str, Any] = (
356
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
357
+ )
358
+ hidden_states = torch.utils.checkpoint.checkpoint(
359
+ create_custom_forward(block),
360
+ hidden_states,
361
+ attention_mask,
362
+ encoder_hidden_states,
363
+ encoder_attention_mask,
364
+ timestep,
365
+ cross_attention_kwargs,
366
+ class_labels,
367
+ **ckpt_kwargs,
368
+ )
369
+ else:
370
+ hidden_states = block(
371
+ hidden_states,
372
+ attention_mask=attention_mask,
373
+ encoder_hidden_states=encoder_hidden_states,
374
+ encoder_attention_mask=encoder_attention_mask,
375
+ timestep=timestep,
376
+ cross_attention_kwargs=cross_attention_kwargs,
377
+ class_labels=class_labels,
378
+ )
379
+
380
+ # 3. Output
381
+ if self.is_input_continuous:
382
+ if not self.use_linear_projection:
383
+ hidden_states = (
384
+ hidden_states.reshape(batch, height, width, inner_dim)
385
+ .permute(0, 3, 1, 2)
386
+ .contiguous()
387
+ )
388
+ hidden_states = (
389
+ self.proj_out(hidden_states, scale=lora_scale)
390
+ if not USE_PEFT_BACKEND
391
+ else self.proj_out(hidden_states)
392
+ )
393
+ else:
394
+ hidden_states = (
395
+ self.proj_out(hidden_states, scale=lora_scale)
396
+ if not USE_PEFT_BACKEND
397
+ else self.proj_out(hidden_states)
398
+ )
399
+ hidden_states = (
400
+ hidden_states.reshape(batch, height, width, inner_dim)
401
+ .permute(0, 3, 1, 2)
402
+ .contiguous()
403
+ )
404
+
405
+ output = hidden_states + residual
406
+ if not return_dict:
407
+ return (output, ref_feature)
408
+
409
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
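# Illustrative usage sketch (added for clarity, not part of this commit). It
# exercises the continuous-input path of the Transformer2DModel above, assuming
# this file is genwarp/models/transformer_2d.py, consistent with the imports
# elsewhere in this diff; all sizes below are arbitrary examples.
import torch
from genwarp.models.transformer_2d import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,   # inner_dim = 8 * 40 = 320
    in_channels=320,
    cross_attention_dim=768,
)
latents = torch.randn(1, 320, 32, 32)   # (batch, channels, height, width)
context = torch.randn(1, 77, 768)       # (batch, tokens, cross_attention_dim)
sample, ref_feature = model(
    latents, encoder_hidden_states=context, return_dict=False
)
# `sample` keeps the input shape; `ref_feature` is the pre-block hidden state
# reshaped to (batch, height, width, inner_dim).
print(sample.shape, ref_feature.shape)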
genwarp/models/transformer_3d.py ADDED
@@ -0,0 +1,179 @@
+ # This code is adapted from below and then modified.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # ==============================================================================
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Optional
11
+
12
+ import torch
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.models import ModelMixin
15
+ from diffusers.utils import BaseOutput
16
+ from diffusers.utils.import_utils import is_xformers_available
17
+ from einops import rearrange, repeat
18
+ from torch import nn
19
+
20
+ from .attention import TemporalBasicTransformerBlock
21
+
22
+
23
+ @dataclass
24
+ class Transformer3DModelOutput(BaseOutput):
25
+ sample: torch.FloatTensor
26
+
27
+
28
+ if is_xformers_available():
29
+ import xformers
30
+ import xformers.ops
31
+ else:
32
+ xformers = None
33
+
34
+
35
+ class Transformer3DModel(ModelMixin, ConfigMixin):
36
+ _supports_gradient_checkpointing = True
37
+
38
+ @register_to_config
39
+ def __init__(
40
+ self,
41
+ num_attention_heads: int = 16,
42
+ attention_head_dim: int = 88,
43
+ in_channels: Optional[int] = None,
44
+ num_layers: int = 1,
45
+ dropout: float = 0.0,
46
+ norm_num_groups: int = 32,
47
+ cross_attention_dim: Optional[int] = None,
48
+ attention_bias: bool = False,
49
+ activation_fn: str = "geglu",
50
+ num_embeds_ada_norm: Optional[int] = None,
51
+ use_linear_projection: bool = False,
52
+ only_cross_attention: bool = False,
53
+ upcast_attention: bool = False,
54
+ unet_use_cross_frame_attention=None,
55
+ unet_use_temporal_attention=None,
56
+ use_zero_convs=False,
57
+ ):
58
+ super().__init__()
59
+ self.use_linear_projection = use_linear_projection
60
+ self.num_attention_heads = num_attention_heads
61
+ self.attention_head_dim = attention_head_dim
62
+ inner_dim = num_attention_heads * attention_head_dim
63
+
64
+ # Define input layers
65
+ self.in_channels = in_channels
66
+
67
+ self.norm = torch.nn.GroupNorm(
68
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
69
+ )
70
+ if use_linear_projection:
71
+ self.proj_in = nn.Linear(in_channels, inner_dim)
72
+ else:
73
+ self.proj_in = nn.Conv2d(
74
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
75
+ )
76
+
77
+ # Define transformers blocks
78
+ self.transformer_blocks = nn.ModuleList(
79
+ [
80
+ TemporalBasicTransformerBlock(
81
+ inner_dim,
82
+ num_attention_heads,
83
+ attention_head_dim,
84
+ dropout=dropout,
85
+ cross_attention_dim=cross_attention_dim,
86
+ activation_fn=activation_fn,
87
+ num_embeds_ada_norm=num_embeds_ada_norm,
88
+ attention_bias=attention_bias,
89
+ only_cross_attention=only_cross_attention,
90
+ upcast_attention=upcast_attention,
91
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
92
+ unet_use_temporal_attention=unet_use_temporal_attention,
93
+ use_zero_convs=use_zero_convs,
94
+ )
95
+ for d in range(num_layers)
96
+ ]
97
+ )
98
+
99
+ # 4. Define output layers
100
+ if use_linear_projection:
101
+ self.proj_out = nn.Linear(inner_dim, in_channels)
102
+ else:
103
+ self.proj_out = nn.Conv2d(
104
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
105
+ )
106
+
107
+ self.gradient_checkpointing = False
108
+
109
+ def _set_gradient_checkpointing(self, module, value=False):
110
+ if hasattr(module, "gradient_checkpointing"):
111
+ module.gradient_checkpointing = value
112
+
113
+ def forward(
114
+ self,
115
+ hidden_states,
116
+ encoder_hidden_states=None,
117
+ timestep=None,
118
+ return_dict: bool = True,
119
+ ):
120
+ # Input
121
+ assert (
122
+ hidden_states.dim() == 5
123
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
124
+ video_length = hidden_states.shape[2]
125
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
126
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
127
+ encoder_hidden_states = repeat(
128
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
129
+ )
130
+
131
+ batch, channel, height, width = hidden_states.shape
132
+ residual = hidden_states
133
+
134
+ hidden_states = self.norm(hidden_states)
135
+ if not self.use_linear_projection:
136
+ hidden_states = self.proj_in(hidden_states)
137
+ inner_dim = hidden_states.shape[1]
138
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
139
+ batch, height * width, inner_dim
140
+ )
141
+ else:
142
+ inner_dim = hidden_states.shape[1]
143
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
144
+ batch, height * width, inner_dim
145
+ )
146
+ hidden_states = self.proj_in(hidden_states)
147
+
148
+ # Blocks
149
+ for i, block in enumerate(self.transformer_blocks):
150
+ hidden_states = block(
151
+ hidden_states,
152
+ encoder_hidden_states=encoder_hidden_states,
153
+ timestep=timestep,
154
+ video_length=video_length,
155
+ )
156
+
157
+ # Output
158
+ if not self.use_linear_projection:
159
+ hidden_states = (
160
+ hidden_states.reshape(batch, height, width, inner_dim)
161
+ .permute(0, 3, 1, 2)
162
+ .contiguous()
163
+ )
164
+ hidden_states = self.proj_out(hidden_states)
165
+ else:
166
+ hidden_states = self.proj_out(hidden_states)
167
+ hidden_states = (
168
+ hidden_states.reshape(batch, height, width, inner_dim)
169
+ .permute(0, 3, 1, 2)
170
+ .contiguous()
171
+ )
172
+
173
+ output = hidden_states + residual
174
+
175
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
176
+ if not return_dict:
177
+ return (output,)
178
+
179
+ return Transformer3DModelOutput(sample=output)
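# Illustrative usage sketch (added for clarity, not part of this commit),
# assuming the module path genwarp.models.transformer_3d and that the imported
# TemporalBasicTransformerBlock behaves like its Moore-AnimateAnyone original.
# The model expects 5D video latents (batch, channels, frames, height, width),
# folds frames into the batch, and repeats the conditioning across frames.
import torch
from genwarp.models.transformer_3d import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,   # inner_dim = 8 * 40 = 320
    in_channels=320,
    cross_attention_dim=768,
)
video_latents = torch.randn(1, 320, 4, 32, 32)   # 4 frames
context = torch.randn(1, 77, 768)                # one conditioning per video
out = model(video_latents, encoder_hidden_states=context)
print(out.sample.shape)                          # -> torch.Size([1, 320, 4, 32, 32])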
genwarp/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1087 @@
+ # This code is adapted from below and then modified.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # -----------------------------------------------------------------------------
8
+ # Diffusers
9
+ # Apache License, Version 2.0
10
+ # Copyright (c) Hugging Face Inc.
11
+ # https://github.com/huggingface/diffusers
12
+ # ==============================================================================
13
+
14
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
15
+ from typing import Any, Dict, Optional, Tuple, Union
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from diffusers.models.activations import get_activation
21
+ from diffusers.models.attention_processor import Attention
22
+ from diffusers.models import DualTransformer2DModel
23
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
24
+ from diffusers.utils import is_torch_version, logging
25
+ from diffusers.utils.torch_utils import apply_freeu
26
+ from torch import nn
27
+
28
+ from .transformer_2d import Transformer2DModel
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ def get_down_block(
34
+ down_block_type: str,
35
+ num_layers: int,
36
+ in_channels: int,
37
+ out_channels: int,
38
+ temb_channels: int,
39
+ add_downsample: bool,
40
+ resnet_eps: float,
41
+ resnet_act_fn: str,
42
+ transformer_layers_per_block: int = 1,
43
+ num_attention_heads: Optional[int] = None,
44
+ resnet_groups: Optional[int] = None,
45
+ cross_attention_dim: Optional[int] = None,
46
+ downsample_padding: Optional[int] = None,
47
+ dual_cross_attention: bool = False,
48
+ use_linear_projection: bool = False,
49
+ only_cross_attention: bool = False,
50
+ upcast_attention: bool = False,
51
+ resnet_time_scale_shift: str = "default",
52
+ attention_type: str = "default",
53
+ resnet_skip_time_act: bool = False,
54
+ resnet_out_scale_factor: float = 1.0,
55
+ cross_attention_norm: Optional[str] = None,
56
+ attention_head_dim: Optional[int] = None,
57
+ downsample_type: Optional[str] = None,
58
+ dropout: float = 0.0,
59
+ ):
60
+ # If attn head dim is not defined, we default it to the number of heads
61
+ if attention_head_dim is None:
62
+ logger.warn(
63
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
64
+ )
65
+ attention_head_dim = num_attention_heads
66
+
67
+ down_block_type = (
68
+ down_block_type[7:]
69
+ if down_block_type.startswith("UNetRes")
70
+ else down_block_type
71
+ )
72
+ if down_block_type == "DownBlock2D":
73
+ return DownBlock2D(
74
+ num_layers=num_layers,
75
+ in_channels=in_channels,
76
+ out_channels=out_channels,
77
+ temb_channels=temb_channels,
78
+ dropout=dropout,
79
+ add_downsample=add_downsample,
80
+ resnet_eps=resnet_eps,
81
+ resnet_act_fn=resnet_act_fn,
82
+ resnet_groups=resnet_groups,
83
+ downsample_padding=downsample_padding,
84
+ resnet_time_scale_shift=resnet_time_scale_shift,
85
+ )
86
+ elif down_block_type == "CrossAttnDownBlock2D":
87
+ if cross_attention_dim is None:
88
+ raise ValueError(
89
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
90
+ )
91
+ return CrossAttnDownBlock2D(
92
+ num_layers=num_layers,
93
+ transformer_layers_per_block=transformer_layers_per_block,
94
+ in_channels=in_channels,
95
+ out_channels=out_channels,
96
+ temb_channels=temb_channels,
97
+ dropout=dropout,
98
+ add_downsample=add_downsample,
99
+ resnet_eps=resnet_eps,
100
+ resnet_act_fn=resnet_act_fn,
101
+ resnet_groups=resnet_groups,
102
+ downsample_padding=downsample_padding,
103
+ cross_attention_dim=cross_attention_dim,
104
+ num_attention_heads=num_attention_heads,
105
+ dual_cross_attention=dual_cross_attention,
106
+ use_linear_projection=use_linear_projection,
107
+ only_cross_attention=only_cross_attention,
108
+ upcast_attention=upcast_attention,
109
+ resnet_time_scale_shift=resnet_time_scale_shift,
110
+ attention_type=attention_type,
111
+ )
112
+ raise ValueError(f"{down_block_type} does not exist.")
113
+
114
+
115
+ def get_up_block(
116
+ up_block_type: str,
117
+ num_layers: int,
118
+ in_channels: int,
119
+ out_channels: int,
120
+ prev_output_channel: int,
121
+ temb_channels: int,
122
+ add_upsample: bool,
123
+ resnet_eps: float,
124
+ resnet_act_fn: str,
125
+ resolution_idx: Optional[int] = None,
126
+ transformer_layers_per_block: int = 1,
127
+ num_attention_heads: Optional[int] = None,
128
+ resnet_groups: Optional[int] = None,
129
+ cross_attention_dim: Optional[int] = None,
130
+ dual_cross_attention: bool = False,
131
+ use_linear_projection: bool = False,
132
+ only_cross_attention: bool = False,
133
+ upcast_attention: bool = False,
134
+ resnet_time_scale_shift: str = "default",
135
+ attention_type: str = "default",
136
+ resnet_skip_time_act: bool = False,
137
+ resnet_out_scale_factor: float = 1.0,
138
+ cross_attention_norm: Optional[str] = None,
139
+ attention_head_dim: Optional[int] = None,
140
+ upsample_type: Optional[str] = None,
141
+ dropout: float = 0.0,
142
+ ) -> nn.Module:
143
+ # If attn head dim is not defined, we default it to the number of heads
144
+ if attention_head_dim is None:
145
+ logger.warn(
146
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
147
+ )
148
+ attention_head_dim = num_attention_heads
149
+
150
+ up_block_type = (
151
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
152
+ )
153
+ if up_block_type == "UpBlock2D":
154
+ return UpBlock2D(
155
+ num_layers=num_layers,
156
+ in_channels=in_channels,
157
+ out_channels=out_channels,
158
+ prev_output_channel=prev_output_channel,
159
+ temb_channels=temb_channels,
160
+ resolution_idx=resolution_idx,
161
+ dropout=dropout,
162
+ add_upsample=add_upsample,
163
+ resnet_eps=resnet_eps,
164
+ resnet_act_fn=resnet_act_fn,
165
+ resnet_groups=resnet_groups,
166
+ resnet_time_scale_shift=resnet_time_scale_shift,
167
+ )
168
+ elif up_block_type == "CrossAttnUpBlock2D":
169
+ if cross_attention_dim is None:
170
+ raise ValueError(
171
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
172
+ )
173
+ return CrossAttnUpBlock2D(
174
+ num_layers=num_layers,
175
+ transformer_layers_per_block=transformer_layers_per_block,
176
+ in_channels=in_channels,
177
+ out_channels=out_channels,
178
+ prev_output_channel=prev_output_channel,
179
+ temb_channels=temb_channels,
180
+ resolution_idx=resolution_idx,
181
+ dropout=dropout,
182
+ add_upsample=add_upsample,
183
+ resnet_eps=resnet_eps,
184
+ resnet_act_fn=resnet_act_fn,
185
+ resnet_groups=resnet_groups,
186
+ cross_attention_dim=cross_attention_dim,
187
+ num_attention_heads=num_attention_heads,
188
+ dual_cross_attention=dual_cross_attention,
189
+ use_linear_projection=use_linear_projection,
190
+ only_cross_attention=only_cross_attention,
191
+ upcast_attention=upcast_attention,
192
+ resnet_time_scale_shift=resnet_time_scale_shift,
193
+ attention_type=attention_type,
194
+ )
195
+
196
+ raise ValueError(f"{up_block_type} does not exist.")
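# Illustrative usage sketch (added for clarity, not part of this commit): the
# factories above map block-type strings from a UNet config to concrete modules
# and fail fast on unknown types. Channel sizes below are made-up examples, and
# the module path genwarp.models.unet_2d_blocks is assumed.
from genwarp.models.unet_2d_blocks import get_down_block

down = get_down_block(
    "CrossAttnDownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    cross_attention_dim=768,
    num_attention_heads=8,
    downsample_padding=1,
    attention_head_dim=80,   # passed explicitly to avoid the fallback warning
)
# get_down_block("SomethingElse2D", ...) would raise ValueError, and a
# cross-attention block without cross_attention_dim raises as well.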
197
+
198
+
199
+ class AutoencoderTinyBlock(nn.Module):
200
+ """
201
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
202
+ blocks.
203
+
204
+ Args:
205
+ in_channels (`int`): The number of input channels.
206
+ out_channels (`int`): The number of output channels.
207
+ act_fn (`str`):
208
+ The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
209
+
210
+ Returns:
211
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
212
+ `out_channels`.
213
+ """
214
+
215
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
216
+ super().__init__()
217
+ act_fn = get_activation(act_fn)
218
+ self.conv = nn.Sequential(
219
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
220
+ act_fn,
221
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
222
+ act_fn,
223
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
224
+ )
225
+ self.skip = (
226
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
227
+ if in_channels != out_channels
228
+ else nn.Identity()
229
+ )
230
+ self.fuse = nn.ReLU()
231
+
232
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
233
+ return self.fuse(self.conv(x) + self.skip(x))
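# Illustrative usage sketch (added for clarity, not part of this commit),
# assuming the module path genwarp.models.unet_2d_blocks: the block is a plain
# conv/ReLU residual unit, and a 1x1 convolution aligns the skip path whenever
# in_channels != out_channels.
import torch
from genwarp.models.unet_2d_blocks import AutoencoderTinyBlock

block = AutoencoderTinyBlock(in_channels=64, out_channels=128, act_fn="relu")
x = torch.randn(2, 64, 32, 32)
print(block(x).shape)   # -> torch.Size([2, 128, 32, 32]); spatial size is preserved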
234
+
235
+
236
+ class UNetMidBlock2D(nn.Module):
237
+ """
238
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
239
+
240
+ Args:
241
+ in_channels (`int`): The number of input channels.
242
+ temb_channels (`int`): The number of temporal embedding channels.
243
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
244
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
245
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
246
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
247
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
248
+ model on tasks with long-range temporal dependencies.
249
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
250
+ resnet_groups (`int`, *optional*, defaults to 32):
251
+ The number of groups to use in the group normalization layers of the resnet blocks.
252
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
253
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
254
+ Whether to use pre-normalization for the resnet blocks.
255
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
256
+ attention_head_dim (`int`, *optional*, defaults to 1):
257
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
258
+ the number of input channels.
259
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
260
+
261
+ Returns:
262
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
263
+ in_channels, height, width)`.
264
+
265
+ """
266
+
267
+ def __init__(
268
+ self,
269
+ in_channels: int,
270
+ temb_channels: int,
271
+ dropout: float = 0.0,
272
+ num_layers: int = 1,
273
+ resnet_eps: float = 1e-6,
274
+ resnet_time_scale_shift: str = "default", # default, spatial
275
+ resnet_act_fn: str = "swish",
276
+ resnet_groups: int = 32,
277
+ attn_groups: Optional[int] = None,
278
+ resnet_pre_norm: bool = True,
279
+ add_attention: bool = True,
280
+ attention_head_dim: int = 1,
281
+ output_scale_factor: float = 1.0,
282
+ ):
283
+ super().__init__()
284
+ resnet_groups = (
285
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
286
+ )
287
+ self.add_attention = add_attention
288
+
289
+ if attn_groups is None:
290
+ attn_groups = (
291
+ resnet_groups if resnet_time_scale_shift == "default" else None
292
+ )
293
+
294
+ # there is always at least one resnet
295
+ resnets = [
296
+ ResnetBlock2D(
297
+ in_channels=in_channels,
298
+ out_channels=in_channels,
299
+ temb_channels=temb_channels,
300
+ eps=resnet_eps,
301
+ groups=resnet_groups,
302
+ dropout=dropout,
303
+ time_embedding_norm=resnet_time_scale_shift,
304
+ non_linearity=resnet_act_fn,
305
+ output_scale_factor=output_scale_factor,
306
+ pre_norm=resnet_pre_norm,
307
+ )
308
+ ]
309
+ attentions = []
310
+
311
+ if attention_head_dim is None:
312
+ logger.warn(
313
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
314
+ )
315
+ attention_head_dim = in_channels
316
+
317
+ for _ in range(num_layers):
318
+ if self.add_attention:
319
+ attentions.append(
320
+ Attention(
321
+ in_channels,
322
+ heads=in_channels // attention_head_dim,
323
+ dim_head=attention_head_dim,
324
+ rescale_output_factor=output_scale_factor,
325
+ eps=resnet_eps,
326
+ norm_num_groups=attn_groups,
327
+ spatial_norm_dim=temb_channels
328
+ if resnet_time_scale_shift == "spatial"
329
+ else None,
330
+ residual_connection=True,
331
+ bias=True,
332
+ upcast_softmax=True,
333
+ _from_deprecated_attn_block=True,
334
+ )
335
+ )
336
+ else:
337
+ attentions.append(None)
338
+
339
+ resnets.append(
340
+ ResnetBlock2D(
341
+ in_channels=in_channels,
342
+ out_channels=in_channels,
343
+ temb_channels=temb_channels,
344
+ eps=resnet_eps,
345
+ groups=resnet_groups,
346
+ dropout=dropout,
347
+ time_embedding_norm=resnet_time_scale_shift,
348
+ non_linearity=resnet_act_fn,
349
+ output_scale_factor=output_scale_factor,
350
+ pre_norm=resnet_pre_norm,
351
+ )
352
+ )
353
+
354
+ self.attentions = nn.ModuleList(attentions)
355
+ self.resnets = nn.ModuleList(resnets)
356
+
357
+ def forward(
358
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
359
+ ) -> torch.FloatTensor:
360
+ hidden_states = self.resnets[0](hidden_states, temb)
361
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
362
+ if attn is not None:
363
+ hidden_states = attn(hidden_states, temb=temb)
364
+ hidden_states = resnet(hidden_states, temb)
365
+
366
+ return hidden_states
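# Illustrative usage sketch (added for clarity, not part of this commit),
# assuming the module path genwarp.models.unet_2d_blocks. The head count of the
# mid-block attention is derived from the channels:
# heads = in_channels // attention_head_dim, so 512 channels with
# attention_head_dim=64 gives 8 heads.
import torch
from genwarp.models.unet_2d_blocks import UNetMidBlock2D

mid = UNetMidBlock2D(in_channels=512, temb_channels=512, attention_head_dim=64)
h = torch.randn(1, 512, 8, 8)
temb = torch.randn(1, 512)
print(mid(h, temb).shape)   # -> torch.Size([1, 512, 8, 8])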
367
+
368
+
369
+ class UNetMidBlock2DCrossAttn(nn.Module):
370
+ def __init__(
371
+ self,
372
+ in_channels: int,
373
+ temb_channels: int,
374
+ dropout: float = 0.0,
375
+ num_layers: int = 1,
376
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
377
+ resnet_eps: float = 1e-6,
378
+ resnet_time_scale_shift: str = "default",
379
+ resnet_act_fn: str = "swish",
380
+ resnet_groups: int = 32,
381
+ resnet_pre_norm: bool = True,
382
+ num_attention_heads: int = 1,
383
+ output_scale_factor: float = 1.0,
384
+ cross_attention_dim: int = 1280,
385
+ dual_cross_attention: bool = False,
386
+ use_linear_projection: bool = False,
387
+ upcast_attention: bool = False,
388
+ attention_type: str = "default",
389
+ ):
390
+ super().__init__()
391
+
392
+ self.has_cross_attention = True
393
+ self.num_attention_heads = num_attention_heads
394
+ resnet_groups = (
395
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
396
+ )
397
+
398
+ # support for variable transformer layers per block
399
+ if isinstance(transformer_layers_per_block, int):
400
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
401
+
402
+ # there is always at least one resnet
403
+ resnets = [
404
+ ResnetBlock2D(
405
+ in_channels=in_channels,
406
+ out_channels=in_channels,
407
+ temb_channels=temb_channels,
408
+ eps=resnet_eps,
409
+ groups=resnet_groups,
410
+ dropout=dropout,
411
+ time_embedding_norm=resnet_time_scale_shift,
412
+ non_linearity=resnet_act_fn,
413
+ output_scale_factor=output_scale_factor,
414
+ pre_norm=resnet_pre_norm,
415
+ )
416
+ ]
417
+ attentions = []
418
+
419
+ for i in range(num_layers):
420
+ if not dual_cross_attention:
421
+ attentions.append(
422
+ Transformer2DModel(
423
+ num_attention_heads,
424
+ in_channels // num_attention_heads,
425
+ in_channels=in_channels,
426
+ num_layers=transformer_layers_per_block[i],
427
+ cross_attention_dim=cross_attention_dim,
428
+ norm_num_groups=resnet_groups,
429
+ use_linear_projection=use_linear_projection,
430
+ upcast_attention=upcast_attention,
431
+ attention_type=attention_type,
432
+ )
433
+ )
434
+ else:
435
+ attentions.append(
436
+ DualTransformer2DModel(
437
+ num_attention_heads,
438
+ in_channels // num_attention_heads,
439
+ in_channels=in_channels,
440
+ num_layers=1,
441
+ cross_attention_dim=cross_attention_dim,
442
+ norm_num_groups=resnet_groups,
443
+ )
444
+ )
445
+ resnets.append(
446
+ ResnetBlock2D(
447
+ in_channels=in_channels,
448
+ out_channels=in_channels,
449
+ temb_channels=temb_channels,
450
+ eps=resnet_eps,
451
+ groups=resnet_groups,
452
+ dropout=dropout,
453
+ time_embedding_norm=resnet_time_scale_shift,
454
+ non_linearity=resnet_act_fn,
455
+ output_scale_factor=output_scale_factor,
456
+ pre_norm=resnet_pre_norm,
457
+ )
458
+ )
459
+
460
+ self.attentions = nn.ModuleList(attentions)
461
+ self.resnets = nn.ModuleList(resnets)
462
+
463
+ self.gradient_checkpointing = False
464
+
465
+ def forward(
466
+ self,
467
+ hidden_states: torch.FloatTensor,
468
+ temb: Optional[torch.FloatTensor] = None,
469
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
470
+ attention_mask: Optional[torch.FloatTensor] = None,
471
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
472
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
473
+ ) -> torch.FloatTensor:
474
+ lora_scale = (
475
+ cross_attention_kwargs.get("scale", 1.0)
476
+ if cross_attention_kwargs is not None
477
+ else 1.0
478
+ )
479
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
480
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
481
+ if self.training and self.gradient_checkpointing:
482
+
483
+ def create_custom_forward(module, return_dict=None):
484
+ def custom_forward(*inputs):
485
+ if return_dict is not None:
486
+ return module(*inputs, return_dict=return_dict)
487
+ else:
488
+ return module(*inputs)
489
+
490
+ return custom_forward
491
+
492
+ ckpt_kwargs: Dict[str, Any] = (
493
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
494
+ )
495
+ hidden_states, ref_feature = attn(
496
+ hidden_states,
497
+ encoder_hidden_states=encoder_hidden_states,
498
+ cross_attention_kwargs=cross_attention_kwargs,
499
+ attention_mask=attention_mask,
500
+ encoder_attention_mask=encoder_attention_mask,
501
+ return_dict=False,
502
+ )
503
+ hidden_states = torch.utils.checkpoint.checkpoint(
504
+ create_custom_forward(resnet),
505
+ hidden_states,
506
+ temb,
507
+ **ckpt_kwargs,
508
+ )
509
+ else:
510
+ hidden_states, ref_feature = attn(
511
+ hidden_states,
512
+ encoder_hidden_states=encoder_hidden_states,
513
+ cross_attention_kwargs=cross_attention_kwargs,
514
+ attention_mask=attention_mask,
515
+ encoder_attention_mask=encoder_attention_mask,
516
+ return_dict=False,
517
+ )
518
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
519
+
520
+ return hidden_states
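# Illustrative note (added for clarity, not in the original diff): `attn` above
# is the Transformer2DModel defined in this repository, whose return_dict=False
# path yields `(sample, ref_feature)`; this mid-block propagates only `sample`
# and leaves `ref_feature` unused here.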
521
+
522
+
523
+ class CrossAttnDownBlock2D(nn.Module):
524
+ def __init__(
525
+ self,
526
+ in_channels: int,
527
+ out_channels: int,
528
+ temb_channels: int,
529
+ dropout: float = 0.0,
530
+ num_layers: int = 1,
531
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
532
+ resnet_eps: float = 1e-6,
533
+ resnet_time_scale_shift: str = "default",
534
+ resnet_act_fn: str = "swish",
535
+ resnet_groups: int = 32,
536
+ resnet_pre_norm: bool = True,
537
+ num_attention_heads: int = 1,
538
+ cross_attention_dim: int = 1280,
539
+ output_scale_factor: float = 1.0,
540
+ downsample_padding: int = 1,
541
+ add_downsample: bool = True,
542
+ dual_cross_attention: bool = False,
543
+ use_linear_projection: bool = False,
544
+ only_cross_attention: bool = False,
545
+ upcast_attention: bool = False,
546
+ attention_type: str = "default",
547
+ ):
548
+ super().__init__()
549
+ resnets = []
550
+ attentions = []
551
+
552
+ self.has_cross_attention = True
553
+ self.num_attention_heads = num_attention_heads
554
+ if isinstance(transformer_layers_per_block, int):
555
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
556
+
557
+ for i in range(num_layers):
558
+ in_channels = in_channels if i == 0 else out_channels
559
+ resnets.append(
560
+ ResnetBlock2D(
561
+ in_channels=in_channels,
562
+ out_channels=out_channels,
563
+ temb_channels=temb_channels,
564
+ eps=resnet_eps,
565
+ groups=resnet_groups,
566
+ dropout=dropout,
567
+ time_embedding_norm=resnet_time_scale_shift,
568
+ non_linearity=resnet_act_fn,
569
+ output_scale_factor=output_scale_factor,
570
+ pre_norm=resnet_pre_norm,
571
+ )
572
+ )
573
+ if not dual_cross_attention:
574
+ attentions.append(
575
+ Transformer2DModel(
576
+ num_attention_heads,
577
+ out_channels // num_attention_heads,
578
+ in_channels=out_channels,
579
+ num_layers=transformer_layers_per_block[i],
580
+ cross_attention_dim=cross_attention_dim,
581
+ norm_num_groups=resnet_groups,
582
+ use_linear_projection=use_linear_projection,
583
+ only_cross_attention=only_cross_attention,
584
+ upcast_attention=upcast_attention,
585
+ attention_type=attention_type,
586
+ )
587
+ )
588
+ else:
589
+ attentions.append(
590
+ DualTransformer2DModel(
591
+ num_attention_heads,
592
+ out_channels // num_attention_heads,
593
+ in_channels=out_channels,
594
+ num_layers=1,
595
+ cross_attention_dim=cross_attention_dim,
596
+ norm_num_groups=resnet_groups,
597
+ )
598
+ )
599
+ self.attentions = nn.ModuleList(attentions)
600
+ self.resnets = nn.ModuleList(resnets)
601
+
602
+ if add_downsample:
603
+ self.downsamplers = nn.ModuleList(
604
+ [
605
+ Downsample2D(
606
+ out_channels,
607
+ use_conv=True,
608
+ out_channels=out_channels,
609
+ padding=downsample_padding,
610
+ name="op",
611
+ )
612
+ ]
613
+ )
614
+ else:
615
+ self.downsamplers = None
616
+
617
+ self.gradient_checkpointing = False
618
+
619
+ def forward(
620
+ self,
621
+ hidden_states: torch.FloatTensor,
622
+ temb: Optional[torch.FloatTensor] = None,
623
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
624
+ attention_mask: Optional[torch.FloatTensor] = None,
625
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
626
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
627
+ additional_residuals: Optional[torch.FloatTensor] = None,
628
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
629
+ output_states = ()
630
+
631
+ lora_scale = (
632
+ cross_attention_kwargs.get("scale", 1.0)
633
+ if cross_attention_kwargs is not None
634
+ else 1.0
635
+ )
636
+
637
+ blocks = list(zip(self.resnets, self.attentions))
638
+
639
+ for i, (resnet, attn) in enumerate(blocks):
640
+ if self.training and self.gradient_checkpointing:
641
+
642
+ def create_custom_forward(module, return_dict=None):
643
+ def custom_forward(*inputs):
644
+ if return_dict is not None:
645
+ return module(*inputs, return_dict=return_dict)
646
+ else:
647
+ return module(*inputs)
648
+
649
+ return custom_forward
650
+
651
+ ckpt_kwargs: Dict[str, Any] = (
652
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
653
+ )
654
+ hidden_states = torch.utils.checkpoint.checkpoint(
655
+ create_custom_forward(resnet),
656
+ hidden_states,
657
+ temb,
658
+ **ckpt_kwargs,
659
+ )
660
+ hidden_states, ref_feature = attn(
661
+ hidden_states,
662
+ encoder_hidden_states=encoder_hidden_states,
663
+ cross_attention_kwargs=cross_attention_kwargs,
664
+ attention_mask=attention_mask,
665
+ encoder_attention_mask=encoder_attention_mask,
666
+ return_dict=False,
667
+ )
668
+ else:
669
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
670
+ hidden_states, ref_feature = attn(
671
+ hidden_states,
672
+ encoder_hidden_states=encoder_hidden_states,
673
+ cross_attention_kwargs=cross_attention_kwargs,
674
+ attention_mask=attention_mask,
675
+ encoder_attention_mask=encoder_attention_mask,
676
+ return_dict=False,
677
+ )
678
+
679
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
680
+ if i == len(blocks) - 1 and additional_residuals is not None:
681
+ hidden_states = hidden_states + additional_residuals
682
+
683
+ output_states = output_states + (hidden_states,)
684
+
685
+ if self.downsamplers is not None:
686
+ for downsampler in self.downsamplers:
687
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
688
+
689
+ output_states = output_states + (hidden_states,)
690
+
691
+ return hidden_states, output_states
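# Illustrative usage sketch (added for clarity, not part of this commit),
# assuming the module path genwarp.models.unet_2d_blocks and made-up sizes:
# the block returns the downsampled features plus a tuple of intermediate
# states that a surrounding UNet reuses as skip connections.
import torch
from genwarp.models.unet_2d_blocks import CrossAttnDownBlock2D

down = CrossAttnDownBlock2D(
    in_channels=320, out_channels=640, temb_channels=1280,
    num_layers=2, num_attention_heads=8, cross_attention_dim=768,
)
h = torch.randn(1, 320, 32, 32)
temb = torch.randn(1, 1280)
context = torch.randn(1, 77, 768)
h, skips = down(h, temb=temb, encoder_hidden_states=context)
print(h.shape)                          # -> torch.Size([1, 640, 16, 16])
print([tuple(s.shape) for s in skips])  # two states at 32x32 and one at 16x16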
692
+
693
+
694
+ class DownBlock2D(nn.Module):
695
+ def __init__(
696
+ self,
697
+ in_channels: int,
698
+ out_channels: int,
699
+ temb_channels: int,
700
+ dropout: float = 0.0,
701
+ num_layers: int = 1,
702
+ resnet_eps: float = 1e-6,
703
+ resnet_time_scale_shift: str = "default",
704
+ resnet_act_fn: str = "swish",
705
+ resnet_groups: int = 32,
706
+ resnet_pre_norm: bool = True,
707
+ output_scale_factor: float = 1.0,
708
+ add_downsample: bool = True,
709
+ downsample_padding: int = 1,
710
+ ):
711
+ super().__init__()
712
+ resnets = []
713
+
714
+ for i in range(num_layers):
715
+ in_channels = in_channels if i == 0 else out_channels
716
+ resnets.append(
717
+ ResnetBlock2D(
718
+ in_channels=in_channels,
719
+ out_channels=out_channels,
720
+ temb_channels=temb_channels,
721
+ eps=resnet_eps,
722
+ groups=resnet_groups,
723
+ dropout=dropout,
724
+ time_embedding_norm=resnet_time_scale_shift,
725
+ non_linearity=resnet_act_fn,
726
+ output_scale_factor=output_scale_factor,
727
+ pre_norm=resnet_pre_norm,
728
+ )
729
+ )
730
+
731
+ self.resnets = nn.ModuleList(resnets)
732
+
733
+ if add_downsample:
734
+ self.downsamplers = nn.ModuleList(
735
+ [
736
+ Downsample2D(
737
+ out_channels,
738
+ use_conv=True,
739
+ out_channels=out_channels,
740
+ padding=downsample_padding,
741
+ name="op",
742
+ )
743
+ ]
744
+ )
745
+ else:
746
+ self.downsamplers = None
747
+
748
+ self.gradient_checkpointing = False
749
+
750
+ def forward(
751
+ self,
752
+ hidden_states: torch.FloatTensor,
753
+ temb: Optional[torch.FloatTensor] = None,
754
+ scale: float = 1.0,
755
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
756
+ output_states = ()
757
+
758
+ for resnet in self.resnets:
759
+ if self.training and self.gradient_checkpointing:
760
+
761
+ def create_custom_forward(module):
762
+ def custom_forward(*inputs):
763
+ return module(*inputs)
764
+
765
+ return custom_forward
766
+
767
+ if is_torch_version(">=", "1.11.0"):
768
+ hidden_states = torch.utils.checkpoint.checkpoint(
769
+ create_custom_forward(resnet),
770
+ hidden_states,
771
+ temb,
772
+ use_reentrant=False,
773
+ )
774
+ else:
775
+ hidden_states = torch.utils.checkpoint.checkpoint(
776
+ create_custom_forward(resnet), hidden_states, temb
777
+ )
778
+ else:
779
+ hidden_states = resnet(hidden_states, temb, scale=scale)
780
+
781
+ output_states = output_states + (hidden_states,)
782
+
783
+ if self.downsamplers is not None:
784
+ for downsampler in self.downsamplers:
785
+ hidden_states = downsampler(hidden_states, scale=scale)
786
+
787
+ output_states = output_states + (hidden_states,)
788
+
789
+ return hidden_states, output_states
790
+
791
+
792
+ class CrossAttnUpBlock2D(nn.Module):
793
+ def __init__(
794
+ self,
795
+ in_channels: int,
796
+ out_channels: int,
797
+ prev_output_channel: int,
798
+ temb_channels: int,
799
+ resolution_idx: Optional[int] = None,
800
+ dropout: float = 0.0,
801
+ num_layers: int = 1,
802
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
803
+ resnet_eps: float = 1e-6,
804
+ resnet_time_scale_shift: str = "default",
805
+ resnet_act_fn: str = "swish",
806
+ resnet_groups: int = 32,
807
+ resnet_pre_norm: bool = True,
808
+ num_attention_heads: int = 1,
809
+ cross_attention_dim: int = 1280,
810
+ output_scale_factor: float = 1.0,
811
+ add_upsample: bool = True,
812
+ dual_cross_attention: bool = False,
813
+ use_linear_projection: bool = False,
814
+ only_cross_attention: bool = False,
815
+ upcast_attention: bool = False,
816
+ attention_type: str = "default",
817
+ ):
818
+ super().__init__()
819
+ resnets = []
820
+ attentions = []
821
+
822
+ self.has_cross_attention = True
823
+ self.num_attention_heads = num_attention_heads
824
+
825
+ if isinstance(transformer_layers_per_block, int):
826
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
827
+
828
+ for i in range(num_layers):
829
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
830
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
831
+
832
+ resnets.append(
833
+ ResnetBlock2D(
834
+ in_channels=resnet_in_channels + res_skip_channels,
835
+ out_channels=out_channels,
836
+ temb_channels=temb_channels,
837
+ eps=resnet_eps,
838
+ groups=resnet_groups,
839
+ dropout=dropout,
840
+ time_embedding_norm=resnet_time_scale_shift,
841
+ non_linearity=resnet_act_fn,
842
+ output_scale_factor=output_scale_factor,
843
+ pre_norm=resnet_pre_norm,
844
+ )
845
+ )
846
+ if not dual_cross_attention:
847
+ attentions.append(
848
+ Transformer2DModel(
849
+ num_attention_heads,
850
+ out_channels // num_attention_heads,
851
+ in_channels=out_channels,
852
+ num_layers=transformer_layers_per_block[i],
853
+ cross_attention_dim=cross_attention_dim,
854
+ norm_num_groups=resnet_groups,
855
+ use_linear_projection=use_linear_projection,
856
+ only_cross_attention=only_cross_attention,
857
+ upcast_attention=upcast_attention,
858
+ attention_type=attention_type,
859
+ )
860
+ )
861
+ else:
862
+ attentions.append(
863
+ DualTransformer2DModel(
864
+ num_attention_heads,
865
+ out_channels // num_attention_heads,
866
+ in_channels=out_channels,
867
+ num_layers=1,
868
+ cross_attention_dim=cross_attention_dim,
869
+ norm_num_groups=resnet_groups,
870
+ )
871
+ )
872
+ self.attentions = nn.ModuleList(attentions)
873
+ self.resnets = nn.ModuleList(resnets)
874
+
875
+ if add_upsample:
876
+ self.upsamplers = nn.ModuleList(
877
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
878
+ )
879
+ else:
880
+ self.upsamplers = None
881
+
882
+ self.gradient_checkpointing = False
883
+ self.resolution_idx = resolution_idx
884
+
885
+ def forward(
886
+ self,
887
+ hidden_states: torch.FloatTensor,
888
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
889
+ temb: Optional[torch.FloatTensor] = None,
890
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
891
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
892
+ upsample_size: Optional[int] = None,
893
+ attention_mask: Optional[torch.FloatTensor] = None,
894
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
895
+ ) -> torch.FloatTensor:
896
+ lora_scale = (
897
+ cross_attention_kwargs.get("scale", 1.0)
898
+ if cross_attention_kwargs is not None
899
+ else 1.0
900
+ )
901
+ is_freeu_enabled = (
902
+ getattr(self, "s1", None)
903
+ and getattr(self, "s2", None)
904
+ and getattr(self, "b1", None)
905
+ and getattr(self, "b2", None)
906
+ )
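# Illustrative note (added for clarity, not in the original diff): FreeU only
# takes effect when the `s1`, `s2`, `b1`, `b2` attributes have been set on this
# block (for example by diffusers' `UNet2DConditionModel.enable_freeu`);
# otherwise `apply_freeu` below is skipped entirely.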
907
+
908
+ for resnet, attn in zip(self.resnets, self.attentions):
909
+ # pop res hidden states
910
+ res_hidden_states = res_hidden_states_tuple[-1]
911
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
912
+
913
+ # FreeU: Only operate on the first two stages
914
+ if is_freeu_enabled:
915
+ hidden_states, res_hidden_states = apply_freeu(
916
+ self.resolution_idx,
917
+ hidden_states,
918
+ res_hidden_states,
919
+ s1=self.s1,
920
+ s2=self.s2,
921
+ b1=self.b1,
922
+ b2=self.b2,
923
+ )
924
+
925
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
926
+
927
+ if self.training and self.gradient_checkpointing:
928
+
929
+ def create_custom_forward(module, return_dict=None):
930
+ def custom_forward(*inputs):
931
+ if return_dict is not None:
932
+ return module(*inputs, return_dict=return_dict)
933
+ else:
934
+ return module(*inputs)
935
+
936
+ return custom_forward
937
+
938
+ ckpt_kwargs: Dict[str, Any] = (
939
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
940
+ )
941
+ hidden_states = torch.utils.checkpoint.checkpoint(
942
+ create_custom_forward(resnet),
943
+ hidden_states,
944
+ temb,
945
+ **ckpt_kwargs,
946
+ )
947
+ hidden_states, ref_feature = attn(
948
+ hidden_states,
949
+ encoder_hidden_states=encoder_hidden_states,
950
+ cross_attention_kwargs=cross_attention_kwargs,
951
+ attention_mask=attention_mask,
952
+ encoder_attention_mask=encoder_attention_mask,
953
+ return_dict=False,
954
+ )
955
+ else:
956
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
957
+ hidden_states, ref_feature = attn(
958
+ hidden_states,
959
+ encoder_hidden_states=encoder_hidden_states,
960
+ cross_attention_kwargs=cross_attention_kwargs,
961
+ attention_mask=attention_mask,
962
+ encoder_attention_mask=encoder_attention_mask,
963
+ return_dict=False,
964
+ )
965
+
966
+ if self.upsamplers is not None:
967
+ for upsampler in self.upsamplers:
968
+ hidden_states = upsampler(
969
+ hidden_states, upsample_size, scale=lora_scale
970
+ )
971
+
972
+ return hidden_states
973
+
974
+
975
+ class UpBlock2D(nn.Module):
976
+ def __init__(
977
+ self,
978
+ in_channels: int,
979
+ prev_output_channel: int,
980
+ out_channels: int,
981
+ temb_channels: int,
982
+ resolution_idx: Optional[int] = None,
983
+ dropout: float = 0.0,
984
+ num_layers: int = 1,
985
+ resnet_eps: float = 1e-6,
986
+ resnet_time_scale_shift: str = "default",
987
+ resnet_act_fn: str = "swish",
988
+ resnet_groups: int = 32,
989
+ resnet_pre_norm: bool = True,
990
+ output_scale_factor: float = 1.0,
991
+ add_upsample: bool = True,
992
+ ):
993
+ super().__init__()
994
+ resnets = []
995
+
996
+ for i in range(num_layers):
997
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
998
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
999
+
1000
+ resnets.append(
1001
+ ResnetBlock2D(
1002
+ in_channels=resnet_in_channels + res_skip_channels,
1003
+ out_channels=out_channels,
1004
+ temb_channels=temb_channels,
1005
+ eps=resnet_eps,
1006
+ groups=resnet_groups,
1007
+ dropout=dropout,
1008
+ time_embedding_norm=resnet_time_scale_shift,
1009
+ non_linearity=resnet_act_fn,
1010
+ output_scale_factor=output_scale_factor,
1011
+ pre_norm=resnet_pre_norm,
1012
+ )
1013
+ )
1014
+
1015
+ self.resnets = nn.ModuleList(resnets)
1016
+
1017
+ if add_upsample:
1018
+ self.upsamplers = nn.ModuleList(
1019
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1020
+ )
1021
+ else:
1022
+ self.upsamplers = None
1023
+
1024
+ self.gradient_checkpointing = False
1025
+ self.resolution_idx = resolution_idx
1026
+
1027
+ def forward(
1028
+ self,
1029
+ hidden_states: torch.FloatTensor,
1030
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1031
+ temb: Optional[torch.FloatTensor] = None,
1032
+ upsample_size: Optional[int] = None,
1033
+ scale: float = 1.0,
1034
+ ) -> torch.FloatTensor:
1035
+ is_freeu_enabled = (
1036
+ getattr(self, "s1", None)
1037
+ and getattr(self, "s2", None)
1038
+ and getattr(self, "b1", None)
1039
+ and getattr(self, "b2", None)
1040
+ )
1041
+
1042
+ for resnet in self.resnets:
1043
+ # pop res hidden states
1044
+ res_hidden_states = res_hidden_states_tuple[-1]
1045
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1046
+
1047
+ # FreeU: Only operate on the first two stages
1048
+ if is_freeu_enabled:
1049
+ hidden_states, res_hidden_states = apply_freeu(
1050
+ self.resolution_idx,
1051
+ hidden_states,
1052
+ res_hidden_states,
1053
+ s1=self.s1,
1054
+ s2=self.s2,
1055
+ b1=self.b1,
1056
+ b2=self.b2,
1057
+ )
1058
+
1059
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1060
+
1061
+ if self.training and self.gradient_checkpointing:
1062
+
1063
+ def create_custom_forward(module):
1064
+ def custom_forward(*inputs):
1065
+ return module(*inputs)
1066
+
1067
+ return custom_forward
1068
+
1069
+ if is_torch_version(">=", "1.11.0"):
1070
+ hidden_states = torch.utils.checkpoint.checkpoint(
1071
+ create_custom_forward(resnet),
1072
+ hidden_states,
1073
+ temb,
1074
+ use_reentrant=False,
1075
+ )
1076
+ else:
1077
+ hidden_states = torch.utils.checkpoint.checkpoint(
1078
+ create_custom_forward(resnet), hidden_states, temb
1079
+ )
1080
+ else:
1081
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1082
+
1083
+ if self.upsamplers is not None:
1084
+ for upsampler in self.upsamplers:
1085
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1086
+
1087
+ return hidden_states
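# Illustrative usage sketch (added for clarity, not part of this commit),
# assuming the module path genwarp.models.unet_2d_blocks and made-up sizes:
# UpBlock2D pops one skip tensor per resnet from the *end* of
# res_hidden_states_tuple and concatenates it with the current features along
# the channel dimension before each resnet.
import torch
from genwarp.models.unet_2d_blocks import UpBlock2D

up = UpBlock2D(
    in_channels=320, prev_output_channel=640, out_channels=320,
    temb_channels=1280, num_layers=3,
)
h = torch.randn(1, 640, 16, 16)
temb = torch.randn(1, 1280)
skips = tuple(torch.randn(1, 320, 16, 16) for _ in range(3))  # one per resnet
print(up(h, res_hidden_states_tuple=skips, temb=temb).shape)
# -> torch.Size([1, 320, 32, 32]) after the final upsampler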
genwarp/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1324 @@
+ # This code is adapted from the sources below and then modified.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # -----------------------------------------------------------------------------
8
+ # Diffusers
9
+ # Apache License, Version 2.0
10
+ # Copyright (c) Hugging Face Inc.
11
+ # https://github.com/huggingface/diffusers
12
+ # ==============================================================================
13
+
14
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import UNet2DConditionLoadersMixin
23
+ from diffusers.models.activations import get_activation
24
+ from diffusers.models.attention_processor import (
25
+ ADDED_KV_ATTENTION_PROCESSORS,
26
+ CROSS_ATTENTION_PROCESSORS,
27
+ AttentionProcessor,
28
+ AttnAddedKVProcessor,
29
+ AttnProcessor,
30
+ )
31
+ from diffusers.models.embeddings import (
32
+ GaussianFourierProjection,
33
+ ImageHintTimeEmbedding,
34
+ ImageProjection,
35
+ ImageTimeEmbedding,
36
+ # PositionNet,
37
+ TextImageProjection,
38
+ TextImageTimeEmbedding,
39
+ TextTimeEmbedding,
40
+ TimestepEmbedding,
41
+ Timesteps,
42
+ )
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from diffusers.utils import (
45
+ USE_PEFT_BACKEND,
46
+ BaseOutput,
47
+ deprecate,
48
+ logging,
49
+ scale_lora_layers,
50
+ unscale_lora_layers,
51
+ )
52
+
53
+ from .unet_2d_blocks import (
54
+ UNetMidBlock2D,
55
+ UNetMidBlock2DCrossAttn,
56
+ get_down_block,
57
+ get_up_block,
58
+ )
59
+
60
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
61
+
62
+
63
+ @dataclass
64
+ class UNet2DConditionOutput(BaseOutput):
65
+ """
66
+ The output of [`UNet2DConditionModel`].
67
+
68
+ Args:
69
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
70
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
71
+ """
72
+
73
+ sample: torch.FloatTensor = None
74
+ ref_features: Tuple[torch.FloatTensor] = None
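+ # Added in this adapted model: optional intermediate UNet features that can be returned alongside `sample`.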
75
+
76
+
77
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
78
+ r"""
79
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
80
+ shaped output.
81
+
82
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
83
+ for all models (such as downloading or saving).
84
+
85
+ Parameters:
86
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
87
+ Height and width of input/output sample.
88
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
89
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
90
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
91
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
92
+ Whether to flip the sin to cos in the time embedding.
93
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
94
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
95
+ The tuple of downsample blocks to use.
96
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
97
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
98
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
99
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
100
+ The tuple of upsample blocks to use.
101
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
102
+ Whether to include self-attention in the basic transformer blocks, see
103
+ [`~models.attention.BasicTransformerBlock`].
104
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
105
+ The tuple of output channels for each block.
106
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
107
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
108
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
109
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
110
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
111
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
112
+ If `None`, normalization and activation layers is skipped in post-processing.
113
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
114
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
115
+ The dimension of the cross attention features.
116
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
117
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
118
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
119
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
120
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
121
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
122
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
123
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
124
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
125
+ encoder_hid_dim (`int`, *optional*, defaults to None):
126
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
127
+ dimension to `cross_attention_dim`.
128
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
129
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
130
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
131
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
132
+ num_attention_heads (`int`, *optional*):
133
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
134
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
135
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
136
+ class_embed_type (`str`, *optional*, defaults to `None`):
137
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
138
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
139
+ addition_embed_type (`str`, *optional*, defaults to `None`):
140
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
141
+ "text". "text" will use the `TextTimeEmbedding` layer.
142
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
143
+ Dimension for the timestep embeddings.
144
+ num_class_embeds (`int`, *optional*, defaults to `None`):
145
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
146
+ class conditioning with `class_embed_type` equal to `None`.
147
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
148
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
149
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
150
+ An optional override for the dimension of the projected time embedding.
151
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
152
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
153
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
154
+ timestep_post_act (`str`, *optional*, defaults to `None`):
155
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
156
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
157
+ The dimension of `cond_proj` layer in the timestep embedding.
158
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
159
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
160
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
161
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
162
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
163
+ embeddings with the class embeddings.
164
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
165
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
166
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
167
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
168
+ otherwise.
169
+ """
170
+
171
+ _supports_gradient_checkpointing = True
172
+
173
+ @register_to_config
174
+ def __init__(
175
+ self,
176
+ sample_size: Optional[int] = None,
177
+ in_channels: int = 4,
178
+ out_channels: int = 4,
179
+ center_input_sample: bool = False,
180
+ flip_sin_to_cos: bool = True,
181
+ freq_shift: int = 0,
182
+ down_block_types: Tuple[str] = (
183
+ "CrossAttnDownBlock2D",
184
+ "CrossAttnDownBlock2D",
185
+ "CrossAttnDownBlock2D",
186
+ "DownBlock2D",
187
+ ),
188
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
189
+ up_block_types: Tuple[str] = (
190
+ "UpBlock2D",
191
+ "CrossAttnUpBlock2D",
192
+ "CrossAttnUpBlock2D",
193
+ "CrossAttnUpBlock2D",
194
+ ),
195
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
196
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
197
+ layers_per_block: Union[int, Tuple[int]] = 2,
198
+ downsample_padding: int = 1,
199
+ mid_block_scale_factor: float = 1,
200
+ dropout: float = 0.0,
201
+ act_fn: str = "silu",
202
+ norm_num_groups: Optional[int] = 32,
203
+ norm_eps: float = 1e-5,
204
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
205
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
206
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
207
+ encoder_hid_dim: Optional[int] = None,
208
+ encoder_hid_dim_type: Optional[str] = None,
209
+ attention_head_dim: Union[int, Tuple[int]] = 8,
210
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
211
+ dual_cross_attention: bool = False,
212
+ use_linear_projection: bool = False,
213
+ class_embed_type: Optional[str] = None,
214
+ addition_embed_type: Optional[str] = None,
215
+ addition_time_embed_dim: Optional[int] = None,
216
+ num_class_embeds: Optional[int] = None,
217
+ upcast_attention: bool = False,
218
+ resnet_time_scale_shift: str = "default",
219
+ resnet_skip_time_act: bool = False,
220
+ resnet_out_scale_factor: int = 1.0,
221
+ time_embedding_type: str = "positional",
222
+ time_embedding_dim: Optional[int] = None,
223
+ time_embedding_act_fn: Optional[str] = None,
224
+ timestep_post_act: Optional[str] = None,
225
+ time_cond_proj_dim: Optional[int] = None,
226
+ conv_in_kernel: int = 3,
227
+ conv_out_kernel: int = 3,
228
+ projection_class_embeddings_input_dim: Optional[int] = None,
229
+ attention_type: str = "default",
230
+ class_embeddings_concat: bool = False,
231
+ mid_block_only_cross_attention: Optional[bool] = None,
232
+ cross_attention_norm: Optional[str] = None,
233
+ addition_embed_type_num_heads=64,
234
+ ):
235
+ super().__init__()
236
+
237
+ self.sample_size = sample_size
238
+
239
+ if num_attention_heads is not None:
240
+ raise ValueError(
241
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
242
+ )
243
+
244
+ # If `num_attention_heads` is not defined (which is the case for most models)
245
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
246
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
247
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
248
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
249
+ # which is why we correct for the naming here.
250
+ num_attention_heads = num_attention_heads or attention_head_dim
251
+
252
+ # Check inputs
253
+ if len(down_block_types) != len(up_block_types):
254
+ raise ValueError(
255
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
256
+ )
257
+
258
+ if len(block_out_channels) != len(down_block_types):
259
+ raise ValueError(
260
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
261
+ )
262
+
263
+ if not isinstance(only_cross_attention, bool) and len(
264
+ only_cross_attention
265
+ ) != len(down_block_types):
266
+ raise ValueError(
267
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
268
+ )
269
+
270
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
271
+ down_block_types
272
+ ):
273
+ raise ValueError(
274
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
275
+ )
276
+
277
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
278
+ down_block_types
279
+ ):
280
+ raise ValueError(
281
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
282
+ )
283
+
284
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
285
+ down_block_types
286
+ ):
287
+ raise ValueError(
288
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
289
+ )
290
+
291
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
292
+ down_block_types
293
+ ):
294
+ raise ValueError(
295
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
296
+ )
297
+ if (
298
+ isinstance(transformer_layers_per_block, list)
299
+ and reverse_transformer_layers_per_block is None
300
+ ):
301
+ for layer_number_per_block in transformer_layers_per_block:
302
+ if isinstance(layer_number_per_block, list):
303
+ raise ValueError(
304
+ "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
305
+ )
306
+
307
+ # input
308
+ conv_in_padding = (conv_in_kernel - 1) // 2
309
+ self.conv_in = nn.Conv2d(
310
+ in_channels,
311
+ block_out_channels[0],
312
+ kernel_size=conv_in_kernel,
313
+ padding=conv_in_padding,
314
+ )
315
+
316
+ # time
317
+ if time_embedding_type == "fourier":
318
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
319
+ if time_embed_dim % 2 != 0:
320
+ raise ValueError(
321
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
322
+ )
323
+ self.time_proj = GaussianFourierProjection(
324
+ time_embed_dim // 2,
325
+ set_W_to_weight=False,
326
+ log=False,
327
+ flip_sin_to_cos=flip_sin_to_cos,
328
+ )
329
+ timestep_input_dim = time_embed_dim
330
+ elif time_embedding_type == "positional":
331
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
332
+
333
+ self.time_proj = Timesteps(
334
+ block_out_channels[0], flip_sin_to_cos, freq_shift
335
+ )
336
+ timestep_input_dim = block_out_channels[0]
337
+ else:
338
+ raise ValueError(
339
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
340
+ )
341
+
342
+ self.time_embedding = TimestepEmbedding(
343
+ timestep_input_dim,
344
+ time_embed_dim,
345
+ act_fn=act_fn,
346
+ post_act_fn=timestep_post_act,
347
+ cond_proj_dim=time_cond_proj_dim,
348
+ )
349
+
350
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
351
+ encoder_hid_dim_type = "text_proj"
352
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
353
+ logger.info(
354
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
355
+ )
356
+
357
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
358
+ raise ValueError(
359
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
360
+ )
361
+
362
+ if encoder_hid_dim_type == "text_proj":
363
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
364
+ elif encoder_hid_dim_type == "text_image_proj":
365
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
366
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
367
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
368
+ self.encoder_hid_proj = TextImageProjection(
369
+ text_embed_dim=encoder_hid_dim,
370
+ image_embed_dim=cross_attention_dim,
371
+ cross_attention_dim=cross_attention_dim,
372
+ )
373
+ elif encoder_hid_dim_type == "image_proj":
374
+ # Kandinsky 2.2
375
+ self.encoder_hid_proj = ImageProjection(
376
+ image_embed_dim=encoder_hid_dim,
377
+ cross_attention_dim=cross_attention_dim,
378
+ )
379
+ elif encoder_hid_dim_type is not None:
380
+ raise ValueError(
381
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
382
+ )
383
+ else:
384
+ self.encoder_hid_proj = None
385
+
386
+ # class embedding
387
+ if class_embed_type is None and num_class_embeds is not None:
388
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
389
+ elif class_embed_type == "timestep":
390
+ self.class_embedding = TimestepEmbedding(
391
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
392
+ )
393
+ elif class_embed_type == "identity":
394
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
395
+ elif class_embed_type == "projection":
396
+ if projection_class_embeddings_input_dim is None:
397
+ raise ValueError(
398
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
399
+ )
400
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
401
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
402
+ # 2. it projects from an arbitrary input dimension.
403
+ #
404
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
405
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
406
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
407
+ self.class_embedding = TimestepEmbedding(
408
+ projection_class_embeddings_input_dim, time_embed_dim
409
+ )
410
+ elif class_embed_type == "simple_projection":
411
+ if projection_class_embeddings_input_dim is None:
412
+ raise ValueError(
413
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
414
+ )
415
+ self.class_embedding = nn.Linear(
416
+ projection_class_embeddings_input_dim, time_embed_dim
417
+ )
418
+ else:
419
+ self.class_embedding = None
420
+
421
+ if addition_embed_type == "text":
422
+ if encoder_hid_dim is not None:
423
+ text_time_embedding_from_dim = encoder_hid_dim
424
+ else:
425
+ text_time_embedding_from_dim = cross_attention_dim
426
+
427
+ self.add_embedding = TextTimeEmbedding(
428
+ text_time_embedding_from_dim,
429
+ time_embed_dim,
430
+ num_heads=addition_embed_type_num_heads,
431
+ )
432
+ elif addition_embed_type == "text_image":
433
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
434
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
435
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
436
+ self.add_embedding = TextImageTimeEmbedding(
437
+ text_embed_dim=cross_attention_dim,
438
+ image_embed_dim=cross_attention_dim,
439
+ time_embed_dim=time_embed_dim,
440
+ )
441
+ elif addition_embed_type == "text_time":
442
+ self.add_time_proj = Timesteps(
443
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
444
+ )
445
+ self.add_embedding = TimestepEmbedding(
446
+ projection_class_embeddings_input_dim, time_embed_dim
447
+ )
448
+ elif addition_embed_type == "image":
449
+ # Kandinsky 2.2
450
+ self.add_embedding = ImageTimeEmbedding(
451
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
452
+ )
453
+ elif addition_embed_type == "image_hint":
454
+ # Kandinsky 2.2 ControlNet
455
+ self.add_embedding = ImageHintTimeEmbedding(
456
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
457
+ )
458
+ elif addition_embed_type is not None:
459
+ raise ValueError(
460
+ f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
461
+ )
462
+
463
+ if time_embedding_act_fn is None:
464
+ self.time_embed_act = None
465
+ else:
466
+ self.time_embed_act = get_activation(time_embedding_act_fn)
467
+
468
+ self.down_blocks = nn.ModuleList([])
469
+ self.up_blocks = nn.ModuleList([])
470
+
471
+ if isinstance(only_cross_attention, bool):
472
+ if mid_block_only_cross_attention is None:
473
+ mid_block_only_cross_attention = only_cross_attention
474
+
475
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
476
+
477
+ if mid_block_only_cross_attention is None:
478
+ mid_block_only_cross_attention = False
479
+
480
+ if isinstance(num_attention_heads, int):
481
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
482
+
483
+ if isinstance(attention_head_dim, int):
484
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
485
+
486
+ if isinstance(cross_attention_dim, int):
487
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
488
+
489
+ if isinstance(layers_per_block, int):
490
+ layers_per_block = [layers_per_block] * len(down_block_types)
491
+
492
+ if isinstance(transformer_layers_per_block, int):
493
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
494
+ down_block_types
495
+ )
496
+
497
+ if class_embeddings_concat:
498
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
499
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
500
+ # regular time embeddings
501
+ blocks_time_embed_dim = time_embed_dim * 2
502
+ else:
503
+ blocks_time_embed_dim = time_embed_dim
504
+
505
+ # down
506
+ output_channel = block_out_channels[0]
507
+ for i, down_block_type in enumerate(down_block_types):
508
+ input_channel = output_channel
509
+ output_channel = block_out_channels[i]
510
+ is_final_block = i == len(block_out_channels) - 1
511
+
512
+ down_block = get_down_block(
513
+ down_block_type,
514
+ num_layers=layers_per_block[i],
515
+ transformer_layers_per_block=transformer_layers_per_block[i],
516
+ in_channels=input_channel,
517
+ out_channels=output_channel,
518
+ temb_channels=blocks_time_embed_dim,
519
+ add_downsample=not is_final_block,
520
+ resnet_eps=norm_eps,
521
+ resnet_act_fn=act_fn,
522
+ resnet_groups=norm_num_groups,
523
+ cross_attention_dim=cross_attention_dim[i],
524
+ num_attention_heads=num_attention_heads[i],
525
+ downsample_padding=downsample_padding,
526
+ dual_cross_attention=dual_cross_attention,
527
+ use_linear_projection=use_linear_projection,
528
+ only_cross_attention=only_cross_attention[i],
529
+ upcast_attention=upcast_attention,
530
+ resnet_time_scale_shift=resnet_time_scale_shift,
531
+ attention_type=attention_type,
532
+ resnet_skip_time_act=resnet_skip_time_act,
533
+ resnet_out_scale_factor=resnet_out_scale_factor,
534
+ cross_attention_norm=cross_attention_norm,
535
+ attention_head_dim=attention_head_dim[i]
536
+ if attention_head_dim[i] is not None
537
+ else output_channel,
538
+ dropout=dropout,
539
+ )
540
+ self.down_blocks.append(down_block)
541
+
542
+ # mid
543
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
544
+ self.mid_block = UNetMidBlock2DCrossAttn(
545
+ transformer_layers_per_block=transformer_layers_per_block[-1],
546
+ in_channels=block_out_channels[-1],
547
+ temb_channels=blocks_time_embed_dim,
548
+ dropout=dropout,
549
+ resnet_eps=norm_eps,
550
+ resnet_act_fn=act_fn,
551
+ output_scale_factor=mid_block_scale_factor,
552
+ resnet_time_scale_shift=resnet_time_scale_shift,
553
+ cross_attention_dim=cross_attention_dim[-1],
554
+ num_attention_heads=num_attention_heads[-1],
555
+ resnet_groups=norm_num_groups,
556
+ dual_cross_attention=dual_cross_attention,
557
+ use_linear_projection=use_linear_projection,
558
+ upcast_attention=upcast_attention,
559
+ attention_type=attention_type,
560
+ )
561
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
562
+ raise NotImplementedError(f"Unsupported mid_block_type: {mid_block_type}")
563
+ elif mid_block_type == "UNetMidBlock2D":
564
+ self.mid_block = UNetMidBlock2D(
565
+ in_channels=block_out_channels[-1],
566
+ temb_channels=blocks_time_embed_dim,
567
+ dropout=dropout,
568
+ num_layers=0,
569
+ resnet_eps=norm_eps,
570
+ resnet_act_fn=act_fn,
571
+ output_scale_factor=mid_block_scale_factor,
572
+ resnet_groups=norm_num_groups,
573
+ resnet_time_scale_shift=resnet_time_scale_shift,
574
+ add_attention=False,
575
+ )
576
+ elif mid_block_type is None:
577
+ self.mid_block = None
578
+ else:
579
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
580
+
581
+ # count how many layers upsample the images
582
+ self.num_upsamplers = 0
583
+
584
+ # up
585
+ reversed_block_out_channels = list(reversed(block_out_channels))
586
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
587
+ reversed_layers_per_block = list(reversed(layers_per_block))
588
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
589
+ reversed_transformer_layers_per_block = (
590
+ list(reversed(transformer_layers_per_block))
591
+ if reverse_transformer_layers_per_block is None
592
+ else reverse_transformer_layers_per_block
593
+ )
594
+ only_cross_attention = list(reversed(only_cross_attention))
595
+
596
+ output_channel = reversed_block_out_channels[0]
597
+ for i, up_block_type in enumerate(up_block_types):
598
+ is_final_block = i == len(block_out_channels) - 1
599
+
600
+ prev_output_channel = output_channel
601
+ output_channel = reversed_block_out_channels[i]
602
+ input_channel = reversed_block_out_channels[
603
+ min(i + 1, len(block_out_channels) - 1)
604
+ ]
605
+
606
+ # add upsample block for all BUT final layer
607
+ if not is_final_block:
608
+ add_upsample = True
609
+ self.num_upsamplers += 1
610
+ else:
611
+ add_upsample = False
612
+
613
+ up_block = get_up_block(
614
+ up_block_type,
615
+ num_layers=reversed_layers_per_block[i] + 1,
616
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
617
+ in_channels=input_channel,
618
+ out_channels=output_channel,
619
+ prev_output_channel=prev_output_channel,
620
+ temb_channels=blocks_time_embed_dim,
621
+ add_upsample=add_upsample,
622
+ resnet_eps=norm_eps,
623
+ resnet_act_fn=act_fn,
624
+ resolution_idx=i,
625
+ resnet_groups=norm_num_groups,
626
+ cross_attention_dim=reversed_cross_attention_dim[i],
627
+ num_attention_heads=reversed_num_attention_heads[i],
628
+ dual_cross_attention=dual_cross_attention,
629
+ use_linear_projection=use_linear_projection,
630
+ only_cross_attention=only_cross_attention[i],
631
+ upcast_attention=upcast_attention,
632
+ resnet_time_scale_shift=resnet_time_scale_shift,
633
+ attention_type=attention_type,
634
+ resnet_skip_time_act=resnet_skip_time_act,
635
+ resnet_out_scale_factor=resnet_out_scale_factor,
636
+ cross_attention_norm=cross_attention_norm,
637
+ attention_head_dim=attention_head_dim[i]
638
+ if attention_head_dim[i] is not None
639
+ else output_channel,
640
+ dropout=dropout,
641
+ )
642
+ self.up_blocks.append(up_block)
643
+ prev_output_channel = output_channel
644
+
645
+ # out
646
+ if norm_num_groups is not None:
647
+ self.conv_norm_out = nn.GroupNorm(
648
+ num_channels=block_out_channels[0],
649
+ num_groups=norm_num_groups,
650
+ eps=norm_eps,
651
+ )
652
+
653
+ self.conv_act = get_activation(act_fn)
654
+
655
+ else:
656
+ self.conv_norm_out = None
657
+ self.conv_act = None
658
+ self.conv_norm_out = None
659
+
660
+ conv_out_padding = (conv_out_kernel - 1) // 2
661
+ # self.conv_out = nn.Conv2d(
662
+ # block_out_channels[0],
663
+ # out_channels,
664
+ # kernel_size=conv_out_kernel,
665
+ # padding=conv_out_padding,
666
+ # )
667
+
668
+ if attention_type in ["gated", "gated-text-image"]:
669
+ positive_len = 768
670
+ if isinstance(cross_attention_dim, int):
671
+ positive_len = cross_attention_dim
672
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
673
+ cross_attention_dim, list
674
+ ):
675
+ positive_len = cross_attention_dim[0]
676
+
677
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
678
+ # self.position_net = PositionNet(
679
+ # positive_len=positive_len,
680
+ # out_dim=cross_attention_dim,
681
+ # feature_type=feature_type,
682
+ # )
683
+
684
+ @property
685
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
686
+ r"""
687
+ Returns:
688
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
689
+ indexed by its weight name.
690
+ """
691
+ # set recursively
692
+ processors = {}
693
+
694
+ def fn_recursive_add_processors(
695
+ name: str,
696
+ module: torch.nn.Module,
697
+ processors: Dict[str, AttentionProcessor],
698
+ ):
699
+ if hasattr(module, "get_processor"):
700
+ processors[f"{name}.processor"] = module.get_processor(
701
+ return_deprecated_lora=True
702
+ )
703
+
704
+ for sub_name, child in module.named_children():
705
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
706
+
707
+ return processors
708
+
709
+ for name, module in self.named_children():
710
+ fn_recursive_add_processors(name, module, processors)
711
+
712
+ return processors
713
+
714
+ def set_attn_processor(
715
+ self,
716
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
717
+ _remove_lora=False,
718
+ ):
719
+ r"""
720
+ Sets the attention processor to use to compute attention.
721
+
722
+ Parameters:
723
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
724
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
725
+ for **all** `Attention` layers.
726
+
727
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
728
+ processor. This is strongly recommended when setting trainable attention processors.
729
+
730
+ """
731
+ count = len(self.attn_processors.keys())
732
+
733
+ if isinstance(processor, dict) and len(processor) != count:
734
+ raise ValueError(
735
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
736
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
737
+ )
738
+
739
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
740
+ if hasattr(module, "set_processor"):
741
+ if not isinstance(processor, dict):
742
+ module.set_processor(processor, _remove_lora=_remove_lora)
743
+ else:
744
+ module.set_processor(
745
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
746
+ )
747
+
748
+ for sub_name, child in module.named_children():
749
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
750
+
751
+ for name, module in self.named_children():
752
+ fn_recursive_attn_processor(name, module, processor)
753
+
754
+ def set_default_attn_processor(self):
755
+ """
756
+ Disables custom attention processors and sets the default attention implementation.
757
+ """
758
+ if all(
759
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
760
+ for proc in self.attn_processors.values()
761
+ ):
762
+ processor = AttnAddedKVProcessor()
763
+ elif all(
764
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
765
+ for proc in self.attn_processors.values()
766
+ ):
767
+ processor = AttnProcessor()
768
+ else:
769
+ raise ValueError(
770
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
771
+ )
772
+
773
+ self.set_attn_processor(processor, _remove_lora=True)
774
+
775
+ def set_attention_slice(self, slice_size):
776
+ r"""
777
+ Enable sliced attention computation.
778
+
779
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
780
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
781
+
782
+ Args:
783
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
784
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
785
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
786
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
787
+ must be a multiple of `slice_size`.
788
+ """
789
+ sliceable_head_dims = []
790
+
791
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
792
+ if hasattr(module, "set_attention_slice"):
793
+ sliceable_head_dims.append(module.sliceable_head_dim)
794
+
795
+ for child in module.children():
796
+ fn_recursive_retrieve_sliceable_dims(child)
797
+
798
+ # retrieve number of attention layers
799
+ for module in self.children():
800
+ fn_recursive_retrieve_sliceable_dims(module)
801
+
802
+ num_sliceable_layers = len(sliceable_head_dims)
803
+
804
+ if slice_size == "auto":
805
+ # half the attention head size is usually a good trade-off between
806
+ # speed and memory
807
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
808
+ elif slice_size == "max":
809
+ # make smallest slice possible
810
+ slice_size = num_sliceable_layers * [1]
811
+
812
+ slice_size = (
813
+ num_sliceable_layers * [slice_size]
814
+ if not isinstance(slice_size, list)
815
+ else slice_size
816
+ )
817
+
818
+ if len(slice_size) != len(sliceable_head_dims):
819
+ raise ValueError(
820
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
821
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
822
+ )
823
+
824
+ for i in range(len(slice_size)):
825
+ size = slice_size[i]
826
+ dim = sliceable_head_dims[i]
827
+ if size is not None and size > dim:
828
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
829
+
830
+ # Recursively walk through all the children.
831
+ # Any children which exposes the set_attention_slice method
832
+ # gets the message
833
+ def fn_recursive_set_attention_slice(
834
+ module: torch.nn.Module, slice_size: List[int]
835
+ ):
836
+ if hasattr(module, "set_attention_slice"):
837
+ module.set_attention_slice(slice_size.pop())
838
+
839
+ for child in module.children():
840
+ fn_recursive_set_attention_slice(child, slice_size)
841
+
842
+ reversed_slice_size = list(reversed(slice_size))
843
+ for module in self.children():
844
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
845
+
846
+ def _set_gradient_checkpointing(self, module, value=False):
847
+ if hasattr(module, "gradient_checkpointing"):
848
+ module.gradient_checkpointing = value
849
+
850
+ def enable_freeu(self, s1, s2, b1, b2):
851
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
852
+
853
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
854
+
855
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
856
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
857
+
858
+ Args:
859
+ s1 (`float`):
860
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
861
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
862
+ s2 (`float`):
863
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
864
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
865
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
866
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
867
+ """
868
+ for i, upsample_block in enumerate(self.up_blocks):
869
+ setattr(upsample_block, "s1", s1)
870
+ setattr(upsample_block, "s2", s2)
871
+ setattr(upsample_block, "b1", b1)
872
+ setattr(upsample_block, "b2", b2)
873
+
874
+ def disable_freeu(self):
875
+ """Disables the FreeU mechanism."""
876
+ freeu_keys = {"s1", "s2", "b1", "b2"}
877
+ for i, upsample_block in enumerate(self.up_blocks):
878
+ for k in freeu_keys:
879
+ if (
880
+ hasattr(upsample_block, k)
881
+ or getattr(upsample_block, k, None) is not None
882
+ ):
883
+ setattr(upsample_block, k, None)
884
+
885
+ def forward(
886
+ self,
887
+ sample: torch.FloatTensor,
888
+ timestep: Union[torch.Tensor, float, int],
889
+ encoder_hidden_states: torch.Tensor,
890
+ class_labels: Optional[torch.Tensor] = None,
891
+ timestep_cond: Optional[torch.Tensor] = None,
892
+ attention_mask: Optional[torch.Tensor] = None,
893
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
894
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
895
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
896
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
897
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
898
+ encoder_attention_mask: Optional[torch.Tensor] = None,
899
+ pose_cond_fea: Optional[torch.Tensor] = None,
900
+ return_dict: bool = True,
901
+ ) -> Union[UNet2DConditionOutput, Tuple]:
902
+ r"""
903
+ The [`UNet2DConditionModel`] forward method.
904
+
905
+ Args:
906
+ sample (`torch.FloatTensor`):
907
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
908
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
909
+ encoder_hidden_states (`torch.FloatTensor`):
910
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
911
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
912
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
913
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
914
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
915
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
916
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
917
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
918
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
919
+ negative values to the attention scores corresponding to "discard" tokens.
920
+ cross_attention_kwargs (`dict`, *optional*):
921
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
922
+ `self.processor` in
923
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
924
+ added_cond_kwargs: (`dict`, *optional*):
925
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
926
+ are passed along to the UNet blocks.
927
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
928
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
929
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
930
+ A tensor that if specified is added to the residual of the middle unet block.
931
+ encoder_attention_mask (`torch.Tensor`):
932
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
933
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
934
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
935
+ return_dict (`bool`, *optional*, defaults to `True`):
936
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
937
+ tuple.
938
+ cross_attention_kwargs (`dict`, *optional*):
939
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
940
+ added_cond_kwargs: (`dict`, *optional*):
941
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
942
+ are passed along to the UNet blocks.
943
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
944
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
945
+ example from ControlNet side model(s)
946
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
947
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
948
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
949
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
950
+
951
+ Returns:
952
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
953
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
954
+ a `tuple` is returned where the first element is the sample tensor.
955
+ """
956
+ # By default samples have to be at least a multiple of the overall upsampling factor.
957
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
958
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
959
+ # on the fly if necessary.
960
+ default_overall_up_factor = 2**self.num_upsamplers
961
+
962
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
963
+ forward_upsample_size = False
964
+ upsample_size = None
965
+
966
+ for dim in sample.shape[-2:]:
967
+ if dim % default_overall_up_factor != 0:
968
+ # Forward upsample size to force interpolation output size.
969
+ forward_upsample_size = True
970
+ break
971
+
972
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
973
+ # expects mask of shape:
974
+ # [batch, key_tokens]
975
+ # adds singleton query_tokens dimension:
976
+ # [batch, 1, key_tokens]
977
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
978
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
979
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
980
+ if attention_mask is not None:
981
+ # assume that mask is expressed as:
982
+ # (1 = keep, 0 = discard)
983
+ # convert mask into a bias that can be added to attention scores:
984
+ # (keep = +0, discard = -10000.0)
985
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
986
+ attention_mask = attention_mask.unsqueeze(1)
987
+
988
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
989
+ if encoder_attention_mask is not None:
990
+ encoder_attention_mask = (
991
+ 1 - encoder_attention_mask.to(sample.dtype)
992
+ ) * -10000.0
993
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
994
+
995
+ # 0. center input if necessary
996
+ if self.config.center_input_sample:
997
+ sample = 2 * sample - 1.0
998
+
999
+ # 1. time
1000
+ timesteps = timestep
1001
+ if not torch.is_tensor(timesteps):
1002
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1003
+ # This would be a good case for the `match` statement (Python 3.10+)
1004
+ is_mps = sample.device.type == "mps"
1005
+ if isinstance(timestep, float):
1006
+ dtype = torch.float32 if is_mps else torch.float64
1007
+ else:
1008
+ dtype = torch.int32 if is_mps else torch.int64
1009
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1010
+ elif len(timesteps.shape) == 0:
1011
+ timesteps = timesteps[None].to(sample.device)
1012
+
1013
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1014
+ timesteps = timesteps.expand(sample.shape[0])
1015
+
1016
+ t_emb = self.time_proj(timesteps)
1017
+
1018
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1019
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1020
+ # there might be better ways to encapsulate this.
1021
+ t_emb = t_emb.to(dtype=sample.dtype)
1022
+
1023
+ emb = self.time_embedding(t_emb, timestep_cond)
1024
+ aug_emb = None
1025
+
1026
+ if self.class_embedding is not None:
1027
+ if class_labels is None:
1028
+ raise ValueError(
1029
+ "class_labels should be provided when num_class_embeds > 0"
1030
+ )
1031
+
1032
+ if self.config.class_embed_type == "timestep":
1033
+ class_labels = self.time_proj(class_labels)
1034
+
1035
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1036
+ # there might be better ways to encapsulate this.
1037
+ class_labels = class_labels.to(dtype=sample.dtype)
1038
+
1039
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1040
+
1041
+ if self.config.class_embeddings_concat:
1042
+ emb = torch.cat([emb, class_emb], dim=-1)
1043
+ else:
1044
+ emb = emb + class_emb
1045
+
1046
+ if self.config.addition_embed_type == "text":
1047
+ aug_emb = self.add_embedding(encoder_hidden_states)
1048
+ elif self.config.addition_embed_type == "text_image":
1049
+ # Kandinsky 2.1 - style
1050
+ if "image_embeds" not in added_cond_kwargs:
1051
+ raise ValueError(
1052
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1053
+ )
1054
+
1055
+ image_embs = added_cond_kwargs.get("image_embeds")
1056
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1057
+ aug_emb = self.add_embedding(text_embs, image_embs)
1058
+ elif self.config.addition_embed_type == "text_time":
1059
+ # SDXL - style
1060
+ if "text_embeds" not in added_cond_kwargs:
1061
+ raise ValueError(
1062
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1063
+ )
1064
+ text_embeds = added_cond_kwargs.get("text_embeds")
1065
+ if "time_ids" not in added_cond_kwargs:
1066
+ raise ValueError(
1067
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1068
+ )
1069
+ time_ids = added_cond_kwargs.get("time_ids")
1070
+ time_embeds = self.add_time_proj(time_ids.flatten())
1071
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1072
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1073
+ add_embeds = add_embeds.to(emb.dtype)
1074
+ aug_emb = self.add_embedding(add_embeds)
1075
+ elif self.config.addition_embed_type == "image":
1076
+ # Kandinsky 2.2 - style
1077
+ if "image_embeds" not in added_cond_kwargs:
1078
+ raise ValueError(
1079
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1080
+ )
1081
+ image_embs = added_cond_kwargs.get("image_embeds")
1082
+ aug_emb = self.add_embedding(image_embs)
1083
+ elif self.config.addition_embed_type == "image_hint":
1084
+ # Kandinsky 2.2 - style
1085
+ if (
1086
+ "image_embeds" not in added_cond_kwargs
1087
+ or "hint" not in added_cond_kwargs
1088
+ ):
1089
+ raise ValueError(
1090
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1091
+ )
1092
+ image_embs = added_cond_kwargs.get("image_embeds")
1093
+ hint = added_cond_kwargs.get("hint")
1094
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1095
+ sample = torch.cat([sample, hint], dim=1)
1096
+
1097
+ emb = emb + aug_emb if aug_emb is not None else emb
1098
+
1099
+ if self.time_embed_act is not None:
1100
+ emb = self.time_embed_act(emb)
1101
+
1102
+ if (
1103
+ self.encoder_hid_proj is not None
1104
+ and self.config.encoder_hid_dim_type == "text_proj"
1105
+ ):
1106
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1107
+ elif (
1108
+ self.encoder_hid_proj is not None
1109
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1110
+ ):
1111
+ # Kandinsky 2.1 - style
1112
+ if "image_embeds" not in added_cond_kwargs:
1113
+ raise ValueError(
1114
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1115
+ )
1116
+
1117
+ image_embeds = added_cond_kwargs.get("image_embeds")
1118
+ encoder_hidden_states = self.encoder_hid_proj(
1119
+ encoder_hidden_states, image_embeds
1120
+ )
1121
+ elif (
1122
+ self.encoder_hid_proj is not None
1123
+ and self.config.encoder_hid_dim_type == "image_proj"
1124
+ ):
1125
+ # Kandinsky 2.2 - style
1126
+ if "image_embeds" not in added_cond_kwargs:
1127
+ raise ValueError(
1128
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1129
+ )
1130
+ image_embeds = added_cond_kwargs.get("image_embeds")
1131
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1132
+ elif (
1133
+ self.encoder_hid_proj is not None
1134
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1135
+ ):
1136
+ if "image_embeds" not in added_cond_kwargs:
1137
+ raise ValueError(
1138
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1139
+ )
1140
+ image_embeds = added_cond_kwargs.get("image_embeds")
1141
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1142
+ encoder_hidden_states.dtype
1143
+ )
1144
+ encoder_hidden_states = torch.cat(
1145
+ [encoder_hidden_states, image_embeds], dim=1
1146
+ )
1147
+
1148
+ # 2. pre-process
1149
+ sample = self.conv_in(sample)
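+ # Adapted behavior (not part of the upstream diffusers UNet): when `pose_cond_fea` is provided,
+ # it is added to the `conv_in` output as an extra conditioning signal.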
1150
+ if pose_cond_fea is not None:
1151
+ sample = sample + pose_cond_fea
1152
+
1153
+ # 2.5 GLIGEN position net
1154
+ # if (
1155
+ # cross_attention_kwargs is not None
1156
+ # and cross_attention_kwargs.get("gligen", None) is not None
1157
+ # ):
1158
+ # cross_attention_kwargs = cross_attention_kwargs.copy()
1159
+ # gligen_args = cross_attention_kwargs.pop("gligen")
1160
+ # cross_attention_kwargs["gligen"] = {
1161
+ # "objs": self.position_net(**gligen_args)
1162
+ # }
1163
+
1164
+ # 3. down
1165
+ lora_scale = (
1166
+ cross_attention_kwargs.get("scale", 1.0)
1167
+ if cross_attention_kwargs is not None
1168
+ else 1.0
1169
+ )
1170
+ if USE_PEFT_BACKEND:
1171
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1172
+ scale_lora_layers(self, lora_scale)
1173
+
1174
+ is_controlnet = (
1175
+ mid_block_additional_residual is not None
1176
+ and down_block_additional_residuals is not None
1177
+ )
1178
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1179
+ is_adapter = down_intrablock_additional_residuals is not None
1180
+ # maintain backward compatibility for legacy usage, where
1181
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1182
+ # but can only use one or the other
1183
+ if (
1184
+ not is_adapter
1185
+ and mid_block_additional_residual is None
1186
+ and down_block_additional_residuals is not None
1187
+ ):
1188
+ deprecate(
1189
+ "T2I should not use down_block_additional_residuals",
1190
+ "1.3.0",
1191
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1192
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1193
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1194
+ standard_warn=False,
1195
+ )
1196
+ down_intrablock_additional_residuals = down_block_additional_residuals
1197
+ is_adapter = True
1198
+
1199
+ down_block_res_samples = (sample,)
1200
+ tot_referece_features = ()
1201
+ for downsample_block in self.down_blocks:
1202
+ if (
1203
+ hasattr(downsample_block, "has_cross_attention")
1204
+ and downsample_block.has_cross_attention
1205
+ ):
1206
+ # For t2i-adapter CrossAttnDownBlock2D
1207
+ additional_residuals = {}
1208
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1209
+ additional_residuals[
1210
+ "additional_residuals"
1211
+ ] = down_intrablock_additional_residuals.pop(0)
1212
+
1213
+ sample, res_samples = downsample_block(
1214
+ hidden_states=sample,
1215
+ temb=emb,
1216
+ encoder_hidden_states=encoder_hidden_states,
1217
+ attention_mask=attention_mask,
1218
+ cross_attention_kwargs=cross_attention_kwargs,
1219
+ encoder_attention_mask=encoder_attention_mask,
1220
+ **additional_residuals,
1221
+ )
1222
+ else:
1223
+ sample, res_samples = downsample_block(
1224
+ hidden_states=sample, temb=emb, scale=lora_scale
1225
+ )
1226
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1227
+ sample += down_intrablock_additional_residuals.pop(0)
1228
+
1229
+ down_block_res_samples += res_samples
1230
+
1231
+ if is_controlnet:
1232
+ new_down_block_res_samples = ()
1233
+
1234
+ for down_block_res_sample, down_block_additional_residual in zip(
1235
+ down_block_res_samples, down_block_additional_residuals
1236
+ ):
1237
+ down_block_res_sample = (
1238
+ down_block_res_sample + down_block_additional_residual
1239
+ )
1240
+ new_down_block_res_samples = new_down_block_res_samples + (
1241
+ down_block_res_sample,
1242
+ )
1243
+
1244
+ down_block_res_samples = new_down_block_res_samples
1245
+
1246
+ # 4. mid
1247
+ if self.mid_block is not None:
1248
+ if (
1249
+ hasattr(self.mid_block, "has_cross_attention")
1250
+ and self.mid_block.has_cross_attention
1251
+ ):
1252
+ sample = self.mid_block(
1253
+ sample,
1254
+ emb,
1255
+ encoder_hidden_states=encoder_hidden_states,
1256
+ attention_mask=attention_mask,
1257
+ cross_attention_kwargs=cross_attention_kwargs,
1258
+ encoder_attention_mask=encoder_attention_mask,
1259
+ )
1260
+ else:
1261
+ sample = self.mid_block(sample, emb)
1262
+
1263
+ # To support T2I-Adapter-XL
1264
+ if (
1265
+ is_adapter
1266
+ and len(down_intrablock_additional_residuals) > 0
1267
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1268
+ ):
1269
+ sample += down_intrablock_additional_residuals.pop(0)
1270
+
1271
+ if is_controlnet:
1272
+ sample = sample + mid_block_additional_residual
1273
+
1274
+ # 5. up
1275
+ for i, upsample_block in enumerate(self.up_blocks):
1276
+ is_final_block = i == len(self.up_blocks) - 1
1277
+
1278
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1279
+ down_block_res_samples = down_block_res_samples[
1280
+ : -len(upsample_block.resnets)
1281
+ ]
1282
+
1283
+ # if we have not reached the final block and need to forward the
1284
+ # upsample size, we do it here
1285
+ if not is_final_block and forward_upsample_size:
1286
+ upsample_size = down_block_res_samples[-1].shape[2:]
1287
+
1288
+ if (
1289
+ hasattr(upsample_block, "has_cross_attention")
1290
+ and upsample_block.has_cross_attention
1291
+ ):
1292
+ sample = upsample_block(
1293
+ hidden_states=sample,
1294
+ temb=emb,
1295
+ res_hidden_states_tuple=res_samples,
1296
+ encoder_hidden_states=encoder_hidden_states,
1297
+ cross_attention_kwargs=cross_attention_kwargs,
1298
+ upsample_size=upsample_size,
1299
+ attention_mask=attention_mask,
1300
+ encoder_attention_mask=encoder_attention_mask,
1301
+ )
1302
+ else:
1303
+ sample = upsample_block(
1304
+ hidden_states=sample,
1305
+ temb=emb,
1306
+ res_hidden_states_tuple=res_samples,
1307
+ upsample_size=upsample_size,
1308
+ scale=lora_scale,
1309
+ )
1310
+
1311
+ # 6. post-process
1312
+ # if self.conv_norm_out:
1313
+ # sample = self.conv_norm_out(sample)
1314
+ # sample = self.conv_act(sample)
1315
+ # sample = self.conv_out(sample)
1316
+
1317
+ if USE_PEFT_BACKEND:
1318
+ # remove `lora_scale` from each PEFT layer
1319
+ unscale_lora_layers(self, lora_scale)
1320
+
1321
+ if not return_dict:
1322
+ return (sample,)
1323
+
1324
+ return UNet2DConditionOutput(sample=sample)
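A minimal usage sketch of the residual kwargs handled in the forward above, assuming a `unet` instance of this class and typical SD 1.5 latent/text shapes; `controlnet_down_res`, `controlnet_mid_res`, and `adapter_states` are hypothetical placeholders, not objects defined in this commit.

import torch

sample = torch.randn(1, 4, 64, 64)      # noisy latents (assumed shape)
timestep = torch.tensor([10])
text_emb = torch.randn(1, 77, 768)      # cross-attention context (assumed dim)

# ControlNet-style conditioning: per-down-block skip residuals plus a mid-block
# residual; both must be supplied together so `is_controlnet` above is True.
out = unet(
    sample,
    timestep,
    encoder_hidden_states=text_emb,
    down_block_additional_residuals=controlnet_down_res,  # tuple of tensors
    mid_block_additional_residual=controlnet_mid_res,
).sample

# T2I-Adapter-style conditioning uses the newer keyword and is consumed block by
# block via pop(0) in the down loop above.
out = unet(
    sample,
    timestep,
    encoder_hidden_states=text_emb,
    down_intrablock_additional_residuals=list(adapter_states),
).sample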
genwarp/models/unet_3d.py ADDED
@@ -0,0 +1,645 @@
1
+ # This code is adapted from below and then modified.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # -----------------------------------------------------------------------------
8
+ # AnimateDiff
9
+ # Apache License, Version 2.0
10
+ # https://github.com/guoyww/AnimateDiff
11
+ # ==============================================================================
12
+
13
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
14
+
15
+ from collections import OrderedDict
16
+ from dataclasses import dataclass
17
+ from os import PathLike
18
+ from pathlib import Path
19
+ from typing import Dict, List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.utils.checkpoint
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.attention_processor import AttentionProcessor
26
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
29
+ from safetensors.torch import load_file
30
+
31
+ from .resnet import InflatedConv3d, InflatedGroupNorm
32
+ from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ @dataclass
38
+ class UNet3DConditionOutput(BaseOutput):
39
+ sample: torch.FloatTensor
40
+
41
+
42
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
43
+ _supports_gradient_checkpointing = True
44
+
45
+ @register_to_config
46
+ def __init__(
47
+ self,
48
+ sample_size: Optional[int] = None,
49
+ in_channels: int = 4,
50
+ out_channels: int = 4,
51
+ center_input_sample: bool = False,
52
+ flip_sin_to_cos: bool = True,
53
+ freq_shift: int = 0,
54
+ down_block_types: Tuple[str] = (
55
+ "CrossAttnDownBlock3D",
56
+ "CrossAttnDownBlock3D",
57
+ "CrossAttnDownBlock3D",
58
+ "DownBlock3D",
59
+ ),
60
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
61
+ up_block_types: Tuple[str] = (
62
+ "UpBlock3D",
63
+ "CrossAttnUpBlock3D",
64
+ "CrossAttnUpBlock3D",
65
+ "CrossAttnUpBlock3D",
66
+ ),
67
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
68
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
69
+ layers_per_block: int = 2,
70
+ downsample_padding: int = 1,
71
+ mid_block_scale_factor: float = 1,
72
+ act_fn: str = "silu",
73
+ norm_num_groups: int = 32,
74
+ norm_eps: float = 1e-5,
75
+ cross_attention_dim: int = 1280,
76
+ attention_head_dim: Union[int, Tuple[int]] = 8,
77
+ dual_cross_attention: bool = False,
78
+ use_linear_projection: bool = False,
79
+ class_embed_type: Optional[str] = None,
80
+ num_class_embeds: Optional[int] = None,
81
+ upcast_attention: bool = False,
82
+ resnet_time_scale_shift: str = "default",
83
+ use_inflated_groupnorm=False,
84
+ # Additional
85
+ use_motion_module=False,
86
+ motion_module_resolutions=(1, 2, 4, 8),
87
+ motion_module_mid_block=False,
88
+ motion_module_decoder_only=False,
89
+ motion_module_type=None,
90
+ motion_module_kwargs={},
91
+ unet_use_cross_frame_attention=None,
92
+ unet_use_temporal_attention=None,
93
+ use_zero_convs=False,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.sample_size = sample_size
98
+ time_embed_dim = block_out_channels[0] * 4
99
+
100
+ # input
101
+ self.conv_in = InflatedConv3d(
102
+ in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
103
+ )
104
+
105
+ # time
106
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
107
+ timestep_input_dim = block_out_channels[0]
108
+
109
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
110
+
111
+ # class embedding
112
+ if class_embed_type is None and num_class_embeds is not None:
113
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
114
+ elif class_embed_type == "timestep":
115
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
116
+ elif class_embed_type == "identity":
117
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
118
+ else:
119
+ self.class_embedding = None
120
+
121
+ self.down_blocks = nn.ModuleList([])
122
+ self.mid_block = None
123
+ self.up_blocks = nn.ModuleList([])
124
+
125
+ if isinstance(only_cross_attention, bool):
126
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
127
+
128
+ if isinstance(attention_head_dim, int):
129
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
130
+
131
+ # down
132
+ output_channel = block_out_channels[0]
133
+ for i, down_block_type in enumerate(down_block_types):
134
+ res = 2**i
135
+ input_channel = output_channel
136
+ output_channel = block_out_channels[i]
137
+ is_final_block = i == len(block_out_channels) - 1
138
+
139
+ down_block = get_down_block(
140
+ down_block_type,
141
+ num_layers=layers_per_block,
142
+ in_channels=input_channel,
143
+ out_channels=output_channel,
144
+ temb_channels=time_embed_dim,
145
+ add_downsample=not is_final_block,
146
+ resnet_eps=norm_eps,
147
+ resnet_act_fn=act_fn,
148
+ resnet_groups=norm_num_groups,
149
+ cross_attention_dim=cross_attention_dim,
150
+ attn_num_head_channels=attention_head_dim[i],
151
+ downsample_padding=downsample_padding,
152
+ dual_cross_attention=dual_cross_attention,
153
+ use_linear_projection=use_linear_projection,
154
+ only_cross_attention=only_cross_attention[i],
155
+ upcast_attention=upcast_attention,
156
+ resnet_time_scale_shift=resnet_time_scale_shift,
157
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
158
+ unet_use_temporal_attention=unet_use_temporal_attention,
159
+ use_inflated_groupnorm=use_inflated_groupnorm,
160
+ use_motion_module=use_motion_module
161
+ and (res in motion_module_resolutions)
162
+ and (not motion_module_decoder_only),
163
+ motion_module_type=motion_module_type,
164
+ motion_module_kwargs=motion_module_kwargs,
165
+ use_zero_convs=use_zero_convs,
166
+ )
167
+ self.down_blocks.append(down_block)
168
+
169
+ # mid
170
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
171
+ self.mid_block = UNetMidBlock3DCrossAttn(
172
+ in_channels=block_out_channels[-1],
173
+ temb_channels=time_embed_dim,
174
+ resnet_eps=norm_eps,
175
+ resnet_act_fn=act_fn,
176
+ output_scale_factor=mid_block_scale_factor,
177
+ resnet_time_scale_shift=resnet_time_scale_shift,
178
+ cross_attention_dim=cross_attention_dim,
179
+ attn_num_head_channels=attention_head_dim[-1],
180
+ resnet_groups=norm_num_groups,
181
+ dual_cross_attention=dual_cross_attention,
182
+ use_linear_projection=use_linear_projection,
183
+ upcast_attention=upcast_attention,
184
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
185
+ unet_use_temporal_attention=unet_use_temporal_attention,
186
+ use_inflated_groupnorm=use_inflated_groupnorm,
187
+ use_motion_module=use_motion_module and motion_module_mid_block,
188
+ motion_module_type=motion_module_type,
189
+ motion_module_kwargs=motion_module_kwargs,
190
+ use_zero_convs=use_zero_convs,
191
+ )
192
+ else:
193
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
194
+
195
+ # count how many layers upsample the videos
196
+ self.num_upsamplers = 0
197
+
198
+ # up
199
+ reversed_block_out_channels = list(reversed(block_out_channels))
200
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
201
+ only_cross_attention = list(reversed(only_cross_attention))
202
+ output_channel = reversed_block_out_channels[0]
203
+ for i, up_block_type in enumerate(up_block_types):
204
+ res = 2 ** (3 - i)
205
+ is_final_block = i == len(block_out_channels) - 1
206
+
207
+ prev_output_channel = output_channel
208
+ output_channel = reversed_block_out_channels[i]
209
+ input_channel = reversed_block_out_channels[
210
+ min(i + 1, len(block_out_channels) - 1)
211
+ ]
212
+
213
+ # add upsample block for all BUT final layer
214
+ if not is_final_block:
215
+ add_upsample = True
216
+ self.num_upsamplers += 1
217
+ else:
218
+ add_upsample = False
219
+
220
+ up_block = get_up_block(
221
+ up_block_type,
222
+ num_layers=layers_per_block + 1,
223
+ in_channels=input_channel,
224
+ out_channels=output_channel,
225
+ prev_output_channel=prev_output_channel,
226
+ temb_channels=time_embed_dim,
227
+ add_upsample=add_upsample,
228
+ resnet_eps=norm_eps,
229
+ resnet_act_fn=act_fn,
230
+ resnet_groups=norm_num_groups,
231
+ cross_attention_dim=cross_attention_dim,
232
+ attn_num_head_channels=reversed_attention_head_dim[i],
233
+ dual_cross_attention=dual_cross_attention,
234
+ use_linear_projection=use_linear_projection,
235
+ only_cross_attention=only_cross_attention[i],
236
+ upcast_attention=upcast_attention,
237
+ resnet_time_scale_shift=resnet_time_scale_shift,
238
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
239
+ unet_use_temporal_attention=unet_use_temporal_attention,
240
+ use_inflated_groupnorm=use_inflated_groupnorm,
241
+ use_motion_module=use_motion_module
242
+ and (res in motion_module_resolutions),
243
+ motion_module_type=motion_module_type,
244
+ motion_module_kwargs=motion_module_kwargs,
245
+ use_zero_convs=use_zero_convs,
246
+ )
247
+ self.up_blocks.append(up_block)
248
+ prev_output_channel = output_channel
249
+
250
+ # out
251
+ if use_inflated_groupnorm:
252
+ self.conv_norm_out = InflatedGroupNorm(
253
+ num_channels=block_out_channels[0],
254
+ num_groups=norm_num_groups,
255
+ eps=norm_eps,
256
+ )
257
+ else:
258
+ self.conv_norm_out = nn.GroupNorm(
259
+ num_channels=block_out_channels[0],
260
+ num_groups=norm_num_groups,
261
+ eps=norm_eps,
262
+ )
263
+ self.conv_act = nn.SiLU()
264
+ self.conv_out = InflatedConv3d(
265
+ block_out_channels[0], out_channels, kernel_size=3, padding=1
266
+ )
267
+
268
+
269
+
270
+ @property
271
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
272
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
273
+ r"""
274
+ Returns:
275
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
276
+ indexed by their weight names.
277
+ """
278
+ # set recursively
279
+ processors = {}
280
+
281
+ def fn_recursive_add_processors(
282
+ name: str,
283
+ module: torch.nn.Module,
284
+ processors: Dict[str, AttentionProcessor],
285
+ ):
286
+ if hasattr(module, "set_processor"):
287
+ processors[f"{name}.processor"] = module.processor
288
+
289
+ for sub_name, child in module.named_children():
290
+ if "temporal_transformer" not in sub_name:
291
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
292
+
293
+ return processors
294
+
295
+ for name, module in self.named_children():
296
+ if "temporal_transformer" not in name:
297
+ fn_recursive_add_processors(name, module, processors)
298
+
299
+ return processors
300
+
301
+ def set_attention_slice(self, slice_size):
302
+ r"""
303
+ Enable sliced attention computation.
304
+
305
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
306
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
307
+
308
+ Args:
309
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
310
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
311
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
312
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
313
+ must be a multiple of `slice_size`.
314
+ """
315
+ sliceable_head_dims = []
316
+
317
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
318
+ if hasattr(module, "set_attention_slice"):
319
+ sliceable_head_dims.append(module.sliceable_head_dim)
320
+
321
+ for child in module.children():
322
+ fn_recursive_retrieve_slicable_dims(child)
323
+
324
+ # retrieve number of attention layers
325
+ for module in self.children():
326
+ fn_recursive_retrieve_slicable_dims(module)
327
+
328
+ num_slicable_layers = len(sliceable_head_dims)
329
+
330
+ if slice_size == "auto":
331
+ # half the attention head size is usually a good trade-off between
332
+ # speed and memory
333
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
334
+ elif slice_size == "max":
335
+ # make smallest slice possible
336
+ slice_size = num_slicable_layers * [1]
337
+
338
+ slice_size = (
339
+ num_slicable_layers * [slice_size]
340
+ if not isinstance(slice_size, list)
341
+ else slice_size
342
+ )
343
+
344
+ if len(slice_size) != len(sliceable_head_dims):
345
+ raise ValueError(
346
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
347
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
348
+ )
349
+
350
+ for i in range(len(slice_size)):
351
+ size = slice_size[i]
352
+ dim = sliceable_head_dims[i]
353
+ if size is not None and size > dim:
354
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
355
+
356
+ # Recursively walk through all the children.
357
+ # Any children which exposes the set_attention_slice method
358
+ # gets the message
359
+ def fn_recursive_set_attention_slice(
360
+ module: torch.nn.Module, slice_size: List[int]
361
+ ):
362
+ if hasattr(module, "set_attention_slice"):
363
+ module.set_attention_slice(slice_size.pop())
364
+
365
+ for child in module.children():
366
+ fn_recursive_set_attention_slice(child, slice_size)
367
+
368
+ reversed_slice_size = list(reversed(slice_size))
369
+ for module in self.children():
370
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
371
+
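# A minimal usage sketch of the method above, assuming `unet` is an instance of
# this class; `slice_size` accepts "auto", "max", or an int, as documented:
#   unet.set_attention_slice("auto")  # half of each sliceable head dim per slice
#   unet.set_attention_slice("max")   # one slice at a time, lowest peak memory
#   unet.set_attention_slice(2)       # attention_head_dim // 2 slices per layer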
372
+ def _set_gradient_checkpointing(self, module, value=False):
373
+ if hasattr(module, "gradient_checkpointing"):
374
+ module.gradient_checkpointing = value
375
+
376
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
377
+ def set_attn_processor(
378
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
379
+ ):
380
+ r"""
381
+ Sets the attention processor to use to compute attention.
382
+
383
+ Parameters:
384
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
385
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
386
+ for **all** `Attention` layers.
387
+
388
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
389
+ processor. This is strongly recommended when setting trainable attention processors.
390
+
391
+ """
392
+ count = len(self.attn_processors.keys())
393
+
394
+ if isinstance(processor, dict) and len(processor) != count:
395
+ raise ValueError(
396
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
397
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
398
+ )
399
+
400
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
401
+ if hasattr(module, "set_processor"):
402
+ if not isinstance(processor, dict):
403
+ module.set_processor(processor)
404
+ else:
405
+ module.set_processor(processor.pop(f"{name}.processor"))
406
+
407
+ for sub_name, child in module.named_children():
408
+ if "temporal_transformer" not in sub_name:
409
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
410
+
411
+ for name, module in self.named_children():
412
+ if "temporal_transformer" not in name:
413
+ fn_recursive_attn_processor(name, module, processor)
414
+
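# A minimal usage sketch of the method above, assuming `unet` is an instance of this
# class and that diffusers' AttnProcessor2_0 is acceptable for its attention layers:
#   from diffusers.models.attention_processor import AttnProcessor2_0
#   unet.set_attn_processor(AttnProcessor2_0())   # one processor for every layer
#   unet.set_attn_processor({k: AttnProcessor2_0() for k in unet.attn_processors})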
415
+ def forward(
416
+ self,
417
+ sample: torch.FloatTensor,
418
+ timestep: Union[torch.Tensor, float, int],
419
+ encoder_hidden_states: torch.Tensor,
420
+ class_labels: Optional[torch.Tensor] = None,
421
+ pose_cond_fea: Optional[torch.Tensor] = None,
422
+ attention_mask: Optional[torch.Tensor] = None,
423
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
424
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
425
+ return_dict: bool = True,
426
+ ) -> Union[UNet3DConditionOutput, Tuple]:
427
+ r"""
428
+ Args:
429
+ sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
430
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
431
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
432
+ return_dict (`bool`, *optional*, defaults to `True`):
433
+ Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
434
+
435
+ Returns:
436
+ [`UNet3DConditionOutput`] or `tuple`:
437
+ [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
438
+ returning a tuple, the first element is the sample tensor.
439
+ """
440
+ # By default samples have to be at least a multiple of the overall upsampling factor.
441
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
442
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
443
+ # on the fly if necessary.
444
+ default_overall_up_factor = 2**self.num_upsamplers
445
+
446
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
447
+ forward_upsample_size = False
448
+ upsample_size = None
449
+
450
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
451
+ logger.info("Forward upsample size to force interpolation output size.")
452
+ forward_upsample_size = True
453
+
454
+ # prepare attention_mask
455
+ if attention_mask is not None:
456
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
457
+ attention_mask = attention_mask.unsqueeze(1)
458
+
459
+ # center input if necessary
460
+ if self.config.center_input_sample:
461
+ sample = 2 * sample - 1.0
462
+
463
+ # time
464
+ timesteps = timestep
465
+ if not torch.is_tensor(timesteps):
466
+ # This would be a good case for the `match` statement (Python 3.10+)
467
+ is_mps = sample.device.type == "mps"
468
+ if isinstance(timestep, float):
469
+ dtype = torch.float32 if is_mps else torch.float64
470
+ else:
471
+ dtype = torch.int32 if is_mps else torch.int64
472
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
473
+ elif len(timesteps.shape) == 0:
474
+ timesteps = timesteps[None].to(sample.device)
475
+
476
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
477
+ timesteps = timesteps.expand(sample.shape[0])
478
+
479
+ t_emb = self.time_proj(timesteps)
480
+
481
+ # timesteps does not contain any weights and will always return f32 tensors
482
+ # but time_embedding might actually be running in fp16. so we need to cast here.
483
+ # there might be better ways to encapsulate this.
484
+ t_emb = t_emb.to(dtype=self.dtype)
485
+ emb = self.time_embedding(t_emb)
486
+
487
+ if self.class_embedding is not None:
488
+ if class_labels is None:
489
+ raise ValueError(
490
+ "class_labels should be provided when num_class_embeds > 0"
491
+ )
492
+
493
+ if self.config.class_embed_type == "timestep":
494
+ class_labels = self.time_proj(class_labels)
495
+
496
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
497
+ emb = emb + class_emb
498
+
499
+ # pre-process
500
+ sample = self.conv_in(sample)
501
+ if pose_cond_fea is not None:
502
+ sample = sample + pose_cond_fea
503
+
504
+ # down
505
+ down_block_res_samples = (sample,)
506
+ for downsample_block in self.down_blocks:
507
+ if (
508
+ hasattr(downsample_block, "has_cross_attention")
509
+ and downsample_block.has_cross_attention
510
+ ):
511
+ sample, res_samples = downsample_block(
512
+ hidden_states=sample,
513
+ temb=emb,
514
+ encoder_hidden_states=encoder_hidden_states,
515
+ attention_mask=attention_mask,
516
+ )
517
+ else:
518
+ sample, res_samples = downsample_block(
519
+ hidden_states=sample,
520
+ temb=emb,
521
+ encoder_hidden_states=encoder_hidden_states,
522
+ )
523
+
524
+ down_block_res_samples += res_samples
525
+
526
+ if down_block_additional_residuals is not None:
527
+ new_down_block_res_samples = ()
528
+
529
+ for down_block_res_sample, down_block_additional_residual in zip(
530
+ down_block_res_samples, down_block_additional_residuals
531
+ ):
532
+ down_block_res_sample = (
533
+ down_block_res_sample + down_block_additional_residual
534
+ )
535
+ new_down_block_res_samples += (down_block_res_sample,)
536
+
537
+ down_block_res_samples = new_down_block_res_samples
538
+
539
+ # mid
540
+ sample = self.mid_block(
541
+ sample,
542
+ emb,
543
+ encoder_hidden_states=encoder_hidden_states,
544
+ attention_mask=attention_mask,
545
+ )
546
+
547
+ if mid_block_additional_residual is not None:
548
+ sample = sample + mid_block_additional_residual
549
+
550
+ # up
551
+ for i, upsample_block in enumerate(self.up_blocks):
552
+ is_final_block = i == len(self.up_blocks) - 1
553
+
554
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
555
+ down_block_res_samples = down_block_res_samples[
556
+ : -len(upsample_block.resnets)
557
+ ]
558
+
559
+ # if we have not reached the final block and need to forward the
560
+ # upsample size, we do it here
561
+ if not is_final_block and forward_upsample_size:
562
+ upsample_size = down_block_res_samples[-1].shape[2:]
563
+
564
+ if (
565
+ hasattr(upsample_block, "has_cross_attention")
566
+ and upsample_block.has_cross_attention
567
+ ):
568
+ sample = upsample_block(
569
+ hidden_states=sample,
570
+ temb=emb,
571
+ res_hidden_states_tuple=res_samples,
572
+ encoder_hidden_states=encoder_hidden_states,
573
+ upsample_size=upsample_size,
574
+ attention_mask=attention_mask,
575
+ )
576
+ else:
577
+ sample = upsample_block(
578
+ hidden_states=sample,
579
+ temb=emb,
580
+ res_hidden_states_tuple=res_samples,
581
+ upsample_size=upsample_size,
582
+ encoder_hidden_states=encoder_hidden_states,
583
+ )
584
+
585
+ # post-process
586
+ sample = self.conv_norm_out(sample)
587
+ sample = self.conv_act(sample)
588
+ sample = self.conv_out(sample)
589
+
590
+ if not return_dict:
591
+ return (sample,)
592
+
593
+ return UNet3DConditionOutput(sample=sample)
594
+
595
+ @classmethod
596
+ def from_pretrained_2d(
597
+ cls,
598
+ config_file: PathLike,
599
+ ckpt_file: PathLike
600
+ ):
601
+ unet_additional_kwargs={
602
+ "use_motion_module": False,
603
+ "unet_use_temporal_attention": False,
604
+ "use_zero_convs": False
605
+ }
606
+
607
+ config_file = Path(config_file)
608
+ ckpt_file = Path(ckpt_file)
609
+
610
+ if not (config_file.exists() and config_file.is_file()):
611
+ raise RuntimeError(f"{config_file} does not exist or is not a file")
612
+ if not (ckpt_file.exists() and ckpt_file.is_file()):
613
+ raise RuntimeError(f"{ckpt_file} does not exist or is not a file")
614
+
615
+ unet_config = cls.load_config(config_file)
616
+ unet_config["_class_name"] = cls.__name__
617
+ unet_config["down_block_types"] = [
618
+ "CrossAttnDownBlock3D",
619
+ "CrossAttnDownBlock3D",
620
+ "CrossAttnDownBlock3D",
621
+ "DownBlock3D",
622
+ ]
623
+ unet_config["up_block_types"] = [
624
+ "UpBlock3D",
625
+ "CrossAttnUpBlock3D",
626
+ "CrossAttnUpBlock3D",
627
+ "CrossAttnUpBlock3D",
628
+ ]
629
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
630
+
631
+ model = cls.from_config(unet_config, **unet_additional_kwargs)
632
+ state_dict = torch.load(
633
+ ckpt_file, map_location="cpu", weights_only=True,
634
+ )
635
+
636
+ # load the weights into the model
637
+ m, u = model.load_state_dict(state_dict, strict=False)
638
+ logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
639
+
640
+ params = [
641
+ p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
642
+ ]
643
+ logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module")
644
+
645
+ return model
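A minimal usage sketch of `from_pretrained_2d` above, assuming a 2D UNet `config.json` and a plain state-dict checkpoint; the paths and tensor shapes below are illustrative assumptions, not files shipped with this commit.

import torch
from genwarp.models.unet_3d import UNet3DConditionModel

unet3d = UNet3DConditionModel.from_pretrained_2d(
    config_file="checkpoints/unet/config.json",     # hypothetical path
    ckpt_file="checkpoints/unet/unet_weights.pth",  # hypothetical path
)

latents = torch.randn(1, 4, 8, 64, 64)   # (batch, channels, frames, height, width)
timestep = torch.tensor([999])
text_emb = torch.randn(1, 77, 768)       # cross-attention context (assumed dim)

with torch.no_grad():
    pred = unet3d(latents, timestep, encoder_hidden_states=text_emb).sample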
genwarp/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,885 @@
1
+ # This code is adapted from below and then modified.
2
+ # -----------------------------------------------------------------------------
3
+ # Moore-AnimateAnyone
4
+ # Apache License, Version 2.0
5
+ # Copyright @2023-2024 Moore Threads Technology Co., Ltd.
6
+ # https://github.com/MooreThreads/Moore-AnimateAnyone
7
+ # -----------------------------------------------------------------------------
8
+ # Diffusers
9
+ # Apache License, Version 2.0
10
+ # Copyright (c) Hugging Face Inc.
11
+ # https://github.com/huggingface/diffusers
12
+ # ==============================================================================
13
+
14
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
15
+
16
+ import pdb
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ from .motion_module import get_motion_module
22
+
23
+ # from .motion_module import get_motion_module
24
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
25
+ from .transformer_3d import Transformer3DModel
26
+
27
+
28
+ def get_down_block(
29
+ down_block_type,
30
+ num_layers,
31
+ in_channels,
32
+ out_channels,
33
+ temb_channels,
34
+ add_downsample,
35
+ resnet_eps,
36
+ resnet_act_fn,
37
+ attn_num_head_channels,
38
+ resnet_groups=None,
39
+ cross_attention_dim=None,
40
+ downsample_padding=None,
41
+ dual_cross_attention=False,
42
+ use_linear_projection=False,
43
+ only_cross_attention=False,
44
+ upcast_attention=False,
45
+ resnet_time_scale_shift="default",
46
+ unet_use_cross_frame_attention=None,
47
+ unet_use_temporal_attention=None,
48
+ use_inflated_groupnorm=None,
49
+ use_motion_module=None,
50
+ motion_module_type=None,
51
+ motion_module_kwargs=None,
52
+ use_zero_convs=False,
53
+ ):
54
+ down_block_type = (
55
+ down_block_type[7:]
56
+ if down_block_type.startswith("UNetRes")
57
+ else down_block_type
58
+ )
59
+ if down_block_type == "DownBlock3D":
60
+ return DownBlock3D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ add_downsample=add_downsample,
66
+ resnet_eps=resnet_eps,
67
+ resnet_act_fn=resnet_act_fn,
68
+ resnet_groups=resnet_groups,
69
+ downsample_padding=downsample_padding,
70
+ resnet_time_scale_shift=resnet_time_scale_shift,
71
+ use_inflated_groupnorm=use_inflated_groupnorm,
72
+ use_motion_module=use_motion_module,
73
+ motion_module_type=motion_module_type,
74
+ motion_module_kwargs=motion_module_kwargs,
75
+ )
76
+ elif down_block_type == "CrossAttnDownBlock3D":
77
+ if cross_attention_dim is None:
78
+ raise ValueError(
79
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
80
+ )
81
+ return CrossAttnDownBlock3D(
82
+ num_layers=num_layers,
83
+ in_channels=in_channels,
84
+ out_channels=out_channels,
85
+ temb_channels=temb_channels,
86
+ add_downsample=add_downsample,
87
+ resnet_eps=resnet_eps,
88
+ resnet_act_fn=resnet_act_fn,
89
+ resnet_groups=resnet_groups,
90
+ downsample_padding=downsample_padding,
91
+ cross_attention_dim=cross_attention_dim,
92
+ attn_num_head_channels=attn_num_head_channels,
93
+ dual_cross_attention=dual_cross_attention,
94
+ use_linear_projection=use_linear_projection,
95
+ only_cross_attention=only_cross_attention,
96
+ upcast_attention=upcast_attention,
97
+ resnet_time_scale_shift=resnet_time_scale_shift,
98
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
99
+ unet_use_temporal_attention=unet_use_temporal_attention,
100
+ use_inflated_groupnorm=use_inflated_groupnorm,
101
+ use_motion_module=use_motion_module,
102
+ motion_module_type=motion_module_type,
103
+ motion_module_kwargs=motion_module_kwargs,
104
+ use_zero_convs=use_zero_convs,
105
+ )
106
+ raise ValueError(f"{down_block_type} does not exist.")
107
+
108
+
109
+ def get_up_block(
110
+ up_block_type,
111
+ num_layers,
112
+ in_channels,
113
+ out_channels,
114
+ prev_output_channel,
115
+ temb_channels,
116
+ add_upsample,
117
+ resnet_eps,
118
+ resnet_act_fn,
119
+ attn_num_head_channels,
120
+ resnet_groups=None,
121
+ cross_attention_dim=None,
122
+ dual_cross_attention=False,
123
+ use_linear_projection=False,
124
+ only_cross_attention=False,
125
+ upcast_attention=False,
126
+ resnet_time_scale_shift="default",
127
+ unet_use_cross_frame_attention=None,
128
+ unet_use_temporal_attention=None,
129
+ use_inflated_groupnorm=None,
130
+ use_motion_module=None,
131
+ motion_module_type=None,
132
+ motion_module_kwargs=None,
133
+ use_zero_convs=False,
134
+ ):
135
+ up_block_type = (
136
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
137
+ )
138
+ if up_block_type == "UpBlock3D":
139
+ return UpBlock3D(
140
+ num_layers=num_layers,
141
+ in_channels=in_channels,
142
+ out_channels=out_channels,
143
+ prev_output_channel=prev_output_channel,
144
+ temb_channels=temb_channels,
145
+ add_upsample=add_upsample,
146
+ resnet_eps=resnet_eps,
147
+ resnet_act_fn=resnet_act_fn,
148
+ resnet_groups=resnet_groups,
149
+ resnet_time_scale_shift=resnet_time_scale_shift,
150
+ use_inflated_groupnorm=use_inflated_groupnorm,
151
+ use_motion_module=use_motion_module,
152
+ motion_module_type=motion_module_type,
153
+ motion_module_kwargs=motion_module_kwargs,
154
+ )
155
+ elif up_block_type == "CrossAttnUpBlock3D":
156
+ if cross_attention_dim is None:
157
+ raise ValueError(
158
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
159
+ )
160
+ return CrossAttnUpBlock3D(
161
+ num_layers=num_layers,
162
+ in_channels=in_channels,
163
+ out_channels=out_channels,
164
+ prev_output_channel=prev_output_channel,
165
+ temb_channels=temb_channels,
166
+ add_upsample=add_upsample,
167
+ resnet_eps=resnet_eps,
168
+ resnet_act_fn=resnet_act_fn,
169
+ resnet_groups=resnet_groups,
170
+ cross_attention_dim=cross_attention_dim,
171
+ attn_num_head_channels=attn_num_head_channels,
172
+ dual_cross_attention=dual_cross_attention,
173
+ use_linear_projection=use_linear_projection,
174
+ only_cross_attention=only_cross_attention,
175
+ upcast_attention=upcast_attention,
176
+ resnet_time_scale_shift=resnet_time_scale_shift,
177
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
178
+ unet_use_temporal_attention=unet_use_temporal_attention,
179
+ use_inflated_groupnorm=use_inflated_groupnorm,
180
+ use_motion_module=use_motion_module,
181
+ motion_module_type=motion_module_type,
182
+ motion_module_kwargs=motion_module_kwargs,
183
+ use_zero_convs=use_zero_convs,
184
+ )
185
+ raise ValueError(f"{up_block_type} does not exist.")
186
+
187
+
188
+ class UNetMidBlock3DCrossAttn(nn.Module):
189
+ def __init__(
190
+ self,
191
+ in_channels: int,
192
+ temb_channels: int,
193
+ dropout: float = 0.0,
194
+ num_layers: int = 1,
195
+ resnet_eps: float = 1e-6,
196
+ resnet_time_scale_shift: str = "default",
197
+ resnet_act_fn: str = "swish",
198
+ resnet_groups: int = 32,
199
+ resnet_pre_norm: bool = True,
200
+ attn_num_head_channels=1,
201
+ output_scale_factor=1.0,
202
+ cross_attention_dim=1280,
203
+ dual_cross_attention=False,
204
+ use_linear_projection=False,
205
+ upcast_attention=False,
206
+ unet_use_cross_frame_attention=None,
207
+ unet_use_temporal_attention=None,
208
+ use_inflated_groupnorm=None,
209
+ use_motion_module=None,
210
+ motion_module_type=None,
211
+ motion_module_kwargs=None,
212
+ use_zero_convs=False,
213
+ ):
214
+ super().__init__()
215
+
216
+ self.has_cross_attention = True
217
+ self.attn_num_head_channels = attn_num_head_channels
218
+ resnet_groups = (
219
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
220
+ )
221
+
222
+ # there is always at least one resnet
223
+ resnets = [
224
+ ResnetBlock3D(
225
+ in_channels=in_channels,
226
+ out_channels=in_channels,
227
+ temb_channels=temb_channels,
228
+ eps=resnet_eps,
229
+ groups=resnet_groups,
230
+ dropout=dropout,
231
+ time_embedding_norm=resnet_time_scale_shift,
232
+ non_linearity=resnet_act_fn,
233
+ output_scale_factor=output_scale_factor,
234
+ pre_norm=resnet_pre_norm,
235
+ use_inflated_groupnorm=use_inflated_groupnorm,
236
+ )
237
+ ]
238
+ attentions = []
239
+ motion_modules = []
240
+
241
+ for _ in range(num_layers):
242
+ if dual_cross_attention:
243
+ raise NotImplementedError
244
+ attentions.append(
245
+ Transformer3DModel(
246
+ attn_num_head_channels,
247
+ in_channels // attn_num_head_channels,
248
+ in_channels=in_channels,
249
+ num_layers=1,
250
+ cross_attention_dim=cross_attention_dim,
251
+ norm_num_groups=resnet_groups,
252
+ use_linear_projection=use_linear_projection,
253
+ upcast_attention=upcast_attention,
254
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
255
+ unet_use_temporal_attention=unet_use_temporal_attention,
256
+ use_zero_convs=use_zero_convs,
257
+ )
258
+ )
259
+ motion_modules.append(
260
+ get_motion_module(
261
+ in_channels=in_channels,
262
+ motion_module_type=motion_module_type,
263
+ motion_module_kwargs=motion_module_kwargs,
264
+ )
265
+ if use_motion_module
266
+ else None
267
+ )
268
+ resnets.append(
269
+ ResnetBlock3D(
270
+ in_channels=in_channels,
271
+ out_channels=in_channels,
272
+ temb_channels=temb_channels,
273
+ eps=resnet_eps,
274
+ groups=resnet_groups,
275
+ dropout=dropout,
276
+ time_embedding_norm=resnet_time_scale_shift,
277
+ non_linearity=resnet_act_fn,
278
+ output_scale_factor=output_scale_factor,
279
+ pre_norm=resnet_pre_norm,
280
+ use_inflated_groupnorm=use_inflated_groupnorm,
281
+ )
282
+ )
283
+
284
+ self.attentions = nn.ModuleList(attentions)
285
+ self.resnets = nn.ModuleList(resnets)
286
+ self.motion_modules = nn.ModuleList(motion_modules)
287
+
288
+ def forward(
289
+ self,
290
+ hidden_states,
291
+ temb=None,
292
+ encoder_hidden_states=None,
293
+ attention_mask=None,
294
+ ):
295
+ hidden_states = self.resnets[0](hidden_states, temb)
296
+ for attn, resnet, motion_module in zip(
297
+ self.attentions, self.resnets[1:], self.motion_modules
298
+ ):
299
+ hidden_states = attn(
300
+ hidden_states,
301
+ encoder_hidden_states=encoder_hidden_states,
302
+ ).sample
303
+ hidden_states = (
304
+ motion_module(
305
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
306
+ )
307
+ if motion_module is not None
308
+ else hidden_states
309
+ )
310
+ hidden_states = resnet(hidden_states, temb)
311
+
312
+ return hidden_states
313
+
314
+
315
+ class CrossAttnDownBlock3D(nn.Module):
316
+ def __init__(
317
+ self,
318
+ in_channels: int,
319
+ out_channels: int,
320
+ temb_channels: int,
321
+ dropout: float = 0.0,
322
+ num_layers: int = 1,
323
+ resnet_eps: float = 1e-6,
324
+ resnet_time_scale_shift: str = "default",
325
+ resnet_act_fn: str = "swish",
326
+ resnet_groups: int = 32,
327
+ resnet_pre_norm: bool = True,
328
+ attn_num_head_channels=1,
329
+ cross_attention_dim=1280,
330
+ output_scale_factor=1.0,
331
+ downsample_padding=1,
332
+ add_downsample=True,
333
+ dual_cross_attention=False,
334
+ use_linear_projection=False,
335
+ only_cross_attention=False,
336
+ upcast_attention=False,
337
+ unet_use_cross_frame_attention=None,
338
+ unet_use_temporal_attention=None,
339
+ use_inflated_groupnorm=None,
340
+ use_motion_module=None,
341
+ motion_module_type=None,
342
+ motion_module_kwargs=None,
343
+ use_zero_convs=False,
344
+ ):
345
+ super().__init__()
346
+ resnets = []
347
+ attentions = []
348
+ motion_modules = []
349
+
350
+ self.has_cross_attention = True
351
+ self.attn_num_head_channels = attn_num_head_channels
352
+
353
+ for i in range(num_layers):
354
+ in_channels = in_channels if i == 0 else out_channels
355
+ resnets.append(
356
+ ResnetBlock3D(
357
+ in_channels=in_channels,
358
+ out_channels=out_channels,
359
+ temb_channels=temb_channels,
360
+ eps=resnet_eps,
361
+ groups=resnet_groups,
362
+ dropout=dropout,
363
+ time_embedding_norm=resnet_time_scale_shift,
364
+ non_linearity=resnet_act_fn,
365
+ output_scale_factor=output_scale_factor,
366
+ pre_norm=resnet_pre_norm,
367
+ use_inflated_groupnorm=use_inflated_groupnorm,
368
+ )
369
+ )
370
+ if dual_cross_attention:
371
+ raise NotImplementedError
372
+ attentions.append(
373
+ Transformer3DModel(
374
+ attn_num_head_channels,
375
+ out_channels // attn_num_head_channels,
376
+ in_channels=out_channels,
377
+ num_layers=1,
378
+ cross_attention_dim=cross_attention_dim,
379
+ norm_num_groups=resnet_groups,
380
+ use_linear_projection=use_linear_projection,
381
+ only_cross_attention=only_cross_attention,
382
+ upcast_attention=upcast_attention,
383
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
384
+ unet_use_temporal_attention=unet_use_temporal_attention,
385
+ use_zero_convs=use_zero_convs,
386
+ )
387
+ )
388
+ motion_modules.append(
389
+ get_motion_module(
390
+ in_channels=out_channels,
391
+ motion_module_type=motion_module_type,
392
+ motion_module_kwargs=motion_module_kwargs,
393
+ )
394
+ if use_motion_module
395
+ else None
396
+ )
397
+
398
+ self.attentions = nn.ModuleList(attentions)
399
+ self.resnets = nn.ModuleList(resnets)
400
+ self.motion_modules = nn.ModuleList(motion_modules)
401
+
402
+ if add_downsample:
403
+ self.downsamplers = nn.ModuleList(
404
+ [
405
+ Downsample3D(
406
+ out_channels,
407
+ use_conv=True,
408
+ out_channels=out_channels,
409
+ padding=downsample_padding,
410
+ name="op",
411
+ )
412
+ ]
413
+ )
414
+ else:
415
+ self.downsamplers = None
416
+
417
+ self.gradient_checkpointing = False
418
+
419
+ def forward(
420
+ self,
421
+ hidden_states,
422
+ temb=None,
423
+ encoder_hidden_states=None,
424
+ attention_mask=None,
425
+ ):
426
+ output_states = ()
427
+
428
+ for i, (resnet, attn, motion_module) in enumerate(
429
+ zip(self.resnets, self.attentions, self.motion_modules)
430
+ ):
431
+ # self.gradient_checkpointing = False
432
+ if self.training and self.gradient_checkpointing:
433
+
434
+ def create_custom_forward(module, return_dict=None):
435
+ def custom_forward(*inputs):
436
+ if return_dict is not None:
437
+ return module(*inputs, return_dict=return_dict)
438
+ else:
439
+ return module(*inputs)
440
+
441
+ return custom_forward
442
+
443
+ hidden_states = torch.utils.checkpoint.checkpoint(
444
+ create_custom_forward(resnet), hidden_states, temb
445
+ )
446
+ hidden_states = torch.utils.checkpoint.checkpoint(
447
+ create_custom_forward(attn, return_dict=False),
448
+ hidden_states,
449
+ encoder_hidden_states,
450
+ )[0]
451
+
452
+ # add motion module
453
+ hidden_states = (
454
+ motion_module(
455
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
456
+ )
457
+ if motion_module is not None
458
+ else hidden_states
459
+ )
460
+
461
+ else:
462
+ hidden_states = resnet(hidden_states, temb)
463
+ hidden_states = attn(
464
+ hidden_states,
465
+ encoder_hidden_states=encoder_hidden_states,
466
+ ).sample
467
+
468
+ # add motion module
469
+ hidden_states = (
470
+ motion_module(
471
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
472
+ )
473
+ if motion_module is not None
474
+ else hidden_states
475
+ )
476
+
477
+ output_states += (hidden_states,)
478
+
479
+ if self.downsamplers is not None:
480
+ for downsampler in self.downsamplers:
481
+ hidden_states = downsampler(hidden_states)
482
+
483
+ output_states += (hidden_states,)
484
+
485
+ return hidden_states, output_states
486
+
487
+
488
+ class DownBlock3D(nn.Module):
489
+ def __init__(
490
+ self,
491
+ in_channels: int,
492
+ out_channels: int,
493
+ temb_channels: int,
494
+ dropout: float = 0.0,
495
+ num_layers: int = 1,
496
+ resnet_eps: float = 1e-6,
497
+ resnet_time_scale_shift: str = "default",
498
+ resnet_act_fn: str = "swish",
499
+ resnet_groups: int = 32,
500
+ resnet_pre_norm: bool = True,
501
+ output_scale_factor=1.0,
502
+ add_downsample=True,
503
+ downsample_padding=1,
504
+ use_inflated_groupnorm=None,
505
+ use_motion_module=None,
506
+ motion_module_type=None,
507
+ motion_module_kwargs=None,
508
+ ):
509
+ super().__init__()
510
+ resnets = []
511
+ motion_modules = []
512
+
513
+ # use_motion_module = False
514
+ for i in range(num_layers):
515
+ in_channels = in_channels if i == 0 else out_channels
516
+ resnets.append(
517
+ ResnetBlock3D(
518
+ in_channels=in_channels,
519
+ out_channels=out_channels,
520
+ temb_channels=temb_channels,
521
+ eps=resnet_eps,
522
+ groups=resnet_groups,
523
+ dropout=dropout,
524
+ time_embedding_norm=resnet_time_scale_shift,
525
+ non_linearity=resnet_act_fn,
526
+ output_scale_factor=output_scale_factor,
527
+ pre_norm=resnet_pre_norm,
528
+ use_inflated_groupnorm=use_inflated_groupnorm,
529
+ )
530
+ )
531
+ motion_modules.append(
532
+ get_motion_module(
533
+ in_channels=out_channels,
534
+ motion_module_type=motion_module_type,
535
+ motion_module_kwargs=motion_module_kwargs,
536
+ )
537
+ if use_motion_module
538
+ else None
539
+ )
540
+
541
+ self.resnets = nn.ModuleList(resnets)
542
+ self.motion_modules = nn.ModuleList(motion_modules)
543
+
544
+ if add_downsample:
545
+ self.downsamplers = nn.ModuleList(
546
+ [
547
+ Downsample3D(
548
+ out_channels,
549
+ use_conv=True,
550
+ out_channels=out_channels,
551
+ padding=downsample_padding,
552
+ name="op",
553
+ )
554
+ ]
555
+ )
556
+ else:
557
+ self.downsamplers = None
558
+
559
+ self.gradient_checkpointing = False
560
+
561
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
562
+ output_states = ()
563
+
564
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
565
+ # print(f"DownBlock3D {self.gradient_checkpointing = }")
566
+ if self.training and self.gradient_checkpointing:
567
+
568
+ def create_custom_forward(module):
569
+ def custom_forward(*inputs):
570
+ return module(*inputs)
571
+
572
+ return custom_forward
573
+
574
+ hidden_states = torch.utils.checkpoint.checkpoint(
575
+ create_custom_forward(resnet), hidden_states, temb
576
+ )
577
+ if motion_module is not None:
578
+ hidden_states = torch.utils.checkpoint.checkpoint(
579
+ create_custom_forward(motion_module),
580
+ hidden_states.requires_grad_(),
581
+ temb,
582
+ encoder_hidden_states,
583
+ )
584
+ else:
585
+ hidden_states = resnet(hidden_states, temb)
586
+
587
+ # add motion module
588
+ hidden_states = (
589
+ motion_module(
590
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
591
+ )
592
+ if motion_module is not None
593
+ else hidden_states
594
+ )
595
+
596
+ output_states += (hidden_states,)
597
+
598
+ if self.downsamplers is not None:
599
+ for downsampler in self.downsamplers:
600
+ hidden_states = downsampler(hidden_states)
601
+
602
+ output_states += (hidden_states,)
603
+
604
+ return hidden_states, output_states
605
+
606
+
607
+ class CrossAttnUpBlock3D(nn.Module):
608
+ def __init__(
609
+ self,
610
+ in_channels: int,
611
+ out_channels: int,
612
+ prev_output_channel: int,
613
+ temb_channels: int,
614
+ dropout: float = 0.0,
615
+ num_layers: int = 1,
616
+ resnet_eps: float = 1e-6,
617
+ resnet_time_scale_shift: str = "default",
618
+ resnet_act_fn: str = "swish",
619
+ resnet_groups: int = 32,
620
+ resnet_pre_norm: bool = True,
621
+ attn_num_head_channels=1,
622
+ cross_attention_dim=1280,
623
+ output_scale_factor=1.0,
624
+ add_upsample=True,
625
+ dual_cross_attention=False,
626
+ use_linear_projection=False,
627
+ only_cross_attention=False,
628
+ upcast_attention=False,
629
+ unet_use_cross_frame_attention=None,
630
+ unet_use_temporal_attention=None,
631
+ use_motion_module=None,
632
+ use_inflated_groupnorm=None,
633
+ motion_module_type=None,
634
+ motion_module_kwargs=None,
635
+ use_zero_convs=False,
636
+ ):
637
+ super().__init__()
638
+ resnets = []
639
+ attentions = []
640
+ motion_modules = []
641
+
642
+ self.has_cross_attention = True
643
+ self.attn_num_head_channels = attn_num_head_channels
644
+
645
+ for i in range(num_layers):
646
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
647
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
648
+
649
+ resnets.append(
650
+ ResnetBlock3D(
651
+ in_channels=resnet_in_channels + res_skip_channels,
652
+ out_channels=out_channels,
653
+ temb_channels=temb_channels,
654
+ eps=resnet_eps,
655
+ groups=resnet_groups,
656
+ dropout=dropout,
657
+ time_embedding_norm=resnet_time_scale_shift,
658
+ non_linearity=resnet_act_fn,
659
+ output_scale_factor=output_scale_factor,
660
+ pre_norm=resnet_pre_norm,
661
+ use_inflated_groupnorm=use_inflated_groupnorm,
662
+ )
663
+ )
664
+ if dual_cross_attention:
665
+ raise NotImplementedError
666
+ attentions.append(
667
+ Transformer3DModel(
668
+ attn_num_head_channels,
669
+ out_channels // attn_num_head_channels,
670
+ in_channels=out_channels,
671
+ num_layers=1,
672
+ cross_attention_dim=cross_attention_dim,
673
+ norm_num_groups=resnet_groups,
674
+ use_linear_projection=use_linear_projection,
675
+ only_cross_attention=only_cross_attention,
676
+ upcast_attention=upcast_attention,
677
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
678
+ unet_use_temporal_attention=unet_use_temporal_attention,
679
+ use_zero_convs=use_zero_convs,
680
+ )
681
+ )
682
+ motion_modules.append(
683
+ get_motion_module(
684
+ in_channels=out_channels,
685
+ motion_module_type=motion_module_type,
686
+ motion_module_kwargs=motion_module_kwargs,
687
+ )
688
+ if use_motion_module
689
+ else None
690
+ )
691
+
692
+ self.attentions = nn.ModuleList(attentions)
693
+ self.resnets = nn.ModuleList(resnets)
694
+ self.motion_modules = nn.ModuleList(motion_modules)
695
+
696
+ if add_upsample:
697
+ self.upsamplers = nn.ModuleList(
698
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
699
+ )
700
+ else:
701
+ self.upsamplers = None
702
+
703
+ self.gradient_checkpointing = False
704
+
705
+ def forward(
706
+ self,
707
+ hidden_states,
708
+ res_hidden_states_tuple,
709
+ temb=None,
710
+ encoder_hidden_states=None,
711
+        upsample_size=None,
+        attention_mask=None,
+    ):
+        for i, (resnet, attn, motion_module) in enumerate(
+            zip(self.resnets, self.attentions, self.motion_modules)
+        ):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                ).sample
+                if motion_module is not None:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(motion_module),
+                        hidden_states.requires_grad_(),
+                        temb,
+                        encoder_hidden_states,
+                    )
+
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                ).sample
+
+                # add motion module
+                hidden_states = (
+                    motion_module(
+                        hidden_states, temb, encoder_hidden_states=encoder_hidden_states
+                    )
+                    if motion_module is not None
+                    else hidden_states
+                )
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+
+class UpBlock3D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor=1.0,
+        add_upsample=True,
+        use_inflated_groupnorm=None,
+        use_motion_module=None,
+        motion_module_type=None,
+        motion_module_kwargs=None,
+    ):
+        super().__init__()
+        resnets = []
+        motion_modules = []
+
+        # use_motion_module = False
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock3D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                    use_inflated_groupnorm=use_inflated_groupnorm,
+                )
+            )
+            motion_modules.append(
+                get_motion_module(
+                    in_channels=out_channels,
+                    motion_module_type=motion_module_type,
+                    motion_module_kwargs=motion_module_kwargs,
+                )
+                if use_motion_module
+                else None
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList(
+                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
+            )
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        res_hidden_states_tuple,
+        temb=None,
+        upsample_size=None,
+        encoder_hidden_states=None,
+    ):
+        for resnet, motion_module in zip(self.resnets, self.motion_modules):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            # print(f"UpBlock3D {self.gradient_checkpointing = }")
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states, temb
+                )
+                if motion_module is not None:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(motion_module),
+                        hidden_states.requires_grad_(),
+                        temb,
+                        encoder_hidden_states,
+                    )
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = (
+                    motion_module(
+                        hidden_states, temb, encoder_hidden_states=encoder_hidden_states
+                    )
+                    if motion_module is not None
+                    else hidden_states
+                )
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
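
For readers skimming the diff, here is a minimal, self-contained sketch (not part of the commit) of the skip-connection pattern that both up blocks above rely on: each layer pops the most recent encoder feature off res_hidden_states_tuple and concatenates it on the channel axis before running its block. ToyBlock and ToyUpBlock are hypothetical stand-ins, not the repo's ResnetBlock3D, attention, or motion modules.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # Stand-in for ResnetBlock3D: a single 3D conv that maps in_ch -> out_ch.
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.relu(self.conv(x))

class ToyUpBlock(nn.Module):
    # Mirrors UpBlock3D.forward's bookkeeping: pop the last skip tensor,
    # concatenate it on dim=1 (channels), then run the layer.
    def __init__(self, ch=64, skip_ch=32, num_layers=2):
        super().__init__()
        self.resnets = nn.ModuleList(
            [ToyBlock(ch + skip_ch, ch) for _ in range(num_layers)]
        )

    def forward(self, hidden_states, res_hidden_states_tuple):
        for resnet in self.resnets:
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
            hidden_states = resnet(hidden_states)
        return hidden_states

x = torch.randn(1, 64, 4, 16, 16)                               # B C F H W
skips = tuple(torch.randn(1, 32, 4, 16, 16) for _ in range(2))  # skips from the encoder
print(ToyUpBlock()(x, skips).shape)                             # torch.Size([1, 64, 4, 16, 16])
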
genwarp/ops.py ADDED
@@ -0,0 +1,130 @@
+from typing import Dict
+from jaxtyping import Float
+
+import numpy as np
+import torch
+from torch import Tensor
+import torch.nn.functional as F
+
+from einops import rearrange
+from splatting import splatting_function
+
+def sph2cart(
+    azi: Float[Tensor, 'B'],
+    ele: Float[Tensor, 'B'],
+    r: Float[Tensor, 'B']
+) -> Float[Tensor, 'B 3']:
+    # z-up, y-right, x-back
+    rcos = r * torch.cos(ele)
+    pos_cart = torch.stack([
+        rcos * torch.cos(azi),
+        rcos * torch.sin(azi),
+        r * torch.sin(ele)
+    ], dim=1)
+
+    return pos_cart
+
+def get_viewport_matrix(
+    width: int,
+    height: int,
+    batch_size: int=1,
+    device: torch.device=None,
+) -> Float[Tensor, 'B 4 4']:
+    N = torch.tensor(
+        [[width/2, 0, 0, width/2],
+         [0, height/2, 0, height/2],
+         [0, 0, 1/2, 1/2],
+         [0, 0, 0, 1]],
+        dtype=torch.float32,
+        device=device
+    )[None].repeat(batch_size, 1, 1)
+    return N
+
+def get_projection_matrix(
+    fovy: Float[Tensor, 'B'],
+    aspect_wh: float,
+    near: float,
+    far: float
+) -> Float[Tensor, 'B 4 4']:
+    batch_size = fovy.shape[0]
+    proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32)
+    proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh)
+    proj_mtx[:, 1, 1] = -1.0 / torch.tan(
+        fovy / 2.0
+    )  # add a negative sign here as the y axis is flipped in nvdiffrast output
+    proj_mtx[:, 2, 2] = -(far + near) / (far - near)
+    proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near)
+    proj_mtx[:, 3, 2] = -1.0
+    return proj_mtx
+
+def camera_lookat(
+    eye: Float[Tensor, 'B 3'],
+    target: Float[Tensor, 'B 3'],
+    up: Float[Tensor, 'B 3']
+) -> Float[Tensor, 'B 4 4']:
+    B = eye.shape[0]
+    f = F.normalize(eye - target)
+    l = F.normalize(torch.linalg.cross(up, f))
+    u = F.normalize(torch.linalg.cross(f, l))
+
+    R = torch.stack((l, u, f), dim=1)  # B 3 3
+    M_R = torch.eye(4, dtype=torch.float32)[None].repeat((B, 1, 1))
+    M_R[..., :3, :3] = R
+
+    T = - eye
+    M_T = torch.eye(4, dtype=torch.float32)[None].repeat((B, 1, 1))
+    M_T[..., :3, 3] = T
+
+    return (M_R @ M_T).to(dtype=torch.float32)
+
+def focal_length_to_fov(
+    focal_length: float,
+    censor_length: float = 24.
+) -> float:
+    return 2 * np.arctan(censor_length / focal_length / 2.)
+
+def forward_warper(
+    image: Float[Tensor, 'B C H W'],
+    screen: Float[Tensor, 'B (H W) 2'],
+    pcd: Float[Tensor, 'B (H W) 4'],
+    mvp_mtx: Float[Tensor, 'B 4 4'],
+    viewport_mtx: Float[Tensor, 'B 4 4'],
+    alpha: float = 0.5
+) -> Dict[str, Tensor]:
+    H, W = image.shape[2:4]
+
+    # Projection.
+    points_c = pcd @ mvp_mtx.mT
+    points_ndc = points_c / points_c[..., 3:4]
+    # To screen.
+    coords_new = points_ndc @ viewport_mtx.mT
+
+    # Masking invalid pixels.
+    invalid = coords_new[..., 2] <= 0
+    coords_new[invalid] = -1000000 if coords_new.dtype == torch.float32 else -1e+4
+
+    # Calculate flow and importance for splatting.
+    new_z = points_c[..., 2:3]
+    flow = coords_new[..., :2] - screen[..., :2]
+    ## Importance.
+    importance = alpha / new_z
+    importance -= importance.amin((1, 2), keepdim=True)
+    importance /= importance.amax((1, 2), keepdim=True) + 1e-6
+    importance = importance * 10 - 10
+    ## Rearrange.
+    importance = rearrange(importance, 'b (h w) c -> b c h w', h=H, w=W)
+    flow = rearrange(flow, 'b (h w) c -> b c h w', h=H, w=W)
+
+    # Splatting.
+    warped = splatting_function('softmax', image, flow, importance, eps=1e-6)
+    ## mask is 1 where there is no splat
+    mask = (warped == 0.).all(dim=1, keepdim=True).to(image.dtype)
+    flow2 = rearrange(coords_new[..., :2], 'b (h w) c -> b c h w', h=H, w=W)
+
+    output = dict(
+        warped=warped,
+        mask=mask,
+        correspondence=flow2
+    )
+
+    return output
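
To show how these helpers fit together (this note is not part of the commit): sph2cart and camera_lookat build the view matrix, get_projection_matrix and get_viewport_matrix take clip-space coordinates to pixels, and forward_warper splats an image into the new view. Note that focal_length_to_fov's censor_length parameter presumably means the sensor length in millimetres. The sketch below assumes the file is importable as genwarp.ops, that the splatting package it imports is installed, and feeds random placeholder data for the image and point cloud; real use would derive the point cloud from a depth map.

import torch
from genwarp.ops import (
    sph2cart, camera_lookat, focal_length_to_fov,
    get_projection_matrix, get_viewport_matrix, forward_warper,
)

B, H, W = 1, 128, 128

# Target camera placed on a sphere around the origin (z-up, see sph2cart).
eye = sph2cart(torch.tensor([0.3]), torch.tensor([0.1]), torch.tensor([1.0]))
view_mtx = camera_lookat(eye, target=torch.zeros(B, 3), up=torch.tensor([[0., 0., 1.]]))

# Projection and viewport; 24 mm focal length is an arbitrary example value.
fovy = torch.full((B,), focal_length_to_fov(24.0))
proj_mtx = get_projection_matrix(fovy, aspect_wh=W / H, near=0.01, far=100.0)
viewport_mtx = get_viewport_matrix(W, H, batch_size=B)
mvp_mtx = proj_mtx @ view_mtx

# Placeholder inputs: an image, the source pixel grid ("screen"), and a
# homogeneous point cloud with one point per pixel, clustered near the origin
# so every point lies in front of the camera.
image = torch.rand(B, 3, H, W)
ys, xs = torch.meshgrid(
    torch.arange(H, dtype=torch.float32),
    torch.arange(W, dtype=torch.float32),
    indexing='ij',
)
screen = torch.stack([xs, ys], dim=-1).reshape(B, H * W, 2)
pcd = torch.cat([0.1 * torch.randn(B, H * W, 3), torch.ones(B, H * W, 1)], dim=-1)

out = forward_warper(image, screen, pcd, mvp_mtx, viewport_mtx)
print(out['warped'].shape, out['mask'].shape, out['correspondence'].shape)
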
requirements.txt ADDED
@@ -0,0 +1,14 @@
+torch==2.0.1
+torchvision==0.15.2
+diffusers
+accelerate
+transformers
+scipy
+opencv-python
+omegaconf
+einops
+roma
+jaxtyping
+timm==0.6.7
+matplotlib==3.6.2
+gradio_model3dgscamera
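
Note that the splatting package imported by genwarp/ops.py is not listed here and has to be installed separately. A quick way to confirm an environment satisfies this list is the throwaway check below (not part of the commit); it assumes each distribution's import name matches its package name, except opencv-python, which imports as cv2.

import importlib

modules = [
    'torch', 'torchvision', 'diffusers', 'accelerate', 'transformers',
    'scipy', 'cv2', 'omegaconf', 'einops', 'roma', 'jaxtyping',
    'timm', 'matplotlib', 'gradio_model3dgscamera', 'splatting',
]
for name in modules:
    importlib.import_module(name)  # raises ImportError if anything is missing
print('all dependencies importable')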