Yiyuan committed
Commit
96a9519
1 Parent(s): 3d95114

Upload 98 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full changeset.
Files changed (50)
  1. .gitignore +11 -0
  2. LICENSE +201 -0
  3. README.md +5 -5
  4. demo/configs/i2v_config.yaml +57 -0
  5. demo/draw_utils.py +124 -0
  6. demo/main_gradio.py +942 -0
  7. models/animatediff/models/__init__.py +0 -0
  8. models/animatediff/models/attention.py +559 -0
  9. models/animatediff/models/motion_module.py +572 -0
  10. models/animatediff/models/resnet.py +197 -0
  11. models/animatediff/models/unet.py +572 -0
  12. models/animatediff/models/unet_blocks.py +733 -0
  13. models/animatediff/pipelines/__init__.py +3 -0
  14. models/animatediff/pipelines/i2v_pipeline.py +729 -0
  15. models/animatediff/utils/convert_from_ckpt.py +964 -0
  16. models/animatediff/utils/convert_lora_safetensor_to_diffusers.py +208 -0
  17. models/animatediff/utils/util.py +334 -0
  18. models/draggan/dnnlib/__init__.py +9 -0
  19. models/draggan/dnnlib/util.py +491 -0
  20. models/draggan/gan_inv/__init__.py +9 -0
  21. models/draggan/gan_inv/inversion.py +277 -0
  22. models/draggan/gan_inv/lpips/__init__.py +5 -0
  23. models/draggan/gan_inv/lpips/base_model.py +58 -0
  24. models/draggan/gan_inv/lpips/dist_model.py +314 -0
  25. models/draggan/gan_inv/lpips/networks_basic.py +188 -0
  26. models/draggan/gan_inv/lpips/pretrained_networks.py +181 -0
  27. models/draggan/gan_inv/lpips/util.py +160 -0
  28. models/draggan/legacy.py +325 -0
  29. models/draggan/torch_utils/__init__.py +9 -0
  30. models/draggan/torch_utils/custom_ops.py +157 -0
  31. models/draggan/torch_utils/misc.py +266 -0
  32. models/draggan/torch_utils/ops/__init__.py +9 -0
  33. models/draggan/torch_utils/ops/bias_act.cpp +99 -0
  34. models/draggan/torch_utils/ops/bias_act.cu +173 -0
  35. models/draggan/torch_utils/ops/bias_act.h +38 -0
  36. models/draggan/torch_utils/ops/bias_act.py +209 -0
  37. models/draggan/torch_utils/ops/conv2d_gradfix.py +198 -0
  38. models/draggan/torch_utils/ops/conv2d_resample.py +143 -0
  39. models/draggan/torch_utils/ops/filtered_lrelu.cpp +300 -0
  40. models/draggan/torch_utils/ops/filtered_lrelu.cu +1284 -0
  41. models/draggan/torch_utils/ops/filtered_lrelu.h +90 -0
  42. models/draggan/torch_utils/ops/filtered_lrelu.py +274 -0
  43. models/draggan/torch_utils/ops/filtered_lrelu_ns.cu +27 -0
  44. models/draggan/torch_utils/ops/filtered_lrelu_rd.cu +27 -0
  45. models/draggan/torch_utils/ops/filtered_lrelu_wr.cu +27 -0
  46. models/draggan/torch_utils/ops/fma.py +60 -0
  47. models/draggan/torch_utils/ops/grid_sample_gradfix.py +77 -0
  48. models/draggan/torch_utils/ops/upfirdn2d.cpp +107 -0
  49. models/draggan/torch_utils/ops/upfirdn2d.cu +384 -0
  50. models/draggan/torch_utils/ops/upfirdn2d.h +59 -0
.gitignore ADDED
@@ -0,0 +1,11 @@
1
+ **/__pycache__
2
+ results/
3
+ results/images/
4
+ results/videos/
5
+ checkpoints/
6
+ examples/ui/saving_test/
7
+ examples/ui/checkpoints/
8
+ checkpoints/
9
+ dustbin/
10
+ ssh.txt
11
+ env.yaml
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,11 +1,11 @@
1
  ---
2
  title: InteractiveVideo
3
- emoji: 📚
4
- colorFrom: pink
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.16.0
8
- app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
 
1
  ---
2
  title: InteractiveVideo
3
+ emoji: 👀
4
+ colorFrom: purple
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.44.0
8
+ app_file: demo/main_gradio.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
demo/configs/i2v_config.yaml ADDED
@@ -0,0 +1,57 @@
1
+ prompts:
2
+ - - lightning, lighthouse
3
+ # - sun rising, lighthouse
4
+ # - fireworks, lighthouse
5
+
6
+ n_prompt:
7
+ - 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'
8
+
9
+ generate:
10
+ model_path: "checkpoints/i2v/unet/pia.ckpt"
11
+ use_image: true
12
+ use_video: false
13
+ sample_width: 512
14
+ sample_height: 512
15
+ video_length: 16
16
+ use_lora: false
17
+ use_db: true
18
+ global_seed: 5658137986800322011
19
+ lora_path: ""
20
+ db_path: "checkpoints/i2v/dreambooth/rcnzCartoon3d_v10.safetensors"
21
+ lora_alpha: 0.8
22
+
23
+ validation_data:
24
+ # mask_sim_range: [0, 1]
25
+ mask_sim_range: [0]
26
+ cond_frame: 0
27
+ num_inference_steps: 25
28
+
29
+ img_mask: ''
30
+ input_name: 'lighthouse'
31
+ validation_input_path: 'img'
32
+ save_path: 'result'
33
+
34
+ noise_scheduler_kwargs:
35
+ num_train_timesteps: 1000
36
+ beta_start: 0.00085
37
+ beta_end: 0.012
38
+ beta_schedule: "linear"
39
+ steps_offset: 1
40
+ clip_sample: false
41
+
42
+ pretrained_model_path: "checkpoints/diffusion_body/stable-diffusion-v1-5"
43
+ unet_additional_kwargs:
44
+ use_motion_module : true
45
+ motion_module_resolutions : [ 1,2,4,8 ]
46
+ unet_use_cross_frame_attention : false
47
+ unet_use_temporal_attention : false
48
+
49
+ motion_module_type: Vanilla
50
+ motion_module_kwargs:
51
+ num_attention_heads : 8
52
+ num_transformer_block : 1
53
+ attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
54
+ temporal_position_encoding : true
55
+ temporal_position_encoding_max_len : 32
56
+ temporal_attention_dim_div : 1
57
+ zero_initialize : true
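For reference, this config is consumed by `demo/main_gradio.py` through OmegaConf. Below is a minimal sketch (not part of the commit) of how its fields are read; it assumes the YAML sits at the path above and that the referenced checkpoints have been downloaded.

```python
# Minimal sketch of how demo/main_gradio.py consumes this config via OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.load("demo/configs/i2v_config.yaml")

# Weights used to build the image-to-video (PIA) pipeline.
print(cfg.pretrained_model_path)   # checkpoints/diffusion_body/stable-diffusion-v1-5
print(cfg.generate.model_path)     # checkpoints/i2v/unet/pia.ckpt
print(cfg.generate.db_path)        # DreamBooth checkpoint (rcnzCartoon3d_v10.safetensors)

# Sampling settings read at generation time.
print(cfg.generate.video_length, cfg.generate.sample_height, cfg.generate.sample_width)
print(cfg.validation_data.mask_sim_range, cfg.validation_data.num_inference_steps)
```

The `noise_scheduler_kwargs` and `unet_additional_kwargs` blocks are presumably consumed when `I2VPipeline.build_pipeline` constructs the scheduler and motion-module UNet.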
demo/draw_utils.py ADDED
@@ -0,0 +1,124 @@
1
+ from PIL import Image
2
+ from PIL import ImageDraw
3
+ import numpy as np
4
+
5
+
6
+ def draw_points_on_image(image,
7
+ points,
8
+ curr_point=None,
9
+ highlight_all=True,
10
+ radius_scale=0.01):
11
+ overlay_rgba = Image.new("RGBA", image.size, 0)
12
+ overlay_draw = ImageDraw.Draw(overlay_rgba)
13
+ for point_key, point in points.items():
14
+ if ((curr_point is not None and curr_point == point_key)
15
+ or highlight_all):
16
+ p_color = (255, 0, 0)
17
+ t_color = (0, 0, 255)
18
+
19
+ else:
20
+ p_color = (255, 0, 0, 35)
21
+ t_color = (0, 0, 255, 35)
22
+
23
+ rad_draw = int(image.size[0] * radius_scale)
24
+
25
+ p_start = point.get("start_temp", point["start"])
26
+ p_target = point["target"]
27
+
28
+ if p_start is not None and p_target is not None:
29
+ p_draw = int(p_start[0]), int(p_start[1])
30
+ t_draw = int(p_target[0]), int(p_target[1])
31
+
32
+ overlay_draw.line(
33
+ (p_draw[0], p_draw[1], t_draw[0], t_draw[1]),
34
+ fill=(255, 255, 0),
35
+ width=2,
36
+ )
37
+
38
+ if p_start is not None:
39
+ p_draw = int(p_start[0]), int(p_start[1])
40
+ overlay_draw.ellipse(
41
+ (
42
+ p_draw[0] - rad_draw,
43
+ p_draw[1] - rad_draw,
44
+ p_draw[0] + rad_draw,
45
+ p_draw[1] + rad_draw,
46
+ ),
47
+ fill=p_color,
48
+ )
49
+
50
+ if curr_point is not None and curr_point == point_key:
51
+ # overlay_draw.text(p_draw, "p", font=font, align="center", fill=(0, 0, 0))
52
+ overlay_draw.text(p_draw, "p", align="center", fill=(0, 0, 0))
53
+
54
+ if p_target is not None:
55
+ t_draw = int(p_target[0]), int(p_target[1])
56
+ overlay_draw.ellipse(
57
+ (
58
+ t_draw[0] - rad_draw,
59
+ t_draw[1] - rad_draw,
60
+ t_draw[0] + rad_draw,
61
+ t_draw[1] + rad_draw,
62
+ ),
63
+ fill=t_color,
64
+ )
65
+
66
+ if curr_point is not None and curr_point == point_key:
67
+ # overlay_draw.text(t_draw, "t", font=font, align="center", fill=(0, 0, 0))
68
+ overlay_draw.text(t_draw, "t", align="center", fill=(0, 0, 0))
69
+
70
+ return Image.alpha_composite(image.convert("RGBA"),
71
+ overlay_rgba).convert("RGB")
72
+
73
+
74
+ def draw_mask_on_image(image, mask):
75
+ if mask is None:
76
+ mask = np.ones((image.height, image.width), dtype=np.uint8)
77
+
78
+ im_mask = np.uint8(mask * 255)
79
+ im_mask_rgba = np.concatenate(
80
+ (
81
+ np.tile(im_mask[..., None], [1, 1, 3]),
82
+ 45 * np.ones(
83
+ (im_mask.shape[0], im_mask.shape[1], 1), dtype=np.uint8),
84
+ ),
85
+ axis=-1,
86
+ )
87
+ im_mask_rgba = Image.fromarray(im_mask_rgba).convert("RGBA")
88
+
89
+ return Image.alpha_composite(image.convert("RGBA"),
90
+ im_mask_rgba).convert("RGB")
91
+
92
+
93
+ def draw_circle_on_mask(mask, x, y, radius, mode='add', inv=False):
94
+ H, W = mask.shape
95
+ J = np.arange(W, dtype=np.int32)
96
+ I = np.arange(H, dtype=np.int32)
97
+ I, J = np.meshgrid(I, J, indexing='ij')
98
+ dis = (I - y)**2 + (J - x)**2
99
+ if inv:
100
+ new_mask = dis > radius**2
101
+ else:
102
+ new_mask = dis <= radius**2
103
+ if mode == 'add':
104
+ return (mask + new_mask).clip(0, 1)
105
+ elif mode == 'mul':
106
+ return mask * new_mask
107
+ return (mask + new_mask).clip(0, 1) # default add mode
108
+
109
+
110
+ def draw_circle_on_image(image, x, y, radius, color=(255, 0, 0)):
111
+ H, W, C = image.shape
112
+ J = np.arange(W, dtype=np.int32)
113
+ I = np.arange(H, dtype=np.int32)
114
+ I, J = np.meshgrid(I, J, indexing='ij')
115
+ dis = (I - y)**2 + (J - x)**2
116
+ mask = dis <= radius**2
117
+ i_color = np.array(color, dtype=np.int32)
118
+ i_color = np.expand_dims(i_color, axis=[0, 1])
119
+ i_mask = mask.astype(np.int32)
120
+ i_mask = np.expand_dims(i_mask, axis=[2])
121
+ i_image = image.astype(np.int32)
122
+ i_image = image + i_mask * i_color
123
+ i_image = np.clip(i_image, 0, 255)
124
+ return i_image.astype(np.uint8)
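These helpers render the drag-UI overlays: `draw_points_on_image` draws handle/target point pairs, `draw_mask_on_image` tints the editable region, and the circle helpers edit masks and images. A hedged usage sketch follows; the image, point coordinates, and output path are illustrative, and it assumes `demo/` is on the Python path (as `main_gradio.py` arranges when run as a script).

```python
# Hedged usage sketch for the drawing helpers above.
import numpy as np
from PIL import Image
from draw_utils import draw_points_on_image, draw_mask_on_image, draw_circle_on_mask

image = Image.new("RGB", (512, 512), "white")   # stand-in for a generated image

# Point pairs follow the structure main_gradio.py keeps in
# state['drag_markers'][0]['points']: a handle ("start"/"start_temp") and a "target".
points = {0: {"start": [100, 120], "start_temp": [100, 120], "target": [200, 180]}}
preview = draw_points_on_image(image, points)

# Build a binary mask, paint a circular editable region into it, then overlay it.
mask = np.zeros((512, 512), dtype=np.uint8)
mask = draw_circle_on_mask(mask, x=150, y=150, radius=40, mode="add")
preview = draw_mask_on_image(preview, mask)
preview.save("points_and_mask_preview.png")     # hypothetical output path
```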
demo/main_gradio.py ADDED
@@ -0,0 +1,942 @@
1
+ import argparse
2
+ import time, os, sys
3
+
4
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
5
+
6
+ os.system('python scripts/download_models.py')
7
+
8
+ import gradio as gr
9
+ from PIL import Image
10
+ import numpy as np
11
+ import torch
12
+ from typing import List, Literal, Dict, Optional
13
+ from draw_utils import draw_points_on_image, draw_mask_on_image
14
+ import cv2
15
+
16
+
17
+ from models.streamdiffusion.wrapper import StreamDiffusionWrapper
18
+
19
+ from models.animatediff.pipelines import I2VPipeline
20
+ from omegaconf import OmegaConf
21
+
22
+ from models.draggan.viz.renderer import Renderer
23
+ from models.draggan.gan_inv.lpips.util import PerceptualLoss
24
+ import models.draggan.dnnlib as dnnlib
25
+ from models.draggan.gan_inv.inversion import PTI
26
+
27
+ import imageio
28
+ import torchvision
29
+ from einops import rearrange
30
+
31
+ # =========================== Model Implementation Start ===================================
32
+
33
+ def save_videos_grid_255(videos: torch.Tensor, path: str, n_rows=6, fps=8):
34
+ videos = rearrange(videos, "b c t h w -> t b c h w")
35
+ outputs = []
36
+ for x in videos:
37
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
38
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
39
+ x = x.numpy().astype(np.uint8)
40
+ outputs.append(x)
41
+
42
+ os.makedirs(os.path.dirname(path), exist_ok=True)
43
+ imageio.mimsave(path, outputs, fps=fps)
44
+
45
+ def reverse_point_pairs(points):
46
+ new_points = []
47
+ for p in points:
48
+ new_points.append([p[1], p[0]])
49
+ return new_points
50
+
51
+ def render_view_image(img, drag_markers, show_mask=False):
52
+ img = draw_points_on_image(img, drag_markers['points'])
53
+ if show_mask:
54
+ img = draw_mask_on_image(img, drag_markers['mask'])
55
+ img = np.array(img).astype(np.uint8)
56
+ img = np.concatenate([
57
+ img,
58
+ 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=img.dtype)
59
+ ], axis=2)
60
+ return Image.fromarray(img)
61
+
62
+
63
+ def update_state_image(state):
64
+ state['generated_image_show'] = render_view_image(
65
+ state['generated_image'],
66
+ state['drag_markers'][0],
67
+ state['is_show_mask'],
68
+ )
69
+ return state['generated_image_show']
70
+
71
+
72
+ class GeneratePipeline:
73
+ def __init__(
74
+ self,
75
+ i2i_body_ckpt: str = "checkpoints/diffusion_body/kohaku-v2.1",
76
+ # i2i_body_ckpt: str = "checkpoints/diffusion_body/stable-diffusion-v1-5",
77
+ i2i_lora_dict: Optional[Dict[str, float]] = {'checkpoints/i2i/lora/lcm-lora-sdv1-5.safetensors': 1.0},
78
+ prompt: str = "",
79
+ negative_prompt: str = "low quality, bad quality, blurry, low resolution",
80
+ frame_buffer_size: int = 1,
81
+ width: int = 512,
82
+ height: int = 512,
83
+ acceleration: Literal["none", "xformers", "tensorrt"] = "xformers",
84
+ use_denoising_batch: bool = True,
85
+ seed: int = 2,
86
+ cfg_type: Literal["none", "full", "self", "initialize"] = "self",
87
+ guidance_scale: float = 1.4,
88
+ delta: float = 0.5,
89
+ do_add_noise: bool = False,
90
+ enable_similar_image_filter: bool = True,
91
+ similar_image_filter_threshold: float = 0.99,
92
+ similar_image_filter_max_skip_frame: float = 10,
93
+ ):
94
+ super(GeneratePipeline, self).__init__()
95
+ if not torch.cuda.is_available():
96
+ acceleration = None
97
+
98
+ self.img2img_model = None
99
+ self.img2video_model = None
100
+ self.img2video_generator = None
101
+ self.sim_ranges = None
102
+
103
+ # set parameters
104
+ self.i2i_body_ckpt = i2i_body_ckpt
105
+ self.i2i_lora_dict = i2i_lora_dict
106
+ self.prompt = prompt
107
+ self.negative_prompt = negative_prompt
108
+ self.frame_buffer_size = frame_buffer_size
109
+ self.width = width
110
+ self.height = height
111
+ self.acceleration = acceleration
112
+ self.use_denoising_batch = use_denoising_batch
113
+ self.seed = seed
114
+ self.cfg_type = cfg_type
115
+ self.guidance_scale = guidance_scale
116
+ self.delta = delta
117
+ self.do_add_noise = do_add_noise
118
+ self.enable_similar_image_filter = enable_similar_image_filter
119
+ self.similar_image_filter_threshold = similar_image_filter_threshold
120
+ self.similar_image_filter_max_skip_frame = similar_image_filter_max_skip_frame
121
+
122
+ self.i2v_config = OmegaConf.load('demo/configs/i2v_config.yaml')
123
+ self.i2v_body_ckpt = self.i2v_config.pretrained_model_path
124
+ self.i2v_unet_path = self.i2v_config.generate.model_path
125
+ self.i2v_dreambooth_ckpt = self.i2v_config.generate.db_path
126
+
127
+ self.lora_alpha = 0
128
+
129
+ assert self.frame_buffer_size == 1
130
+
131
+ def init_model(self):
132
+ # StreamDiffusion
133
+ self.img2img_model = StreamDiffusionWrapper(
134
+ model_id_or_path=self.i2i_body_ckpt,
135
+ lora_dict=self.i2i_lora_dict,
136
+ t_index_list=[32, 45],
137
+ frame_buffer_size=self.frame_buffer_size,
138
+ width=self.width,
139
+ height=self.height,
140
+ warmup=10,
141
+ acceleration=self.acceleration,
142
+ do_add_noise=self.do_add_noise,
143
+ enable_similar_image_filter=self.enable_similar_image_filter,
144
+ similar_image_filter_threshold=self.similar_image_filter_threshold,
145
+ similar_image_filter_max_skip_frame=self.similar_image_filter_max_skip_frame,
146
+ mode="img2img",
147
+ use_denoising_batch=self.use_denoising_batch,
148
+ cfg_type=self.cfg_type,
149
+ seed=self.seed,
150
+ use_lcm_lora=False,
151
+ )
152
+ self.img2img_model.prepare(
153
+ prompt=self.prompt,
154
+ negative_prompt=self.negative_prompt,
155
+ num_inference_steps=50,
156
+ guidance_scale=self.guidance_scale,
157
+ delta=self.delta,
158
+ )
159
+
160
+ # PIA
161
+ self.img2video_model = I2VPipeline.build_pipeline(
162
+ self.i2v_config,
163
+ self.i2v_body_ckpt,
164
+ self.i2v_unet_path,
165
+ self.i2v_dreambooth_ckpt,
166
+ None, # lora path
167
+ self.lora_alpha,
168
+ )
169
+ if torch.cuda.is_available():
170
+ device = 'cuda'
171
+ else:
172
+ device = 'cpu'
173
+ self.img2video_generator = torch.Generator(device=device)
174
+ self.img2video_generator.manual_seed(self.i2v_config.generate.global_seed)
175
+ self.sim_ranges = self.i2v_config.validation_data.mask_sim_range
176
+
177
+ # Drag GAN
178
+ self.drag_model = Renderer(disable_timing=True)
179
+
180
+ def generate_image(self, image, text, start_time=None):
181
+ if text is not None:
182
+ pos_prompt, neg_prompt = text
183
+ self.img2img_model.prepare(
184
+ prompt=pos_prompt,
185
+ negative_prompt=neg_prompt,
186
+ num_inference_steps=50,
187
+ guidance_scale=self.guidance_scale,
188
+ delta=self.delta,
189
+ )
190
+ sampled_inputs = [image]
191
+ input_batch = torch.cat(sampled_inputs)
192
+ output_images = self.img2img_model.stream(
193
+ input_batch.to(device=self.img2img_model.device, dtype=self.img2img_model.dtype)
194
+ )
195
+ # if start_time is not None:
196
+ # print('Generate Done: {}'.format(time.perf_counter() - start_time))
197
+ output_images = output_images.cpu()
198
+ # if start_time is not None:
199
+ # print('Move Done: {}'.format(time.perf_counter() - start_time))
200
+ return output_images
201
+
202
+ def generate_video(self, image, text, height=None, width=None):
203
+ pos_prompt, neg_prompt = text
204
+ sim_range = self.sim_ranges[0]
205
+ print(f"using sim_range : {sim_range}")
206
+ self.i2v_config.validation_data.mask_sim_range = sim_range
207
+ sample = self.img2video_model(
208
+ image = image,
209
+ prompt = pos_prompt,
210
+ generator = self.img2video_generator,
211
+ video_length = self.i2v_config.generate.video_length,
212
+ height = height if height is not None else self.i2v_config.generate.sample_height,
213
+ width = width if width is not None else self.i2v_config.generate.sample_width,
214
+ negative_prompt = neg_prompt,
215
+ mask_sim_template_idx = self.i2v_config.validation_data.mask_sim_range,
216
+ **self.i2v_config.validation_data,
217
+ ).videos
218
+ return sample
219
+
220
+ def prepare_drag_model(
221
+ self,
222
+ custom_image: Image,
223
+ latent_space = 'w+',
224
+ trunc_psi = 0.7,
225
+ trunc_cutoff = None,
226
+ seed = 0,
227
+ lr = 0.001,
228
+ generator_params = dnnlib.EasyDict(),
229
+ pretrained_weight = 'stylegan2_lions_512_pytorch',
230
+ ):
231
+ self.drag_model.init_network(
232
+ generator_params, # res
233
+ pretrained_weight, # pkl
234
+ seed, # w0_seed,
235
+ None, # w_load
236
+ latent_space == 'w+', # w_plus
237
+ 'const',
238
+ trunc_psi, # trunc_psi,
239
+ trunc_cutoff, # trunc_cutoff,
240
+ None, # input_transform
241
+ lr # lr,
242
+ )
243
+
244
+ if torch.cuda.is_available():
245
+ percept = PerceptualLoss(model="net-lin", net="vgg", use_gpu=True)
246
+ else:
247
+ percept = PerceptualLoss(model="net-lin", net="vgg", use_gpu=False)
248
+
249
+ pti = PTI(self.drag_model.G, percept, max_pti_step=400)
250
+ inversed_img, w_pivot = pti.train(custom_image, latent_space == 'w+')
251
+ inversed_img = (inversed_img[0] * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
252
+ inversed_img = inversed_img.cpu().numpy()
253
+ inversed_img = Image.fromarray(inversed_img)
254
+ mask = np.ones((inversed_img.height, inversed_img.width),
255
+ dtype=np.uint8)
256
+ generator_params.image = inversed_img
257
+ generator_params.w = w_pivot.detach().cpu().numpy()
258
+ self.drag_model.set_latent(w_pivot, trunc_psi, trunc_cutoff)
259
+
260
+ del percept
261
+ del pti
262
+ print('inverse end')
263
+
264
+ return generator_params, mask
265
+
266
+ def drag_image(
267
+ self,
268
+ points,
269
+ mask,
270
+ motion_lambda = 20,
271
+ r1_in_pixels = 3,
272
+ r2_in_pixels = 12,
273
+ trunc_psi = 0.7,
274
+ draw_interval = 1,
275
+ generator_params = dnnlib.EasyDict(),
276
+ ):
277
+ p_in_pixels = []
278
+ t_in_pixels = []
279
+ valid_points = []
280
+ # Transform the points into torch tensors
281
+ for key_point, point in points.items():
282
+ try:
283
+ p_start = point.get("start_temp", point["start"])
284
+ p_end = point["target"]
285
+
286
+ if p_start is None or p_end is None:
287
+ continue
288
+
289
+ except KeyError:
290
+ continue
291
+
292
+ p_in_pixels.append(p_start)
293
+ t_in_pixels.append(p_end)
294
+ valid_points.append(key_point)
295
+
296
+ mask = torch.tensor(mask).float()
297
+ drag_mask = 1 - mask
298
+
299
+ # reverse points order
300
+ p_to_opt = reverse_point_pairs(p_in_pixels)
301
+ t_to_opt = reverse_point_pairs(t_in_pixels)
302
+ step_idx = 0
303
+
304
+ self.drag_model._render_drag_impl(
305
+ generator_params,
306
+ p_to_opt, # point
307
+ t_to_opt, # target
308
+ drag_mask, # mask,
309
+ motion_lambda, # lambda_mask
310
+ reg = 0,
311
+ feature_idx = 5, # NOTE: do not support change for now
312
+ r1 = r1_in_pixels, # r1
313
+ r2 = r2_in_pixels, # r2
314
+ # random_seed = 0,
315
+ # noise_mode = 'const',
316
+ trunc_psi = trunc_psi,
317
+ # force_fp32 = False,
318
+ # layer_name = None,
319
+ # sel_channels = 3,
320
+ # base_channel = 0,
321
+ # img_scale_db = 0,
322
+ # img_normalize = False,
323
+ # untransform = False,
324
+ is_drag=True,
325
+ to_pil=True
326
+ )
327
+
328
+
329
+ points_upd = points
330
+ if step_idx % draw_interval == 0:
331
+ for key_point, p_i, t_i in zip(valid_points, p_to_opt,
332
+ t_to_opt):
333
+ points_upd[key_point]["start_temp"] = [
334
+ p_i[1],
335
+ p_i[0],
336
+ ]
337
+ points_upd[key_point]["target"] = [
338
+ t_i[1],
339
+ t_i[0],
340
+ ]
341
+ start_temp = points_upd[key_point][
342
+ "start_temp"]
343
+
344
+ image_result = generator_params['image']
345
+
346
+ return image_result
347
+
348
+ # ============================= Model Implementation ENd ===================================
349
+
350
+
351
+ parser = argparse.ArgumentParser()
352
+ parser.add_argument('--share', action='store_true',default='True')
353
+ parser.add_argument('--cache-dir', type=str, default='./checkpoints')
354
+ parser.add_argument(
355
+ "--listen",
356
+ action="store_true",
357
+ help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests",
358
+ )
359
+ args = parser.parse_args()
360
+
361
+
362
+ class CustomImageMask(gr.Image):
363
+ is_template = True
364
+ def __init__(
365
+ self,
366
+ source='upload',
367
+ tool='sketch',
368
+ elem_id="image_upload",
369
+ label='Generated Image',
370
+ type="pil",
371
+ mask_opacity=0.5,
372
+ brush_color='#FFFFFF',
373
+ height=400,
374
+ interactive=True,
375
+ **kwargs
376
+ ):
377
+ super(CustomImageMask, self).__init__(
378
+ source=source,
379
+ tool=tool,
380
+ elem_id=elem_id,
381
+ label=label,
382
+ type=type,
383
+ mask_opacity=mask_opacity,
384
+ brush_color=brush_color,
385
+ height=height,
386
+ interactive=interactive,
387
+ **kwargs
388
+ )
389
+
390
+ def preprocess(self, x):
391
+ if x is None:
392
+ return x
393
+ if self.tool == 'sketch' and self.source in ['upload', 'webcam'] and type(x) != dict:
394
+ decode_image = gr.processing_utils.decode_base64_to_image(x)
395
+ width, height = decode_image.size
396
+ mask = np.ones((height, width, 4), dtype=np.uint8)
397
+ mask[..., -1] = 255
398
+ mask = self.postprocess(mask)
399
+ x = {'image': x, 'mask': mask}
400
+ return super().preprocess(x)
401
+
402
+
403
+ draggan_ckpts = os.listdir('checkpoints/drag')
404
+ draggan_ckpts.sort()
405
+
406
+
407
+ generate_pipeline = GeneratePipeline()
408
+ generate_pipeline.init_model()
409
+
410
+
411
+ with gr.Blocks() as demo:
412
+ global_state = gr.State(
413
+ {
414
+ 'is_image_generation': True,
415
+ 'is_image_text_prompt_up-to-date': True,
416
+ 'is_show_mask': False,
417
+ 'is_dragging': False,
418
+ 'generated_image': None,
419
+ 'generated_image_show': None,
420
+ 'drag_markers': [
421
+ {
422
+ 'points': {},
423
+ 'mask': None
424
+ }
425
+ ],
426
+ 'generator_params': dnnlib.EasyDict(),
427
+ 'default_image_text_prompts': ('', 'low quality, bad quality, blurry, low resolution'),
428
+ 'default_video_text_prompts': ('', 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'),
429
+ 'image_text_prompts': ('', 'low quality, bad quality, blurry, low resolution'),
430
+ 'video_text_prompts': ('', 'wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted, disfigured, bad eyes, wrong lips,weird mouth, bad teeth, mutated hands and fingers, bad anatomy,wrong anatomy, amputation, extra limb, missing limb, floating,limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg'),
431
+ 'params': {
432
+ 'seed': 0,
433
+ 'motion_lambda': 20,
434
+ 'r1_in_pixels': 3,
435
+ 'r2_in_pixels': 12,
436
+ 'magnitude_direction_in_pixels': 1.0,
437
+ 'latent_space': 'w+',
438
+ 'trunc_psi': 0.7,
439
+ 'trunc_cutoff': None,
440
+ 'lr': 0.001,
441
+ },
442
+ 'device': None, # device,
443
+ 'draw_interval': 1,
444
+ 'points': {},
445
+ 'curr_point': None,
446
+ 'curr_type_point': 'start',
447
+ 'editing_state': 'add_points',
448
+ 'pretrained_weight': draggan_ckpts[0],
449
+ 'video_preview_resolution': '512 x 512',
450
+ 'viewer_height': 300,
451
+ 'viewer_width': 300
452
+ }
453
+ )
454
+
455
+ with gr.Column():
456
+ with gr.Row():
457
+ with gr.Column(scale=8, min_width=10):
458
+ with gr.Tab('Image Text Prompts'):
459
+ image_pos_text_prompt_editor = gr.Textbox(placeholder='Positive Prompts', label='Positive', min_width=10)
460
+ image_neg_text_prompt_editor = gr.Textbox(placeholder='Negative Prompts', label='Negative', min_width=10)
461
+ with gr.Tab('Video Text Prompts'):
462
+ video_pos_text_prompt_editor = gr.Textbox(placeholder='Positive Prompts', label='Positive', min_width=10)
463
+ video_neg_text_prompt_editor = gr.Textbox(placeholder='Negative Prompts', label='Negative', min_width=10)
464
+ with gr.Tab('Drag Image'):
465
+ with gr.Row():
466
+ with gr.Column(scale=1, min_width=10):
467
+ drag_mode_on_button = gr.Button('Drag Mode On', size='sm', min_width=10)
468
+ drag_mode_off_button = gr.Button('Drag Mode Off', size='sm', min_width=10)
469
+ drag_checkpoint_dropdown = gr.Dropdown(choices=draggan_ckpts, value=draggan_ckpts[0], label='checkpoint', min_width=10)
470
+ with gr.Column(scale=1, min_width=10):
471
+ with gr.Row():
472
+ drag_start_button = gr.Button('start', size='sm', min_width=10)
473
+ drag_stop_button = gr.Button('stop', size='sm', min_width=10)
474
+ with gr.Row():
475
+ add_point_button = gr.Button('add point', size='sm', min_width=10)
476
+ reset_point_button = gr.Button('reset point', size='sm', min_width=10)
477
+ with gr.Row():
478
+ steps_number = gr.Number(0, label='steps', interactive=False)
479
+ with gr.Column(scale=1, min_width=10):
480
+ with gr.Row():
481
+ draw_mask_button = gr.Button('draw mask', size='sm', min_width=10)
482
+ reset_mask_button = gr.Button('reset mask', size='sm', min_width=10)
483
+ with gr.Row():
484
+ show_mask_checkbox = gr.Checkbox(value=False, label='show mask', min_width=10, interactive=True)
485
+ with gr.Row():
486
+ motion_lambda_number = gr.Number(20, label='Motion Lambda', minimum=1, maximum=100, step=1, interactive=True)
487
+ with gr.Tab('More'):
488
+ with gr.Row():
489
+ with gr.Column(scale=2, min_width=10):
490
+ video_preview_resolution_dropdown = gr.Dropdown(choices=['256 x 256', '512 x 512'], value='512 x 512', label='Video Preview Resolution', min_width=10)
491
+ sample_image_dropdown = gr.Dropdown(choices=['samples/canvas.jpg'] + ['samples/sample{:>02d}.jpg'.format(i) for i in range(1, 8)], value=None, label='Choose A Sample Image', min_width=10)
492
+ with gr.Column(scale=1, min_width=10):
493
+ confirm_text_button = gr.Button('Confirm Text', size='sm', min_width=10)
494
+ generate_video_button = gr.Button('Generate Video', size='sm', min_width=10)
495
+ clear_video_button = gr.Button('Clear Video', size='sm', min_width=10)
496
+ with gr.Row():
497
+ captured_image_viewer = gr.Image(source='upload', tool='color-sketch', type='pil', label='Image Drawer', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=True, shape=(global_state.value['viewer_width'], global_state.value['viewer_height'])) #
498
+ generated_image_viewer = CustomImageMask(source='upload', tool='sketch', elem_id="image_upload", label='Generated Image', type="pil", mask_opacity=0.5, brush_color='#FFFFFF', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=True)
499
+ generated_video_viewer = gr.Video(source='upload', label='Generated Video', height=global_state.value['viewer_height'], width=global_state.value['viewer_width'], interactive=False)
500
+
501
+ gr.Markdown(
502
+ """
503
+ ## Quick Start
504
+
505
+ 1. Select a sample image in the `More` tab.
506
+ 2. Draw to edit the sample image in the leftmost image viewer.
507
+ 3. Click `Generate Video` and enjoy it!
508
+
509
+ ## Note
510
+ Due to limitations of the gradio implementation, image-to-image generation may show noticeable latency after the model has finished generating.
511
+ We recommend using our local demo at [github](https://github.com/invictus717/InteractiveVideo) for a better experience.
512
+
513
+ ## Advanced Usage
514
+
515
+ 1. **Try different text prompts.** Enter positive or negative prompts for image / video generation, and
516
+ click `Confirm Text` to enable your prompts.
517
+ 2. **Drag images.** Go to `Drag Image` tab, choose a suitable checkpoint and click `Drag Mode On`.
518
+ It might take a minute to prepare. Properly add points and use masks, then click `start` to
519
+ start dragging. Once you are satisfied, click the `stop` button.
520
+ 3. **Adjust video resolution** in the `More` tab.
521
+ 4. **Draw from scratch** by choosing `canvas.jpg` in the `More` tab and enjoy yourself!
522
+ """
523
+ )
524
+
525
+ # ========================= Main Function Start =============================
526
+ def on_captured_image_viewer_update(state, image):
527
+ if image is None:
528
+ return state, gr.Image.update(None)
529
+ if state['is_image_text_prompt_up-to-date']:
530
+ text_prompts = None
531
+ else:
532
+ text_prompts = state['image_text_prompts']
533
+ state['is_image_text_prompt_up-to-date'] = True
534
+
535
+ # start_time = time.perf_counter()
536
+
537
+ input_image = np.array(image).astype(np.float32)
538
+ input_image = (input_image / 255 - 0.5) * 2
539
+ input_image = torch.tensor(input_image).permute([2, 0, 1])
540
+ noisy_image = torch.randn_like(input_image)
541
+
542
+ # print('preprocess done: {}'.format(time.perf_counter() - start_time))
543
+
544
+ output_image = generate_pipeline.generate_image(
545
+ input_image,
546
+ text_prompts,
547
+ # start_time,
548
+ )[0]
549
+ output_image = generate_pipeline.generate_image(
550
+ noisy_image,
551
+ None,
552
+ # start_time,
553
+ )[0] # TODO: is there more elegant way?
554
+ output_image = output_image.permute([1, 2, 0])
555
+ output_image = (output_image / 2 + 0.5).clamp(0, 1) * 255
556
+
557
+ output_image = output_image.to(torch.uint8).cpu().numpy()
558
+ output_image = Image.fromarray(output_image)
559
+
560
+ # print('postprocess done: {}'.format(time.perf_counter() - start_time))
561
+
562
+ # output_image = image
563
+ state['generated_image'] = output_image
564
+ output_image = update_state_image(state)
565
+
566
+ # print('draw done: {}'.format(time.perf_counter() - start_time))
567
+ return state, gr.Image.update(output_image, interactive=False)
568
+
569
+ captured_image_viewer.change(
570
+ fn=on_captured_image_viewer_update,
571
+ inputs=[global_state, captured_image_viewer],
572
+ outputs=[global_state, generated_image_viewer]
573
+ )
574
+
575
+ def on_generated_image_viewer_edit(state, data_dict):
576
+ mask = data_dict['mask']
577
+ state['drag_markers'][0]['mask'] = np.array(mask)[:, :, 0] // 255
578
+ image = update_state_image(state)
579
+ return state, image
580
+
581
+ generated_image_viewer.edit(
582
+ fn=on_generated_image_viewer_edit,
583
+ inputs=[global_state, generated_image_viewer],
584
+ outputs=[global_state, generated_image_viewer]
585
+ )
586
+
587
+ def on_generate_video_click(state):
588
+ input_image = np.array(state['generated_image'])
589
+ text_prompts = state['video_text_prompts']
590
+ video_preview_resolution = state['video_preview_resolution'].split('x')
591
+ height = int(video_preview_resolution[0].strip(' '))
592
+ width = int(video_preview_resolution[1].strip(' '))
593
+ output_video = generate_pipeline.generate_video(
594
+ input_image,
595
+ text_prompts,
596
+ height = height,
597
+ width = width
598
+ )[0]
599
+ output_video = output_video.clamp(0, 1) * 255
600
+ output_video = output_video.to(torch.uint8)
601
+ # 3 T H W
602
+ print('[video generation done]')
603
+
604
+ fps = 5 # frames per second
605
+ video_size = (height, width)
606
+ fourcc = cv2.VideoWriter.fourcc(*'mp4v')
607
+ if not os.access('results', os.F_OK):
608
+ os.makedirs('results')
609
+ video_writer = cv2.VideoWriter('results/gradio_temp.mp4', fourcc, fps, video_size) # Create VideoWriter object
610
+ for i in range(output_video.shape[1]):
611
+ frame = output_video[:, i, :, :].permute([1, 2, 0]).cpu().numpy()
612
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
613
+ video_writer.write(frame)
614
+
615
+ video_writer.release()
616
+ return state, gr.Video.update('results/gradio_temp.mp4')
617
+
618
+ generate_video_button.click(
619
+ fn=on_generate_video_click,
620
+ inputs=[global_state],
621
+ outputs=[global_state, generated_video_viewer]
622
+ )
623
+
624
+ def on_clear_video_click(state):
625
+ return state, gr.Video.update(None)
626
+
627
+ clear_video_button.click(
628
+ fn=on_clear_video_click,
629
+ inputs=[global_state],
630
+ outputs=[global_state, generated_video_viewer]
631
+ )
632
+
633
+ def on_drag_mode_on_click(state):
634
+ # prepare DragGAN for custom image
635
+ custom_image = state['generated_image']
636
+ current_ckpt_name = state['pretrained_weight']
637
+ generate_pipeline.prepare_drag_model(
638
+ custom_image,
639
+ generator_params = state['generator_params'],
640
+ pretrained_weight = os.path.join('checkpoints/drag/', current_ckpt_name),
641
+ )
642
+ state['generated_image'] = state['generator_params'].image
643
+ view_image = update_state_image(state)
644
+ return state, gr.Image.update(view_image, interactive=True)
645
+
646
+ drag_mode_on_button.click(
647
+ fn=on_drag_mode_on_click,
648
+ inputs=[global_state],
649
+ outputs=[global_state, generated_image_viewer]
650
+ )
651
+
652
+ def on_drag_mode_off_click(state, image):
653
+ return on_captured_image_viewer_update(state, image)
654
+
655
+ drag_mode_off_button.click(
656
+ fn=on_drag_mode_off_click,
657
+ inputs=[global_state, captured_image_viewer],
658
+ outputs=[global_state, generated_image_viewer]
659
+ )
660
+
661
+ def on_drag_start_click(state):
662
+ state['is_dragging'] = True
663
+ points = state['drag_markers'][0]['points']
664
+ if state['drag_markers'][0]['mask'] is None:
665
+ mask = np.ones((state['generator_params'].image.height, state['generator_params'].image.width), dtype=np.uint8)
666
+ else:
667
+ mask = state['drag_markers'][0]['mask']
668
+ cur_step = 0
669
+ while True:
670
+ if not state['is_dragging']:
671
+ break
672
+ generated_image = generate_pipeline.drag_image(
673
+ points,
674
+ mask,
675
+ motion_lambda = state['params']['motion_lambda'],
676
+ generator_params = state['generator_params']
677
+ )
678
+ state['drag_markers'] = [{'points': points, 'mask': mask}]
679
+ state['generated_image'] = generated_image
680
+ cur_step += 1
681
+ view_image = update_state_image(state)
682
+ if cur_step % 50 == 0:
683
+ print('[{} / {}]'.format(cur_step, 'inf'))
684
+ yield (
685
+ state,
686
+ gr.Image.update(view_image, interactive=False), # generated image viewer
687
+ gr.Number.update(cur_step), # step
688
+ )
689
+
690
+ view_image = update_state_image(state)
691
+ return (
692
+ state,
693
+ gr.Image.update(view_image, interactive=True),
694
+ gr.Number.update(cur_step),
695
+ )
696
+
697
+ drag_start_button.click(
698
+ fn=on_drag_start_click,
699
+ inputs=[global_state],
700
+ outputs=[global_state, generated_image_viewer, steps_number]
701
+ )
702
+
703
+ def on_drag_stop_click(state):
704
+ state['is_dragging'] = False
705
+ return state
706
+
707
+ drag_stop_button.click(
708
+ fn=on_drag_stop_click,
709
+ inputs=[global_state],
710
+ outputs=[global_state]
711
+ )
712
+
713
+ # ========================= Main Function End =============================
714
+
715
+ # ====================== Update Text Prompts Start ====================
716
+ def on_image_pos_text_prompt_editor_submit(state, text):
717
+ if len(text) == 0:
718
+ temp = state['image_text_prompts']
719
+ state['image_text_prompts'] = (state['default_image_text_prompts'][0], temp[1])
720
+ else:
721
+ temp = state['image_text_prompts']
722
+ state['image_text_prompts'] = (text, temp[1])
723
+ state['is_image_text_prompt_up-to-date'] = False
724
+ return state
725
+
726
+ image_pos_text_prompt_editor.submit(
727
+ fn=on_image_pos_text_prompt_editor_submit,
728
+ inputs=[global_state, image_pos_text_prompt_editor],
729
+ outputs=None
730
+ )
731
+
732
+ def on_image_neg_text_prompt_editor_submit(state, text):
733
+ if len(text) == 0:
734
+ temp = state['image_text_prompts']
735
+ state['image_text_prompts'] = (temp[0], state['default_image_text_prompts'][1])
736
+ else:
737
+ temp = state['image_text_prompts']
738
+ state['image_text_prompts'] = (temp[0], text)
739
+ state['is_image_text_prompt_up-to-date'] = False
740
+ return state
741
+
742
+ image_neg_text_prompt_editor.submit(
743
+ fn=on_image_neg_text_prompt_editor_submit,
744
+ inputs=[global_state, image_neg_text_prompt_editor],
745
+ outputs=None
746
+ )
747
+
748
+ def on_video_pos_text_prompt_editor_submit(state, text):
749
+ if len(text) == 0:
750
+ temp = state['video_text_prompts']
751
+ state['video_text_prompts'] = (state['default_video_text_prompts'][0], temp[1])
752
+ else:
753
+ temp = state['video_text_prompts']
754
+ state['video_text_prompts'] = (text, temp[1])
755
+ return state
756
+
757
+ video_pos_text_prompt_editor.submit(
758
+ fn=on_video_pos_text_prompt_editor_submit,
759
+ inputs=[global_state, video_pos_text_prompt_editor],
760
+ outputs=None
761
+ )
762
+
763
+ def on_video_neg_text_prompt_editor_submit(state, text):
764
+ if len(text) == 0:
765
+ temp = state['video_text_prompts']
766
+ state['video_text_prompts'] = (temp[0], state['default_video_text_prompts'][1])
767
+ else:
768
+ temp = state['video_text_prompts']
769
+ state['video_text_prompts'] = (temp[0], text)
770
+ return state
771
+
772
+ video_neg_text_prompt_editor.submit(
773
+ fn=on_video_neg_text_prompt_editor_submit,
774
+ inputs=[global_state, video_neg_text_prompt_editor],
775
+ outputs=None
776
+ )
777
+
778
+ def on_confirm_text_click(state, image, img_pos_t, img_neg_t, vid_pos_t, vid_neg_t):
779
+ state = on_image_pos_text_prompt_editor_submit(state, img_pos_t)
780
+ state = on_image_neg_text_prompt_editor_submit(state, img_neg_t)
781
+ state = on_video_pos_text_prompt_editor_submit(state, vid_pos_t)
782
+ state = on_video_neg_text_prompt_editor_submit(state, vid_neg_t)
783
+ return on_captured_image_viewer_update(state, image)
784
+
785
+ confirm_text_button.click(
786
+ fn=on_confirm_text_click,
787
+ inputs=[global_state, captured_image_viewer, image_pos_text_prompt_editor, image_neg_text_prompt_editor,
788
+ video_pos_text_prompt_editor, video_neg_text_prompt_editor],
789
+ outputs=[global_state, generated_image_viewer]
790
+ )
791
+
792
+ # ====================== Update Text Prompts End ====================
793
+
794
+ # ======================= Drag Point Edit Start =========================
795
+
796
+ def on_image_clicked(state, evt: gr.SelectData):
797
+ """
798
+ This function only supports click-based point selection.
799
+ """
800
+ pos_x, pos_y = evt.index
801
+ drag_markers = state['drag_markers']
802
+ key_points = list(drag_markers[0]['points'].keys())
803
+ key_points.sort(reverse=False)
804
+ if len(key_points) == 0: # no point pairs, add a new point pair
805
+ drag_markers[0]['points'][0] = {
806
+ 'start_temp': [pos_x, pos_y],
807
+ 'start': [pos_x, pos_y],
808
+ 'target': None,
809
+ }
810
+ else:
811
+ largest_id = key_points[-1]
812
+ if drag_markers[0]['points'][largest_id]['target'] is None: # target is not set
813
+ drag_markers[0]['points'][largest_id]['target'] = [pos_x, pos_y]
814
+ else: # target is set, add a new point pair
815
+ drag_markers[0]['points'][largest_id + 1] = {
816
+ 'start_temp': [pos_x, pos_y],
817
+ 'start': [pos_x, pos_y],
818
+ 'target': None,
819
+ }
820
+ state['drag_markers'] = drag_markers
821
+ image = update_state_image(state)
822
+ return state, gr.Image.update(image, interactive=False)
823
+
824
+ generated_image_viewer.select(
825
+ fn=on_image_clicked,
826
+ inputs=[global_state],
827
+ outputs=[global_state, generated_image_viewer],
828
+ )
829
+
830
+ def on_add_point_click(state):
831
+ return gr.Image.update(state['generated_image_show'], interactive=False)
832
+
833
+ add_point_button.click(
834
+ fn=on_add_point_click,
835
+ inputs=[global_state],
836
+ outputs=[generated_image_viewer]
837
+ )
838
+
839
+ def on_reset_point_click(state):
840
+ drag_markers = state['drag_markers']
841
+ drag_markers[0]['points'] = {}
842
+ state['drag_markers'] = drag_markers
843
+ image = update_state_image(state)
844
+ return state, gr.Image.update(image)
845
+
846
+ reset_point_button.click(
847
+ fn=on_reset_point_click,
848
+ inputs=[global_state],
849
+ outputs=[global_state, generated_image_viewer]
850
+ )
851
+
852
+ # ======================= Drag Point Edit End =========================
853
+
854
+ # ======================= Drag Mask Edit Start =========================
855
+
856
+ def on_draw_mask_click(state):
857
+ return gr.Image.update(state['generated_image_show'], interactive=True)
858
+
859
+ draw_mask_button.click(
860
+ fn=on_draw_mask_click,
861
+ inputs=[global_state],
862
+ outputs=[generated_image_viewer]
863
+ )
864
+
865
+ def on_reset_mask_click(state):
866
+ drag_markers = state['drag_markers']
867
+ drag_markers[0]['mask'] = np.ones_like(drag_markers[0]['mask'])
868
+ state['drag_markers'] = drag_markers
869
+ image = update_state_image(state)
870
+ return state, gr.Image.update(image)
871
+
872
+ reset_mask_button.click(
873
+ fn=on_reset_mask_click,
874
+ inputs=[global_state],
875
+ outputs=[global_state, generated_image_viewer]
876
+ )
877
+
878
+ def on_show_mask_click(state, evt: gr.SelectData):
879
+ state['is_show_mask'] = evt.selected
880
+ image = update_state_image(state)
881
+ return state, image
882
+
883
+ show_mask_checkbox.select(
884
+ fn=on_show_mask_click,
885
+ inputs=[global_state],
886
+ outputs=[global_state, generated_image_viewer]
887
+ )
888
+
889
+ # ======================= Drag Mask Edit End =========================
890
+
891
+ # ======================= Drag Setting Start =========================
892
+
893
+ def on_motion_lambda_change(state, number):
894
+ state['params']['number'] = number
895
+ return state
896
+
897
+ motion_lambda_number.input(
898
+ fn=on_motion_lambda_change,
899
+ inputs=[global_state, motion_lambda_number],
900
+ outputs=[global_state]
901
+ )
902
+
903
+ def on_drag_checkpoint_change(state, checkpoint):
904
+ state['pretrained_weight'] = checkpoint
905
+ print(type(checkpoint), checkpoint)
906
+ return state
907
+
908
+ drag_checkpoint_dropdown.change(
909
+ fn=on_drag_checkpoint_change,
910
+ inputs=[global_state, drag_checkpoint_dropdown],
911
+ outputs=[global_state]
912
+ )
913
+
914
+ # ======================= Drag Setting End =========================
915
+
916
+ # ======================= General Setting Start =========================
917
+
918
+ def on_video_preview_resolution_change(state, resolution):
919
+ state['video_preview_resolution'] = resolution
920
+ return state
921
+
922
+ video_preview_resolution_dropdown.change(
923
+ fn=on_video_preview_resolution_change,
924
+ inputs=[global_state, video_preview_resolution_dropdown],
925
+ outputs=[global_state]
926
+ )
927
+
928
+ def on_sample_image_change(state, image):
929
+ return state, gr.Image.update(image)
930
+
931
+ sample_image_dropdown.change(
932
+ fn=on_sample_image_change,
933
+ inputs=[global_state, sample_image_dropdown],
934
+ outputs=[global_state, captured_image_viewer]
935
+ )
936
+
937
+ # ======================= General Setting End =========================
938
+
939
+
940
+ demo.queue(concurrency_count=3, max_size=20)
941
+ # demo.launch(share=False, server_name="0.0.0.0" if args.listen else "127.0.0.1")
942
+ demo.launch()
models/animatediff/models/__init__.py ADDED
File without changes
models/animatediff/models/attention.py ADDED
@@ -0,0 +1,559 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
11
+ from diffusers.models import ModelMixin
12
+ from diffusers.models.attention import Attention
13
+ from diffusers.utils import BaseOutput
14
+ from diffusers.utils.import_utils import is_xformers_available
15
+ from diffusers.models.attention import FeedForward, AdaLayerNorm
16
+
17
+ from einops import rearrange, repeat
18
+ import pdb
19
+
20
+ @dataclass
21
+ class Transformer3DModelOutput(BaseOutput):
22
+ sample: torch.FloatTensor
23
+
24
+
25
+ if is_xformers_available():
26
+ import xformers
27
+ import xformers.ops
28
+ else:
29
+ xformers = None
30
+
31
+
32
+ class Transformer3DModel(ModelMixin, ConfigMixin):
33
+ @register_to_config
34
+ def __init__(
35
+ self,
36
+ num_attention_heads: int = 16,
37
+ attention_head_dim: int = 88,
38
+ in_channels: Optional[int] = None,
39
+ num_layers: int = 1,
40
+ dropout: float = 0.0,
41
+ norm_num_groups: int = 32,
42
+ cross_attention_dim: Optional[int] = None,
43
+ attention_bias: bool = False,
44
+ activation_fn: str = "geglu",
45
+ num_embeds_ada_norm: Optional[int] = None,
46
+ use_linear_projection: bool = False,
47
+ only_cross_attention: bool = False,
48
+ upcast_attention: bool = False,
49
+
50
+ unet_use_cross_frame_attention=None,
51
+ unet_use_temporal_attention=None,
52
+ ):
53
+ super().__init__()
54
+ self.use_linear_projection = use_linear_projection
55
+ self.num_attention_heads = num_attention_heads
56
+ self.attention_head_dim = attention_head_dim
57
+ inner_dim = num_attention_heads * attention_head_dim
58
+
59
+ # Define input layers
60
+ self.in_channels = in_channels
61
+
62
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
63
+ if use_linear_projection:
64
+ self.proj_in = nn.Linear(in_channels, inner_dim)
65
+ else:
66
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
67
+
68
+ # Define transformers blocks
69
+ self.transformer_blocks = nn.ModuleList(
70
+ [
71
+ BasicTransformerBlock(
72
+ inner_dim,
73
+ num_attention_heads,
74
+ attention_head_dim,
75
+ dropout=dropout,
76
+ cross_attention_dim=cross_attention_dim,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ attention_bias=attention_bias,
80
+ only_cross_attention=only_cross_attention,
81
+ upcast_attention=upcast_attention,
82
+
83
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
84
+ unet_use_temporal_attention=unet_use_temporal_attention,
85
+ )
86
+ for d in range(num_layers)
87
+ ]
88
+ )
89
+
90
+ # 4. Define output layers
91
+ if use_linear_projection:
92
+ self.proj_out = nn.Linear(in_channels, inner_dim)
93
+ else:
94
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
95
+
96
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
97
+ # Input
98
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
99
+ video_length = hidden_states.shape[2]
100
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
101
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
102
+
103
+ batch, channel, height, weight = hidden_states.shape
104
+ residual = hidden_states
105
+
106
+ hidden_states = self.norm(hidden_states)
107
+ if not self.use_linear_projection:
108
+ hidden_states = self.proj_in(hidden_states)
109
+ inner_dim = hidden_states.shape[1]
110
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
111
+ else:
112
+ inner_dim = hidden_states.shape[1]
113
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
114
+ hidden_states = self.proj_in(hidden_states)
115
+
116
+ # Blocks
117
+ for block in self.transformer_blocks:
118
+ hidden_states = block(
119
+ hidden_states,
120
+ encoder_hidden_states=encoder_hidden_states,
121
+ timestep=timestep,
122
+ video_length=video_length
123
+ )
124
+
125
+ # Output
126
+ if not self.use_linear_projection:
127
+ hidden_states = (
128
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
129
+ )
130
+ hidden_states = self.proj_out(hidden_states)
131
+ else:
132
+ hidden_states = self.proj_out(hidden_states)
133
+ hidden_states = (
134
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
135
+ )
136
+
137
+ output = hidden_states + residual
138
+
139
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
140
+ if not return_dict:
141
+ return (output,)
142
+
143
+ return Transformer3DModelOutput(sample=output)
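+
+ # Shape flow through this forward pass (annotation added for clarity, not
+ # part of the upstream file): the 5-D video latent (b, c, f, h, w) is folded
+ # to (b*f, c, h, w), flattened to (b*f, h*w, inner_dim) for the spatial
+ # transformer blocks (text conditioning is repeated once per frame), then
+ # projected back, added to the residual, and unfolded to (b, c, f, h, w).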
144
+
145
+
146
+ class BasicTransformerBlock(nn.Module):
147
+ def __init__(
148
+ self,
149
+ dim: int,
150
+ num_attention_heads: int,
151
+ attention_head_dim: int,
152
+ dropout=0.0,
153
+ cross_attention_dim: Optional[int] = None,
154
+ activation_fn: str = "geglu",
155
+ num_embeds_ada_norm: Optional[int] = None,
156
+ attention_bias: bool = False,
157
+ only_cross_attention: bool = False,
158
+ upcast_attention: bool = False,
159
+
160
+ unet_use_cross_frame_attention = None,
161
+ unet_use_temporal_attention = None,
162
+ ):
163
+ super().__init__()
164
+ self.only_cross_attention = only_cross_attention
165
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
166
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
167
+ self.unet_use_temporal_attention = unet_use_temporal_attention
168
+
169
+ # SC-Attn
170
+ assert unet_use_cross_frame_attention is not None
171
+ if unet_use_cross_frame_attention:
172
+ self.attn1 = SparseCausalAttention(
173
+ query_dim=dim,
174
+ heads=num_attention_heads,
175
+ dim_head=attention_head_dim,
176
+ dropout=dropout,
177
+ bias=attention_bias,
178
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
179
+ upcast_attention=upcast_attention,
180
+ )
181
+ else:
182
+ self.attn1 = Attention(
183
+ query_dim=dim,
184
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
185
+ heads=num_attention_heads,
186
+ dim_head=attention_head_dim,
187
+ dropout=dropout,
188
+ bias=attention_bias,
189
+ upcast_attention=upcast_attention,
190
+ )
191
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
192
+
193
+ # Cross-Attn
194
+ if cross_attention_dim is not None:
195
+ self.attn2 = Attention(
196
+ query_dim=dim,
197
+ cross_attention_dim=cross_attention_dim,
198
+ heads=num_attention_heads,
199
+ dim_head=attention_head_dim,
200
+ dropout=dropout,
201
+ bias=attention_bias,
202
+ upcast_attention=upcast_attention,
203
+ )
204
+ else:
205
+ self.attn2 = None
206
+
207
+ if cross_attention_dim is not None:
208
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
209
+ else:
210
+ self.norm2 = None
211
+
212
+ # Feed-forward
213
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
214
+ self.norm3 = nn.LayerNorm(dim)
215
+
216
+ # Temp-Attn
217
+ assert unet_use_temporal_attention is not None
218
+ if unet_use_temporal_attention:
219
+ self.attn_temp = Attention(
220
+ query_dim=dim,
221
+ heads=num_attention_heads,
222
+ dim_head=attention_head_dim,
223
+ dropout=dropout,
224
+ bias=attention_bias,
225
+ upcast_attention=upcast_attention,
226
+ )
227
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
228
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
229
+
230
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
231
+ # SparseCausal-Attention
232
+ norm_hidden_states = (
233
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
234
+ )
235
+
236
+ # if self.only_cross_attention:
237
+ # hidden_states = (
238
+ # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
239
+ # )
240
+ # else:
241
+ # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
242
+
243
+ # pdb.set_trace()
244
+ if self.unet_use_cross_frame_attention:
245
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
246
+ else:
247
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
248
+
249
+ if self.attn2 is not None:
250
+ # Cross-Attention
251
+ norm_hidden_states = (
252
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
253
+ )
254
+ hidden_states = (
255
+ self.attn2(
256
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
257
+ )
258
+ + hidden_states
259
+ )
260
+
261
+ # Feed-forward
262
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
263
+
264
+ # Temporal-Attention
265
+ if self.unet_use_temporal_attention:
266
+ d = hidden_states.shape[1]
267
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
268
+ norm_hidden_states = (
269
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
270
+ )
271
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
272
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
273
+
274
+ return hidden_states
275
+
276
+ class CrossAttention(nn.Module):
277
+ r"""
278
+ A cross attention layer.
279
+
280
+ Parameters:
281
+ query_dim (`int`): The number of channels in the query.
282
+ cross_attention_dim (`int`, *optional*):
283
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
284
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
285
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
286
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
287
+ bias (`bool`, *optional*, defaults to False):
288
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
289
+ """
290
+
291
+ def __init__(
292
+ self,
293
+ query_dim: int,
294
+ cross_attention_dim: Optional[int] = None,
295
+ heads: int = 8,
296
+ dim_head: int = 64,
297
+ dropout: float = 0.0,
298
+ bias=False,
299
+ upcast_attention: bool = False,
300
+ upcast_softmax: bool = False,
301
+ added_kv_proj_dim: Optional[int] = None,
302
+ norm_num_groups: Optional[int] = None,
303
+ ):
304
+ super().__init__()
305
+ inner_dim = dim_head * heads
306
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
307
+ self.upcast_attention = upcast_attention
308
+ self.upcast_softmax = upcast_softmax
309
+
310
+ self.scale = dim_head**-0.5
311
+
312
+ self.heads = heads
313
+ # for slice_size > 0 the attention score computation
314
+ # is split across the batch axis to save memory
315
+ # You can set slice_size with `set_attention_slice`
316
+ self.sliceable_head_dim = heads
317
+ self._slice_size = None
318
+ self._use_memory_efficient_attention_xformers = False
319
+ self.added_kv_proj_dim = added_kv_proj_dim
320
+
321
+ if norm_num_groups is not None:
322
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
323
+ else:
324
+ self.group_norm = None
325
+
326
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
327
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
328
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
329
+
330
+ if self.added_kv_proj_dim is not None:
331
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
332
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
333
+
334
+ self.to_out = nn.ModuleList([])
335
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
336
+ self.to_out.append(nn.Dropout(dropout))
337
+
338
+ def reshape_heads_to_batch_dim(self, tensor):
339
+ batch_size, seq_len, dim = tensor.shape
340
+ head_size = self.heads
341
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
342
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
343
+ return tensor
344
+
345
+ def reshape_batch_dim_to_heads(self, tensor):
346
+ batch_size, seq_len, dim = tensor.shape
347
+ head_size = self.heads
348
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
349
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
350
+ return tensor
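+
+ # Shape walkthrough for the two reshape helpers above (illustrative):
+ #   reshape_heads_to_batch_dim: (B, L, H*Dh) -> (B*H, L, Dh)
+ #   reshape_batch_dim_to_heads: (B*H, L, Dh) -> (B, L, H*Dh)
+ # with B = batch, L = sequence length, H = self.heads, Dh = head dimension.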
351
+
352
+ def set_attention_slice(self, slice_size):
353
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
354
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
355
+
356
+ self._slice_size = slice_size
357
+
358
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
359
+ batch_size, sequence_length, _ = hidden_states.shape
360
+
361
+ encoder_hidden_states = encoder_hidden_states
362
+
363
+ if self.group_norm is not None:
364
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
365
+
366
+ query = self.to_q(hidden_states)
367
+ dim = query.shape[-1]
368
+ query = self.reshape_heads_to_batch_dim(query)
369
+
370
+ if self.added_kv_proj_dim is not None:
371
+ key = self.to_k(hidden_states)
372
+ value = self.to_v(hidden_states)
373
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
374
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
375
+
376
+ key = self.reshape_heads_to_batch_dim(key)
377
+ value = self.reshape_heads_to_batch_dim(value)
378
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
379
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
380
+
381
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
382
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
383
+ else:
384
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
385
+ key = self.to_k(encoder_hidden_states)
386
+ value = self.to_v(encoder_hidden_states)
387
+
388
+ key = self.reshape_heads_to_batch_dim(key)
389
+ value = self.reshape_heads_to_batch_dim(value)
390
+
391
+ if attention_mask is not None:
392
+ if attention_mask.shape[-1] != query.shape[1]:
393
+ target_length = query.shape[1]
394
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
395
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
396
+
397
+ # attention, what we cannot get enough of
398
+ if self._use_memory_efficient_attention_xformers:
399
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
400
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
401
+ hidden_states = hidden_states.to(query.dtype)
402
+ else:
403
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
404
+ hidden_states = self._attention(query, key, value, attention_mask)
405
+ else:
406
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
407
+
408
+ # linear proj
409
+ hidden_states = self.to_out[0](hidden_states)
410
+
411
+ # dropout
412
+ hidden_states = self.to_out[1](hidden_states)
413
+ return hidden_states
414
+
415
+ def _attention(self, query, key, value, attention_mask=None):
416
+ if self.upcast_attention:
417
+ query = query.float()
418
+ key = key.float()
419
+
420
+ attention_scores = torch.baddbmm(
421
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
422
+ query,
423
+ key.transpose(-1, -2),
424
+ beta=0,
425
+ alpha=self.scale,
426
+ )
427
+
428
+ if attention_mask is not None:
429
+ attention_scores = attention_scores + attention_mask
430
+
431
+ if self.upcast_softmax:
432
+ attention_scores = attention_scores.float()
433
+
434
+ attention_probs = attention_scores.softmax(dim=-1)
435
+
436
+ # cast back to the original dtype
437
+ attention_probs = attention_probs.to(value.dtype)
438
+
439
+ # compute attention output
440
+ hidden_states = torch.bmm(attention_probs, value)
441
+
442
+ # reshape hidden_states
443
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
444
+ return hidden_states
445
+
446
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
447
+ batch_size_attention = query.shape[0]
448
+ hidden_states = torch.zeros(
449
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
450
+ )
451
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
452
+ for i in range(hidden_states.shape[0] // slice_size):
453
+ start_idx = i * slice_size
454
+ end_idx = (i + 1) * slice_size
455
+
456
+ query_slice = query[start_idx:end_idx]
457
+ key_slice = key[start_idx:end_idx]
458
+
459
+ if self.upcast_attention:
460
+ query_slice = query_slice.float()
461
+ key_slice = key_slice.float()
462
+
463
+ attn_slice = torch.baddbmm(
464
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
465
+ query_slice,
466
+ key_slice.transpose(-1, -2),
467
+ beta=0,
468
+ alpha=self.scale,
469
+ )
470
+
471
+ if attention_mask is not None:
472
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
473
+
474
+ if self.upcast_softmax:
475
+ attn_slice = attn_slice.float()
476
+
477
+ attn_slice = attn_slice.softmax(dim=-1)
478
+
479
+ # cast back to the original dtype
480
+ attn_slice = attn_slice.to(value.dtype)
481
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
482
+
483
+ hidden_states[start_idx:end_idx] = attn_slice
484
+
485
+ # reshape hidden_states
486
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
487
+ return hidden_states
488
+
489
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
490
+ # TODO attention_mask
491
+ query = query.contiguous()
492
+ key = key.contiguous()
493
+ value = value.contiguous()
494
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
495
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
496
+ return hidden_states
497
+
498
+
499
+
500
+ class SparseCausalAttention(CrossAttention):
501
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
502
+ batch_size, sequence_length, _ = hidden_states.shape
503
+
504
+ encoder_hidden_states = encoder_hidden_states
505
+
506
+ if self.group_norm is not None:
507
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
508
+
509
+ query = self.to_q(hidden_states)
510
+ dim = query.shape[-1]
511
+ query = self.reshape_heads_to_batch_dim(query)
512
+
513
+ if self.added_kv_proj_dim is not None:
514
+ raise NotImplementedError
515
+
516
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
517
+ key = self.to_k(encoder_hidden_states)
518
+ value = self.to_v(encoder_hidden_states)
519
+
520
+ former_frame_index = torch.arange(video_length) - 1
521
+ former_frame_index[0] = 0
522
+
523
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
524
+ #key = torch.cat([key[:, [0] * video_length], key[:, [0] * video_length]], dim=2)
525
+ key = key[:, [0] * video_length]
526
+ key = rearrange(key, "b f d c -> (b f) d c")
527
+
528
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
529
+ #value = torch.cat([value[:, [0] * video_length], value[:, [0] * video_length]], dim=2)
530
+ #value = value[:, former_frame_index]
531
+ value = rearrange(value, "b f d c -> (b f) d c")
532
+
533
+ key = self.reshape_heads_to_batch_dim(key)
534
+ value = self.reshape_heads_to_batch_dim(value)
535
+
536
+ if attention_mask is not None:
537
+ if attention_mask.shape[-1] != query.shape[1]:
538
+ target_length = query.shape[1]
539
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
540
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
541
+
542
+ # attention, what we cannot get enough of
543
+ if self._use_memory_efficient_attention_xformers:
544
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
545
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
546
+ hidden_states = hidden_states.to(query.dtype)
547
+ else:
548
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
549
+ hidden_states = self._attention(query, key, value, attention_mask)
550
+ else:
551
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
552
+
553
+ # linear proj
554
+ hidden_states = self.to_out[0](hidden_states)
555
+
556
+ # dropout
557
+ hidden_states = self.to_out[1](hidden_states)
558
+ return hidden_states
559
+
models/animatediff/models/motion_module.py ADDED
@@ -0,0 +1,572 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import numpy as np
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+ import torchvision
10
+
11
+ from diffusers.utils import BaseOutput
12
+ from diffusers.utils.import_utils import is_xformers_available
13
+ from diffusers.models.attention import FeedForward
14
+
15
+ from einops import rearrange, repeat
16
+ import math
17
+
18
+
19
+ def zero_module(module):
20
+ # Zero out the parameters of a module and return it.
21
+ for p in module.parameters():
22
+ p.detach().zero_()
23
+ return module
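+
+ # zero_module is used below to zero-initialize the temporal transformer's
+ # output projection, so a freshly inserted motion module starts out as an
+ # identity mapping and only contributes once its weights are trained.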
24
+
25
+
26
+ @dataclass
27
+ class TemporalTransformer3DModelOutput(BaseOutput):
28
+ sample: torch.FloatTensor
29
+
30
+
31
+ if is_xformers_available():
32
+ import xformers
33
+ import xformers.ops
34
+ else:
35
+ xformers = None
36
+
37
+
38
+ def get_motion_module(
39
+ in_channels,
40
+ motion_module_type: str,
41
+ motion_module_kwargs: dict
42
+ ):
43
+ if motion_module_type == "Vanilla":
44
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
45
+ else:
46
+ raise ValueError(f"unknown motion_module_type: {motion_module_type}")
47
+
48
+
49
+ class VanillaTemporalModule(nn.Module):
50
+ def __init__(
51
+ self,
52
+ in_channels,
53
+ num_attention_heads = 8,
54
+ num_transformer_block = 2,
55
+ attention_block_types =( "Temporal_Self", "Temporal_Self" ),
56
+ cross_frame_attention_mode = None,
57
+ temporal_position_encoding = False,
58
+ temporal_position_encoding_max_len = 32,
59
+ temporal_attention_dim_div = 1,
60
+ zero_initialize = True,
61
+ ):
62
+ super().__init__()
63
+
64
+ self.temporal_transformer = TemporalTransformer3DModel(
65
+ in_channels=in_channels,
66
+ num_attention_heads=num_attention_heads,
67
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
68
+ num_layers=num_transformer_block,
69
+ attention_block_types=attention_block_types,
70
+ cross_frame_attention_mode=cross_frame_attention_mode,
71
+ temporal_position_encoding=temporal_position_encoding,
72
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
73
+ )
74
+
75
+ if zero_initialize:
76
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
77
+
78
+ def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
79
+ hidden_states = input_tensor
80
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
81
+
82
+ output = hidden_states
83
+ return output
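+
+ # temb and anchor_frame_idx are accepted only for interface compatibility
+ # with the UNet blocks; this module forwards just the hidden states,
+ # encoder hidden states and attention mask to the temporal transformer.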
84
+
85
+
86
+ class TemporalTransformer3DModel(nn.Module):
87
+ def __init__(
88
+ self,
89
+ in_channels,
90
+ num_attention_heads,
91
+ attention_head_dim,
92
+
93
+ num_layers,
94
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
95
+ dropout = 0.0,
96
+ norm_num_groups = 32,
97
+ cross_attention_dim = 1280,
98
+ activation_fn = "geglu",
99
+ attention_bias = False,
100
+ upcast_attention = False,
101
+
102
+ cross_frame_attention_mode = None,
103
+ temporal_position_encoding = False,
104
+ temporal_position_encoding_max_len = 32,
105
+ ):
106
+ super().__init__()
107
+
108
+ inner_dim = num_attention_heads * attention_head_dim
109
+
110
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
111
+ self.proj_in = nn.Linear(in_channels, inner_dim)
112
+
113
+ self.transformer_blocks = nn.ModuleList(
114
+ [
115
+ TemporalTransformerBlock(
116
+ dim=inner_dim,
117
+ num_attention_heads=num_attention_heads,
118
+ attention_head_dim=attention_head_dim,
119
+ attention_block_types=attention_block_types,
120
+ dropout=dropout,
121
+ norm_num_groups=norm_num_groups,
122
+ cross_attention_dim=cross_attention_dim,
123
+ activation_fn=activation_fn,
124
+ attention_bias=attention_bias,
125
+ upcast_attention=upcast_attention,
126
+ cross_frame_attention_mode=cross_frame_attention_mode,
127
+ temporal_position_encoding=temporal_position_encoding,
128
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
129
+ )
130
+ for d in range(num_layers)
131
+ ]
132
+ )
133
+ self.proj_out = nn.Linear(inner_dim, in_channels)
134
+
135
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
136
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
137
+ video_length = hidden_states.shape[2]
138
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
139
+
140
+ batch, channel, height, weight = hidden_states.shape
141
+ residual = hidden_states
142
+
143
+ hidden_states = self.norm(hidden_states)
144
+ inner_dim = hidden_states.shape[1]
145
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
146
+ hidden_states = self.proj_in(hidden_states)
147
+
148
+ # Transformer Blocks
149
+ for block in self.transformer_blocks:
150
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
151
+
152
+ # output
153
+ hidden_states = self.proj_out(hidden_states)
154
+ hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
155
+
156
+ output = hidden_states + residual
157
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
158
+
159
+ return output
160
+
161
+
162
+ class TemporalTransformerBlock(nn.Module):
163
+ def __init__(
164
+ self,
165
+ dim,
166
+ num_attention_heads,
167
+ attention_head_dim,
168
+ attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
169
+ dropout = 0.0,
170
+ norm_num_groups = 32,
171
+ cross_attention_dim = 768,
172
+ activation_fn = "geglu",
173
+ attention_bias = False,
174
+ upcast_attention = False,
175
+ cross_frame_attention_mode = None,
176
+ temporal_position_encoding = False,
177
+ temporal_position_encoding_max_len = 32,
178
+ ):
179
+ super().__init__()
180
+
181
+ attention_blocks = []
182
+ norms = []
183
+
184
+ for block_name in attention_block_types:
185
+ attention_blocks.append(
186
+ VersatileAttention(
187
+ attention_mode=block_name.split("_")[0],
188
+ cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
189
+
190
+ query_dim=dim,
191
+ heads=num_attention_heads,
192
+ dim_head=attention_head_dim,
193
+ dropout=dropout,
194
+ bias=attention_bias,
195
+ upcast_attention=upcast_attention,
196
+
197
+ cross_frame_attention_mode=cross_frame_attention_mode,
198
+ temporal_position_encoding=temporal_position_encoding,
199
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
200
+ )
201
+ )
202
+ norms.append(nn.LayerNorm(dim))
203
+
204
+ self.attention_blocks = nn.ModuleList(attention_blocks)
205
+ self.norms = nn.ModuleList(norms)
206
+
207
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
208
+ self.ff_norm = nn.LayerNorm(dim)
209
+
210
+
211
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
212
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
213
+ norm_hidden_states = norm(hidden_states)
214
+ hidden_states = attention_block(
215
+ norm_hidden_states,
216
+ encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
217
+ video_length=video_length,
218
+ ) + hidden_states
219
+
220
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
221
+
222
+ output = hidden_states
223
+ return output
224
+
225
+
226
+ class PositionalEncoding(nn.Module):
227
+ def __init__(
228
+ self,
229
+ d_model,
230
+ dropout = 0.,
231
+ max_len = 32
232
+ ):
233
+ super().__init__()
234
+ self.dropout = nn.Dropout(p=dropout)
235
+ position = torch.arange(max_len).unsqueeze(1)
236
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
237
+ pe = torch.zeros(1, max_len, d_model)
238
+ pe[0, :, 0::2] = torch.sin(position * div_term)
239
+ pe[0, :, 1::2] = torch.cos(position * div_term)
240
+ self.register_buffer('pe', pe)
241
+
242
+ def forward(self, x):
243
+ x = x + self.pe[:, :x.size(1)]
244
+ return self.dropout(x)
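+
+ # The encoding added in forward follows the standard sinusoidal scheme:
+ #   pe[pos, 2i]   = sin(pos / 10000^(2i/d_model))
+ #   pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))
+ # Rough shape check (illustrative, not part of the module): for
+ # PositionalEncoding(d_model=320, max_len=32) and x of shape
+ # (batch*spatial, frames, 320) with frames <= 32, pe(x) keeps x's shape and
+ # adds the encoding along the frame axis.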
245
+
246
+
247
+
248
+ class CrossAttention(nn.Module):
249
+ r"""
250
+ A cross attention layer.
251
+
252
+ Parameters:
253
+ query_dim (`int`): The number of channels in the query.
254
+ cross_attention_dim (`int`, *optional*):
255
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
256
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
257
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
258
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
259
+ bias (`bool`, *optional*, defaults to False):
260
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
261
+ """
262
+
263
+ def __init__(
264
+ self,
265
+ query_dim: int,
266
+ cross_attention_dim: Optional[int] = None,
267
+ heads: int = 8,
268
+ dim_head: int = 64,
269
+ dropout: float = 0.0,
270
+ bias=False,
271
+ upcast_attention: bool = False,
272
+ upcast_softmax: bool = False,
273
+ added_kv_proj_dim: Optional[int] = None,
274
+ norm_num_groups: Optional[int] = None,
275
+ ):
276
+ super().__init__()
277
+ inner_dim = dim_head * heads
278
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
279
+ self.upcast_attention = upcast_attention
280
+ self.upcast_softmax = upcast_softmax
281
+
282
+ self.scale = dim_head**-0.5
283
+
284
+ self.heads = heads
285
+ # for slice_size > 0 the attention score computation
286
+ # is split across the batch axis to save memory
287
+ # You can set slice_size with `set_attention_slice`
288
+ self.sliceable_head_dim = heads
289
+ self._slice_size = None
290
+ self._use_memory_efficient_attention_xformers = False
291
+ self.added_kv_proj_dim = added_kv_proj_dim
292
+
293
+ if norm_num_groups is not None:
294
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
295
+ else:
296
+ self.group_norm = None
297
+
298
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
299
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
300
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
301
+
302
+ if self.added_kv_proj_dim is not None:
303
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
304
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
305
+
306
+ self.to_out = nn.ModuleList([])
307
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
308
+ self.to_out.append(nn.Dropout(dropout))
309
+
310
+ def reshape_heads_to_batch_dim(self, tensor):
311
+ batch_size, seq_len, dim = tensor.shape
312
+ head_size = self.heads
313
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
314
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
315
+ return tensor
316
+
317
+ def reshape_batch_dim_to_heads(self, tensor):
318
+ batch_size, seq_len, dim = tensor.shape
319
+ head_size = self.heads
320
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
321
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
322
+ return tensor
323
+
324
+ def set_attention_slice(self, slice_size):
325
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
326
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
327
+
328
+ self._slice_size = slice_size
329
+
330
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
331
+ batch_size, sequence_length, _ = hidden_states.shape
332
+
333
+ encoder_hidden_states = encoder_hidden_states
334
+
335
+ if self.group_norm is not None:
336
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
337
+
338
+ query = self.to_q(hidden_states)
339
+ dim = query.shape[-1]
340
+ query = self.reshape_heads_to_batch_dim(query)
341
+
342
+ if self.added_kv_proj_dim is not None:
343
+ key = self.to_k(hidden_states)
344
+ value = self.to_v(hidden_states)
345
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
346
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
347
+
348
+ key = self.reshape_heads_to_batch_dim(key)
349
+ value = self.reshape_heads_to_batch_dim(value)
350
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
351
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
352
+
353
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
354
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
355
+ else:
356
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
357
+ key = self.to_k(encoder_hidden_states)
358
+ value = self.to_v(encoder_hidden_states)
359
+
360
+ key = self.reshape_heads_to_batch_dim(key)
361
+ value = self.reshape_heads_to_batch_dim(value)
362
+
363
+ if attention_mask is not None:
364
+ if attention_mask.shape[-1] != query.shape[1]:
365
+ target_length = query.shape[1]
366
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
367
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
368
+
369
+ # attention, what we cannot get enough of
370
+ if self._use_memory_efficient_attention_xformers:
371
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
372
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
373
+ hidden_states = hidden_states.to(query.dtype)
374
+ else:
375
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
376
+ hidden_states = self._attention(query, key, value, attention_mask)
377
+ else:
378
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
379
+
380
+ # linear proj
381
+ hidden_states = self.to_out[0](hidden_states)
382
+
383
+ # dropout
384
+ hidden_states = self.to_out[1](hidden_states)
385
+ return hidden_states
386
+
387
+ def _attention(self, query, key, value, attention_mask=None):
388
+ if self.upcast_attention:
389
+ query = query.float()
390
+ key = key.float()
391
+
392
+ attention_scores = torch.baddbmm(
393
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
394
+ query,
395
+ key.transpose(-1, -2),
396
+ beta=0,
397
+ alpha=self.scale,
398
+ )
399
+
400
+ if attention_mask is not None:
401
+ attention_scores = attention_scores + attention_mask
402
+
403
+ if self.upcast_softmax:
404
+ attention_scores = attention_scores.float()
405
+
406
+ attention_probs = attention_scores.softmax(dim=-1)
407
+
408
+ # cast back to the original dtype
409
+ attention_probs = attention_probs.to(value.dtype)
410
+
411
+ # compute attention output
412
+ hidden_states = torch.bmm(attention_probs, value)
413
+
414
+ # reshape hidden_states
415
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
416
+ return hidden_states
417
+
418
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
419
+ batch_size_attention = query.shape[0]
420
+ hidden_states = torch.zeros(
421
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
422
+ )
423
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
424
+ for i in range(hidden_states.shape[0] // slice_size):
425
+ start_idx = i * slice_size
426
+ end_idx = (i + 1) * slice_size
427
+
428
+ query_slice = query[start_idx:end_idx]
429
+ key_slice = key[start_idx:end_idx]
430
+
431
+ if self.upcast_attention:
432
+ query_slice = query_slice.float()
433
+ key_slice = key_slice.float()
434
+
435
+ attn_slice = torch.baddbmm(
436
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
437
+ query_slice,
438
+ key_slice.transpose(-1, -2),
439
+ beta=0,
440
+ alpha=self.scale,
441
+ )
442
+
443
+ if attention_mask is not None:
444
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
445
+
446
+ if self.upcast_softmax:
447
+ attn_slice = attn_slice.float()
448
+
449
+ attn_slice = attn_slice.softmax(dim=-1)
450
+
451
+ # cast back to the original dtype
452
+ attn_slice = attn_slice.to(value.dtype)
453
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
454
+
455
+ hidden_states[start_idx:end_idx] = attn_slice
456
+
457
+ # reshape hidden_states
458
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
459
+ return hidden_states
460
+
461
+ def set_use_memory_efficient_attention_xformers(self, *args, **kwargs):
462
+ print('Set Xformers for MotionModule\'s Attention.')
463
+ self._use_memory_efficient_attention_xformers = True
464
+
465
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
466
+ # TODO attention_mask
467
+ query = query.contiguous()
468
+ key = key.contiguous()
469
+ value = value.contiguous()
470
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
471
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
472
+ return hidden_states
473
+
474
+ def _memory_efficient_attention_pt20(self, query, key, value, attention_mask):
475
+ query = query.contiguous()
476
+ key = key.contiguous()
477
+ value = value.contiguous()
478
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0, is_causal=False)
479
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
480
+ return hidden_states
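+
+ # This helper expects query/key/value already in the (batch*heads, seq, head_dim)
+ # layout produced by reshape_heads_to_batch_dim; F.scaled_dot_product_attention
+ # then applies its default 1/sqrt(head_dim) scaling, which matches self.scale,
+ # so no explicit scale argument is passed.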
481
+
482
+
483
+ class VersatileAttention(CrossAttention):
484
+ def __init__(
485
+ self,
486
+ attention_mode = None,
487
+ cross_frame_attention_mode = None,
488
+ temporal_position_encoding = False,
489
+ temporal_position_encoding_max_len = 32,
490
+ *args, **kwargs
491
+ ):
492
+ super().__init__(*args, **kwargs)
493
+ assert attention_mode == "Temporal"
494
+
495
+ self.attention_mode = attention_mode
496
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
497
+
498
+ self.pos_encoder = PositionalEncoding(
499
+ kwargs["query_dim"],
500
+ dropout=0.,
501
+ max_len=temporal_position_encoding_max_len
502
+ ) if (temporal_position_encoding and attention_mode == "Temporal") else None
503
+
504
+ def extra_repr(self):
505
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
506
+
507
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
508
+ batch_size, sequence_length, _ = hidden_states.shape
509
+
510
+ if self.attention_mode == "Temporal":
511
+ d = hidden_states.shape[1]
512
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
513
+
514
+ if self.pos_encoder is not None:
515
+ hidden_states = self.pos_encoder(hidden_states)
516
+
517
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
518
+ else:
519
+ raise NotImplementedError
520
+
521
+ encoder_hidden_states = encoder_hidden_states
522
+
523
+ if self.group_norm is not None:
524
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
525
+
526
+ query = self.to_q(hidden_states)
527
+ dim = query.shape[-1]
528
+ query = self.reshape_heads_to_batch_dim(query)
529
+
530
+ if self.added_kv_proj_dim is not None:
531
+ raise NotImplementedError
532
+
533
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
534
+ key = self.to_k(encoder_hidden_states)
535
+ value = self.to_v(encoder_hidden_states)
536
+
537
+ key = self.reshape_heads_to_batch_dim(key)
538
+ value = self.reshape_heads_to_batch_dim(value)
539
+
540
+ if attention_mask is not None:
541
+ if attention_mask.shape[-1] != query.shape[1]:
542
+ target_length = query.shape[1]
543
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
544
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
545
+
546
+ # attention, what we cannot get enough of
547
+ if hasattr(F, 'scaled_dot_product_attention'):
548
+ # NOTE: pt20's scaled_dot_product_attention seems more memory efficient than
549
+ # xformers' memory_efficient_attention, set it as the first class citizen
550
+ hidden_states = self._memory_efficient_attention_pt20(query, key, value, attention_mask)
551
+ hidden_states = hidden_states.to(query.dtype)
552
+ elif self._use_memory_efficient_attention_xformers:
553
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
554
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
555
+ hidden_states = hidden_states.to(query.dtype)
556
+ else:
557
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
558
+ hidden_states = self._attention(query, key, value, attention_mask)
559
+ else:
560
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
561
+
562
+ # linear proj
563
+ hidden_states = self.to_out[0](hidden_states)
564
+
565
+ # dropout
566
+ hidden_states = self.to_out[1](hidden_states)
567
+
568
+ if self.attention_mode == "Temporal":
569
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
570
+
571
+ return hidden_states
572
+
models/animatediff/models/resnet.py ADDED
@@ -0,0 +1,197 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from einops import rearrange
8
+
9
+
10
+ class InflatedConv3d(nn.Conv2d):
11
+ def forward(self, x):
12
+ video_length = x.shape[2]
13
+
14
+ x = rearrange(x, "b c f h w -> (b f) c h w")
15
+ x = super().forward(x)
16
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
17
+
18
+ return x
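+
+ # InflatedConv3d runs a plain 2D convolution on every frame by folding the
+ # frame axis into the batch axis. Rough usage (illustrative):
+ #   conv = InflatedConv3d(4, 320, kernel_size=3, padding=1)
+ #   y = conv(torch.randn(1, 4, 16, 64, 64))   # -> (1, 320, 16, 64, 64)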
19
+
20
+
21
+ class Upsample3D(nn.Module):
22
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
23
+ super().__init__()
24
+ self.channels = channels
25
+ self.out_channels = out_channels or channels
26
+ self.use_conv = use_conv
27
+ self.use_conv_transpose = use_conv_transpose
28
+ self.name = name
29
+
30
+ conv = None
31
+ if use_conv_transpose:
32
+ raise NotImplementedError
33
+ elif use_conv:
34
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
35
+
36
+ def forward(self, hidden_states, output_size=None):
37
+ assert hidden_states.shape[1] == self.channels
38
+
39
+ if self.use_conv_transpose:
40
+ raise NotImplementedError
41
+
42
+ # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
43
+ dtype = hidden_states.dtype
44
+ if dtype == torch.bfloat16:
45
+ hidden_states = hidden_states.to(torch.float32)
46
+
47
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
48
+ if hidden_states.shape[0] >= 64:
49
+ hidden_states = hidden_states.contiguous()
50
+
51
+ # if `output_size` is passed we force the interpolation output
52
+ # size and do not make use of `scale_factor=2`
53
+ if output_size is None:
54
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
55
+ else:
56
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
57
+
58
+ # If the input is bfloat16, we cast back to bfloat16
59
+ if dtype == torch.bfloat16:
60
+ hidden_states = hidden_states.to(dtype)
61
+
62
+ # if self.use_conv:
63
+ # if self.name == "conv":
64
+ # hidden_states = self.conv(hidden_states)
65
+ # else:
66
+ # hidden_states = self.Conv2d_0(hidden_states)
67
+ hidden_states = self.conv(hidden_states)
68
+
69
+ return hidden_states
70
+
71
+
72
+ class Downsample3D(nn.Module):
73
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.out_channels = out_channels or channels
77
+ self.use_conv = use_conv
78
+ self.padding = padding
79
+ stride = 2
80
+ self.name = name
81
+
82
+ if use_conv:
83
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
84
+ else:
85
+ raise NotImplementedError
86
+
87
+ def forward(self, hidden_states):
88
+ assert hidden_states.shape[1] == self.channels
89
+ if self.use_conv and self.padding == 0:
90
+ raise NotImplementedError
91
+
92
+ assert hidden_states.shape[1] == self.channels
93
+ hidden_states = self.conv(hidden_states)
94
+
95
+ return hidden_states
96
+
97
+
98
+ class ResnetBlock3D(nn.Module):
99
+ def __init__(
100
+ self,
101
+ *,
102
+ in_channels,
103
+ out_channels=None,
104
+ conv_shortcut=False,
105
+ dropout=0.0,
106
+ temb_channels=512,
107
+ groups=32,
108
+ groups_out=None,
109
+ pre_norm=True,
110
+ eps=1e-6,
111
+ non_linearity="swish",
112
+ time_embedding_norm="default",
113
+ output_scale_factor=1.0,
114
+ use_in_shortcut=None,
115
+ ):
116
+ super().__init__()
117
+ self.pre_norm = pre_norm
118
+ self.pre_norm = True
119
+ self.in_channels = in_channels
120
+ out_channels = in_channels if out_channels is None else out_channels
121
+ self.out_channels = out_channels
122
+ self.use_conv_shortcut = conv_shortcut
123
+ self.time_embedding_norm = time_embedding_norm
124
+ self.output_scale_factor = output_scale_factor
125
+
126
+ if groups_out is None:
127
+ groups_out = groups
128
+
129
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
130
+
131
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
132
+
133
+ if temb_channels is not None:
134
+ if self.time_embedding_norm == "default":
135
+ time_emb_proj_out_channels = out_channels
136
+ elif self.time_embedding_norm == "scale_shift":
137
+ time_emb_proj_out_channels = out_channels * 2
138
+ else:
139
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
140
+
141
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
142
+ else:
143
+ self.time_emb_proj = None
144
+
145
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
146
+ self.dropout = torch.nn.Dropout(dropout)
147
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
148
+
149
+ if non_linearity == "swish":
150
+ self.nonlinearity = lambda x: F.silu(x)
151
+ elif non_linearity == "mish":
152
+ self.nonlinearity = Mish()
153
+ elif non_linearity == "silu":
154
+ self.nonlinearity = nn.SiLU()
155
+
156
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
157
+
158
+ self.conv_shortcut = None
159
+ if self.use_in_shortcut:
160
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
161
+
162
+ def forward(self, input_tensor, temb):
163
+ hidden_states = input_tensor
164
+
165
+ hidden_states = self.norm1(hidden_states)
166
+ hidden_states = self.nonlinearity(hidden_states)
167
+
168
+ hidden_states = self.conv1(hidden_states)
169
+
170
+ if temb is not None:
171
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
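+ # The projected time embedding (B, C_out) is broadcast over the frame and
+ # spatial axes by the three trailing None indices, matching the
+ # (B, C, F, H, W) layout used by InflatedConv3d.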
172
+
173
+ if temb is not None and self.time_embedding_norm == "default":
174
+ hidden_states = hidden_states + temb
175
+
176
+ hidden_states = self.norm2(hidden_states)
177
+
178
+ if temb is not None and self.time_embedding_norm == "scale_shift":
179
+ scale, shift = torch.chunk(temb, 2, dim=1)
180
+ hidden_states = hidden_states * (1 + scale) + shift
181
+
182
+ hidden_states = self.nonlinearity(hidden_states)
183
+
184
+ hidden_states = self.dropout(hidden_states)
185
+ hidden_states = self.conv2(hidden_states)
186
+
187
+ if self.conv_shortcut is not None:
188
+ input_tensor = self.conv_shortcut(input_tensor)
189
+
190
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
191
+
192
+ return output_tensor
193
+
194
+
195
+ class Mish(torch.nn.Module):
196
+ def forward(self, hidden_states):
197
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
models/animatediff/models/unet.py ADDED
@@ -0,0 +1,572 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff
2
+
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import os
7
+ import json
8
+ import pdb
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.checkpoint
13
+ try:
14
+ from diffusers.models.cross_attention import AttnProcessor
15
+ except ImportError:  # AttnProcessor location differs across diffusers versions
16
+ from diffusers.models.attention_processor import AttnProcessor
17
+ from typing import Dict
18
+
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.models import ModelMixin
21
+ from diffusers.loaders import UNet2DConditionLoadersMixin
22
+ from diffusers.utils import BaseOutput, logging
23
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
24
+ from .unet_blocks import (
25
+ CrossAttnDownBlock3D,
26
+ CrossAttnUpBlock3D,
27
+ DownBlock3D,
28
+ UNetMidBlock3DCrossAttn,
29
+ UpBlock3D,
30
+ get_down_block,
31
+ get_up_block,
32
+ )
33
+ from .resnet import InflatedConv3d
34
+ from .motion_module import VersatileAttention
35
+ def zero_module(module):
36
+ # Zero out the parameters of a module and return it.
37
+ for p in module.parameters():
38
+ p.detach().zero_()
39
+ return module
40
+
41
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
42
+
43
+
44
+ @dataclass
45
+ class UNet3DConditionOutput(BaseOutput):
46
+ sample: torch.FloatTensor
47
+
48
+
49
+ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
50
+ _supports_gradient_checkpointing = True
51
+
52
+ @register_to_config
53
+ def __init__(
54
+ self,
55
+ sample_size: Optional[int] = None,
56
+ in_channels: int = 4,
57
+ out_channels: int = 4,
58
+ center_input_sample: bool = False,
59
+ flip_sin_to_cos: bool = True,
60
+ freq_shift: int = 0,
61
+ down_block_types: Tuple[str] = (
62
+ "CrossAttnDownBlock3D",
63
+ "CrossAttnDownBlock3D",
64
+ "CrossAttnDownBlock3D",
65
+ "DownBlock3D",
66
+ ),
67
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
68
+ up_block_types: Tuple[str] = (
69
+ "UpBlock3D",
70
+ "CrossAttnUpBlock3D",
71
+ "CrossAttnUpBlock3D",
72
+ "CrossAttnUpBlock3D"
73
+ ),
74
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
75
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
76
+ layers_per_block: int = 2,
77
+ downsample_padding: int = 1,
78
+ mid_block_scale_factor: float = 1,
79
+ act_fn: str = "silu",
80
+ norm_num_groups: int = 32,
81
+ norm_eps: float = 1e-5,
82
+ cross_attention_dim: int = 1280,
83
+ attention_head_dim: Union[int, Tuple[int]] = 8,
84
+ dual_cross_attention: bool = False,
85
+ use_linear_projection: bool = False,
86
+ class_embed_type: Optional[str] = None,
87
+ num_class_embeds: Optional[int] = None,
88
+ upcast_attention: bool = False,
89
+ resnet_time_scale_shift: str = "default",
90
+
91
+ # Additional
92
+ use_motion_module = True,
93
+ motion_module_resolutions = ( 1,2,4,8 ),
94
+ motion_module_mid_block = False,
95
+ motion_module_decoder_only = False,
96
+ motion_module_type = None,
97
+ motion_module_kwargs = {},
98
+ unet_use_cross_frame_attention = None,
99
+ unet_use_temporal_attention = None,
100
+
101
+ ):
102
+ super().__init__()
103
+
104
+ self.sample_size = sample_size
105
+ time_embed_dim = block_out_channels[0] * 4
106
+
107
+ # Image to Video Conv
108
+ # input
109
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
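+ # NOTE (added comment): InflatedConv3d applies its 2D kernel to each frame independently (pseudo-3D), which is what lets pretrained image-UNet weights be reused on video tensors.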
110
+
111
+ # time
112
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
113
+ timestep_input_dim = block_out_channels[0]
114
+
115
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
116
+
117
+ # class embedding
118
+ if class_embed_type is None and num_class_embeds is not None:
119
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
120
+ elif class_embed_type == "timestep":
121
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
122
+ elif class_embed_type == "identity":
123
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
124
+ else:
125
+ self.class_embedding = None
126
+
127
+ self.down_blocks = nn.ModuleList([])
128
+ self.mid_block = None
129
+ self.up_blocks = nn.ModuleList([])
130
+
131
+ if isinstance(only_cross_attention, bool):
132
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
133
+
134
+ if isinstance(attention_head_dim, int):
135
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
136
+
137
+ # down
138
+ output_channel = block_out_channels[0]
139
+ for i, down_block_type in enumerate(down_block_types):
140
+ res = 2 ** i
141
+ input_channel = output_channel
142
+ output_channel = block_out_channels[i]
143
+ is_final_block = i == len(block_out_channels) - 1
144
+
145
+ down_block = get_down_block(
146
+ down_block_type,
147
+ num_layers=layers_per_block,
148
+ in_channels=input_channel,
149
+ out_channels=output_channel,
150
+ temb_channels=time_embed_dim,
151
+ add_downsample=not is_final_block,
152
+ resnet_eps=norm_eps,
153
+ resnet_act_fn=act_fn,
154
+ resnet_groups=norm_num_groups,
155
+ cross_attention_dim=cross_attention_dim,
156
+ attn_num_head_channels=attention_head_dim[i],
157
+ downsample_padding=downsample_padding,
158
+ dual_cross_attention=dual_cross_attention,
159
+ use_linear_projection=use_linear_projection,
160
+ only_cross_attention=only_cross_attention[i],
161
+ upcast_attention=upcast_attention,
162
+ resnet_time_scale_shift=resnet_time_scale_shift,
163
+
164
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
165
+ unet_use_temporal_attention=unet_use_temporal_attention,
166
+
167
+ use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
168
+ motion_module_type=motion_module_type,
169
+ motion_module_kwargs=motion_module_kwargs,
170
+ )
171
+ self.down_blocks.append(down_block)
172
+
173
+ # mid
174
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
175
+ self.mid_block = UNetMidBlock3DCrossAttn(
176
+ in_channels=block_out_channels[-1],
177
+ temb_channels=time_embed_dim,
178
+ resnet_eps=norm_eps,
179
+ resnet_act_fn=act_fn,
180
+ output_scale_factor=mid_block_scale_factor,
181
+ resnet_time_scale_shift=resnet_time_scale_shift,
182
+ cross_attention_dim=cross_attention_dim,
183
+ attn_num_head_channels=attention_head_dim[-1],
184
+ resnet_groups=norm_num_groups,
185
+ dual_cross_attention=dual_cross_attention,
186
+ use_linear_projection=use_linear_projection,
187
+ upcast_attention=upcast_attention,
188
+
189
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
190
+ unet_use_temporal_attention=unet_use_temporal_attention,
191
+
192
+ use_motion_module=use_motion_module and motion_module_mid_block,
193
+ motion_module_type=motion_module_type,
194
+ motion_module_kwargs=motion_module_kwargs,
195
+ )
196
+ else:
197
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
198
+
199
+ # count how many layers upsample the videos
200
+ self.num_upsamplers = 0
201
+
202
+ # up
203
+ reversed_block_out_channels = list(reversed(block_out_channels))
204
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
205
+ only_cross_attention = list(reversed(only_cross_attention))
206
+ output_channel = reversed_block_out_channels[0]
207
+ for i, up_block_type in enumerate(up_block_types):
208
+ res = 2 ** (3 - i)
209
+ is_final_block = i == len(block_out_channels) - 1
210
+
211
+ prev_output_channel = output_channel
212
+ output_channel = reversed_block_out_channels[i]
213
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
214
+
215
+ # add upsample block for all BUT final layer
216
+ if not is_final_block:
217
+ add_upsample = True
218
+ self.num_upsamplers += 1
219
+ else:
220
+ add_upsample = False
221
+
222
+ up_block = get_up_block(
223
+ up_block_type,
224
+ num_layers=layers_per_block + 1,
225
+ in_channels=input_channel,
226
+ out_channels=output_channel,
227
+ prev_output_channel=prev_output_channel,
228
+ temb_channels=time_embed_dim,
229
+ add_upsample=add_upsample,
230
+ resnet_eps=norm_eps,
231
+ resnet_act_fn=act_fn,
232
+ resnet_groups=norm_num_groups,
233
+ cross_attention_dim=cross_attention_dim,
234
+ attn_num_head_channels=reversed_attention_head_dim[i],
235
+ dual_cross_attention=dual_cross_attention,
236
+ use_linear_projection=use_linear_projection,
237
+ only_cross_attention=only_cross_attention[i],
238
+ upcast_attention=upcast_attention,
239
+ resnet_time_scale_shift=resnet_time_scale_shift,
240
+
241
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
242
+ unet_use_temporal_attention=unet_use_temporal_attention,
243
+
244
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
245
+ motion_module_type=motion_module_type,
246
+ motion_module_kwargs=motion_module_kwargs,
247
+ )
248
+ self.up_blocks.append(up_block)
249
+ prev_output_channel = output_channel
250
+
251
+ # out
252
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
253
+ self.conv_act = nn.SiLU()
254
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
255
+
256
+ @property
257
+ def attn_processors(self) -> Dict[str, AttnProcessor]:
258
+ r"""
259
+ Returns:
260
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
261
+ indexed by its weight name.
262
+ """
263
+ # set recursively
264
+ processors = {}
265
+
266
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
267
+ if hasattr(module, "set_processor"):
268
+ processors[f"{name}.processor"] = module.processor
269
+
270
+ for sub_name, child in module.named_children():
271
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
272
+
273
+ return processors
274
+
275
+ for name, module in self.named_children():
276
+ fn_recursive_add_processors(name, module, processors)
277
+
278
+ return processors
279
+
280
+ def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
281
+ r"""
282
+ Parameters:
283
+ processor (`dict` of `AttnProcessor` or `AttnProcessor`):
284
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
285
+ of **all** `CrossAttention` layers.
286
+ In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
287
+
288
+ """
289
+ count = len(self.attn_processors.keys())
290
+
291
+ if isinstance(processor, dict) and len(processor) != count:
292
+ raise ValueError(
293
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
294
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
295
+ )
296
+
297
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
298
+ if hasattr(module, "set_processor"):
299
+ if not isinstance(processor, dict):
300
+ print(f'Set {module}')
301
+ module.set_processor(processor)
302
+ else:
303
+ print(f'Set {module}')
304
+ module.set_processor(processor.pop(f"{name}.processor"))
305
+
306
+ for sub_name, child in module.named_children():
307
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
308
+
309
+ for name, module in self.named_children():
310
+ fn_recursive_attn_processor(name, module, processor)
311
+
312
+ def set_attention_slice(self, slice_size):
313
+ r"""
314
+ Enable sliced attention computation.
315
+
316
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
317
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
318
+
319
+ Args:
320
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
321
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
322
+ `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
323
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
324
+ must be a multiple of `slice_size`.
325
+ """
326
+ sliceable_head_dims = []
327
+
328
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
329
+ if hasattr(module, "set_attention_slice"):
330
+ sliceable_head_dims.append(module.sliceable_head_dim)
331
+
332
+ for child in module.children():
333
+ fn_recursive_retrieve_slicable_dims(child)
334
+
335
+ # retrieve number of attention layers
336
+ for module in self.children():
337
+ fn_recursive_retrieve_slicable_dims(module)
338
+
339
+ num_slicable_layers = len(sliceable_head_dims)
340
+
341
+ if slice_size == "auto":
342
+ # half the attention head size is usually a good trade-off between
343
+ # speed and memory
344
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
345
+ elif slice_size == "max":
346
+ # make smallest slice possible
347
+ slice_size = num_slicable_layers * [1]
348
+
349
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
350
+
351
+ if len(slice_size) != len(sliceable_head_dims):
352
+ raise ValueError(
353
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
354
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
355
+ )
356
+
357
+ for i in range(len(slice_size)):
358
+ size = slice_size[i]
359
+ dim = sliceable_head_dims[i]
360
+ if size is not None and size > dim:
361
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
362
+
363
+ # Recursively walk through all the children.
364
+ # Any children which exposes the set_attention_slice method
365
+ # gets the message
366
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
367
+ if hasattr(module, "set_attention_slice"):
368
+ module.set_attention_slice(slice_size.pop())
369
+
370
+ for child in module.children():
371
+ fn_recursive_set_attention_slice(child, slice_size)
372
+
373
+ reversed_slice_size = list(reversed(slice_size))
374
+ for module in self.children():
375
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
376
+
377
+ def _set_gradient_checkpointing(self, module, value=False):
378
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
379
+ module.gradient_checkpointing = value
380
+
381
+ def forward(
382
+ self,
383
+ sample: torch.FloatTensor,
384
+ mask_sample: torch.FloatTensor,
385
+ masked_sample: torch.FloatTensor,
386
+ timestep: Union[torch.Tensor, float, int],
387
+ encoder_hidden_states: torch.Tensor,
388
+ class_labels: Optional[torch.Tensor] = None,
389
+ attention_mask: Optional[torch.Tensor] = None,
390
+ image_embeds: Optional[torch.Tensor] = None,
391
+ return_dict: bool = True,
392
+ ) -> Union[UNet3DConditionOutput, Tuple]:
393
+ r"""
394
+ Args:
395
+ sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor
396
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
397
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
398
+ return_dict (`bool`, *optional*, defaults to `True`):
399
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
400
+
401
+ Returns:
402
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
403
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
404
+ returning a tuple, the first element is the sample tensor.
405
+ """
406
+ # image to video b c f h w
407
+ sample = torch.cat([sample, mask_sample, masked_sample], dim=1).to(sample.device)
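+ # Added note: the noisy latent, frame mask and masked-image latent are stacked along the channel dim to form the 9-channel input expected by the inflated conv_in (see I2VPipeline.build_pipeline).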
408
+
409
+ # By default samples have to be at least a multiple of the overall upsampling factor.
410
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
411
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
412
+ # on the fly if necessary.
413
+
414
+ default_overall_up_factor = 2**self.num_upsamplers
415
+
416
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
417
+ forward_upsample_size = False
418
+ upsample_size = None
419
+
420
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
421
+ logger.info("Forward upsample size to force interpolation output size.")
422
+ forward_upsample_size = True
423
+
424
+ # prepare attention_mask
425
+ if attention_mask is not None:
426
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * - 10000.0
427
+ attention_mask = attention_mask.unsqueeze(1)
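+ # Added note: the 0/1 mask becomes an additive bias; kept positions contribute 0, masked positions -10000, so they are suppressed after the attention softmax.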
428
+
429
+ # center input if necessary
430
+ if self.config.center_input_sample:
431
+ sample = 2 * sample - 1.0
432
+
433
+ # time
434
+ timesteps = timestep
435
+ if not torch.is_tensor(timesteps):
436
+ # This would be a good case for the `match` statement (Python 3.10+)
437
+ is_mps = sample.device.type == "mps"
438
+ if isinstance(timestep, float):
439
+ dtype = torch.float32 if is_mps else torch.float64
440
+ else:
441
+ dtype = torch.int32 if is_mps else torch.int64
442
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
443
+ elif len(timesteps.shape) == 0:
444
+ timesteps = timesteps[None].to(sample.device)
445
+
446
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
447
+ timesteps = timesteps.expand(sample.shape[0])
448
+
449
+ t_emb = self.time_proj(timesteps)
450
+
451
+ # timesteps does not contain any weights and will always return f32 tensors
452
+ # but time_embedding might actually be running in fp16. so we need to cast here.
453
+ # there might be better ways to encapsulate this.
454
+ t_emb = t_emb.to(dtype=self.dtype)
455
+ emb = self.time_embedding(t_emb)
456
+
457
+ if self.class_embedding is not None:
458
+ if class_labels is None:
459
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
460
+
461
+ if self.config.class_embed_type == "timestep":
462
+ class_labels = self.time_proj(class_labels)
463
+
464
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
465
+ emb = emb + class_emb
466
+
467
+ # prepare for ip-adapter
468
+ if image_embeds is not None:
469
+ image_embeds = self.encoder_hid_proj(
470
+ image_embeds).to(encoder_hidden_states.dtype)
471
+ encoder_hidden_states = torch.cat(
472
+ [encoder_hidden_states, image_embeds], dim=1)
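+ # Added note: IP-Adapter conditioning appends the projected image tokens to the text tokens so cross-attention can attend to both.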
473
+
474
+ # pre-process
475
+ # b c f h w
476
+ # 2 4 16 64 64
477
+ sample = self.conv_in(sample)
478
+ # down
479
+ down_block_res_samples = (sample,)
480
+ for downsample_block in self.down_blocks:
481
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
482
+ sample, res_samples = downsample_block(
483
+ hidden_states=sample,
484
+ temb=emb,
485
+ encoder_hidden_states=encoder_hidden_states,
486
+ attention_mask=attention_mask,
487
+ )
488
+ else:
489
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)
490
+ down_block_res_samples += res_samples
491
+
492
+ # mid
493
+ sample = self.mid_block(
494
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
495
+ )
496
+
497
+ # up
498
+ for i, upsample_block in enumerate(self.up_blocks):
499
+ is_final_block = i == len(self.up_blocks) - 1
500
+
501
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
502
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
503
+
504
+ # if we have not reached the final block and need to forward the
505
+ # upsample size, we do it here
506
+ if not is_final_block and forward_upsample_size:
507
+ upsample_size = down_block_res_samples[-1].shape[2:]
508
+
509
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
510
+ sample = upsample_block(
511
+ hidden_states=sample,
512
+ temb=emb,
513
+ res_hidden_states_tuple=res_samples,
514
+ encoder_hidden_states=encoder_hidden_states,
515
+ upsample_size=upsample_size,
516
+ attention_mask=attention_mask,
517
+ )
518
+ else:
519
+ sample = upsample_block(
520
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
521
+ )
522
+
523
+ # post-process
524
+ sample = self.conv_norm_out(sample)
525
+ sample = self.conv_act(sample)
526
+ sample = self.conv_out(sample)
527
+
528
+ if not return_dict:
529
+ return (sample,)
530
+
531
+ return UNet3DConditionOutput(sample=sample)
532
+
533
+ @classmethod
534
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
535
+ if subfolder is not None:
536
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
537
+ print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
538
+
539
+ config_file = os.path.join(pretrained_model_path, 'config.json')
540
+ if not os.path.isfile(config_file):
541
+ raise RuntimeError(f"{config_file} does not exist")
542
+ with open(config_file, "r") as f:
543
+ config = json.load(f)
544
+ config["_class_name"] = cls.__name__
545
+ config["down_block_types"] = [
546
+ "CrossAttnDownBlock3D",
547
+ "CrossAttnDownBlock3D",
548
+ "CrossAttnDownBlock3D",
549
+ "DownBlock3D"
550
+ ]
551
+ config["up_block_types"] = [
552
+ "UpBlock3D",
553
+ "CrossAttnUpBlock3D",
554
+ "CrossAttnUpBlock3D",
555
+ "CrossAttnUpBlock3D"
556
+ ]
557
+
558
+ from diffusers.utils import WEIGHTS_NAME
559
+ model = cls.from_config(config, **unet_additional_kwargs)
560
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
561
+ if not os.path.isfile(model_file):
562
+ raise RuntimeError(f"{model_file} does not exist")
563
+ state_dict = torch.load(model_file, map_location="cpu")
564
+
565
+ m, u = model.load_state_dict(state_dict, strict=False)
566
+ print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
567
+ # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
568
+
569
+ params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
570
+ print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
571
+
572
+ return model
models/animatediff/models/unet_blocks.py ADDED
@@ -0,0 +1,733 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ from .attention import Transformer3DModel
7
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
8
+ from .motion_module import get_motion_module
9
+
10
+ import pdb
11
+
12
+ def get_down_block(
13
+ down_block_type,
14
+ num_layers,
15
+ in_channels,
16
+ out_channels,
17
+ temb_channels,
18
+ add_downsample,
19
+ resnet_eps,
20
+ resnet_act_fn,
21
+ attn_num_head_channels,
22
+ resnet_groups=None,
23
+ cross_attention_dim=None,
24
+ downsample_padding=None,
25
+ dual_cross_attention=False,
26
+ use_linear_projection=False,
27
+ only_cross_attention=False,
28
+ upcast_attention=False,
29
+ resnet_time_scale_shift="default",
30
+
31
+ unet_use_cross_frame_attention=None,
32
+ unet_use_temporal_attention=None,
33
+
34
+ use_motion_module=None,
35
+
36
+ motion_module_type=None,
37
+ motion_module_kwargs=None,
38
+ ):
39
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
40
+ if down_block_type == "DownBlock3D":
41
+ return DownBlock3D(
42
+ num_layers=num_layers,
43
+ in_channels=in_channels,
44
+ out_channels=out_channels,
45
+ temb_channels=temb_channels,
46
+ add_downsample=add_downsample,
47
+ resnet_eps=resnet_eps,
48
+ resnet_act_fn=resnet_act_fn,
49
+ resnet_groups=resnet_groups,
50
+ downsample_padding=downsample_padding,
51
+ resnet_time_scale_shift=resnet_time_scale_shift,
52
+
53
+ use_motion_module=use_motion_module,
54
+ motion_module_type=motion_module_type,
55
+ motion_module_kwargs=motion_module_kwargs,
56
+ )
57
+ elif down_block_type == "CrossAttnDownBlock3D":
58
+ if cross_attention_dim is None:
59
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
60
+ return CrossAttnDownBlock3D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ add_downsample=add_downsample,
66
+ resnet_eps=resnet_eps,
67
+ resnet_act_fn=resnet_act_fn,
68
+ resnet_groups=resnet_groups,
69
+ downsample_padding=downsample_padding,
70
+ cross_attention_dim=cross_attention_dim,
71
+ attn_num_head_channels=attn_num_head_channels,
72
+ dual_cross_attention=dual_cross_attention,
73
+ use_linear_projection=use_linear_projection,
74
+ only_cross_attention=only_cross_attention,
75
+ upcast_attention=upcast_attention,
76
+ resnet_time_scale_shift=resnet_time_scale_shift,
77
+
78
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
79
+ unet_use_temporal_attention=unet_use_temporal_attention,
80
+
81
+ use_motion_module=use_motion_module,
82
+ motion_module_type=motion_module_type,
83
+ motion_module_kwargs=motion_module_kwargs,
84
+ )
85
+ raise ValueError(f"{down_block_type} does not exist.")
86
+
87
+
88
+ def get_up_block(
89
+ up_block_type,
90
+ num_layers,
91
+ in_channels,
92
+ out_channels,
93
+ prev_output_channel,
94
+ temb_channels,
95
+ add_upsample,
96
+ resnet_eps,
97
+ resnet_act_fn,
98
+ attn_num_head_channels,
99
+ resnet_groups=None,
100
+ cross_attention_dim=None,
101
+ dual_cross_attention=False,
102
+ use_linear_projection=False,
103
+ only_cross_attention=False,
104
+ upcast_attention=False,
105
+ resnet_time_scale_shift="default",
106
+
107
+ unet_use_cross_frame_attention=None,
108
+ unet_use_temporal_attention=None,
109
+
110
+ use_motion_module=None,
111
+ motion_module_type=None,
112
+ motion_module_kwargs=None,
113
+ ):
114
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
115
+ if up_block_type == "UpBlock3D":
116
+ return UpBlock3D(
117
+ num_layers=num_layers,
118
+ in_channels=in_channels,
119
+ out_channels=out_channels,
120
+ prev_output_channel=prev_output_channel,
121
+ temb_channels=temb_channels,
122
+ add_upsample=add_upsample,
123
+ resnet_eps=resnet_eps,
124
+ resnet_act_fn=resnet_act_fn,
125
+ resnet_groups=resnet_groups,
126
+ resnet_time_scale_shift=resnet_time_scale_shift,
127
+
128
+ use_motion_module=use_motion_module,
129
+ motion_module_type=motion_module_type,
130
+ motion_module_kwargs=motion_module_kwargs,
131
+ )
132
+ elif up_block_type == "CrossAttnUpBlock3D":
133
+ if cross_attention_dim is None:
134
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
135
+ return CrossAttnUpBlock3D(
136
+ num_layers=num_layers,
137
+ in_channels=in_channels,
138
+ out_channels=out_channels,
139
+ prev_output_channel=prev_output_channel,
140
+ temb_channels=temb_channels,
141
+ add_upsample=add_upsample,
142
+ resnet_eps=resnet_eps,
143
+ resnet_act_fn=resnet_act_fn,
144
+ resnet_groups=resnet_groups,
145
+ cross_attention_dim=cross_attention_dim,
146
+ attn_num_head_channels=attn_num_head_channels,
147
+ dual_cross_attention=dual_cross_attention,
148
+ use_linear_projection=use_linear_projection,
149
+ only_cross_attention=only_cross_attention,
150
+ upcast_attention=upcast_attention,
151
+ resnet_time_scale_shift=resnet_time_scale_shift,
152
+
153
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
154
+ unet_use_temporal_attention=unet_use_temporal_attention,
155
+
156
+ use_motion_module=use_motion_module,
157
+ motion_module_type=motion_module_type,
158
+ motion_module_kwargs=motion_module_kwargs,
159
+ )
160
+ raise ValueError(f"{up_block_type} does not exist.")
161
+
162
+
163
+ class UNetMidBlock3DCrossAttn(nn.Module):
164
+ def __init__(
165
+ self,
166
+ in_channels: int,
167
+ temb_channels: int,
168
+ dropout: float = 0.0,
169
+ num_layers: int = 1,
170
+ resnet_eps: float = 1e-6,
171
+ resnet_time_scale_shift: str = "default",
172
+ resnet_act_fn: str = "swish",
173
+ resnet_groups: int = 32,
174
+ resnet_pre_norm: bool = True,
175
+ attn_num_head_channels=1,
176
+ output_scale_factor=1.0,
177
+ cross_attention_dim=1280,
178
+ dual_cross_attention=False,
179
+ use_linear_projection=False,
180
+ upcast_attention=False,
181
+
182
+ unet_use_cross_frame_attention=None,
183
+ unet_use_temporal_attention=None,
184
+
185
+ use_motion_module=None,
186
+
187
+ motion_module_type=None,
188
+ motion_module_kwargs=None,
189
+ ):
190
+ super().__init__()
191
+
192
+ self.has_cross_attention = True
193
+ self.attn_num_head_channels = attn_num_head_channels
194
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
195
+
196
+ # there is always at least one resnet
197
+ resnets = [
198
+ ResnetBlock3D(
199
+ in_channels=in_channels,
200
+ out_channels=in_channels,
201
+ temb_channels=temb_channels,
202
+ eps=resnet_eps,
203
+ groups=resnet_groups,
204
+ dropout=dropout,
205
+ time_embedding_norm=resnet_time_scale_shift,
206
+ non_linearity=resnet_act_fn,
207
+ output_scale_factor=output_scale_factor,
208
+ pre_norm=resnet_pre_norm,
209
+ )
210
+ ]
211
+ attentions = []
212
+ motion_modules = []
213
+
214
+ for _ in range(num_layers):
215
+ if dual_cross_attention:
216
+ raise NotImplementedError
217
+ attentions.append(
218
+ Transformer3DModel(
219
+ attn_num_head_channels,
220
+ in_channels // attn_num_head_channels,
221
+ in_channels=in_channels,
222
+ num_layers=1,
223
+ cross_attention_dim=cross_attention_dim,
224
+ norm_num_groups=resnet_groups,
225
+ use_linear_projection=use_linear_projection,
226
+ upcast_attention=upcast_attention,
227
+
228
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
229
+ unet_use_temporal_attention=unet_use_temporal_attention,
230
+ )
231
+ )
232
+ motion_modules.append(
233
+ get_motion_module(
234
+ in_channels=in_channels,
235
+ motion_module_type=motion_module_type,
236
+ motion_module_kwargs=motion_module_kwargs,
237
+ ) if use_motion_module else None
238
+ )
239
+ resnets.append(
240
+ ResnetBlock3D(
241
+ in_channels=in_channels,
242
+ out_channels=in_channels,
243
+ temb_channels=temb_channels,
244
+ eps=resnet_eps,
245
+ groups=resnet_groups,
246
+ dropout=dropout,
247
+ time_embedding_norm=resnet_time_scale_shift,
248
+ non_linearity=resnet_act_fn,
249
+ output_scale_factor=output_scale_factor,
250
+ pre_norm=resnet_pre_norm,
251
+ )
252
+ )
253
+
254
+ self.attentions = nn.ModuleList(attentions)
255
+ self.resnets = nn.ModuleList(resnets)
256
+ self.motion_modules = nn.ModuleList(motion_modules)
257
+
258
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
259
+ hidden_states = self.resnets[0](hidden_states, temb)
260
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
261
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
262
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
263
+ hidden_states = resnet(hidden_states, temb)
264
+
265
+ return hidden_states
266
+
267
+
268
+ class CrossAttnDownBlock3D(nn.Module):
269
+ def __init__(
270
+ self,
271
+ in_channels: int,
272
+ out_channels: int,
273
+ temb_channels: int,
274
+ dropout: float = 0.0,
275
+ num_layers: int = 1,
276
+ resnet_eps: float = 1e-6,
277
+ resnet_time_scale_shift: str = "default",
278
+ resnet_act_fn: str = "swish",
279
+ resnet_groups: int = 32,
280
+ resnet_pre_norm: bool = True,
281
+ attn_num_head_channels=1,
282
+ cross_attention_dim=1280,
283
+ output_scale_factor=1.0,
284
+ downsample_padding=1,
285
+ add_downsample=True,
286
+ dual_cross_attention=False,
287
+ use_linear_projection=False,
288
+ only_cross_attention=False,
289
+ upcast_attention=False,
290
+
291
+ unet_use_cross_frame_attention=None,
292
+ unet_use_temporal_attention=None,
293
+
294
+ use_motion_module=None,
295
+
296
+ motion_module_type=None,
297
+ motion_module_kwargs=None,
298
+ ):
299
+ super().__init__()
300
+ resnets = []
301
+ attentions = []
302
+ motion_modules = []
303
+
304
+ self.has_cross_attention = True
305
+ self.attn_num_head_channels = attn_num_head_channels
306
+
307
+ for i in range(num_layers):
308
+ in_channels = in_channels if i == 0 else out_channels
309
+ resnets.append(
310
+ ResnetBlock3D(
311
+ in_channels=in_channels,
312
+ out_channels=out_channels,
313
+ temb_channels=temb_channels,
314
+ eps=resnet_eps,
315
+ groups=resnet_groups,
316
+ dropout=dropout,
317
+ time_embedding_norm=resnet_time_scale_shift,
318
+ non_linearity=resnet_act_fn,
319
+ output_scale_factor=output_scale_factor,
320
+ pre_norm=resnet_pre_norm,
321
+ )
322
+ )
323
+ if dual_cross_attention:
324
+ raise NotImplementedError
325
+ attentions.append(
326
+ Transformer3DModel(
327
+ attn_num_head_channels,
328
+ out_channels // attn_num_head_channels,
329
+ in_channels=out_channels,
330
+ num_layers=1,
331
+ cross_attention_dim=cross_attention_dim,
332
+ norm_num_groups=resnet_groups,
333
+ use_linear_projection=use_linear_projection,
334
+ only_cross_attention=only_cross_attention,
335
+ upcast_attention=upcast_attention,
336
+
337
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
338
+ unet_use_temporal_attention=unet_use_temporal_attention,
339
+ )
340
+ )
341
+ motion_modules.append(
342
+ get_motion_module(
343
+ in_channels=out_channels,
344
+ motion_module_type=motion_module_type,
345
+ motion_module_kwargs=motion_module_kwargs,
346
+ ) if use_motion_module else None
347
+ )
348
+
349
+ self.attentions = nn.ModuleList(attentions)
350
+ self.resnets = nn.ModuleList(resnets)
351
+ self.motion_modules = nn.ModuleList(motion_modules)
352
+
353
+ if add_downsample:
354
+ self.downsamplers = nn.ModuleList(
355
+ [
356
+ Downsample3D(
357
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
358
+ )
359
+ ]
360
+ )
361
+ else:
362
+ self.downsamplers = None
363
+
364
+ self.gradient_checkpointing = False
365
+
366
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
367
+ output_states = ()
368
+
369
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
370
+ if self.training and self.gradient_checkpointing:
371
+
372
+ def create_custom_forward(module, return_dict=None):
373
+ def custom_forward(*inputs):
374
+ if return_dict is not None:
375
+ return module(*inputs, return_dict=return_dict)
376
+ else:
377
+ return module(*inputs)
378
+
379
+ return custom_forward
380
+
381
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
382
+ hidden_states = torch.utils.checkpoint.checkpoint(
383
+ create_custom_forward(attn, return_dict=False),
384
+ hidden_states,
385
+ encoder_hidden_states,
386
+ )[0]
387
+ if motion_module is not None:
388
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
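+ # Added note: requires_grad_() keeps at least one checkpoint input differentiable so the motion module is recomputed with gradients during backward.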
389
+
390
+ else:
391
+ hidden_states = resnet(hidden_states, temb)
392
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
393
+
394
+ # add motion module
395
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
396
+
397
+ output_states += (hidden_states,)
398
+
399
+ if self.downsamplers is not None:
400
+ for downsampler in self.downsamplers:
401
+ hidden_states = downsampler(hidden_states)
402
+
403
+ output_states += (hidden_states,)
404
+
405
+ return hidden_states, output_states
406
+
407
+
408
+ class DownBlock3D(nn.Module):
409
+ def __init__(
410
+ self,
411
+ in_channels: int,
412
+ out_channels: int,
413
+ temb_channels: int,
414
+ dropout: float = 0.0,
415
+ num_layers: int = 1,
416
+ resnet_eps: float = 1e-6,
417
+ resnet_time_scale_shift: str = "default",
418
+ resnet_act_fn: str = "swish",
419
+ resnet_groups: int = 32,
420
+ resnet_pre_norm: bool = True,
421
+ output_scale_factor=1.0,
422
+ add_downsample=True,
423
+ downsample_padding=1,
424
+
425
+ use_motion_module=None,
426
+ motion_module_type=None,
427
+ motion_module_kwargs=None,
428
+ ):
429
+ super().__init__()
430
+ resnets = []
431
+ motion_modules = []
432
+
433
+ for i in range(num_layers):
434
+ in_channels = in_channels if i == 0 else out_channels
435
+ resnets.append(
436
+ ResnetBlock3D(
437
+ in_channels=in_channels,
438
+ out_channels=out_channels,
439
+ temb_channels=temb_channels,
440
+ eps=resnet_eps,
441
+ groups=resnet_groups,
442
+ dropout=dropout,
443
+ time_embedding_norm=resnet_time_scale_shift,
444
+ non_linearity=resnet_act_fn,
445
+ output_scale_factor=output_scale_factor,
446
+ pre_norm=resnet_pre_norm,
447
+ )
448
+ )
449
+ motion_modules.append(
450
+ get_motion_module(
451
+ in_channels=out_channels,
452
+ motion_module_type=motion_module_type,
453
+ motion_module_kwargs=motion_module_kwargs,
454
+ ) if use_motion_module else None
455
+ )
456
+
457
+ self.resnets = nn.ModuleList(resnets)
458
+ self.motion_modules = nn.ModuleList(motion_modules)
459
+
460
+ if add_downsample:
461
+ self.downsamplers = nn.ModuleList(
462
+ [
463
+ Downsample3D(
464
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
465
+ )
466
+ ]
467
+ )
468
+ else:
469
+ self.downsamplers = None
470
+
471
+ self.gradient_checkpointing = False
472
+
473
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
474
+ output_states = ()
475
+
476
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
477
+ if self.training and self.gradient_checkpointing:
478
+ def create_custom_forward(module):
479
+ def custom_forward(*inputs):
480
+ return module(*inputs)
481
+
482
+ return custom_forward
483
+
484
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
485
+ if motion_module is not None:
486
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
487
+ else:
488
+ hidden_states = resnet(hidden_states, temb)
489
+
490
+ # add motion module
491
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
492
+
493
+ output_states += (hidden_states,)
494
+
495
+ if self.downsamplers is not None:
496
+ for downsampler in self.downsamplers:
497
+ hidden_states = downsampler(hidden_states)
498
+
499
+ output_states += (hidden_states,)
500
+
501
+ return hidden_states, output_states
502
+
503
+
504
+ class CrossAttnUpBlock3D(nn.Module):
505
+ def __init__(
506
+ self,
507
+ in_channels: int,
508
+ out_channels: int,
509
+ prev_output_channel: int,
510
+ temb_channels: int,
511
+ dropout: float = 0.0,
512
+ num_layers: int = 1,
513
+ resnet_eps: float = 1e-6,
514
+ resnet_time_scale_shift: str = "default",
515
+ resnet_act_fn: str = "swish",
516
+ resnet_groups: int = 32,
517
+ resnet_pre_norm: bool = True,
518
+ attn_num_head_channels=1,
519
+ cross_attention_dim=1280,
520
+ output_scale_factor=1.0,
521
+ add_upsample=True,
522
+ dual_cross_attention=False,
523
+ use_linear_projection=False,
524
+ only_cross_attention=False,
525
+ upcast_attention=False,
526
+
527
+ unet_use_cross_frame_attention=None,
528
+ unet_use_temporal_attention=None,
529
+
530
+ use_motion_module=None,
531
+
532
+ motion_module_type=None,
533
+ motion_module_kwargs=None,
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+ attentions = []
538
+ motion_modules = []
539
+
540
+ self.has_cross_attention = True
541
+ self.attn_num_head_channels = attn_num_head_channels
542
+
543
+ for i in range(num_layers):
544
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
545
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
546
+
547
+ resnets.append(
548
+ ResnetBlock3D(
549
+ in_channels=resnet_in_channels + res_skip_channels,
550
+ out_channels=out_channels,
551
+ temb_channels=temb_channels,
552
+ eps=resnet_eps,
553
+ groups=resnet_groups,
554
+ dropout=dropout,
555
+ time_embedding_norm=resnet_time_scale_shift,
556
+ non_linearity=resnet_act_fn,
557
+ output_scale_factor=output_scale_factor,
558
+ pre_norm=resnet_pre_norm,
559
+ )
560
+ )
561
+ if dual_cross_attention:
562
+ raise NotImplementedError
563
+ attentions.append(
564
+ Transformer3DModel(
565
+ attn_num_head_channels,
566
+ out_channels // attn_num_head_channels,
567
+ in_channels=out_channels,
568
+ num_layers=1,
569
+ cross_attention_dim=cross_attention_dim,
570
+ norm_num_groups=resnet_groups,
571
+ use_linear_projection=use_linear_projection,
572
+ only_cross_attention=only_cross_attention,
573
+ upcast_attention=upcast_attention,
574
+
575
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
576
+ unet_use_temporal_attention=unet_use_temporal_attention,
577
+ )
578
+ )
579
+ motion_modules.append(
580
+ get_motion_module(
581
+ in_channels=out_channels,
582
+ motion_module_type=motion_module_type,
583
+ motion_module_kwargs=motion_module_kwargs,
584
+ ) if use_motion_module else None
585
+ )
586
+
587
+ self.attentions = nn.ModuleList(attentions)
588
+ self.resnets = nn.ModuleList(resnets)
589
+ self.motion_modules = nn.ModuleList(motion_modules)
590
+
591
+ if add_upsample:
592
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
593
+ else:
594
+ self.upsamplers = None
595
+
596
+ self.gradient_checkpointing = False
597
+
598
+ def forward(
599
+ self,
600
+ hidden_states,
601
+ res_hidden_states_tuple,
602
+ temb=None,
603
+ encoder_hidden_states=None,
604
+ upsample_size=None,
605
+ attention_mask=None,
606
+ ):
607
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
608
+ # pop res hidden states
609
+ res_hidden_states = res_hidden_states_tuple[-1]
610
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
611
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
612
+
613
+ if self.training and self.gradient_checkpointing:
614
+
615
+ def create_custom_forward(module, return_dict=None):
616
+ def custom_forward(*inputs):
617
+ if return_dict is not None:
618
+ return module(*inputs, return_dict=return_dict)
619
+ else:
620
+ return module(*inputs)
621
+
622
+ return custom_forward
623
+
624
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
625
+ hidden_states = torch.utils.checkpoint.checkpoint(
626
+ create_custom_forward(attn, return_dict=False),
627
+ hidden_states,
628
+ encoder_hidden_states,
629
+ )[0]
630
+ if motion_module is not None:
631
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
632
+
633
+ else:
634
+ hidden_states = resnet(hidden_states, temb)
635
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
636
+
637
+ # add motion module
638
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
639
+
640
+ if self.upsamplers is not None:
641
+ for upsampler in self.upsamplers:
642
+ hidden_states = upsampler(hidden_states, upsample_size)
643
+
644
+ return hidden_states
645
+
646
+
647
+ class UpBlock3D(nn.Module):
648
+ def __init__(
649
+ self,
650
+ in_channels: int,
651
+ prev_output_channel: int,
652
+ out_channels: int,
653
+ temb_channels: int,
654
+ dropout: float = 0.0,
655
+ num_layers: int = 1,
656
+ resnet_eps: float = 1e-6,
657
+ resnet_time_scale_shift: str = "default",
658
+ resnet_act_fn: str = "swish",
659
+ resnet_groups: int = 32,
660
+ resnet_pre_norm: bool = True,
661
+ output_scale_factor=1.0,
662
+ add_upsample=True,
663
+
664
+ use_motion_module=None,
665
+ motion_module_type=None,
666
+ motion_module_kwargs=None,
667
+ ):
668
+ super().__init__()
669
+ resnets = []
670
+ motion_modules = []
671
+
672
+ for i in range(num_layers):
673
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
674
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
675
+
676
+ resnets.append(
677
+ ResnetBlock3D(
678
+ in_channels=resnet_in_channels + res_skip_channels,
679
+ out_channels=out_channels,
680
+ temb_channels=temb_channels,
681
+ eps=resnet_eps,
682
+ groups=resnet_groups,
683
+ dropout=dropout,
684
+ time_embedding_norm=resnet_time_scale_shift,
685
+ non_linearity=resnet_act_fn,
686
+ output_scale_factor=output_scale_factor,
687
+ pre_norm=resnet_pre_norm,
688
+ )
689
+ )
690
+ motion_modules.append(
691
+ get_motion_module(
692
+ in_channels=out_channels,
693
+ motion_module_type=motion_module_type,
694
+ motion_module_kwargs=motion_module_kwargs,
695
+ ) if use_motion_module else None
696
+ )
697
+
698
+ self.resnets = nn.ModuleList(resnets)
699
+ self.motion_modules = nn.ModuleList(motion_modules)
700
+
701
+ if add_upsample:
702
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
703
+ else:
704
+ self.upsamplers = None
705
+
706
+ self.gradient_checkpointing = False
707
+
708
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,):
709
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
710
+ # pop res hidden states
711
+ res_hidden_states = res_hidden_states_tuple[-1]
712
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
713
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
714
+
715
+ if self.training and self.gradient_checkpointing:
716
+ def create_custom_forward(module):
717
+ def custom_forward(*inputs):
718
+ return module(*inputs)
719
+
720
+ return custom_forward
721
+
722
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
723
+ if motion_module is not None:
724
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
725
+ else:
726
+ hidden_states = resnet(hidden_states, temb)
727
+ hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
728
+
729
+ if self.upsamplers is not None:
730
+ for upsampler in self.upsamplers:
731
+ hidden_states = upsampler(hidden_states, upsample_size)
732
+
733
+ return hidden_states
models/animatediff/pipelines/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .i2v_pipeline import I2VPipeline
2
+
3
+ __all__ = ['I2VPipeline']
models/animatediff/pipelines/i2v_pipeline.py ADDED
@@ -0,0 +1,729 @@
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+ import inspect
3
+ import os.path as osp
4
+ from dataclasses import dataclass
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from diffusers.configuration_utils import FrozenDict
10
+ from diffusers.loaders import IPAdapterMixin, TextualInversionLoaderMixin
11
+ from diffusers.models import AutoencoderKL
12
+ from diffusers.pipelines import DiffusionPipeline
13
+ from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
14
+ EulerAncestralDiscreteScheduler,
15
+ EulerDiscreteScheduler, LMSDiscreteScheduler,
16
+ PNDMScheduler)
17
+ from diffusers.utils import (BaseOutput, deprecate, is_accelerate_available,
18
+ logging)
19
+ from diffusers.utils.import_utils import is_xformers_available
20
+ from einops import rearrange
21
+ from omegaconf import OmegaConf
22
+ from packaging import version
23
+ from safetensors import safe_open
24
+ from tqdm import tqdm
25
+ from transformers import (CLIPImageProcessor, CLIPTextModel, CLIPTokenizer,
26
+ CLIPVisionModelWithProjection)
27
+
28
+ from animatediff.models.resnet import InflatedConv3d
29
+ from animatediff.models.unet import UNet3DConditionModel
30
+ from animatediff.utils.convert_from_ckpt import (convert_ldm_clip_checkpoint,
31
+ convert_ldm_unet_checkpoint,
32
+ convert_ldm_vae_checkpoint)
33
+ from animatediff.utils.convert_lora_safetensor_to_diffusers import \
34
+ convert_lora_model_level
35
+ from animatediff.utils.util import prepare_mask_coef_by_statistics
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+
40
+ DEFAULT_N_PROMPT = ('wrong white balance, dark, sketches,worst quality,'
41
+ 'low quality, deformed, distorted, disfigured, bad eyes, '
42
+ 'wrong lips,weird mouth, bad teeth, mutated hands and fingers, '
43
+ 'bad anatomy,wrong anatomy, amputation, extra limb, '
44
+ 'missing limb, floating,limbs, disconnected limbs, mutation, '
45
+ 'ugly, disgusting, bad_pictures, negative_hand-neg')
46
+
47
+
48
+ @dataclass
49
+ class AnimationPipelineOutput(BaseOutput):
50
+ videos: Union[torch.Tensor, np.ndarray]
51
+
52
+
53
+ class I2VPipeline(DiffusionPipeline, IPAdapterMixin, TextualInversionLoaderMixin):
54
+ _optional_components = []
55
+
56
+ def __init__(
57
+ self,
58
+ vae: AutoencoderKL,
59
+ text_encoder: CLIPTextModel,
60
+ tokenizer: CLIPTokenizer,
61
+ unet: UNet3DConditionModel,
62
+ scheduler: Union[
63
+ DDIMScheduler,
64
+ PNDMScheduler,
65
+ LMSDiscreteScheduler,
66
+ EulerDiscreteScheduler,
67
+ EulerAncestralDiscreteScheduler,
68
+ DPMSolverMultistepScheduler,
69
+ ],
70
+ # memory_format: torch.memory_format,
71
+ feature_extractor: CLIPImageProcessor = None,
72
+ image_encoder: CLIPVisionModelWithProjection = None,
73
+ ):
74
+ super().__init__()
75
+
76
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
77
+ deprecation_message = (
78
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
79
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
80
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
81
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
82
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
83
+ " file"
84
+ )
85
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
86
+ new_config = dict(scheduler.config)
87
+ new_config["steps_offset"] = 1
88
+ scheduler._internal_dict = FrozenDict(new_config)
89
+
90
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
91
+ deprecation_message = (
92
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
93
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
94
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
95
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
96
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
97
+ )
98
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
99
+ new_config = dict(scheduler.config)
100
+ new_config["clip_sample"] = False
101
+ scheduler._internal_dict = FrozenDict(new_config)
102
+
103
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
104
+ version.parse(unet.config._diffusers_version).base_version
105
+ ) < version.parse("0.9.0.dev0")
106
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
107
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
108
+ deprecation_message = (
109
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
110
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
111
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
112
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
113
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
114
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
115
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
116
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
117
+ " the `unet/config.json` file"
118
+ )
119
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
120
+ new_config = dict(unet.config)
121
+ new_config["sample_size"] = 64
122
+ unet._internal_dict = FrozenDict(new_config)
123
+
124
+ self.register_modules(
125
+ vae=vae,
126
+ text_encoder=text_encoder,
127
+ tokenizer=tokenizer,
128
+ unet=unet,
129
+ image_encoder=image_encoder,
130
+ feature_extractor=feature_extractor,
131
+ scheduler=scheduler,
132
+ )
133
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
134
+ # self.memory_format = memory_format
135
+ self.use_ip_adapter = False
136
+
137
+ @classmethod
138
+ def build_pipeline(cls,
139
+ base_cfg,
140
+ base_model: str,
141
+ unet_path: str,
142
+ dreambooth_path: Optional[str] = None,
143
+ lora_path: Optional[str] = None,
144
+ lora_alpha: int = 0,
145
+ vae_path: Optional[str] = None,
146
+ ip_adapter_path: Optional[str] = None,
147
+ ip_adapter_scale: float = 0.0,
148
+ only_load_vae_decoder: bool = False,
149
+ only_load_vae_encoder: bool = False) -> 'I2VPipeline':
150
+ """Method to build the pipeline in a faster way.
151
+ Args:
152
+ base_cfg: The config to build the model
153
+ base_model: The model id to initialize StableDiffusion
154
+ unet_path: Path for i2v unet
155
+
156
+ dreambooth_path: path for dreambooth model
157
+ lora_path: path for lora model
158
+ lora_alpha: value for lora scale
159
+
160
+ only_load_vae_decoder: Only load VAE decoder from dreambooth / VAE ckpt
161
+ and maintain the encoder as original.
162
+
163
+ """
164
+ # build unet
165
+ unet = UNet3DConditionModel.from_pretrained_2d(
166
+ base_model, subfolder="unet",
167
+ unet_additional_kwargs=OmegaConf.to_container(
168
+ base_cfg.unet_additional_kwargs))
169
+
170
+ old_weights = unet.conv_in.weight
171
+ old_bias = unet.conv_in.bias
172
+ new_conv1 = InflatedConv3d(
173
+ 9, old_weights.shape[0],
174
+ kernel_size=unet.conv_in.kernel_size,
175
+ stride=unet.conv_in.stride,
176
+ padding=unet.conv_in.padding,
177
+ bias=True if old_bias is not None else False)
178
+ param = torch.zeros((320,5,3,3),requires_grad=True)
179
+ new_conv1.weight = torch.nn.Parameter(torch.cat((old_weights,param),dim=1))
180
+ if old_bias is not None:
181
+ new_conv1.bias = old_bias
182
+ unet.conv_in = new_conv1
183
+ unet.config["in_channels"] = 9
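+ # Added note: the pretrained 4-channel conv_in weights are kept and 5 zero-initialized condition channels are appended, so the expanded 9-channel stem initially behaves like the original image UNet.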
184
+
185
+ unet_ckpt = torch.load(unet_path, map_location='cpu')
186
+ unet.load_state_dict(unet_ckpt, strict=False)
187
+ # NOTE: only load temporal layers and condition module
188
+ # for key, value in unet_ckpt.items():
189
+ # if 'motion' in key or 'conv_in' in key:
190
+ # unet.state_dict()[key].copy_(value)
191
+
192
+ # load vae, tokenizer, text encoder
193
+ vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae")
194
+ tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer")
195
+ text_encoder = CLIPTextModel.from_pretrained(base_model, subfolder="text_encoder")
196
+ noise_scheduler = DDIMScheduler(**OmegaConf.to_container(base_cfg.noise_scheduler_kwargs))
197
+
198
+ if dreambooth_path:
199
+
200
+ print(" >>> Begin loading DreamBooth >>>")
201
+ base_model_state_dict = {}
202
+ with safe_open(dreambooth_path, framework="pt", device="cpu") as f:
203
+ for key in f.keys():
204
+ base_model_state_dict[key] = f.get_tensor(key)
205
+
206
+ # load unet
207
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, unet.config)
208
+
209
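+ # The DreamBooth checkpoint only provides a 4-channel conv_in, so re-attach the 5 extra
+ # input channels that were already loaded from the i2v UNet checkpoint above.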
+ old_value = converted_unet_checkpoint['conv_in.weight']
210
+ new_param = unet_ckpt['conv_in.weight'][:,4:,:,:].clone().cpu()
211
+ new_value = torch.nn.Parameter(torch.cat((old_value, new_param), dim=1))
212
+ converted_unet_checkpoint['conv_in.weight'] = new_value
213
+ unet.load_state_dict(converted_unet_checkpoint, strict=False)
214
+
215
+ # load vae
216
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(
217
+ base_model_state_dict, vae.config,
218
+ only_decoder=only_load_vae_decoder,
219
+ only_encoder=only_load_vae_encoder,)
220
+ need_strict = not (only_load_vae_decoder or only_load_vae_encoder)
221
+ vae.load_state_dict(converted_vae_checkpoint, strict=need_strict)
222
+ print('Prefix in loaded VAE checkpoint: ')
223
+ print(set([k.split('.')[0] for k in converted_vae_checkpoint.keys()]))
224
+
225
+ # load text encoder
226
+ text_encoder_checkpoint = convert_ldm_clip_checkpoint(base_model_state_dict)
227
+ text_encoder_checkpoint.pop('text_model.embeddings.position_ids', None)  # not a model parameter; drop it if present
228
+ if text_encoder_checkpoint:
229
+ text_encoder.load_state_dict(text_encoder_checkpoint)
230
+
231
+ print(" <<< Loaded DreamBooth <<<")
232
+
233
+ if vae_path:
234
+ print(' >>> Begin loading VAE >>>')
235
+ vae_state_dict = {}
236
+ if vae_path.endswith('safetensors'):
237
+ with safe_open(vae_path, framework="pt", device="cpu") as f:
238
+ for key in f.keys():
239
+ vae_state_dict[key] = f.get_tensor(key)
240
+ elif vae_path.endswith('ckpt') or vae_path.endswith('pt'):
241
+ vae_state_dict = torch.load(vae_path, map_location='cpu')
242
+ if 'state_dict' in vae_state_dict:
243
+ vae_state_dict = vae_state_dict['state_dict']
244
+
245
+ vae_state_dict = {f'first_stage_model.{k}': v for k, v in vae_state_dict.items()}
246
+
247
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(
248
+ vae_state_dict, vae.config,
249
+ only_decoder=only_load_vae_decoder,
250
+ only_encoder=only_load_vae_encoder,)
251
+ print('Prefix in loaded VAE checkpoint: ')
252
+ print(set([k.split('.')[0] for k in converted_vae_checkpoint.keys()]))
253
+ need_strict = not (only_load_vae_decoder or only_load_vae_encoder)
254
+ vae.load_state_dict(converted_vae_checkpoint, strict=need_strict)
255
+ print(" <<< Loaded VAE <<<")
256
+
257
+ if lora_path:
258
+
259
+ print(" >>> Begin loading LoRA >>>")
260
+
261
+ lora_dict = {}
262
+ print("lora_path:",lora_path)
263
+ # exit()
264
+ with safe_open(lora_path, framework='pt', device='cpu') as file:
265
+ for k in file.keys():
266
+ lora_dict[k] = file.get_tensor(k)
267
+ unet, text_encoder = convert_lora_model_level(
268
+ lora_dict, unet, text_encoder, alpha=lora_alpha)
269
+
270
+ print(" <<< Loaded LoRA <<<")
271
+
272
+ # move model to device
273
+ if not torch.cuda.is_available():
274
+ device = torch.device('cpu')
275
+ unet_dtype = torch.float32
276
+ tenc_dtype = torch.float32
277
+ vae_dtype = torch.float32
278
+ else:
279
+ device = torch.device('cuda')
280
+ unet_dtype = torch.float16
281
+ tenc_dtype = torch.float16
282
+ vae_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
283
+
284
+ unet = unet.to(device=device, dtype=unet_dtype)
285
+ text_encoder = text_encoder.to(device=device, dtype=tenc_dtype)
286
+ vae = vae.to(device=device, dtype=vae_dtype)
287
+ print(f'Set Unet to {unet_dtype}')
288
+ print(f'Set text encoder to {tenc_dtype}')
289
+ print(f'Set vae to {vae_dtype}')
290
+
291
+ if torch.cuda.is_available() and is_xformers_available():
292
+ unet.enable_xformers_memory_efficient_attention()
293
+
294
+ pipeline = cls(unet=unet,
295
+ vae=vae,
296
+ tokenizer=tokenizer,
297
+ text_encoder=text_encoder,
298
+ scheduler=noise_scheduler)
299
+
300
+ # ip_adapter_path = 'h94/IP-Adapter'
301
+ if ip_adapter_path and ip_adapter_scale > 0:
302
+ ip_adapter_name = 'ip-adapter_sd15.bin'
303
+ # only a Hub repo id needs the 'models' subfolder; a local directory does not
304
+ if not osp.isdir(ip_adapter_path):
305
+ subfolder = 'models'
306
+ else:
307
+ subfolder = ''
308
+ pipeline.load_ip_adapter(ip_adapter_path, subfolder, ip_adapter_name)
309
+ pipeline.set_ip_adapter_scale(ip_adapter_scale)
310
+ pipeline.use_ip_adapter = True
311
+ print(f'Load IP-Adapter, scale: {ip_adapter_scale}')
312
+
313
+ # text_inversion_path = './models/TextualInversion/easynegative.safetensors'
314
+ # if text_inversion_path:
315
+ # pipeline.load_textual_inversion(text_inversion_path, 'easynegative')
316
+
317
+ return pipeline
318
+
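+ # Minimal usage sketch (paths below are illustrative; `base_cfg` must provide
+ # `unet_additional_kwargs` and `noise_scheduler_kwargs`; `image` is an HxWx3 uint8 array):
+ #   cfg = OmegaConf.load('path/to/i2v_config.yaml')
+ #   pipe = I2VPipeline.build_pipeline(
+ #       cfg, base_model='runwayml/stable-diffusion-v1-5',
+ #       unet_path='path/to/i2v_unet.ckpt')
+ #   video = pipe(image, prompt='a corgi running on the beach',
+ #                video_length=16, height=512, width=512).videos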
319
+ def enable_vae_slicing(self):
320
+ self.vae.enable_slicing()
321
+
322
+ def disable_vae_slicing(self):
323
+ self.vae.disable_slicing()
324
+
325
+ def enable_sequential_cpu_offload(self, gpu_id=0):
326
+ if is_accelerate_available():
327
+ from accelerate import cpu_offload
328
+ else:
329
+ raise ImportError("Please install accelerate via `pip install accelerate`")
330
+
331
+ if not torch.cuda.is_available():
332
+ device = torch.device('cpu')
333
+ else:
334
+ device = torch.device(f"cuda:{gpu_id}")
335
+
336
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
337
+ if cpu_offloaded_model is not None:
338
+ cpu_offload(cpu_offloaded_model, device)
339
+
340
+ @property
341
+ def _execution_device(self):
342
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
343
+ return self.device
344
+ for module in self.unet.modules():
345
+ if (
346
+ hasattr(module, "_hf_hook")
347
+ and hasattr(module._hf_hook, "execution_device")
348
+ and module._hf_hook.execution_device is not None
349
+ ):
350
+ return torch.device(module._hf_hook.execution_device)
351
+ return self.device
352
+
353
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
354
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
355
+
356
+ text_inputs = self.tokenizer(
357
+ prompt,
358
+ padding="max_length",
359
+ max_length=self.tokenizer.model_max_length,
360
+ truncation=True,
361
+ return_tensors="pt",
362
+ )
363
+ text_input_ids = text_inputs.input_ids
364
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
365
+
366
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
367
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
368
+ logger.warning(
369
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
370
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
371
+ )
372
+
373
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
374
+ attention_mask = text_inputs.attention_mask.to(device)
375
+ else:
376
+ attention_mask = None
377
+
378
+ text_embeddings = self.text_encoder(
379
+ text_input_ids.to(device),
380
+ attention_mask=attention_mask,
381
+ )
382
+ text_embeddings = text_embeddings[0]
383
+
384
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
385
+ bs_embed, seq_len, _ = text_embeddings.shape
386
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
387
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
388
+
389
+ # get unconditional embeddings for classifier free guidance
390
+ if do_classifier_free_guidance:
391
+ uncond_tokens: List[str]
392
+ if negative_prompt is None:
393
+ uncond_tokens = [""] * batch_size
394
+ elif type(prompt) is not type(negative_prompt):
395
+ raise TypeError(
396
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
397
+ f" {type(prompt)}."
398
+ )
399
+ elif isinstance(negative_prompt, str):
400
+ uncond_tokens = [negative_prompt]
401
+ elif batch_size != len(negative_prompt):
402
+ raise ValueError(
403
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
404
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
405
+ " the batch size of `prompt`."
406
+ )
407
+ else:
408
+ uncond_tokens = negative_prompt
409
+
410
+ max_length = text_input_ids.shape[-1]
411
+ uncond_input = self.tokenizer(
412
+ uncond_tokens,
413
+ padding="max_length",
414
+ max_length=max_length,
415
+ truncation=True,
416
+ return_tensors="pt",
417
+ )
418
+
419
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
420
+ attention_mask = uncond_input.attention_mask.to(device)
421
+ else:
422
+ attention_mask = None
423
+
424
+ uncond_embeddings = self.text_encoder(
425
+ uncond_input.input_ids.to(device),
426
+ attention_mask=attention_mask,
427
+ )
428
+ uncond_embeddings = uncond_embeddings[0]
429
+
430
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
431
+ seq_len = uncond_embeddings.shape[1]
432
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
433
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
434
+
435
+ # For classifier free guidance, we need to do two forward passes.
436
+ # Here we concatenate the unconditional and text embeddings into a single batch
437
+ # to avoid doing two forward passes
438
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
439
+
440
+ return text_embeddings
441
+
442
+ def decode_latents(self, latents):
443
+ video_length = latents.shape[2]
444
+ latents = 1 / 0.18215 * latents
445
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
446
+ # video = self.vae.decode(latents).sample
447
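+ # Decode one frame at a time rather than the whole (b*f) batch; slower, but keeps
+ # peak VRAM low for long videos.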
+ video = []
448
+ for frame_idx in tqdm(range(latents.shape[0])):
449
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
450
+ video = torch.cat(video)
451
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
452
+ video = (video / 2 + 0.5).clamp(0, 1)
453
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
454
+ video = video.cpu().float().numpy()
455
+ return video
456
+
457
+ def prepare_extra_step_kwargs(self, generator, eta):
458
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
459
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
460
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
461
+ # and should be between [0, 1]
462
+
463
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
464
+ extra_step_kwargs = {}
465
+ if accepts_eta:
466
+ extra_step_kwargs["eta"] = eta
467
+
468
+ # check if the scheduler accepts generator
469
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
470
+ if accepts_generator:
471
+ extra_step_kwargs["generator"] = generator
472
+ return extra_step_kwargs
473
+
474
+ def check_inputs(self, prompt, height, width, callback_steps):
475
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
476
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
477
+
478
+ if height % 8 != 0 or width % 8 != 0:
479
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
480
+
481
+ if (callback_steps is None) or (
482
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
483
+ ):
484
+ raise ValueError(
485
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
486
+ f" {type(callback_steps)}."
487
+ )
488
+
489
+ def get_timesteps(self, num_inference_steps, strength, device):
490
+ # get the original timestep using init_timestep
491
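+ # e.g. with num_inference_steps=50 and strength=0.8: init_timestep=40, t_start=10,
+ # so the 10 noisiest steps are skipped and denoising starts from a partially noised latent.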
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
492
+
493
+ t_start = max(num_inference_steps - init_timestep, 0)
494
+ timesteps = self.scheduler.timesteps[t_start:]
495
+
496
+ return timesteps, num_inference_steps - t_start
497
+
498
+ def prepare_latents(self, add_noise_time_step, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
499
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
500
+
501
+ if isinstance(generator, list) and len(generator) != batch_size:
502
+ raise ValueError(
503
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
504
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
505
+ )
506
+ if latents is None:
507
+ rand_device = "cpu" if device.type == "mps" else device
508
+
509
+ if isinstance(generator, list):
510
+ shape = shape
511
+ # shape = (1,) + shape[1:]
512
+ latents = [
513
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
514
+ for i in range(batch_size)
515
+ ]
516
+ latents = torch.cat(latents, dim=0).to(device)
517
+ else:
518
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
519
+ else:
520
+ if latents.shape != shape:
521
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
522
+ latents = latents.to(device)
523
+
524
+ return latents
525
+
526
+ def encode_image(self, image, device, num_images_per_prompt):
527
+ """Encode image for ip-adapter. Copied from
528
+ https://github.com/huggingface/diffusers/blob/f9487783228cd500a21555da3346db40e8f05992/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L492-L514 # noqa
529
+ """
530
+ dtype = next(self.image_encoder.parameters()).dtype
531
+
532
+ if not isinstance(image, torch.Tensor):
533
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
534
+
535
+ image = image.to(device=device, dtype=dtype)
536
+ image_embeds = self.image_encoder(image).image_embeds
537
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
538
+
539
+ uncond_image_embeds = torch.zeros_like(image_embeds)
540
+ return image_embeds, uncond_image_embeds
541
+
542
+ @torch.no_grad()
543
+ def __call__(
544
+ self,
545
+ image: np.ndarray,
546
+ prompt: Union[str, List[str]],
547
+ video_length: Optional[int],
548
+ height: Optional[int] = None,
549
+ width: Optional[int] = None,
550
+ global_inf_num: int = 0,
551
+ num_inference_steps: int = 50,
552
+ guidance_scale: float = 7.5,
553
+ negative_prompt: Optional[Union[str, List[str]]] = None,
554
+ num_videos_per_prompt: Optional[int] = 1,
555
+ eta: float = 0.0,
556
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
557
+ latents: Optional[torch.FloatTensor] = None,
558
+ output_type: Optional[str] = "tensor",
559
+ return_dict: bool = True,
560
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
561
+ callback_steps: Optional[int] = 1,
562
+
563
+ cond_frame: int = 0,
564
+ mask_sim_template_idx: int = 0,
565
+ ip_adapter_scale: float = 0,
566
+ strength: float = 1,
567
+ progress_fn=None,
568
+ **kwargs,
569
+ ):
570
+ # Default height and width to unet
571
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
572
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
573
+
574
+ assert strength > 0 and strength <= 1, (
575
+ f'"strength" for img2vid must in (0, 1]. But receive {strength}.')
576
+
577
+ # Check inputs. Raise error if not correct
578
+ self.check_inputs(prompt, height, width, callback_steps)
579
+
580
+ # Define call parameters
581
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
582
+ batch_size = 1
583
+ if latents is not None:
584
+ batch_size = latents.shape[0]
585
+ if isinstance(prompt, list):
586
+ batch_size = len(prompt)
587
+
588
+ device = self._execution_device
589
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
590
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
591
+ # corresponds to doing no classifier free guidance.
592
+ do_classifier_free_guidance = guidance_scale > 1.0
593
+
594
+ # Encode input prompt
595
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
596
+
597
+ if negative_prompt is None:
598
+ negative_prompt = DEFAULT_N_PROMPT
599
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
600
+ text_embeddings = self._encode_prompt(
601
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
602
+ )
603
+
604
+ # Prepare timesteps
605
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
606
+ #timesteps = self.scheduler.timesteps
607
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
608
+ latent_timestep = timesteps[:1].repeat(batch_size)
609
+
610
+ # Prepare latent variables
611
+ num_channels_latents = self.unet.in_channels
612
+ latents = self.prepare_latents(
613
+ latent_timestep,
614
+ batch_size * num_videos_per_prompt,
615
+ 4,
616
+ video_length,
617
+ height,
618
+ width,
619
+ text_embeddings.dtype,
620
+ device,
621
+ generator,
622
+ latents,
623
+ )
624
+ # print("latents_1:",latents.shape) # (1,4,16,64,64)
625
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
626
+
627
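+ # Conditioning image: normalize to [-1, 1], encode with the VAE, resize to the latent
+ # resolution, and scale by 0.18215 (the SD latent scaling factor); it is then broadcast
+ # to every frame below as the masked-image condition.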
+ raw_image = image.copy()
628
+ image = torch.from_numpy(image)[None, ...].permute(0, 3, 1, 2)
629
+ image = image / 255 # [0, 1]
630
+ image = image * 2 - 1 # [-1, 1]
631
+ image = image.to(device=device, dtype=self.vae.dtype)
632
+
633
+ if isinstance(generator, list):
634
+ image_latent = [
635
+ self.vae.encode(image[k : k + 1]).latent_dist.sample(generator[k]) for k in range(batch_size)
636
+ ]
637
+ image_latent = torch.cat(image_latent, dim=0)
638
+ else:
639
+ image_latent = self.vae.encode(image).latent_dist.sample(generator)
640
+
641
+ image_latent = image_latent.to(device=device, dtype=self.unet.dtype)
642
+ image_latent = torch.nn.functional.interpolate(image_latent, size=[shape[-2], shape[-1]])
643
+ image_latent_padding = image_latent.clone() * 0.18215
644
+ mask = torch.zeros((shape[0], 1, shape[2], shape[3], shape[4])).to(device=device, dtype=self.unet.dtype)
645
+
646
+ # prepare mask
647
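+ # prepare_mask_coef_by_statistics returns one scalar per frame; each coefficient is written
+ # into the single-channel mask below and (roughly) controls how strongly that frame follows
+ # the conditioning frame `cond_frame`, with the preset chosen by mask_sim_template_idx.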
+ mask_coef = prepare_mask_coef_by_statistics(video_length, cond_frame, mask_sim_template_idx)
648
+
649
+ masked_image = torch.zeros(shape[0], 4, shape[2], shape[3], shape[4]).to(device=device, dtype=self.unet.dtype)
650
+ for f in range(video_length):
651
+ mask[:,:,f,:,:] = mask_coef[f]
652
+ masked_image[:,:,f,:,:] = image_latent_padding.clone()
653
+
654
+ # Prepare extra step kwargs.
655
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
656
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
657
+ masked_image = torch.cat([masked_image] * 2) if do_classifier_free_guidance else masked_image
658
+ # Denoising loop
659
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
660
+
661
+ # prepare for ip-adapter
662
+ if self.use_ip_adapter:
663
+ image_embeds, neg_image_embeds = self.encode_image(raw_image, device, num_videos_per_prompt)
664
+ image_embeds = torch.cat([neg_image_embeds, image_embeds])
665
+ image_embeds = image_embeds.to(device=device, dtype=self.unet.dtype)
666
+
667
+ self.set_ip_adapter_scale(ip_adapter_scale)
668
+ print(f'Set IP-Adapter Scale as {ip_adapter_scale}')
669
+
670
+ else:
671
+
672
+ image_embeds = None
673
+
674
+ # when strength < 1, start denoising from the masked-image latent noised to the first (truncated) timestep instead of from pure Gaussian noise (img2img-style)
675
+ if strength < 1:
676
+ noise = torch.randn_like(latents)
677
+ latents = self.scheduler.add_noise(masked_image[0], noise, timesteps[0])
678
+ # print(latents.shape)
679
+
680
+ if progress_fn is None:
681
+ progress_bar = tqdm(timesteps)
682
+ terminal_pbar = None
683
+ else:
684
+ progress_bar = progress_fn.tqdm(timesteps)
685
+ terminal_pbar = tqdm(total=len(timesteps))
686
+
687
+ # with self.progress_bar(total=num_inference_steps) as progress_bar:
688
+ for i, t in enumerate(progress_bar):
689
+ # expand the latents if we are doing classifier free guidance
690
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
691
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
692
+
693
+ # predict the noise residual
694
+ noise_pred = self.unet(
695
+ latent_model_input,
696
+ mask,
697
+ masked_image,
698
+ t,
699
+ encoder_hidden_states=text_embeddings,
700
+ image_embeds=image_embeds
701
+ )['sample']
702
+
703
+ # perform guidance
704
+ if do_classifier_free_guidance:
705
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
706
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
707
+
708
+ # compute the previous noisy sample x_t -> x_t-1
709
+ # print("latents_2:",latents.shape)
710
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
711
+
712
+ # call the callback, if provided
713
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
714
+ if callback is not None and i % callback_steps == 0:
715
+ callback(i, t, latents)
716
+ if terminal_pbar is not None:
717
+ terminal_pbar.update(1)
718
+
719
+ # Post-processing
720
+ video = self.decode_latents(latents.to(device, dtype=self.vae.dtype))
721
+
722
+ # Convert to tensor
723
+ if output_type == "tensor":
724
+ video = torch.from_numpy(video)
725
+
726
+ if not return_dict:
727
+ return video
728
+
729
+ return AnimationPipelineOutput(videos=video)
models/animatediff/utils/convert_from_ckpt.py ADDED
@@ -0,0 +1,964 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the Stable Diffusion checkpoints."""
16
+
17
+ import re
18
+ from io import BytesIO
19
+ from typing import Optional
20
+
21
+ import requests
22
+ import torch
23
+ from transformers import (
24
+ AutoFeatureExtractor,
25
+ BertTokenizerFast,
26
+ CLIPImageProcessor,
27
+ CLIPTextModel,
28
+ CLIPTextModelWithProjection,
29
+ CLIPTokenizer,
30
+ CLIPVisionConfig,
31
+ CLIPVisionModelWithProjection,
32
+ )
33
+
34
+ from diffusers.models import (
35
+ AutoencoderKL,
36
+ PriorTransformer,
37
+ UNet2DConditionModel,
38
+ )
39
+ from diffusers.schedulers import (
40
+ DDIMScheduler,
41
+ DDPMScheduler,
42
+ DPMSolverMultistepScheduler,
43
+ EulerAncestralDiscreteScheduler,
44
+ EulerDiscreteScheduler,
45
+ HeunDiscreteScheduler,
46
+ LMSDiscreteScheduler,
47
+ PNDMScheduler,
48
+ UnCLIPScheduler,
49
+ )
50
+ from diffusers.utils.import_utils import BACKENDS_MAPPING
+ # imports for the LDMBert / PaintByExample conversion helpers defined below
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
+ from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
51
+
52
+
53
+ def shave_segments(path, n_shave_prefix_segments=1):
54
+ """
55
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
56
+ """
57
+ if n_shave_prefix_segments >= 0:
58
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
59
+ else:
60
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
61
+
62
+
63
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
64
+ """
65
+ Updates paths inside resnets to the new naming scheme (local renaming)
66
+ """
67
+ mapping = []
68
+ for old_item in old_list:
69
+ new_item = old_item.replace("in_layers.0", "norm1")
70
+ new_item = new_item.replace("in_layers.2", "conv1")
71
+
72
+ new_item = new_item.replace("out_layers.0", "norm2")
73
+ new_item = new_item.replace("out_layers.3", "conv2")
74
+
75
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
76
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
77
+
78
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
79
+
80
+ mapping.append({"old": old_item, "new": new_item})
81
+
82
+ return mapping
83
+
84
+
85
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
86
+ """
87
+ Updates paths inside resnets to the new naming scheme (local renaming)
88
+ """
89
+ mapping = []
90
+ for old_item in old_list:
91
+ new_item = old_item
92
+
93
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
94
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
95
+
96
+ mapping.append({"old": old_item, "new": new_item})
97
+
98
+ return mapping
99
+
100
+
101
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
102
+ """
103
+ Updates paths inside attentions to the new naming scheme (local renaming)
104
+ """
105
+ mapping = []
106
+ for old_item in old_list:
107
+ new_item = old_item
108
+
109
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
110
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
111
+
112
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
113
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
114
+
115
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
116
+
117
+ mapping.append({"old": old_item, "new": new_item})
118
+
119
+ return mapping
120
+
121
+
122
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
123
+ """
124
+ Updates paths inside attentions to the new naming scheme (local renaming)
125
+ """
126
+ mapping = []
127
+ for old_item in old_list:
128
+ new_item = old_item
129
+
130
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
131
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
132
+
133
+ new_item = new_item.replace("q.weight", "to_q.weight")
134
+ new_item = new_item.replace("q.bias", "to_q.bias")
135
+
136
+ new_item = new_item.replace("k.weight", "to_k.weight")
137
+ new_item = new_item.replace("k.bias", "to_k.bias")
138
+
139
+ new_item = new_item.replace("v.weight", "to_v.weight")
140
+ new_item = new_item.replace("v.bias", "to_v.bias")
141
+
142
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
143
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
144
+
145
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
146
+
147
+ mapping.append({"old": old_item, "new": new_item})
148
+ return mapping
149
+
150
+
151
+ def assign_to_checkpoint(
152
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
153
+ ):
154
+ """
155
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
156
+ attention layers, and takes into account additional replacements that may arise.
157
+
158
+ Assigns the weights to the new checkpoint.
159
+ """
160
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
161
+
162
+ # Splits the attention layers into three variables.
163
+ if attention_paths_to_split is not None:
164
+ for path, path_map in attention_paths_to_split.items():
165
+ old_tensor = old_checkpoint[path]
166
+ channels = old_tensor.shape[0] // 3
167
+
168
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
169
+
170
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
171
+
172
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
173
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
174
+
175
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
176
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
177
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
178
+
179
+ for path in paths:
180
+ new_path = path["new"]
181
+
182
+ # These have already been assigned
183
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
184
+ continue
185
+
186
+ # Global renaming happens here
187
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
188
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
189
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
190
+
191
+ if additional_replacements is not None:
192
+ for replacement in additional_replacements:
193
+ new_path = new_path.replace(replacement["old"], replacement["new"])
194
+
195
+ # proj_attn.weight has to be converted from conv 1D to linear
196
+ if "proj_attn.weight" in new_path:
197
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
198
+ elif 'to_out.0.weight' in new_path:
199
+ checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
200
+ elif any([qkv in new_path for qkv in ['to_q', 'to_k', 'to_v']]):
201
+ checkpoint[new_path] = old_checkpoint[path['old']].squeeze()
202
+ else:
203
+ checkpoint[new_path] = old_checkpoint[path["old"]]
204
+
205
+
206
+ def conv_attn_to_linear(checkpoint):
207
+ keys = list(checkpoint.keys())
208
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
209
+ for key in keys:
210
+ if ".".join(key.split(".")[-2:]) in attn_keys:
211
+ if checkpoint[key].ndim > 2:
212
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
213
+ elif "proj_attn.weight" in key:
214
+ if checkpoint[key].ndim > 2:
215
+ checkpoint[key] = checkpoint[key][:, :, 0]
216
+
217
+
218
+ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
219
+ """
220
+ Creates a config for the diffusers based on the config of the LDM model.
221
+ """
222
+ if controlnet:
223
+ unet_params = original_config.model.params.control_stage_config.params
224
+ else:
225
+ unet_params = original_config.model.params.unet_config.params
226
+
227
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
228
+
229
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
230
+
231
+ down_block_types = []
232
+ resolution = 1
233
+ for i in range(len(block_out_channels)):
234
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
235
+ down_block_types.append(block_type)
236
+ if i != len(block_out_channels) - 1:
237
+ resolution *= 2
238
+
239
+ up_block_types = []
240
+ for i in range(len(block_out_channels)):
241
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
242
+ up_block_types.append(block_type)
243
+ resolution //= 2
244
+
245
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
246
+
247
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
248
+ use_linear_projection = (
249
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
250
+ )
251
+ if use_linear_projection:
252
+ # stable diffusion 2-base-512 and 2-768
253
+ if head_dim is None:
254
+ head_dim = [5, 10, 20, 20]
255
+
256
+ class_embed_type = None
257
+ projection_class_embeddings_input_dim = None
258
+
259
+ if "num_classes" in unet_params:
260
+ if unet_params.num_classes == "sequential":
261
+ class_embed_type = "projection"
262
+ assert "adm_in_channels" in unet_params
263
+ projection_class_embeddings_input_dim = unet_params.adm_in_channels
264
+ else:
265
+ raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
266
+
267
+ config = {
268
+ "sample_size": image_size // vae_scale_factor,
269
+ "in_channels": unet_params.in_channels,
270
+ "down_block_types": tuple(down_block_types),
271
+ "block_out_channels": tuple(block_out_channels),
272
+ "layers_per_block": unet_params.num_res_blocks,
273
+ "cross_attention_dim": unet_params.context_dim,
274
+ "attention_head_dim": head_dim,
275
+ "use_linear_projection": use_linear_projection,
276
+ "class_embed_type": class_embed_type,
277
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
278
+ }
279
+
280
+ if not controlnet:
281
+ config["out_channels"] = unet_params.out_channels
282
+ config["up_block_types"] = tuple(up_block_types)
283
+
284
+ return config
285
+
286
+
287
+ def create_vae_diffusers_config(original_config, image_size: int):
288
+ """
289
+ Creates a config for the diffusers based on the config of the LDM model.
290
+ """
291
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
292
+ _ = original_config.model.params.first_stage_config.params.embed_dim
293
+
294
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
295
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
296
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
297
+
298
+ config = {
299
+ "sample_size": image_size,
300
+ "in_channels": vae_params.in_channels,
301
+ "out_channels": vae_params.out_ch,
302
+ "down_block_types": tuple(down_block_types),
303
+ "up_block_types": tuple(up_block_types),
304
+ "block_out_channels": tuple(block_out_channels),
305
+ "latent_channels": vae_params.z_channels,
306
+ "layers_per_block": vae_params.num_res_blocks,
307
+ }
308
+ return config
309
+
310
+
311
+ def create_diffusers_schedular(original_config):
312
+ schedular = DDIMScheduler(
313
+ num_train_timesteps=original_config.model.params.timesteps,
314
+ beta_start=original_config.model.params.linear_start,
315
+ beta_end=original_config.model.params.linear_end,
316
+ beta_schedule="scaled_linear",
317
+ )
318
+ return schedular
319
+
320
+
321
+ def create_ldm_bert_config(original_config):
322
+ bert_params = original_config.model.parms.cond_stage_config.params
323
+ config = LDMBertConfig(
324
+ d_model=bert_params.n_embed,
325
+ encoder_layers=bert_params.n_layer,
326
+ encoder_ffn_dim=bert_params.n_embed * 4,
327
+ )
328
+ return config
329
+
330
+
331
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
332
+ """
333
+ Takes a state dict and a config, and returns a converted checkpoint.
334
+ """
335
+
336
+ # extract state_dict for UNet
337
+ unet_state_dict = {}
338
+ keys = list(checkpoint.keys())
339
+
340
+ if controlnet:
341
+ unet_key = "control_model."
342
+ else:
343
+ unet_key = "model.diffusion_model."
344
+
345
+ # at least 100 parameters have to start with `model_ema` for the checkpoint to be treated as EMA
346
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
347
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
348
+ print(
349
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
350
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
351
+ )
352
+ for key in keys:
353
+ if key.startswith("model.diffusion_model"):
354
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
355
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
356
+ else:
357
+ if sum(k.startswith("model_ema") for k in keys) > 100:
358
+ print(
359
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
360
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
361
+ )
362
+
363
+ for key in keys:
364
+ if key.startswith(unet_key):
365
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
366
+
367
+ new_checkpoint = {}
368
+
369
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
370
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
371
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
372
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
373
+
374
+ if config["class_embed_type"] is None:
375
+ # No parameters to port
376
+ ...
377
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
378
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
379
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
380
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
381
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
382
+ else:
383
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
384
+
385
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
386
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
387
+
388
+ if not controlnet:
389
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
390
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
391
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
392
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
393
+
394
+ # Retrieves the keys for the input blocks only
395
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
396
+ input_blocks = {
397
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
398
+ for layer_id in range(num_input_blocks)
399
+ }
400
+
401
+ # Retrieves the keys for the middle blocks only
402
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
403
+ middle_blocks = {
404
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
405
+ for layer_id in range(num_middle_blocks)
406
+ }
407
+
408
+ # Retrieves the keys for the output blocks only
409
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
410
+ output_blocks = {
411
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
412
+ for layer_id in range(num_output_blocks)
413
+ }
414
+
415
+ for i in range(1, num_input_blocks):
416
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
417
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
418
+
419
+ resnets = [
420
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
421
+ ]
422
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
423
+
424
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
425
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
426
+ f"input_blocks.{i}.0.op.weight"
427
+ )
428
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
429
+ f"input_blocks.{i}.0.op.bias"
430
+ )
431
+
432
+ paths = renew_resnet_paths(resnets)
433
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
434
+ assign_to_checkpoint(
435
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
436
+ )
437
+
438
+ if len(attentions):
439
+ paths = renew_attention_paths(attentions)
440
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
441
+ assign_to_checkpoint(
442
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
443
+ )
444
+
445
+ resnet_0 = middle_blocks[0]
446
+ attentions = middle_blocks[1]
447
+ resnet_1 = middle_blocks[2]
448
+
449
+ resnet_0_paths = renew_resnet_paths(resnet_0)
450
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
451
+
452
+ resnet_1_paths = renew_resnet_paths(resnet_1)
453
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
454
+
455
+ attentions_paths = renew_attention_paths(attentions)
456
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
457
+ assign_to_checkpoint(
458
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
459
+ )
460
+
461
+ for i in range(num_output_blocks):
462
+ block_id = i // (config["layers_per_block"] + 1)
463
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
464
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
465
+ output_block_list = {}
466
+
467
+ for layer in output_block_layers:
468
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
469
+ if layer_id in output_block_list:
470
+ output_block_list[layer_id].append(layer_name)
471
+ else:
472
+ output_block_list[layer_id] = [layer_name]
473
+
474
+ if len(output_block_list) > 1:
475
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
476
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
477
+
478
+ resnet_0_paths = renew_resnet_paths(resnets)
479
+ paths = renew_resnet_paths(resnets)
480
+
481
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
482
+ assign_to_checkpoint(
483
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
484
+ )
485
+
486
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
487
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
488
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
489
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
490
+ f"output_blocks.{i}.{index}.conv.weight"
491
+ ]
492
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
493
+ f"output_blocks.{i}.{index}.conv.bias"
494
+ ]
495
+
496
+ # Clear attentions as they have been attributed above.
497
+ if len(attentions) == 2:
498
+ attentions = []
499
+
500
+ if len(attentions):
501
+ paths = renew_attention_paths(attentions)
502
+ meta_path = {
503
+ "old": f"output_blocks.{i}.1",
504
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
505
+ }
506
+ assign_to_checkpoint(
507
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
508
+ )
509
+ else:
510
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
511
+ for path in resnet_0_paths:
512
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
513
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
514
+
515
+ new_checkpoint[new_path] = unet_state_dict[old_path]
516
+
517
+ if controlnet:
518
+ # conditioning embedding
519
+
520
+ orig_index = 0
521
+
522
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
523
+ f"input_hint_block.{orig_index}.weight"
524
+ )
525
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
526
+ f"input_hint_block.{orig_index}.bias"
527
+ )
528
+
529
+ orig_index += 2
530
+
531
+ diffusers_index = 0
532
+
533
+ while diffusers_index < 6:
534
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
535
+ f"input_hint_block.{orig_index}.weight"
536
+ )
537
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
538
+ f"input_hint_block.{orig_index}.bias"
539
+ )
540
+ diffusers_index += 1
541
+ orig_index += 2
542
+
543
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
544
+ f"input_hint_block.{orig_index}.weight"
545
+ )
546
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
547
+ f"input_hint_block.{orig_index}.bias"
548
+ )
549
+
550
+ # down blocks
551
+ for i in range(num_input_blocks):
552
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
553
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
554
+
555
+ # mid block
556
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
557
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
558
+
559
+ return new_checkpoint
560
+
561
+
562
+ def convert_ldm_vae_checkpoint(checkpoint, config, only_decoder=False, only_encoder=False):
563
+ # extract state dict for VAE
564
+ vae_state_dict = {}
565
+ vae_key = "first_stage_model."
566
+ keys = list(checkpoint.keys())
567
+ for key in keys:
568
+ if key.startswith(vae_key):
569
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
570
+
571
+ new_checkpoint = {}
572
+
573
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
574
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
575
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
576
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
577
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
578
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
579
+
580
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
581
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
582
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
583
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
584
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
585
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
586
+
587
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
588
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
589
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
590
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
591
+
592
+ # Retrieves the keys for the encoder down blocks only
593
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
594
+ down_blocks = {
595
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
596
+ }
597
+
598
+ # Retrieves the keys for the decoder up blocks only
599
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
600
+ up_blocks = {
601
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
602
+ }
603
+
604
+ for i in range(num_down_blocks):
605
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
606
+
607
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
608
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
609
+ f"encoder.down.{i}.downsample.conv.weight"
610
+ )
611
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
612
+ f"encoder.down.{i}.downsample.conv.bias"
613
+ )
614
+
615
+ paths = renew_vae_resnet_paths(resnets)
616
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
617
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
618
+
619
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
620
+ num_mid_res_blocks = 2
621
+ for i in range(1, num_mid_res_blocks + 1):
622
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
623
+
624
+ paths = renew_vae_resnet_paths(resnets)
625
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
626
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
627
+
628
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
629
+ paths = renew_vae_attention_paths(mid_attentions)
630
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
631
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
632
+ conv_attn_to_linear(new_checkpoint)
633
+
634
+ for i in range(num_up_blocks):
635
+ block_id = num_up_blocks - 1 - i
636
+ resnets = [
637
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
638
+ ]
639
+
640
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
641
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
642
+ f"decoder.up.{block_id}.upsample.conv.weight"
643
+ ]
644
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
645
+ f"decoder.up.{block_id}.upsample.conv.bias"
646
+ ]
647
+
648
+ paths = renew_vae_resnet_paths(resnets)
649
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
650
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
651
+
652
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
653
+ num_mid_res_blocks = 2
654
+ for i in range(1, num_mid_res_blocks + 1):
655
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
656
+
657
+ paths = renew_vae_resnet_paths(resnets)
658
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
659
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
660
+
661
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
662
+ paths = renew_vae_attention_paths(mid_attentions)
663
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
664
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
665
+ conv_attn_to_linear(new_checkpoint)
666
+
667
+ if only_decoder:
668
+ new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('decoder') or k.startswith('post_quant')}
669
+ elif only_encoder:
670
+ new_checkpoint = {k: v for k, v in new_checkpoint.items() if k.startswith('encoder') or k.startswith('quant')}
671
+
672
+ return new_checkpoint
673
+
674
+
675
+ def convert_ldm_bert_checkpoint(checkpoint, config):
676
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
677
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
678
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
679
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
680
+
681
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
682
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
683
+
684
+ def _copy_linear(hf_linear, pt_linear):
685
+ hf_linear.weight = pt_linear.weight
686
+ hf_linear.bias = pt_linear.bias
687
+
688
+ def _copy_layer(hf_layer, pt_layer):
689
+ # copy layer norms
690
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
691
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
692
+
693
+ # copy attn
694
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
695
+
696
+ # copy MLP
697
+ pt_mlp = pt_layer[1][1]
698
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
699
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
700
+
701
+ def _copy_layers(hf_layers, pt_layers):
702
+ for i, hf_layer in enumerate(hf_layers):
703
+ if i != 0:
704
+ i += i
705
+ pt_layer = pt_layers[i : i + 2]
706
+ _copy_layer(hf_layer, pt_layer)
707
+
708
+ hf_model = LDMBertModel(config).eval()
709
+
710
+ # copy embeds
711
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
712
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
713
+
714
+ # copy layer norm
715
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
716
+
717
+ # copy hidden layers
718
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
719
+
720
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
721
+
722
+ return hf_model
723
+
724
+
725
+ def convert_ldm_clip_checkpoint(checkpoint):
726
+ keys = list(checkpoint.keys())
727
+
728
+ text_model_dict = {}
729
+ for key in keys:
730
+ if key.startswith("cond_stage_model.transformer"):
731
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
732
+
733
+ return text_model_dict
734
+
735
+
736
+ textenc_conversion_lst = [
737
+ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
738
+ ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
739
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
740
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
741
+ ]
742
+ textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
743
+
744
+ textenc_transformer_conversion_lst = [
745
+ # (stable-diffusion, HF Diffusers)
746
+ ("resblocks.", "text_model.encoder.layers."),
747
+ ("ln_1", "layer_norm1"),
748
+ ("ln_2", "layer_norm2"),
749
+ (".c_fc.", ".fc1."),
750
+ (".c_proj.", ".fc2."),
751
+ (".attn", ".self_attn"),
752
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
753
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
754
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
755
+ ]
756
+ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
757
+ textenc_pattern = re.compile("|".join(protected.keys()))
758
+
759
+
760
+ def convert_paint_by_example_checkpoint(checkpoint):
761
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
762
+ model = PaintByExampleImageEncoder(config)
763
+
764
+ keys = list(checkpoint.keys())
765
+
766
+ text_model_dict = {}
767
+
768
+ for key in keys:
769
+ if key.startswith("cond_stage_model.transformer"):
770
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
771
+
772
+ # load clip vision
773
+ model.model.load_state_dict(text_model_dict)
774
+
775
+ # load mapper
776
+ keys_mapper = {
777
+ k[len("cond_stage_model.mapper.res") :]: v
778
+ for k, v in checkpoint.items()
779
+ if k.startswith("cond_stage_model.mapper")
780
+ }
781
+
782
+ MAPPING = {
783
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
784
+ "attn.c_proj": ["attn1.to_out.0"],
785
+ "ln_1": ["norm1"],
786
+ "ln_2": ["norm3"],
787
+ "mlp.c_fc": ["ff.net.0.proj"],
788
+ "mlp.c_proj": ["ff.net.2"],
789
+ }
790
+
791
+ mapped_weights = {}
792
+ for key, value in keys_mapper.items():
793
+ prefix = key[: len("blocks.i")]
794
+ suffix = key.split(prefix)[-1].split(".")[-1]
795
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
796
+ mapped_names = MAPPING[name]
797
+
798
+ num_splits = len(mapped_names)
799
+ for i, mapped_name in enumerate(mapped_names):
800
+ new_name = ".".join([prefix, mapped_name, suffix])
801
+ shape = value.shape[0] // num_splits
802
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
803
+
804
+ model.mapper.load_state_dict(mapped_weights)
805
+
806
+ # load final layer norm
807
+ model.final_layer_norm.load_state_dict(
808
+ {
809
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
810
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
811
+ }
812
+ )
813
+
814
+ # load final proj
815
+ model.proj_out.load_state_dict(
816
+ {
817
+ "bias": checkpoint["proj_out.bias"],
818
+ "weight": checkpoint["proj_out.weight"],
819
+ }
820
+ )
821
+
822
+ # load uncond vector
823
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
824
+ return model
825
+
826
+
827
+ def convert_open_clip_checkpoint(checkpoint):
828
+ text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
829
+
830
+ keys = list(checkpoint.keys())
831
+
832
+ text_model_dict = {}
833
+
834
+ if "cond_stage_model.model.text_projection" in checkpoint:
835
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
836
+ else:
837
+ d_model = 1024
838
+
839
+ text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
840
+
841
+ for key in keys:
842
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
843
+ continue
844
+ if key in textenc_conversion_map:
845
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
846
+ if key.startswith("cond_stage_model.model.transformer."):
847
+ new_key = key[len("cond_stage_model.model.transformer.") :]
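+ # OpenCLIP stores q/k/v as one fused in_proj tensor; split it into separate q_proj/k_proj/v_proj weights and biases of size d_model each.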
848
+ if new_key.endswith(".in_proj_weight"):
849
+ new_key = new_key[: -len(".in_proj_weight")]
850
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
851
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
852
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
853
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
854
+ elif new_key.endswith(".in_proj_bias"):
855
+ new_key = new_key[: -len(".in_proj_bias")]
856
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
857
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
858
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
859
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
860
+ else:
861
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
862
+
863
+ text_model_dict[new_key] = checkpoint[key]
864
+
865
+ text_model.load_state_dict(text_model_dict)
866
+
867
+ return text_model
868
+
869
+
870
+ def stable_unclip_image_encoder(original_config):
871
+ """
872
+ Returns the image processor and clip image encoder for the img2img unclip pipeline.
873
+
874
+ We currently know of two types of stable unclip models which separately use the clip and the openclip image
875
+ encoders.
876
+ """
877
+
878
+ image_embedder_config = original_config.model.params.embedder_config
879
+
880
+ sd_clip_image_embedder_class = image_embedder_config.target
881
+ sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
882
+
883
+ if sd_clip_image_embedder_class == "ClipImageEmbedder":
884
+ clip_model_name = image_embedder_config.params.model
885
+
886
+ if clip_model_name == "ViT-L/14":
887
+ feature_extractor = CLIPImageProcessor()
888
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
889
+ else:
890
+ raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
891
+
892
+ elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
893
+ feature_extractor = CLIPImageProcessor()
894
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
895
+ else:
896
+ raise NotImplementedError(
897
+ f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
898
+ )
899
+
900
+ return feature_extractor, image_encoder
901
+
902
+
903
+ def stable_unclip_image_noising_components(
904
+ original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
905
+ ):
906
+ """
907
+ Returns the noising components for the img2img and txt2img unclip pipelines.
908
+
909
+ Converts the stability noise augmentor into
910
+ 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
911
+ 2. a `DDPMScheduler` for holding the noise schedule
912
+
913
+ If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
914
+ """
915
+ noise_aug_config = original_config.model.params.noise_aug_config
916
+ noise_aug_class = noise_aug_config.target
917
+ noise_aug_class = noise_aug_class.split(".")[-1]
918
+
919
+ if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
920
+ noise_aug_config = noise_aug_config.params
921
+ embedding_dim = noise_aug_config.timestep_dim
922
+ max_noise_level = noise_aug_config.noise_schedule_config.timesteps
923
+ beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
924
+
925
+ image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
926
+ image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
927
+
928
+ if "clip_stats_path" in noise_aug_config:
929
+ if clip_stats_path is None:
930
+ raise ValueError("This stable unclip config requires a `clip_stats_path`")
931
+
932
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
933
+ clip_mean = clip_mean[None, :]
934
+ clip_std = clip_std[None, :]
935
+
936
+ clip_stats_state_dict = {
937
+ "mean": clip_mean,
938
+ "std": clip_std,
939
+ }
940
+
941
+ image_normalizer.load_state_dict(clip_stats_state_dict)
942
+ else:
943
+ raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
944
+
945
+ return image_normalizer, image_noising_scheduler
946
+
947
+
948
+ def convert_controlnet_checkpoint(
949
+ checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
950
+ ):
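+ # Reuse the LDM UNet conversion with controlnet=True to obtain a diffusers ControlNetModel from the checkpoint.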
951
+ ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
952
+ ctrlnet_config["upcast_attention"] = upcast_attention
953
+
954
+ ctrlnet_config.pop("sample_size")
955
+
956
+ controlnet_model = ControlNetModel(**ctrlnet_config)
957
+
958
+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
959
+ checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
960
+ )
961
+
962
+ controlnet_model.load_state_dict(converted_ctrl_checkpoint)
963
+
964
+ return controlnet_model
models/animatediff/utils/convert_lora_safetensor_to_diffusers.py ADDED
@@ -0,0 +1,208 @@
1
+ # coding=utf-8
2
+ # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Conversion script for the LoRA's safetensors checkpoints. """
17
+
18
+ import argparse
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+
23
+ from diffusers import StableDiffusionPipeline
24
+ import pdb
25
+
26
+ def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
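+ # Merge LoRA weights directly into the pipeline's UNet / text encoder in place: W <- W + alpha * (up @ down).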
27
+ # load base model
28
+ # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
29
+
30
+ # load LoRA weight from .safetensors
31
+ # state_dict = load_file(checkpoint_path)
32
+
33
+ visited = []
34
+
35
+ # directly update weight in diffusers model
36
+ for key in state_dict:
37
+ # keys typically look like the example below; printing them can help with debugging
38
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
39
+
40
+ # alpha has already been set, so skip these entries
41
+ if ".alpha" in key or key in visited:
42
+ continue
43
+
44
+ if "text" in key:
45
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
46
+ curr_layer = pipeline.text_encoder
47
+ else:
48
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
49
+ curr_layer = pipeline.unet
50
+
51
+ # find the target layer
52
+ temp_name = layer_infos.pop(0)
53
+ while len(layer_infos) > -1:
54
+ try:
55
+ curr_layer = curr_layer.__getattr__(temp_name)
56
+ if len(layer_infos) > 0:
57
+ temp_name = layer_infos.pop(0)
58
+ elif len(layer_infos) == 0:
59
+ break
60
+ except Exception:
61
+ if len(temp_name) > 0:
62
+ temp_name += "_" + layer_infos.pop(0)
63
+ else:
64
+ temp_name = layer_infos.pop(0)
65
+
66
+ pair_keys = []
67
+ if "lora_down" in key:
68
+ pair_keys.append(key.replace("lora_down", "lora_up"))
69
+ pair_keys.append(key)
70
+ else:
71
+ pair_keys.append(key)
72
+ pair_keys.append(key.replace("lora_up", "lora_down"))
73
+
74
+ # update weight
75
+ if len(state_dict[pair_keys[0]].shape) == 4:
76
+ weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
77
+ weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
78
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
79
+ else:
80
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
81
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
82
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
83
+
84
+ # update visited list
85
+ for item in pair_keys:
86
+ visited.append(item)
87
+
88
+ return pipeline
89
+
90
+
91
+ def convert_lora_model_level(state_dict, unet, text_encoder=None, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
92
+ """convert lora in model level instead of pipeline leval
93
+ """
94
+
95
+ visited = []
96
+
97
+ # directly update weight in diffusers model
98
+ for key in state_dict:
99
+ # keys typically look like the example below; printing them can help with debugging
100
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
101
+
102
+ # alpha has already been set, so skip these entries
103
+ if ".alpha" in key or key in visited:
104
+ continue
105
+
106
+ if "text" in key:
107
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
108
+ assert text_encoder is not None, (
109
+ 'text_encoder must be passed since lora contains text encoder layers')
110
+ curr_layer = text_encoder
111
+ else:
112
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
113
+ curr_layer = unet
114
+
115
+ # find the target layer
116
+ temp_name = layer_infos.pop(0)
117
+ while len(layer_infos) > -1:
118
+ try:
119
+ curr_layer = curr_layer.__getattr__(temp_name)
120
+ if len(layer_infos) > 0:
121
+ temp_name = layer_infos.pop(0)
122
+ elif len(layer_infos) == 0:
123
+ break
124
+ except Exception:
125
+ if len(temp_name) > 0:
126
+ temp_name += "_" + layer_infos.pop(0)
127
+ else:
128
+ temp_name = layer_infos.pop(0)
129
+
130
+ pair_keys = []
131
+ if "lora_down" in key:
132
+ pair_keys.append(key.replace("lora_down", "lora_up"))
133
+ pair_keys.append(key)
134
+ else:
135
+ pair_keys.append(key)
136
+ pair_keys.append(key.replace("lora_up", "lora_down"))
137
+
138
+ # update weight
139
+ # NOTE: handles LoCon/LyCORIS-style conv LoRA weights; this path may still have bugs
140
+ if 'conv_in' in pair_keys[0]:
141
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
142
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
143
+ weight_up = weight_up.view(weight_up.size(0), -1)
144
+ weight_down = weight_down.view(weight_down.size(0), -1)
145
+ shape = [e for e in curr_layer.weight.data.shape]
146
+ shape[1] = 4
147
+ curr_layer.weight.data[:, :4, ...] += alpha * (weight_up @ weight_down).view(*shape)
148
+ elif 'conv' in pair_keys[0]:
149
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
150
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
151
+ weight_up = weight_up.view(weight_up.size(0), -1)
152
+ weight_down = weight_down.view(weight_down.size(0), -1)
153
+ shape = [e for e in curr_layer.weight.data.shape]
154
+ curr_layer.weight.data += alpha * (weight_up @ weight_down).view(*shape)
155
+ elif len(state_dict[pair_keys[0]].shape) == 4:
156
+ weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
157
+ weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
158
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
159
+ else:
160
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
161
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
162
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
163
+
164
+ # update visited list
165
+ for item in pair_keys:
166
+ visited.append(item)
167
+
168
+ return unet, text_encoder
169
+
170
+
171
+ if __name__ == "__main__":
172
+ parser = argparse.ArgumentParser()
173
+
174
+ parser.add_argument(
175
+ "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
176
+ )
177
+ parser.add_argument(
178
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
179
+ )
180
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
181
+ parser.add_argument(
182
+ "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
183
+ )
184
+ parser.add_argument(
185
+ "--lora_prefix_text_encoder",
186
+ default="lora_te",
187
+ type=str,
188
+ help="The prefix of text encoder weight in safetensors",
189
+ )
190
+ parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
191
+ parser.add_argument(
192
+ "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
193
+ )
194
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
195
+
196
+ args = parser.parse_args()
197
+
198
+ base_model_path = args.base_model_path
199
+ checkpoint_path = args.checkpoint_path
200
+ dump_path = args.dump_path
201
+ lora_prefix_unet = args.lora_prefix_unet
202
+ lora_prefix_text_encoder = args.lora_prefix_text_encoder
203
+ alpha = args.alpha
204
+
205
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+ state_dict = load_file(checkpoint_path)
+ pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
206
+
207
+ pipe = pipe.to(args.device)
208
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
models/animatediff/utils/util.py ADDED
@@ -0,0 +1,334 @@
1
+ import os
2
+ import imageio
3
+ import numpy as np
4
+ from typing import Union, Optional
5
+
6
+ import torch
7
+ import torchvision
8
+ import torch.distributed as dist
9
+
10
+ from tqdm import tqdm
11
+ from einops import rearrange
12
+ import cv2
13
+ import math
14
+ import moviepy.editor as mpy
15
+ from PIL import Image
16
+
17
+ # We recommend to use the following affinity score(motion magnitude)
18
+ # Also encourage to try to construct different score by yourself
19
+ RANGE_LIST = [
20
+ [1.0, 0.9, 0.85, 0.85, 0.85, 0.8], # 0 Small Motion
21
+ [1.0, 0.8, 0.8, 0.8, 0.79, 0.78, 0.75], # Moderate Motion
22
+ [1.0, 0.8, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, 0.6, 0.5, 0.5], # Large Motion
23
+ [1.0 , 0.9 , 0.85, 0.85, 0.85, 0.8 , 0.8 , 0.8 , 0.8 , 0.8 , 0.8 , 0.8 , 0.85, 0.85, 0.9 , 1.0 ], # Loop
24
+ [1.0 , 0.8 , 0.8 , 0.8 , 0.79, 0.78, 0.75, 0.75, 0.75, 0.75, 0.75, 0.78, 0.79, 0.8 , 0.8 , 1.0 ], # Loop
25
+ [1.0 , 0.8 , 0.7 , 0.7 , 0.7 , 0.7 , 0.6 , 0.5 , 0.5 , 0.6 , 0.7 , 0.7 , 0.7 , 0.7 , 0.8 , 1.0 ], # Loop
26
+ [0.5, 0.2], # Style Transfer Large Motion
27
+ [0.5, 0.4, 0.4, 0.4, 0.35, 0.35, 0.3, 0.25, 0.2], # Style Transfer Moderate Motion
28
+ [0.5, 0.4, 0.4, 0.4, 0.35, 0.3], # Style Transfer Candidate Small Motion
29
+ ]
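+ # The sim_range index passed to prepare_mask_coef_by_statistics selects one of these presets: 0-2 increasing motion, 3-5 looping variants, 6-8 style-transfer variants.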
30
+
31
+
32
+ def zero_rank_print(s):
33
+ if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
34
+
35
+ def save_videos_mp4(video: torch.Tensor, path: str, fps: int=8):
36
+ video = rearrange(video, "b c t h w -> t b c h w")
37
+ num_frames, batch_size, channels, height, width = video.shape
38
+ assert batch_size == 1,\
39
+ 'Only support batch size == 1'
40
+ video = video.squeeze(1)
41
+ video = rearrange(video, "t c h w -> t h w c")
42
+ def make_frame(t):
43
+ frame_tensor = video[int(t * fps)]
44
+ frame_np = (frame_tensor * 255).numpy().astype('uint8')
45
+ return frame_np
46
+ clip = mpy.VideoClip(make_frame, duration=num_frames / fps)
47
+ clip.write_videofile(path, fps=fps, codec='libx264')
48
+
49
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
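+ # videos: (b, c, t, h, w) in [0, 1] (or [-1, 1] with rescale=True); each frame batch is tiled into a grid and the sequence is written with imageio.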
50
+ videos = rearrange(videos, "b c t h w -> t b c h w")
51
+ outputs = []
52
+ for x in videos:
53
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
54
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
55
+ if rescale:
56
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
57
+ x = torch.clamp((x * 255), 0, 255).numpy().astype(np.uint8)
58
+ outputs.append(x)
59
+
60
+ os.makedirs(os.path.dirname(path), exist_ok=True)
61
+ imageio.mimsave(path, outputs, fps=fps)
62
+
63
+
64
+ # DDIM Inversion
65
+ @torch.no_grad()
66
+ def init_prompt(prompt, pipeline):
67
+ uncond_input = pipeline.tokenizer(
68
+ [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
69
+ return_tensors="pt"
70
+ )
71
+ uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
72
+ text_input = pipeline.tokenizer(
73
+ [prompt],
74
+ padding="max_length",
75
+ max_length=pipeline.tokenizer.model_max_length,
76
+ truncation=True,
77
+ return_tensors="pt",
78
+ )
79
+ text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
80
+ context = torch.cat([uncond_embeddings, text_embeddings])
81
+
82
+ return context
83
+
84
+
85
+ def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
86
+ sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
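+ # One DDIM inversion step: move the sample from the current timestep to the next (noisier) timestep using the predicted noise.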
87
+ timestep, next_timestep = min(
88
+ timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
89
+ alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
90
+ alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
91
+ beta_prod_t = 1 - alpha_prod_t
92
+ next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
93
+ next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
94
+ next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
95
+ return next_sample
96
+
97
+
98
+ def get_noise_pred_single(latents, t, context, unet):
99
+ noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
100
+ return noise_pred
101
+
102
+
103
+ @torch.no_grad()
104
+ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
105
+ context = init_prompt(prompt, pipeline)
106
+ uncond_embeddings, cond_embeddings = context.chunk(2)
107
+ all_latent = [latent]
108
+ latent = latent.clone().detach()
109
+ for i in tqdm(range(num_inv_steps)):
110
+ t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
111
+ noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
112
+ latent = next_step(noise_pred, t, latent, ddim_scheduler)
113
+ all_latent.append(latent)
114
+ return all_latent
115
+
116
+
117
+ @torch.no_grad()
118
+ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
119
+ ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
120
+ return ddim_latents
121
+
122
+ def prepare_mask_coef(video_length:int, cond_frame:int, sim_range:list=[0.2, 1.0]):
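+ # Linearly decrease the coefficient with distance from the conditioning frame, e.g. (hypothetical values) video_length=5, cond_frame=0, sim_range=[0.2, 1.0] gives roughly [1.0, 0.8, 0.6, 0.4, 0.2].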
123
+
124
+ assert len(sim_range) == 2, \
125
+ 'sim_range should have length 2: the min and max similarity'
126
+
127
+ assert video_length > 1, \
128
+ 'video_length should be greater than 1'
129
+
130
+ assert video_length > cond_frame,\
131
+ 'video_length should be greater than cond_frame'
132
+
133
+ diff = abs(sim_range[0] - sim_range[1]) / (video_length - 1)
134
+ coef = [1.0] * video_length
135
+ for f in range(video_length):
136
+ f_diff = diff * abs(cond_frame - f)
137
+ f_diff = 1 - f_diff
138
+ coef[f] *= f_diff
139
+
140
+ return coef
141
+
142
+ def prepare_mask_coef_by_statistics(video_length: int, cond_frame: int, sim_range: int):
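+ # Pick the RANGE_LIST preset indexed by sim_range, pad it to video_length, and reorder it by each frame's distance from the conditioning frame.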
143
+ assert video_length > 0, \
144
+ 'video_length should be greater than 0'
145
+
146
+ assert video_length > cond_frame,\
147
+ 'video_length should be greater than cond_frame'
148
+
149
+ range_list = RANGE_LIST
150
+
151
+ assert sim_range < len(range_list),\
152
+ f'sim_range type {sim_range} not implemented'
153
+
154
+ coef = range_list[sim_range]
155
+ coef = coef + ([coef[-1]] * (video_length - len(coef)))
156
+
157
+ order = [abs(i - cond_frame) for i in range(video_length)]
158
+ coef = [coef[order[i]] for i in range(video_length)]
159
+
160
+ return coef
161
+
162
+
163
+ def prepare_mask_coef_multi_cond(video_length:int, cond_frames:list, sim_range:list=[0.2, 1.0]):
164
+ assert len(sim_range) == 2, \
165
+ 'sim_range should have length 2: the min and max similarity'
166
+
167
+ assert video_length > 1, \
168
+ 'video_length should be greater than 1'
169
+
170
+ assert isinstance(cond_frames, list), \
171
+ 'cond_frames should be a list'
172
+
173
+ assert video_length > max(cond_frames),\
174
+ 'video_length should be greater than cond_frame'
175
+
176
+ if max(sim_range) == min(sim_range):
177
+ cond_coefs = [sim_range[0]] * video_length
178
+ return cond_coefs
179
+
180
+ cond_coefs = []
181
+
182
+ for cond_frame in cond_frames:
183
+ cond_coef = prepare_mask_coef(video_length, cond_frame, sim_range)
184
+ cond_coefs.append(cond_coef)
185
+
186
+ mixed_coef = [0] * video_length
187
+ for conds in range(len(cond_frames)):
188
+
189
+ for f in range(video_length):
190
+ mixed_coef[f] = abs(cond_coefs[conds][f] - mixed_coef[f])
191
+
192
+ if conds > 0:
193
+ min_num = min(mixed_coef)
194
+ max_num = max(mixed_coef)
195
+
196
+ for f in range(video_length):
197
+ mixed_coef[f] = (mixed_coef[f] - min_num) / (max_num - min_num)
198
+
199
+ mixed_max = max(mixed_coef)
200
+ mixed_min = min(mixed_coef)
201
+ for f in range(video_length):
202
+ mixed_coef[f] = (max(sim_range) - min(sim_range)) * (mixed_coef[f] - mixed_min) / (mixed_max - mixed_min) + min(sim_range)
203
+
204
+ mixed_coef = [x if min(sim_range) <= x <= max(sim_range) else min(sim_range) if x < min(sim_range) else max(sim_range) for x in mixed_coef]
205
+
206
+ return mixed_coef
207
+
208
+ def prepare_masked_latent_cond(video_length: int, cond_frames: list):
209
+ for cond_frame in cond_frames:
210
+ assert cond_frame < video_length, \
211
+ 'cond_frame should be smaller than video_length'
212
+ assert cond_frame > -1, \
213
+ f'cond_frame should be in the range of [0, {video_length}]'
214
+
215
+ cond_frames.sort()
216
+ nearest = [cond_frames[0]] * video_length
217
+ for f in range(video_length):
218
+ for cond_frame in cond_frames:
219
+ if abs(nearest[f] - f) > abs(cond_frame - f):
220
+ nearest[f] = cond_frame
221
+
222
+ masked_latent_cond = nearest
223
+
224
+ return masked_latent_cond
225
+
226
+ def estimated_kernel_size(frame_width: int, frame_height: int) -> int:
227
+ """Estimate kernel size based on video resolution."""
228
+ # TODO: This equation is based on manual estimation from a few videos.
229
+ # Create a more comprehensive test suite to optimize against.
230
+ size: int = 4 + round(math.sqrt(frame_width * frame_height) / 192)
231
+ if size % 2 == 0:
232
+ size += 1
233
+ return size
234
+
235
+ def detect_edges(lum: np.ndarray) -> np.ndarray:
236
+ """Detect edges using the luma channel of a frame.
237
+
238
+ Arguments:
239
+ lum: 2D 8-bit image representing the luma channel of a frame.
240
+
241
+ Returns:
242
+ 2D 8-bit image of the same size as the input, where pixels with values of 255
243
+ represent edges, and all other pixels are 0.
244
+ """
245
+ # Initialize kernel.
246
+ kernel_size = estimated_kernel_size(lum.shape[1], lum.shape[0])
247
+ kernel = np.ones((kernel_size, kernel_size), np.uint8)
248
+
249
+ # Estimate levels for thresholding.
250
+ # TODO(0.6.3): Add config file entries for sigma, aperture/kernel size, etc.
251
+ sigma: float = 1.0 / 3.0
252
+ median = np.median(lum)
253
+ low = int(max(0, (1.0 - sigma) * median))
254
+ high = int(min(255, (1.0 + sigma) * median))
255
+
256
+ # Calculate edges using Canny algorithm, and reduce noise by dilating the edges.
257
+ # This increases edge overlap leading to improved robustness against noise and slow
258
+ # camera movement. Note that very large kernel sizes can negatively affect accuracy.
259
+ edges = cv2.Canny(lum, low, high)
260
+ return cv2.dilate(edges, kernel)
261
+
262
+ def prepare_mask_coef_by_score(video_shape: list, cond_frame_idx: list, sim_range: list = [0.2, 1.0],
263
+ statistic: list = [1, 100], coef_max: int = 0.98, score: Optional[torch.Tensor] = None):
264
+ '''
265
+ the shape of video_data is (b f c h w)
266
+ cond_frame_idx is a list, with length of batch_size
267
+ the shape of statistic is (f 2)
268
+ the shape of score is (b f)
269
+ the shape of coef is (b f)
270
+ '''
271
+ assert len(video_shape) == 2, \
272
+ f'the shape of video_shape should be (b f c h w), but now get {len(video_shape.shape)} channels'
273
+
274
+ batch_size, frame_num = video_shape[0], video_shape[1]
275
+
276
+ score = score.permute(0, 2, 1).squeeze(0)
277
+
278
+ # list -> b 1
279
+ cond_frame_mat = torch.tensor(cond_frame_idx).unsqueeze(-1)
280
+
281
+ statistic = torch.tensor(statistic)
282
+ # (f 2) -> (b f 2)
283
+ statistic = statistic.repeat(batch_size, 1, 1)
284
+
285
+ # shape of order (b f), shape of cond_mat (b f)
286
+ order = torch.arange(0, frame_num, 1)
287
+ order = order.repeat(batch_size, 1)
288
+ cond_mat = torch.ones((batch_size, frame_num)) * cond_frame_mat
289
+ order = abs(order - cond_mat)
290
+
291
+ statistic = statistic[:,order.to(torch.long)][0,:,:,:]
292
+
293
+ # score (b f) max_s (b f 1)
294
+ max_stats = torch.max(statistic, dim=2).values.to(dtype=score.dtype)
295
+ min_stats = torch.min(statistic, dim=2).values.to(dtype=score.dtype)
296
+
297
+ score[score > max_stats] = max_stats[score > max_stats] * 0.95
298
+ score[score < min_stats] = min_stats[score < min_stats]
299
+
300
+ eps = 1e-10
301
+ coef = 1 - abs((score / (max_stats + eps)) * (max(sim_range) - min(sim_range)))
302
+
303
+ indices = torch.arange(coef.shape[0]).unsqueeze(1)
304
+ coef[indices, cond_frame_mat] = 1.0
305
+
306
+ return coef
307
+
308
+ def preprocess_img(img_path, max_size:int=512):
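+ # Resize so the longer edge is at most max_size, then round both sides down to multiples of 8 (the downstream VAE/UNet expects dimensions divisible by 8); returns (uint8 HxWxC array, height, width).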
309
+
310
+ ori_image = Image.open(img_path).convert('RGB')
311
+
312
+ width, height = ori_image.size
313
+
314
+ long_edge = max(width, height)
315
+ if long_edge > max_size:
316
+ scale_factor = max_size / long_edge
317
+ else:
318
+ scale_factor = 1
319
+ width = int(width * scale_factor)
320
+ height = int(height * scale_factor)
321
+ ori_image = ori_image.resize((width, height))
322
+
323
+ if (width % 8 != 0) or (height % 8 != 0):
324
+ in_width = (width // 8) * 8
325
+ in_height = (height // 8) * 8
326
+ else:
327
+ in_width = width
328
+ in_height = height
329
+ in_image = ori_image
330
+
331
+ in_image = ori_image.resize((in_width, in_height))
332
+ # in_image = ori_image.resize((512, 512))
333
+ in_image_np = np.array(in_image)
334
+ return in_image_np, in_height, in_width
models/draggan/dnnlib/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ from .util import EasyDict, make_cache_dir_path
models/draggan/dnnlib/util.py ADDED
@@ -0,0 +1,491 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+
11
+ import ctypes
12
+ import fnmatch
13
+ import importlib
14
+ import inspect
15
+ import numpy as np
16
+ import os
17
+ import shutil
18
+ import sys
19
+ import types
20
+ import io
21
+ import pickle
22
+ import re
23
+ import requests
24
+ import html
25
+ import hashlib
26
+ import glob
27
+ import tempfile
28
+ import urllib
29
+ import urllib.request
30
+ import uuid
31
+
32
+ from distutils.util import strtobool
33
+ from typing import Any, List, Tuple, Union
34
+
35
+
36
+ # Util classes
37
+ # ------------------------------------------------------------------------------------------
38
+
39
+
40
+ class EasyDict(dict):
41
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
42
+
43
+ def __getattr__(self, name: str) -> Any:
44
+ try:
45
+ return self[name]
46
+ except KeyError:
47
+ raise AttributeError(name)
48
+
49
+ def __setattr__(self, name: str, value: Any) -> None:
50
+ self[name] = value
51
+
52
+ def __delattr__(self, name: str) -> None:
53
+ del self[name]
54
+
55
+
56
+ class Logger(object):
57
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
58
+
59
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
60
+ self.file = None
61
+
62
+ if file_name is not None:
63
+ self.file = open(file_name, file_mode)
64
+
65
+ self.should_flush = should_flush
66
+ self.stdout = sys.stdout
67
+ self.stderr = sys.stderr
68
+
69
+ sys.stdout = self
70
+ sys.stderr = self
71
+
72
+ def __enter__(self) -> "Logger":
73
+ return self
74
+
75
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
76
+ self.close()
77
+
78
+ def write(self, text: Union[str, bytes]) -> None:
79
+ """Write text to stdout (and a file) and optionally flush."""
80
+ if isinstance(text, bytes):
81
+ text = text.decode()
82
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
83
+ return
84
+
85
+ if self.file is not None:
86
+ self.file.write(text)
87
+
88
+ self.stdout.write(text)
89
+
90
+ if self.should_flush:
91
+ self.flush()
92
+
93
+ def flush(self) -> None:
94
+ """Flush written text to both stdout and a file, if open."""
95
+ if self.file is not None:
96
+ self.file.flush()
97
+
98
+ self.stdout.flush()
99
+
100
+ def close(self) -> None:
101
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
102
+ self.flush()
103
+
104
+ # if using multiple loggers, prevent closing in wrong order
105
+ if sys.stdout is self:
106
+ sys.stdout = self.stdout
107
+ if sys.stderr is self:
108
+ sys.stderr = self.stderr
109
+
110
+ if self.file is not None:
111
+ self.file.close()
112
+ self.file = None
113
+
114
+
115
+ # Cache directories
116
+ # ------------------------------------------------------------------------------------------
117
+
118
+ _dnnlib_cache_dir = None
119
+
120
+ def set_cache_dir(path: str) -> None:
121
+ global _dnnlib_cache_dir
122
+ _dnnlib_cache_dir = path
123
+
124
+ def make_cache_dir_path(*paths: str) -> str:
125
+ if _dnnlib_cache_dir is not None:
126
+ return os.path.join(_dnnlib_cache_dir, *paths)
127
+ if 'DNNLIB_CACHE_DIR' in os.environ:
128
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
129
+ if 'HOME' in os.environ:
130
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
131
+ if 'USERPROFILE' in os.environ:
132
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
133
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
134
+
135
+ # Small util functions
136
+ # ------------------------------------------------------------------------------------------
137
+
138
+
139
+ def format_time(seconds: Union[int, float]) -> str:
140
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
141
+ s = int(np.rint(seconds))
142
+
143
+ if s < 60:
144
+ return "{0}s".format(s)
145
+ elif s < 60 * 60:
146
+ return "{0}m {1:02}s".format(s // 60, s % 60)
147
+ elif s < 24 * 60 * 60:
148
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
149
+ else:
150
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
151
+
152
+
153
+ def format_time_brief(seconds: Union[int, float]) -> str:
154
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
155
+ s = int(np.rint(seconds))
156
+
157
+ if s < 60:
158
+ return "{0}s".format(s)
159
+ elif s < 60 * 60:
160
+ return "{0}m {1:02}s".format(s // 60, s % 60)
161
+ elif s < 24 * 60 * 60:
162
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
163
+ else:
164
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
165
+
166
+
167
+ def ask_yes_no(question: str) -> bool:
168
+ """Ask the user the question until the user inputs a valid answer."""
169
+ while True:
170
+ try:
171
+ print("{0} [y/n]".format(question))
172
+ return strtobool(input().lower())
173
+ except ValueError:
174
+ pass
175
+
176
+
177
+ def tuple_product(t: Tuple) -> Any:
178
+ """Calculate the product of the tuple elements."""
179
+ result = 1
180
+
181
+ for v in t:
182
+ result *= v
183
+
184
+ return result
185
+
186
+
187
+ _str_to_ctype = {
188
+ "uint8": ctypes.c_ubyte,
189
+ "uint16": ctypes.c_uint16,
190
+ "uint32": ctypes.c_uint32,
191
+ "uint64": ctypes.c_uint64,
192
+ "int8": ctypes.c_byte,
193
+ "int16": ctypes.c_int16,
194
+ "int32": ctypes.c_int32,
195
+ "int64": ctypes.c_int64,
196
+ "float32": ctypes.c_float,
197
+ "float64": ctypes.c_double
198
+ }
199
+
200
+
201
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
202
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
203
+ type_str = None
204
+
205
+ if isinstance(type_obj, str):
206
+ type_str = type_obj
207
+ elif hasattr(type_obj, "__name__"):
208
+ type_str = type_obj.__name__
209
+ elif hasattr(type_obj, "name"):
210
+ type_str = type_obj.name
211
+ else:
212
+ raise RuntimeError("Cannot infer type name from input")
213
+
214
+ assert type_str in _str_to_ctype.keys()
215
+
216
+ my_dtype = np.dtype(type_str)
217
+ my_ctype = _str_to_ctype[type_str]
218
+
219
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
220
+
221
+ return my_dtype, my_ctype
222
+
223
+
224
+ def is_pickleable(obj: Any) -> bool:
225
+ try:
226
+ with io.BytesIO() as stream:
227
+ pickle.dump(obj, stream)
228
+ return True
229
+ except:
230
+ return False
231
+
232
+
233
+ # Functionality to import modules/objects by name, and call functions by name
234
+ # ------------------------------------------------------------------------------------------
235
+
236
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
237
+ """Searches for the underlying module behind the name to some python object.
238
+ Returns the module and the object name (original name with module part removed)."""
239
+
240
+ # allow convenience shorthands, substitute them by full names
241
+ obj_name = re.sub("^np.", "numpy.", obj_name)
242
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
243
+
244
+ # list alternatives for (module_name, local_obj_name)
245
+ parts = obj_name.split(".")
246
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
247
+
248
+ # try each alternative in turn
249
+ for module_name, local_obj_name in name_pairs:
250
+ try:
251
+ module = importlib.import_module(module_name) # may raise ImportError
252
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
253
+ return module, local_obj_name
254
+ except:
255
+ pass
256
+
257
+ # maybe some of the modules themselves contain errors?
258
+ for module_name, _local_obj_name in name_pairs:
259
+ try:
260
+ importlib.import_module(module_name) # may raise ImportError
261
+ except ImportError:
262
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
263
+ raise
264
+
265
+ # maybe the requested attribute is missing?
266
+ for module_name, local_obj_name in name_pairs:
267
+ try:
268
+ module = importlib.import_module(module_name) # may raise ImportError
269
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
270
+ except ImportError:
271
+ pass
272
+
273
+ # we are out of luck, but we have no idea why
274
+ raise ImportError(obj_name)
275
+
276
+
277
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
278
+ """Traverses the object name and returns the last (rightmost) python object."""
279
+ if obj_name == '':
280
+ return module
281
+ obj = module
282
+ for part in obj_name.split("."):
283
+ obj = getattr(obj, part)
284
+ return obj
285
+
286
+
287
+ def get_obj_by_name(name: str) -> Any:
288
+ """Finds the python object with the given name."""
289
+ module, obj_name = get_module_from_obj_name(name)
290
+ return get_obj_from_module(module, obj_name)
291
+
292
+
293
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
294
+ """Finds the python object with the given name and calls it as a function."""
295
+ assert func_name is not None
296
+ func_obj = get_obj_by_name(func_name)
297
+ assert callable(func_obj)
298
+ return func_obj(*args, **kwargs)
299
+
300
+
301
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
302
+ """Finds the python class with the given name and constructs it with the given arguments."""
303
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
304
+
305
+
306
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
307
+ """Get the directory path of the module containing the given object name."""
308
+ module, _ = get_module_from_obj_name(obj_name)
309
+ return os.path.dirname(inspect.getfile(module))
310
+
311
+
312
+ def is_top_level_function(obj: Any) -> bool:
313
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
314
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
315
+
316
+
317
+ def get_top_level_function_name(obj: Any) -> str:
318
+ """Return the fully-qualified name of a top-level function."""
319
+ assert is_top_level_function(obj)
320
+ module = obj.__module__
321
+ if module == '__main__':
322
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
323
+ return module + "." + obj.__name__
324
+
325
+
326
+ # File system helpers
327
+ # ------------------------------------------------------------------------------------------
328
+
329
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
330
+ """List all files recursively in a given directory while ignoring given file and directory names.
331
+ Returns list of tuples containing both absolute and relative paths."""
332
+ assert os.path.isdir(dir_path)
333
+ base_name = os.path.basename(os.path.normpath(dir_path))
334
+
335
+ if ignores is None:
336
+ ignores = []
337
+
338
+ result = []
339
+
340
+ for root, dirs, files in os.walk(dir_path, topdown=True):
341
+ for ignore_ in ignores:
342
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
343
+
344
+ # dirs need to be edited in-place
345
+ for d in dirs_to_remove:
346
+ dirs.remove(d)
347
+
348
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
349
+
350
+ absolute_paths = [os.path.join(root, f) for f in files]
351
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
352
+
353
+ if add_base_to_relative:
354
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
355
+
356
+ assert len(absolute_paths) == len(relative_paths)
357
+ result += zip(absolute_paths, relative_paths)
358
+
359
+ return result
360
+
361
+
362
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
363
+ """Takes in a list of tuples of (src, dst) paths and copies files.
364
+ Will create all necessary directories."""
365
+ for file in files:
366
+ target_dir_name = os.path.dirname(file[1])
367
+
368
+ # will create all intermediate-level directories
369
+ if not os.path.exists(target_dir_name):
370
+ os.makedirs(target_dir_name)
371
+
372
+ shutil.copyfile(file[0], file[1])
373
+
374
+
375
+ # URL helpers
376
+ # ------------------------------------------------------------------------------------------
377
+
378
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
379
+ """Determine whether the given object is a valid URL string."""
380
+ if not isinstance(obj, str) or not "://" in obj:
381
+ return False
382
+ if allow_file_urls and obj.startswith('file://'):
383
+ return True
384
+ try:
385
+ res = requests.compat.urlparse(obj)
386
+ if not res.scheme or not res.netloc or not "." in res.netloc:
387
+ return False
388
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
389
+ if not res.scheme or not res.netloc or not "." in res.netloc:
390
+ return False
391
+ except:
392
+ return False
393
+ return True
394
+
395
+
396
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
397
+ """Download the given URL and return a binary-mode file object to access the data."""
398
+ assert num_attempts >= 1
399
+ assert not (return_filename and (not cache))
400
+
401
+ # Doesn't look like an URL scheme so interpret it as a local filename.
402
+ if not re.match('^[a-z]+://', url):
403
+ return url if return_filename else open(url, "rb")
404
+
405
+ # Handle file URLs. This code handles unusual file:// patterns that
406
+ # arise on Windows:
407
+ #
408
+ # file:///c:/foo.txt
409
+ #
410
+ # which would translate to a local '/c:/foo.txt' filename that's
411
+ # invalid. Drop the forward slash for such pathnames.
412
+ #
413
+ # If you touch this code path, you should test it on both Linux and
414
+ # Windows.
415
+ #
416
+ # Some internet resources suggest using urllib.request.url2pathname() but
417
+ # but that converts forward slashes to backslashes and this causes
418
+ # its own set of problems.
419
+ if url.startswith('file://'):
420
+ filename = urllib.parse.urlparse(url).path
421
+ if re.match(r'^/[a-zA-Z]:', filename):
422
+ filename = filename[1:]
423
+ return filename if return_filename else open(filename, "rb")
424
+
425
+ assert is_url(url)
426
+
427
+ # Lookup from cache.
428
+ if cache_dir is None:
429
+ cache_dir = make_cache_dir_path('downloads')
430
+
431
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
432
+ if cache:
433
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
434
+ if len(cache_files) == 1:
435
+ filename = cache_files[0]
436
+ return filename if return_filename else open(filename, "rb")
437
+
438
+ # Download.
439
+ url_name = None
440
+ url_data = None
441
+ with requests.Session() as session:
442
+ if verbose:
443
+ print("Downloading %s ..." % url, end="", flush=True)
444
+ for attempts_left in reversed(range(num_attempts)):
445
+ try:
446
+ with session.get(url) as res:
447
+ res.raise_for_status()
448
+ if len(res.content) == 0:
449
+ raise IOError("No data received")
450
+
451
+ if len(res.content) < 8192:
452
+ content_str = res.content.decode("utf-8")
453
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
454
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
455
+ if len(links) == 1:
456
+ url = requests.compat.urljoin(url, links[0])
457
+ raise IOError("Google Drive virus checker nag")
458
+ if "Google Drive - Quota exceeded" in content_str:
459
+ raise IOError("Google Drive download quota exceeded -- please try again later")
460
+
461
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
462
+ url_name = match[1] if match else url
463
+ url_data = res.content
464
+ if verbose:
465
+ print(" done")
466
+ break
467
+ except KeyboardInterrupt:
468
+ raise
469
+ except:
470
+ if not attempts_left:
471
+ if verbose:
472
+ print(" failed")
473
+ raise
474
+ if verbose:
475
+ print(".", end="", flush=True)
476
+
477
+ # Save to cache.
478
+ if cache:
479
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
480
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
481
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
482
+ os.makedirs(cache_dir, exist_ok=True)
483
+ with open(temp_file, "wb") as f:
484
+ f.write(url_data)
485
+ os.replace(temp_file, cache_file) # atomic
486
+ if return_filename:
487
+ return cache_file
488
+
489
+ # Return data as file object.
490
+ assert not return_filename
491
+ return io.BytesIO(url_data)
models/draggan/gan_inv/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
models/draggan/gan_inv/inversion.py ADDED
@@ -0,0 +1,277 @@
1
+ import math
2
+ import os
3
+ from draggan.viz import renderer
4
+ import torch
5
+ from torch import optim
6
+ from torch.nn import functional as F
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ import dataclasses
11
+ import draggan.dnnlib as dnnlib
12
+ from .lpips import util
13
+
14
+
15
+ def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
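+ # Learning-rate schedule: linear warm-up over the first rampup fraction, cosine ramp-down over the last rampdown fraction of optimization.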
16
+ lr_ramp = min(1, (1 - t) / rampdown)
17
+ lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
18
+ lr_ramp = lr_ramp * min(1, t / rampup)
19
+
20
+ return initial_lr * lr_ramp
21
+
22
+
23
+ def make_image(tensor):
24
+ return (
25
+ tensor.detach()
26
+ .clamp_(min=-1, max=1)
27
+ .add(1)
28
+ .div_(2)
29
+ .mul(255)
30
+ .type(torch.uint8)
31
+ .permute(0, 2, 3, 1)
32
+ .to("cpu")
33
+ .numpy()
34
+ )
35
+
36
+
37
+ @dataclasses.dataclass
38
+ class InverseConfig:
39
+ lr_warmup = 0.05
40
+ lr_decay = 0.25
41
+ lr = 0.1
42
+ noise = 0.05
43
+ noise_decay = 0.75
44
+ # step = 1000
45
+ step = 1000
46
+ noise_regularize = 1e5
47
+ mse = 0.1
48
+
49
+
50
+
51
+ def inverse_image(
52
+ g_ema,
53
+ image,
54
+ percept,
55
+ image_size=256,
56
+ w_plus = False,
57
+ config=InverseConfig(),
58
+ device='cuda:0'
59
+ ):
60
+ args = config
61
+
62
+ n_mean_latent = 10000
63
+
64
+ resize = min(image_size, 256)
65
+
66
+ if torch.is_tensor(image)==False:
67
+ transform = transforms.Compose(
68
+ [
69
+ transforms.Resize(resize,),
70
+ transforms.CenterCrop(resize),
71
+ transforms.ToTensor(),
72
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
73
+ ]
74
+ )
75
+
76
+ img = transform(image)
77
+
78
+ else:
79
+ img = transforms.functional.resize(image,resize)
80
+ transform = transforms.Compose(
81
+ [
82
+ transforms.CenterCrop(resize),
83
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
84
+ ]
85
+ )
86
+ img = transform(img)
87
+ imgs = []
88
+ imgs.append(img)
89
+ imgs = torch.stack(imgs, 0).to(device)
90
+
91
+ with torch.no_grad():
92
+
93
+ #noise_sample = torch.randn(n_mean_latent, 512, device=device)
94
+ noise_sample = torch.randn(n_mean_latent, g_ema.z_dim, device=device)
95
+ #label = torch.zeros([n_mean_latent,g_ema.c_dim],device = device)
96
+ w_samples = g_ema.mapping(noise_sample,None)
97
+ w_samples = w_samples[:, :1, :]
98
+ w_avg = w_samples.mean(0)
99
+ w_std = ((w_samples - w_avg).pow(2).sum() / n_mean_latent) ** 0.5
100
+
101
+
102
+
103
+
104
+ noises = {name: buf for (name, buf) in g_ema.synthesis.named_buffers() if 'noise_const' in name}
105
+ for noise in noises.values():
106
+ noise = torch.randn_like(noise)
107
+ noise.requires_grad = True
108
+
109
+
110
+
111
+ w_opt = w_avg.detach().clone()
112
+ if w_plus:
113
+ w_opt = w_opt.repeat(1,g_ema.mapping.num_ws, 1)
114
+ w_opt.requires_grad = True
115
+ #if args.w_plus:
116
+ #latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
117
+
118
+
119
+
120
+ optimizer = optim.Adam([w_opt] + list(noises.values()), lr=args.lr)
121
+
122
+ pbar = tqdm(range(args.step))
123
+ latent_path = []
124
+
125
+ for i in pbar:
126
+ t = i / args.step
127
+ lr = get_lr(t, args.lr)
128
+ optimizer.param_groups[0]["lr"] = lr
129
+ noise_strength = w_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
130
+
131
+ w_noise = torch.randn_like(w_opt) * noise_strength
132
+ if w_plus:
133
+ ws = w_opt + w_noise
134
+ else:
135
+ ws = (w_opt + w_noise).repeat([1, g_ema.mapping.num_ws, 1])
136
+
137
+ img_gen = g_ema.synthesis(ws, noise_mode='const', force_fp32=True)
138
+
139
+ #latent_n = latent_noise(latent_in, noise_strength.item())
140
+
141
+ #latent, noise = g_ema.prepare([latent_n], input_is_latent=True, noise=noises)
142
+ #img_gen, F = g_ema.generate(latent, noise)
143
+
144
+ # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
145
+
146
+ if img_gen.shape[2] > 256:
147
+ img_gen = F.interpolate(img_gen, size=(256, 256), mode='area')
148
+
149
+ p_loss = percept(img_gen,imgs)
150
+
151
+
152
+ # Noise regularization.
153
+ reg_loss = 0.0
154
+ for v in noises.values():
155
+ noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
156
+ while True:
157
+ reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
158
+ reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
159
+ if noise.shape[2] <= 8:
160
+ break
161
+ noise = F.avg_pool2d(noise, kernel_size=2)
162
+ mse_loss = F.mse_loss(img_gen, imgs)
163
+
164
+ loss = p_loss + args.noise_regularize * reg_loss + args.mse * mse_loss
165
+
166
+ optimizer.zero_grad()
167
+ loss.backward()
168
+ optimizer.step()
169
+
170
+ # Normalize noise.
171
+ with torch.no_grad():
172
+ for buf in noises.values():
173
+ buf -= buf.mean()
174
+ buf *= buf.square().mean().rsqrt()
175
+
176
+ if (i + 1) % 100 == 0:
177
+ latent_path.append(w_opt.detach().clone())
178
+
179
+ pbar.set_description(
180
+ (
181
+ f"perceptual: {p_loss.item():.4f}; noise regularize: {reg_loss:.4f};"
182
+ f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
183
+ )
184
+ )
185
+
186
+ #latent, noise = g_ema.prepare([latent_path[-1]], input_is_latent=True, noise=noises)
187
+ #img_gen, F = g_ema.generate(latent, noise)
188
+ if w_plus:
189
+ ws = latent_path[-1]
190
+ else:
191
+ ws = latent_path[-1].repeat([1, g_ema.mapping.num_ws, 1])
192
+
193
+ img_gen = g_ema.synthesis(ws, noise_mode='const')
194
+
195
+
196
+ result = {
197
+ "latent": latent_path[-1],
198
+ "sample": img_gen,
199
+ "real": imgs,
200
+ }
201
+
202
+ return result
203
+
204
+ def toogle_grad(model, flag=True):
205
+ for p in model.parameters():
206
+ p.requires_grad = flag
207
+
208
+
209
+ class PTI:
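+ # Pivotal Tuning Inversion: invert the image to a pivot latent first, then fine-tune the generator weights so G(pivot) reconstructs the real image more faithfully.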
210
+ def __init__(self,G, percept, l2_lambda = 1,max_pti_step = 400, pti_lr = 3e-4 ):
211
+ self.g_ema = G
212
+ self.l2_lambda = l2_lambda
213
+ self.max_pti_step = max_pti_step
214
+ self.pti_lr = pti_lr
215
+ self.percept = percept
216
+ def cacl_loss(self,percept, generated_image,real_image):
217
+
218
+ mse_loss = F.mse_loss(generated_image, real_image)
219
+ p_loss = percept(generated_image, real_image).sum()
220
+ loss = p_loss +self.l2_lambda * mse_loss
221
+ return loss
222
+
223
+ def train(self,img,w_plus=False):
224
+ if not torch.cuda.is_available():
225
+ device = 'cpu'
226
+ else:
227
+ device = 'cuda'
228
+ if not torch.is_tensor(img):
229
+ transform = transforms.Compose(
230
+ [
231
+ transforms.Resize(self.g_ema.img_resolution, ),
232
+ transforms.CenterCrop(self.g_ema.img_resolution),
233
+ transforms.ToTensor(),
234
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
235
+ ]
236
+ )
237
+
238
+ real_img = transform(img).to(device).unsqueeze(0)
239
+
240
+ else:
241
+ img = transforms.functional.resize(img, self.g_ema.img_resolution)
242
+ transform = transforms.Compose(
243
+ [
244
+ transforms.CenterCrop(self.g_ema.img_resolution),
245
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
246
+ ]
247
+ )
248
+ real_img = transform(img).to(device).unsqueeze(0)
249
+ inversed_result = inverse_image(self.g_ema,img,self.percept,self.g_ema.img_resolution,w_plus,device=device)
250
+ w_pivot = inversed_result['latent']
251
+ if w_plus:
252
+ ws = w_pivot
253
+ else:
254
+ ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])
255
+ toogle_grad(self.g_ema,True)
256
+ optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)
257
+ print('start PTI')
258
+ pbar = tqdm(range(self.max_pti_step))
259
+ for i in pbar:
260
+ t = i / self.max_pti_step
261
+ lr = get_lr(t, self.pti_lr)
262
+ optimizer.param_groups[0]["lr"] = lr
263
+
264
+ generated_image = self.g_ema.synthesis(ws,noise_mode='const')
265
+ loss = self.cacl_loss(self.percept,generated_image,real_img)
266
+ pbar.set_description(
267
+ (
268
+ f"loss: {loss.item():.4f}"
269
+ )
270
+ )
271
+ optimizer.zero_grad()
272
+ loss.backward()
273
+ optimizer.step()
274
+ with torch.no_grad():
275
+ generated_image = self.g_ema.synthesis(ws, noise_mode='const')
276
+
277
+ return generated_image,ws
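Usage note: PTI above first calls inverse_image to find a pivot latent, then briefly fine-tunes the generator weights around that pivot. A minimal sketch of driving it is below; the checkpoint path, example image path, and package-style import paths are illustrative assumptions, not part of this upload.

import torch
from PIL import Image

from models.draggan import dnnlib, legacy                        # assumes the repo root is on PYTHONPATH
from models.draggan.gan_inv.inversion import PTI
from models.draggan.gan_inv.lpips.util import PerceptualLoss

device = 'cuda' if torch.cuda.is_available() else 'cpu'
with dnnlib.util.open_url('checkpoints/stylegan2-ffhq-512x512.pkl') as f:    # hypothetical checkpoint
    G = legacy.load_network_pkl(f)['G_ema'].to(device)

percept = PerceptualLoss(model='net-lin', net='vgg', use_gpu=(device == 'cuda'))
pti = PTI(G, percept)                                            # defaults: l2_lambda=1, 400 PTI steps
image = Image.open('examples/source.png').convert('RGB')         # hypothetical input image
generated, ws = pti.train(image, w_plus=True)                    # tuned render and its W+ latent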
models/draggan/gan_inv/lpips/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
models/draggan/gan_inv/lpips/base_model.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from torch.autograd import Variable
5
+ from pdb import set_trace as st
6
+ from IPython import embed
7
+
8
+ class BaseModel():
9
+ def __init__(self):
10
+ pass
11
+
12
+ def name(self):
13
+ return 'BaseModel'
14
+
15
+ def initialize(self, use_gpu=True, gpu_ids=[0]):
16
+ self.use_gpu = use_gpu
17
+ self.gpu_ids = gpu_ids
18
+
19
+ def forward(self):
20
+ pass
21
+
22
+ def get_image_paths(self):
23
+ pass
24
+
25
+ def optimize_parameters(self):
26
+ pass
27
+
28
+ def get_current_visuals(self):
29
+ return self.input
30
+
31
+ def get_current_errors(self):
32
+ return {}
33
+
34
+ def save(self, label):
35
+ pass
36
+
37
+ # helper saving function that can be used by subclasses
38
+ def save_network(self, network, path, network_label, epoch_label):
39
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
40
+ save_path = os.path.join(path, save_filename)
41
+ torch.save(network.state_dict(), save_path)
42
+
43
+ # helper loading function that can be used by subclasses
44
+ def load_network(self, network, network_label, epoch_label):
45
+ save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
46
+ save_path = os.path.join(self.save_dir, save_filename)
47
+ print('Loading network from %s'%save_path)
48
+ network.load_state_dict(torch.load(save_path))
49
+
50
+ def update_learning_rate():
51
+ pass
52
+
53
+ def get_image_paths(self):
54
+ return self.image_paths
55
+
56
+ def save_done(self, flag=False):
57
+ np.save(os.path.join(self.save_dir, 'done_flag'),flag)
58
+ np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
models/draggan/gan_inv/lpips/dist_model.py ADDED
@@ -0,0 +1,314 @@
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ import os
9
+ from collections import OrderedDict
10
+ from torch.autograd import Variable
11
+ import itertools
12
+ from .base_model import BaseModel
13
+ from scipy.ndimage import zoom
14
+ import fractions
15
+ import functools
16
+ import skimage.transform
17
+ from tqdm import tqdm
18
+ import urllib
19
+
20
+ from IPython import embed
21
+
22
+ from . import networks_basic as networks
23
+ from . import util
24
+
25
+
26
+ class DownloadProgressBar(tqdm):
27
+ def update_to(self, b=1, bsize=1, tsize=None):
28
+ if tsize is not None:
29
+ self.total = tsize
30
+ self.update(b * bsize - self.n)
31
+
32
+
33
+ def get_path(base_path):
34
+ BASE_DIR = os.path.join('checkpoints')
35
+
36
+ save_path = os.path.join(BASE_DIR, base_path)
37
+ if not os.path.exists(save_path):
38
+ url = f"https://huggingface.co/aaronb/StyleGAN2/resolve/main/{base_path}"
39
+ print(f'{base_path} not found')
40
+ print('Try to download from huggingface: ', url)
41
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
42
+ download_url(url, save_path)
43
+ print('Downloaded to ', save_path)
44
+ return save_path
45
+
46
+
47
+ def download_url(url, output_path):
48
+ with DownloadProgressBar(unit='B', unit_scale=True,
49
+ miniters=1, desc=url.split('/')[-1]) as t:
50
+ urllib.request.urlretrieve(url, filename=output_path, reporthook=t.update_to)
51
+
52
+
53
+ class DistModel(BaseModel):
54
+ def name(self):
55
+ return self.model_name
56
+
57
+ def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
58
+ use_gpu=True, printNet=False, spatial=False,
59
+ is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
60
+ '''
61
+ INPUTS
62
+ model - ['net-lin'] for linearly calibrated network
63
+ ['net'] for off-the-shelf network
64
+ ['L2'] for L2 distance in Lab colorspace
65
+ ['SSIM'] for ssim in RGB colorspace
66
+ net - ['squeeze','alex','vgg']
67
+ model_path - if None, will look in weights/[NET_NAME].pth
68
+ colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
69
+ use_gpu - bool - whether or not to use a GPU
70
+ printNet - bool - whether or not to print network architecture out
71
+ spatial - bool - whether to output an array containing varying distances across spatial dimensions
72
+ spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
73
+ spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
74
+ spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
75
+ is_train - bool - [True] for training mode
76
+ lr - float - initial learning rate
77
+ beta1 - float - initial momentum term for adam
78
+ version - 0.1 for latest, 0.0 was original (with a bug)
79
+ gpu_ids - int array - [0] by default, gpus to use
80
+ '''
81
+ BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
82
+
83
+ self.model = model
84
+ self.net = net
85
+ self.is_train = is_train
86
+ self.spatial = spatial
87
+ self.gpu_ids = gpu_ids
88
+ self.model_name = '%s [%s]' % (model, net)
89
+
90
+ if(self.model == 'net-lin'): # pretrained net + linear layer
91
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
92
+ use_dropout=True, spatial=spatial, version=version, lpips=True)
93
+ kw = {}
94
+ if not use_gpu:
95
+ kw['map_location'] = 'cpu'
96
+ if(model_path is None):
97
+ model_path = get_path('weights/v%s/%s.pth' % (version, net))
98
+
99
+ if(not is_train):
100
+ print('Loading model from: %s' % model_path)
101
+ self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
102
+
103
+ elif(self.model == 'net'): # pretrained network
104
+ self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
105
+ elif(self.model in ['L2', 'l2']):
106
+ self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing
107
+ self.model_name = 'L2'
108
+ elif(self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']):
109
+ self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace)
110
+ self.model_name = 'SSIM'
111
+ else:
112
+ raise ValueError("Model [%s] not recognized." % self.model)
113
+
114
+ self.parameters = list(self.net.parameters())
115
+
116
+ if self.is_train: # training mode
117
+ # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
118
+ self.rankLoss = networks.BCERankingLoss()
119
+ self.parameters += list(self.rankLoss.net.parameters())
120
+ self.lr = lr
121
+ self.old_lr = lr
122
+ self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
123
+ else: # test mode
124
+ self.net.eval()
125
+
126
+ if(use_gpu):
127
+ self.net.to(gpu_ids[0])
128
+ self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
129
+ if(self.is_train):
130
+ self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
131
+
132
+ if(printNet):
133
+ print('---------- Networks initialized -------------')
134
+ networks.print_network(self.net)
135
+ print('-----------------------------------------------')
136
+
137
+ def forward(self, in0, in1, retPerLayer=False):
138
+ ''' Function computes the distance between image patches in0 and in1
139
+ INPUTS
140
+ in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
141
+ OUTPUT
142
+ computed distances between in0 and in1
143
+ '''
144
+
145
+ return self.net.forward(in0, in1, retPerLayer=retPerLayer)
146
+
147
+ # ***** TRAINING FUNCTIONS *****
148
+ def optimize_parameters(self):
149
+ self.forward_train()
150
+ self.optimizer_net.zero_grad()
151
+ self.backward_train()
152
+ self.optimizer_net.step()
153
+ self.clamp_weights()
154
+
155
+ def clamp_weights(self):
156
+ for module in self.net.modules():
157
+ if(hasattr(module, 'weight') and module.kernel_size == (1, 1)):
158
+ module.weight.data = torch.clamp(module.weight.data, min=0)
159
+
160
+ def set_input(self, data):
161
+ self.input_ref = data['ref']
162
+ self.input_p0 = data['p0']
163
+ self.input_p1 = data['p1']
164
+ self.input_judge = data['judge']
165
+
166
+ if(self.use_gpu):
167
+ self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
168
+ self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
169
+ self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
170
+ self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
171
+
172
+ self.var_ref = Variable(self.input_ref, requires_grad=True)
173
+ self.var_p0 = Variable(self.input_p0, requires_grad=True)
174
+ self.var_p1 = Variable(self.input_p1, requires_grad=True)
175
+
176
+ def forward_train(self): # run forward pass
177
+ # print(self.net.module.scaling_layer.shift)
178
+ # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
179
+
180
+ self.d0 = self.forward(self.var_ref, self.var_p0)
181
+ self.d1 = self.forward(self.var_ref, self.var_p1)
182
+ self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge)
183
+
184
+ self.var_judge = Variable(1. * self.input_judge).view(self.d0.size())
185
+
186
+ self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.)
187
+
188
+ return self.loss_total
189
+
190
+ def backward_train(self):
191
+ torch.mean(self.loss_total).backward()
192
+
193
+ def compute_accuracy(self, d0, d1, judge):
194
+ ''' d0, d1 are Variables, judge is a Tensor '''
195
+ d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
196
+ judge_per = judge.cpu().numpy().flatten()
197
+ return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
198
+
199
+ def get_current_errors(self):
200
+ retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
201
+ ('acc_r', self.acc_r)])
202
+
203
+ for key in retDict.keys():
204
+ retDict[key] = np.mean(retDict[key])
205
+
206
+ return retDict
207
+
208
+ def get_current_visuals(self):
209
+ zoom_factor = 256 / self.var_ref.data.size()[2]
210
+
211
+ ref_img = util.tensor2im(self.var_ref.data)
212
+ p0_img = util.tensor2im(self.var_p0.data)
213
+ p1_img = util.tensor2im(self.var_p1.data)
214
+
215
+ ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
216
+ p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
217
+ p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
218
+
219
+ return OrderedDict([('ref', ref_img_vis),
220
+ ('p0', p0_img_vis),
221
+ ('p1', p1_img_vis)])
222
+
223
+ def save(self, path, label):
224
+ if(self.use_gpu):
225
+ self.save_network(self.net.module, path, '', label)
226
+ else:
227
+ self.save_network(self.net, path, '', label)
228
+ self.save_network(self.rankLoss.net, path, 'rank', label)
229
+
230
+ def update_learning_rate(self, nepoch_decay):
231
+ lrd = self.lr / nepoch_decay
232
+ lr = self.old_lr - lrd
233
+
234
+ for param_group in self.optimizer_net.param_groups:
235
+ param_group['lr'] = lr
236
+
237
+ print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
238
+ self.old_lr = lr
239
+
240
+
241
+ def score_2afc_dataset(data_loader, func, name=''):
242
+ ''' Function computes Two Alternative Forced Choice (2AFC) score using
243
+ distance function 'func' in dataset 'data_loader'
244
+ INPUTS
245
+ data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
246
+ func - callable distance function - calling d=func(in0,in1) should take 2
247
+ pytorch tensors with shape Nx3xXxY, and return numpy array of length N
248
+ OUTPUTS
249
+ [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
250
+ [1] - dictionary with following elements
251
+ d0s,d1s - N arrays containing distances between reference patch to perturbed patches
252
+ gts - N array in [0,1], preferred patch selected by human evaluators
253
+ (closer to "0" for left patch p0, "1" for right patch p1,
254
+ "0.6" means 60pct people preferred right patch, 40pct preferred left)
255
+ scores - N array in [0,1], corresponding to what percentage function agreed with humans
256
+ CONSTS
257
+ N - number of test triplets in data_loader
258
+ '''
259
+
260
+ d0s = []
261
+ d1s = []
262
+ gts = []
263
+
264
+ for data in tqdm(data_loader.load_data(), desc=name):
265
+ d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
266
+ d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
267
+ gts += data['judge'].cpu().numpy().flatten().tolist()
268
+
269
+ d0s = np.array(d0s)
270
+ d1s = np.array(d1s)
271
+ gts = np.array(gts)
272
+ scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5
273
+
274
+ return(np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
275
+
276
+
277
+ def score_jnd_dataset(data_loader, func, name=''):
278
+ ''' Function computes JND score using distance function 'func' in dataset 'data_loader'
279
+ INPUTS
280
+ data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
281
+ func - callable distance function - calling d=func(in0,in1) should take 2
282
+ pytorch tensors with shape Nx3xXxY, and return pytorch array of length N
283
+ OUTPUTS
284
+ [0] - JND score in [0,1], mAP score (area under precision-recall curve)
285
+ [1] - dictionary with following elements
286
+ ds - N array containing distances between two patches shown to human evaluator
287
+ sames - N array containing fraction of people who thought the two patches were identical
288
+ CONSTS
289
+ N - number of test triplets in data_loader
290
+ '''
291
+
292
+ ds = []
293
+ gts = []
294
+
295
+ for data in tqdm(data_loader.load_data(), desc=name):
296
+ ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist()
297
+ gts += data['same'].cpu().numpy().flatten().tolist()
298
+
299
+ sames = np.array(gts)
300
+ ds = np.array(ds)
301
+
302
+ sorted_inds = np.argsort(ds)
303
+ ds_sorted = ds[sorted_inds]
304
+ sames_sorted = sames[sorted_inds]
305
+
306
+ TPs = np.cumsum(sames_sorted)
307
+ FPs = np.cumsum(1 - sames_sorted)
308
+ FNs = np.sum(sames_sorted) - TPs
309
+
310
+ precs = TPs / (TPs + FPs)
311
+ recs = TPs / (TPs + FNs)
312
+ score = util.voc_ap(recs, precs)
313
+
314
+ return(score, dict(ds=ds, sames=sames))
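Rough usage of the selector documented in initialize above: the sketch builds the linearly calibrated 'net-lin' AlexNet variant on CPU and scores one image pair. It assumes the LPIPS weight download handled by get_path (and the torchvision AlexNet download) succeed.

import torch
from models.draggan.gan_inv.lpips.dist_model import DistModel

dm = DistModel()
dm.initialize(model='net-lin', net='alex', use_gpu=False)   # fetches weights into ./checkpoints if missing
img0 = torch.rand(1, 3, 256, 256) * 2 - 1                   # inputs are expected in [-1, 1]
img1 = torch.rand(1, 3, 256, 256) * 2 - 1
dist = dm.forward(img0, img1)                               # Nx1x1x1 tensor of distances
print(float(dist.view(-1)[0]))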
models/draggan/gan_inv/lpips/networks_basic.py ADDED
@@ -0,0 +1,188 @@
1
+
2
+ from __future__ import absolute_import
3
+
4
+ import sys
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.init as init
8
+ from torch.autograd import Variable
9
+ import numpy as np
10
+ from pdb import set_trace as st
11
+ from skimage import color
12
+ from IPython import embed
13
+ from . import pretrained_networks as pn
14
+
15
+ from . import util
16
+
17
+
18
+ def spatial_average(in_tens, keepdim=True):
19
+ return in_tens.mean([2,3],keepdim=keepdim)
20
+
21
+ def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
22
+ in_H = in_tens.shape[2]
23
+ scale_factor = 1.*out_H/in_H
24
+
25
+ return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
26
+
27
+ # Learned perceptual metric
28
+ class PNetLin(nn.Module):
29
+ def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
30
+ super(PNetLin, self).__init__()
31
+
32
+ self.pnet_type = pnet_type
33
+ self.pnet_tune = pnet_tune
34
+ self.pnet_rand = pnet_rand
35
+ self.spatial = spatial
36
+ self.lpips = lpips
37
+ self.version = version
38
+ self.scaling_layer = ScalingLayer()
39
+
40
+ if(self.pnet_type in ['vgg','vgg16']):
41
+ net_type = pn.vgg16
42
+ self.chns = [64,128,256,512,512]
43
+ elif(self.pnet_type=='alex'):
44
+ net_type = pn.alexnet
45
+ self.chns = [64,192,384,256,256]
46
+ elif(self.pnet_type=='squeeze'):
47
+ net_type = pn.squeezenet
48
+ self.chns = [64,128,256,384,384,512,512]
49
+ self.L = len(self.chns)
50
+
51
+ self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
52
+
53
+ if(lpips):
54
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
55
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
56
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
57
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
58
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
59
+ self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
60
+ if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
61
+ self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
62
+ self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
63
+ self.lins+=[self.lin5,self.lin6]
64
+
65
+ def forward(self, in0, in1, retPerLayer=False):
66
+ # v0.0 - original release had a bug, where input was not scaled
67
+ in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
68
+ outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
69
+ feats0, feats1, diffs = {}, {}, {}
70
+
71
+ for kk in range(self.L):
72
+ feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
73
+ diffs[kk] = (feats0[kk]-feats1[kk])**2
74
+
75
+ if(self.lpips):
76
+ if(self.spatial):
77
+ res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
78
+ else:
79
+ res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
80
+ else:
81
+ if(self.spatial):
82
+ res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
83
+ else:
84
+ res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
85
+
86
+ val = res[0]
87
+ for l in range(1,self.L):
88
+ val += res[l]
89
+
90
+ if(retPerLayer):
91
+ return (val, res)
92
+ else:
93
+ return val
94
+
95
+ class ScalingLayer(nn.Module):
96
+ def __init__(self):
97
+ super(ScalingLayer, self).__init__()
98
+ self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
99
+ self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
100
+
101
+ def forward(self, inp):
102
+ return (inp - self.shift) / self.scale
103
+
104
+
105
+ class NetLinLayer(nn.Module):
106
+ ''' A single linear layer which does a 1x1 conv '''
107
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
108
+ super(NetLinLayer, self).__init__()
109
+
110
+ layers = [nn.Dropout(),] if(use_dropout) else []
111
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
112
+ self.model = nn.Sequential(*layers)
113
+
114
+
115
+ class Dist2LogitLayer(nn.Module):
116
+ ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
117
+ def __init__(self, chn_mid=32, use_sigmoid=True):
118
+ super(Dist2LogitLayer, self).__init__()
119
+
120
+ layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
121
+ layers += [nn.LeakyReLU(0.2,True),]
122
+ layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
123
+ layers += [nn.LeakyReLU(0.2,True),]
124
+ layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
125
+ if(use_sigmoid):
126
+ layers += [nn.Sigmoid(),]
127
+ self.model = nn.Sequential(*layers)
128
+
129
+ def forward(self,d0,d1,eps=0.1):
130
+ return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
131
+
132
+ class BCERankingLoss(nn.Module):
133
+ def __init__(self, chn_mid=32):
134
+ super(BCERankingLoss, self).__init__()
135
+ self.net = Dist2LogitLayer(chn_mid=chn_mid)
136
+ # self.parameters = list(self.net.parameters())
137
+ self.loss = torch.nn.BCELoss()
138
+
139
+ def forward(self, d0, d1, judge):
140
+ per = (judge+1.)/2.
141
+ self.logit = self.net.forward(d0,d1)
142
+ return self.loss(self.logit, per)
143
+
144
+ # L2, DSSIM metrics
145
+ class FakeNet(nn.Module):
146
+ def __init__(self, use_gpu=True, colorspace='Lab'):
147
+ super(FakeNet, self).__init__()
148
+ self.use_gpu = use_gpu
149
+ self.colorspace=colorspace
150
+
151
+ class L2(FakeNet):
152
+
153
+ def forward(self, in0, in1, retPerLayer=None):
154
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
155
+
156
+ if(self.colorspace=='RGB'):
157
+ (N,C,X,Y) = in0.size()
158
+ value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
159
+ return value
160
+ elif(self.colorspace=='Lab'):
161
+ value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
162
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
163
+ ret_var = Variable( torch.Tensor((value,) ) )
164
+ if(self.use_gpu):
165
+ ret_var = ret_var.cuda()
166
+ return ret_var
167
+
168
+ class DSSIM(FakeNet):
169
+
170
+ def forward(self, in0, in1, retPerLayer=None):
171
+ assert(in0.size()[0]==1) # currently only supports batchSize 1
172
+
173
+ if(self.colorspace=='RGB'):
174
+ value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
175
+ elif(self.colorspace=='Lab'):
176
+ value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
177
+ util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
178
+ ret_var = Variable( torch.Tensor((value,) ) )
179
+ if(self.use_gpu):
180
+ ret_var = ret_var.cuda()
181
+ return ret_var
182
+
183
+ def print_network(net):
184
+ num_params = 0
185
+ for param in net.parameters():
186
+ num_params += param.numel()
187
+ print('Network',net)
188
+ print('Total number of parameters: %d' % num_params)
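For orientation, PNetLin is the module DistModel wires up; with lpips=False it reduces to a plain backbone feature distance, which makes a light smoke test (only the torchvision AlexNet weights are needed). A minimal sketch, not the calibrated metric itself:

import torch
from models.draggan.gan_inv.lpips.networks_basic import PNetLin

net = PNetLin(pnet_type='alex', lpips=False).eval()   # skips the learned 1x1 layers, so no LPIPS weight file
x = torch.rand(1, 3, 64, 64) * 2 - 1
y = torch.rand(1, 3, 64, 64) * 2 - 1
with torch.no_grad():
    d = net(x, y)                                     # sums per-layer averaged squared feature differences
print(d.shape)                                        # torch.Size([1, 1, 1, 1])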
models/draggan/gan_inv/lpips/pretrained_networks.py ADDED
@@ -0,0 +1,181 @@
1
+ from collections import namedtuple
2
+ import torch
3
+ from torchvision import models as tv
4
+ from IPython import embed
5
+
6
+ class squeezenet(torch.nn.Module):
7
+ def __init__(self, requires_grad=False, pretrained=True):
8
+ super(squeezenet, self).__init__()
9
+ pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10
+ self.slice1 = torch.nn.Sequential()
11
+ self.slice2 = torch.nn.Sequential()
12
+ self.slice3 = torch.nn.Sequential()
13
+ self.slice4 = torch.nn.Sequential()
14
+ self.slice5 = torch.nn.Sequential()
15
+ self.slice6 = torch.nn.Sequential()
16
+ self.slice7 = torch.nn.Sequential()
17
+ self.N_slices = 7
18
+ for x in range(2):
19
+ self.slice1.add_module(str(x), pretrained_features[x])
20
+ for x in range(2,5):
21
+ self.slice2.add_module(str(x), pretrained_features[x])
22
+ for x in range(5, 8):
23
+ self.slice3.add_module(str(x), pretrained_features[x])
24
+ for x in range(8, 10):
25
+ self.slice4.add_module(str(x), pretrained_features[x])
26
+ for x in range(10, 11):
27
+ self.slice5.add_module(str(x), pretrained_features[x])
28
+ for x in range(11, 12):
29
+ self.slice6.add_module(str(x), pretrained_features[x])
30
+ for x in range(12, 13):
31
+ self.slice7.add_module(str(x), pretrained_features[x])
32
+ if not requires_grad:
33
+ for param in self.parameters():
34
+ param.requires_grad = False
35
+
36
+ def forward(self, X):
37
+ h = self.slice1(X)
38
+ h_relu1 = h
39
+ h = self.slice2(h)
40
+ h_relu2 = h
41
+ h = self.slice3(h)
42
+ h_relu3 = h
43
+ h = self.slice4(h)
44
+ h_relu4 = h
45
+ h = self.slice5(h)
46
+ h_relu5 = h
47
+ h = self.slice6(h)
48
+ h_relu6 = h
49
+ h = self.slice7(h)
50
+ h_relu7 = h
51
+ vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52
+ out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53
+
54
+ return out
55
+
56
+
57
+ class alexnet(torch.nn.Module):
58
+ def __init__(self, requires_grad=False, pretrained=True):
59
+ super(alexnet, self).__init__()
60
+ alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61
+ self.slice1 = torch.nn.Sequential()
62
+ self.slice2 = torch.nn.Sequential()
63
+ self.slice3 = torch.nn.Sequential()
64
+ self.slice4 = torch.nn.Sequential()
65
+ self.slice5 = torch.nn.Sequential()
66
+ self.N_slices = 5
67
+ for x in range(2):
68
+ self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69
+ for x in range(2, 5):
70
+ self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71
+ for x in range(5, 8):
72
+ self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73
+ for x in range(8, 10):
74
+ self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75
+ for x in range(10, 12):
76
+ self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77
+ if not requires_grad:
78
+ for param in self.parameters():
79
+ param.requires_grad = False
80
+
81
+ def forward(self, X):
82
+ h = self.slice1(X)
83
+ h_relu1 = h
84
+ h = self.slice2(h)
85
+ h_relu2 = h
86
+ h = self.slice3(h)
87
+ h_relu3 = h
88
+ h = self.slice4(h)
89
+ h_relu4 = h
90
+ h = self.slice5(h)
91
+ h_relu5 = h
92
+ alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93
+ out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94
+
95
+ return out
96
+
97
+ class vgg16(torch.nn.Module):
98
+ def __init__(self, requires_grad=False, pretrained=True):
99
+ super(vgg16, self).__init__()
100
+ vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101
+ self.slice1 = torch.nn.Sequential()
102
+ self.slice2 = torch.nn.Sequential()
103
+ self.slice3 = torch.nn.Sequential()
104
+ self.slice4 = torch.nn.Sequential()
105
+ self.slice5 = torch.nn.Sequential()
106
+ self.N_slices = 5
107
+ for x in range(4):
108
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
109
+ for x in range(4, 9):
110
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
111
+ for x in range(9, 16):
112
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
113
+ for x in range(16, 23):
114
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
115
+ for x in range(23, 30):
116
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
117
+ if not requires_grad:
118
+ for param in self.parameters():
119
+ param.requires_grad = False
120
+
121
+ def forward(self, X):
122
+ h = self.slice1(X)
123
+ h_relu1_2 = h
124
+ h = self.slice2(h)
125
+ h_relu2_2 = h
126
+ h = self.slice3(h)
127
+ h_relu3_3 = h
128
+ h = self.slice4(h)
129
+ h_relu4_3 = h
130
+ h = self.slice5(h)
131
+ h_relu5_3 = h
132
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134
+
135
+ return out
136
+
137
+
138
+
139
+ class resnet(torch.nn.Module):
140
+ def __init__(self, requires_grad=False, pretrained=True, num=18):
141
+ super(resnet, self).__init__()
142
+ if(num==18):
143
+ self.net = tv.resnet18(pretrained=pretrained)
144
+ elif(num==34):
145
+ self.net = tv.resnet34(pretrained=pretrained)
146
+ elif(num==50):
147
+ self.net = tv.resnet50(pretrained=pretrained)
148
+ elif(num==101):
149
+ self.net = tv.resnet101(pretrained=pretrained)
150
+ elif(num==152):
151
+ self.net = tv.resnet152(pretrained=pretrained)
152
+ self.N_slices = 5
153
+
154
+ self.conv1 = self.net.conv1
155
+ self.bn1 = self.net.bn1
156
+ self.relu = self.net.relu
157
+ self.maxpool = self.net.maxpool
158
+ self.layer1 = self.net.layer1
159
+ self.layer2 = self.net.layer2
160
+ self.layer3 = self.net.layer3
161
+ self.layer4 = self.net.layer4
162
+
163
+ def forward(self, X):
164
+ h = self.conv1(X)
165
+ h = self.bn1(h)
166
+ h = self.relu(h)
167
+ h_relu1 = h
168
+ h = self.maxpool(h)
169
+ h = self.layer1(h)
170
+ h_conv2 = h
171
+ h = self.layer2(h)
172
+ h_conv3 = h
173
+ h = self.layer3(h)
174
+ h_conv4 = h
175
+ h = self.layer4(h)
176
+ h_conv5 = h
177
+
178
+ outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179
+ out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180
+
181
+ return out
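These wrappers only expose intermediate backbone activations as namedtuples for the metric above. A quick sketch (torchvision downloads the VGG16 weights on first use):

import torch
from models.draggan.gan_inv.lpips.pretrained_networks import vgg16

backbone = vgg16(requires_grad=False, pretrained=True).eval()
feats = backbone(torch.rand(1, 3, 224, 224))
print([f.shape[1] for f in feats])   # per-slice channel counts: [64, 128, 256, 512, 512]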
models/draggan/gan_inv/lpips/util.py ADDED
@@ -0,0 +1,160 @@
1
+
2
+ from __future__ import absolute_import
3
+ from __future__ import division
4
+ from __future__ import print_function
5
+
6
+ import numpy as np
7
+ from skimage.metrics import structural_similarity
8
+ import torch
9
+
10
+
11
+ from . import dist_model
12
+
13
+ class PerceptualLoss(torch.nn.Module):
14
+ def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
15
+ # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
16
+ super(PerceptualLoss, self).__init__()
17
+ print('Setting up Perceptual loss...')
18
+ self.use_gpu = use_gpu
19
+ self.spatial = spatial
20
+ self.gpu_ids = gpu_ids
21
+ self.model = dist_model.DistModel()
22
+ self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
23
+ print('...[%s] initialized'%self.model.name())
24
+ print('...Done')
25
+
26
+ def forward(self, pred, target, normalize=False):
27
+ """
28
+ Pred and target are Variables.
29
+ If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
30
+ If normalize is False, assumes the images are already between [-1,+1]
31
+
32
+ Inputs pred and target are Nx3xHxW
33
+ Output pytorch Variable N long
34
+ """
35
+
36
+ if normalize:
37
+ target = 2 * target - 1
38
+ pred = 2 * pred - 1
39
+
40
+ return self.model.forward(target, pred)
41
+
42
+ def normalize_tensor(in_feat,eps=1e-10):
43
+ norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
44
+ return in_feat/(norm_factor+eps)
45
+
46
+ def l2(p0, p1, range=255.):
47
+ return .5*np.mean((p0 / range - p1 / range)**2)
48
+
49
+ def psnr(p0, p1, peak=255.):
50
+ return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
51
+
52
+ def dssim(p0, p1, range=255.):
53
+ return (1 - structural_similarity(p0, p1, data_range=range, multichannel=True)) / 2.
54
+
55
+ def rgb2lab(in_img,mean_cent=False):
56
+ from skimage import color
57
+ img_lab = color.rgb2lab(in_img)
58
+ if(mean_cent):
59
+ img_lab[:,:,0] = img_lab[:,:,0]-50
60
+ return img_lab
61
+
62
+ def tensor2np(tensor_obj):
63
+ # change dimension of a tensor object into a numpy array
64
+ return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
65
+
66
+ def np2tensor(np_obj):
67
+ # change dimenion of np array into tensor array
68
+ return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
69
+
70
+ def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
71
+ # image tensor to lab tensor
72
+ from skimage import color
73
+
74
+ img = tensor2im(image_tensor)
75
+ img_lab = color.rgb2lab(img)
76
+ if(mc_only):
77
+ img_lab[:,:,0] = img_lab[:,:,0]-50
78
+ if(to_norm and not mc_only):
79
+ img_lab[:,:,0] = img_lab[:,:,0]-50
80
+ img_lab = img_lab/100.
81
+
82
+ return np2tensor(img_lab)
83
+
84
+ def tensorlab2tensor(lab_tensor,return_inbnd=False):
85
+ from skimage import color
86
+ import warnings
87
+ warnings.filterwarnings("ignore")
88
+
89
+ lab = tensor2np(lab_tensor)*100.
90
+ lab[:,:,0] = lab[:,:,0]+50
91
+
92
+ rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
93
+ if(return_inbnd):
94
+ # convert back to lab, see if we match
95
+ lab_back = color.rgb2lab(rgb_back.astype('uint8'))
96
+ mask = 1.*np.isclose(lab_back,lab,atol=2.)
97
+ mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
98
+ return (im2tensor(rgb_back),mask)
99
+ else:
100
+ return im2tensor(rgb_back)
101
+
102
+ def rgb2lab(input):
103
+ from skimage import color
104
+ return color.rgb2lab(input / 255.)
105
+
106
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
107
+ image_numpy = image_tensor[0].cpu().float().numpy()
108
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
109
+ return image_numpy.astype(imtype)
110
+
111
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
112
+ return torch.Tensor((image / factor - cent)
113
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
114
+
115
+ def tensor2vec(vector_tensor):
116
+ return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
117
+
118
+ def voc_ap(rec, prec, use_07_metric=False):
119
+ """ ap = voc_ap(rec, prec, [use_07_metric])
120
+ Compute VOC AP given precision and recall.
121
+ If use_07_metric is true, uses the
122
+ VOC 07 11 point method (default:False).
123
+ """
124
+ if use_07_metric:
125
+ # 11 point metric
126
+ ap = 0.
127
+ for t in np.arange(0., 1.1, 0.1):
128
+ if np.sum(rec >= t) == 0:
129
+ p = 0
130
+ else:
131
+ p = np.max(prec[rec >= t])
132
+ ap = ap + p / 11.
133
+ else:
134
+ # correct AP calculation
135
+ # first append sentinel values at the end
136
+ mrec = np.concatenate(([0.], rec, [1.]))
137
+ mpre = np.concatenate(([0.], prec, [0.]))
138
+
139
+ # compute the precision envelope
140
+ for i in range(mpre.size - 1, 0, -1):
141
+ mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
142
+
143
+ # to calculate area under PR curve, look for points
144
+ # where X axis (recall) changes value
145
+ i = np.where(mrec[1:] != mrec[:-1])[0]
146
+
147
+ # and sum (\Delta recall) * prec
148
+ ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
149
+ return ap
150
+
151
+ def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
152
+ # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
153
+ image_numpy = image_tensor[0].cpu().float().numpy()
154
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
155
+ return image_numpy.astype(imtype)
156
+
157
+ def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
158
+ # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
159
+ return torch.Tensor((image / factor - cent)
160
+ [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
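A note on the conventions this package assumes: images travel as 1x3xHxW tensors scaled to [-1, 1], and im2tensor / tensor2im convert to and from HxWx3 uint8 arrays. A minimal round-trip sketch:

import numpy as np
from models.draggan.gan_inv.lpips.util import im2tensor, tensor2im

img = (np.random.rand(32, 32, 3) * 255).astype(np.uint8)
t = im2tensor(img)                                             # 1x3x32x32, scaled to [-1, 1]
back = tensor2im(t)                                            # HxWx3 uint8 again
assert np.abs(img.astype(int) - back.astype(int)).max() <= 1   # only rounding error survives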
models/draggan/legacy.py ADDED
@@ -0,0 +1,325 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Converting legacy network pickle into the new format."""
10
+
11
+ import click
12
+ import pickle
13
+ import re
14
+ import copy
15
+ import numpy as np
16
+ import torch
17
+ import sys, os
18
+ sys.path.append(os.path.dirname(__file__))
19
+ import dnnlib as dnnlib
20
+ from torch_utils import misc
21
+
22
+ #----------------------------------------------------------------------------
23
+
24
+ def load_network_pkl(f, force_fp16=False):
25
+ data = _LegacyUnpickler(f).load()
26
+
27
+ # Legacy TensorFlow pickle => convert.
28
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
29
+ tf_G, tf_D, tf_Gs = data
30
+ G = convert_tf_generator(tf_G)
31
+ D = convert_tf_discriminator(tf_D)
32
+ G_ema = convert_tf_generator(tf_Gs)
33
+ data = dict(G=G, D=D, G_ema=G_ema)
34
+
35
+ # Add missing fields.
36
+ if 'training_set_kwargs' not in data:
37
+ data['training_set_kwargs'] = None
38
+ if 'augment_pipe' not in data:
39
+ data['augment_pipe'] = None
40
+
41
+ # Validate contents.
42
+ assert isinstance(data['G'], torch.nn.Module)
43
+ assert isinstance(data['D'], torch.nn.Module)
44
+ assert isinstance(data['G_ema'], torch.nn.Module)
45
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
46
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
47
+
48
+ # Force FP16.
49
+ if force_fp16:
50
+ for key in ['G', 'D', 'G_ema']:
51
+ old = data[key]
52
+ kwargs = copy.deepcopy(old.init_kwargs)
53
+ fp16_kwargs = kwargs.get('synthesis_kwargs', kwargs)
54
+ fp16_kwargs.num_fp16_res = 4
55
+ fp16_kwargs.conv_clamp = 256
56
+ if kwargs != old.init_kwargs:
57
+ new = type(old)(**kwargs).eval().requires_grad_(False)
58
+ misc.copy_params_and_buffers(old, new, require_all=True)
59
+ data[key] = new
60
+ return data
61
+
62
+ #----------------------------------------------------------------------------
63
+
64
+ class _TFNetworkStub(dnnlib.EasyDict):
65
+ pass
66
+
67
+ class _LegacyUnpickler(pickle.Unpickler):
68
+ def find_class(self, module, name):
69
+ if module == 'dnnlib.tflib.network' and name == 'Network':
70
+ return _TFNetworkStub
71
+ return super().find_class(module, name)
72
+
73
+ #----------------------------------------------------------------------------
74
+
75
+ def _collect_tf_params(tf_net):
76
+ # pylint: disable=protected-access
77
+ tf_params = dict()
78
+ def recurse(prefix, tf_net):
79
+ for name, value in tf_net.variables:
80
+ tf_params[prefix + name] = value
81
+ for name, comp in tf_net.components.items():
82
+ recurse(prefix + name + '/', comp)
83
+ recurse('', tf_net)
84
+ return tf_params
85
+
86
+ #----------------------------------------------------------------------------
87
+
88
+ def _populate_module_params(module, *patterns):
89
+ for name, tensor in misc.named_params_and_buffers(module):
90
+ found = False
91
+ value = None
92
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
93
+ match = re.fullmatch(pattern, name)
94
+ if match:
95
+ found = True
96
+ if value_fn is not None:
97
+ value = value_fn(*match.groups())
98
+ break
99
+ try:
100
+ assert found
101
+ if value is not None:
102
+ tensor.copy_(torch.from_numpy(np.array(value)))
103
+ except:
104
+ print(name, list(tensor.shape))
105
+ raise
106
+
107
+ #----------------------------------------------------------------------------
108
+
109
+ def convert_tf_generator(tf_G):
110
+ if tf_G.version < 4:
111
+ raise ValueError('TensorFlow pickle version too low')
112
+
113
+ # Collect kwargs.
114
+ tf_kwargs = tf_G.static_kwargs
115
+ known_kwargs = set()
116
+ def kwarg(tf_name, default=None, none=None):
117
+ known_kwargs.add(tf_name)
118
+ val = tf_kwargs.get(tf_name, default)
119
+ return val if val is not None else none
120
+
121
+ # Convert kwargs.
122
+ from training import networks_stylegan2
123
+ network_class = networks_stylegan2.Generator
124
+ kwargs = dnnlib.EasyDict(
125
+ z_dim = kwarg('latent_size', 512),
126
+ c_dim = kwarg('label_size', 0),
127
+ w_dim = kwarg('dlatent_size', 512),
128
+ img_resolution = kwarg('resolution', 1024),
129
+ img_channels = kwarg('num_channels', 3),
130
+ channel_base = kwarg('fmap_base', 16384) * 2,
131
+ channel_max = kwarg('fmap_max', 512),
132
+ num_fp16_res = kwarg('num_fp16_res', 0),
133
+ conv_clamp = kwarg('conv_clamp', None),
134
+ architecture = kwarg('architecture', 'skip'),
135
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
136
+ use_noise = kwarg('use_noise', True),
137
+ activation = kwarg('nonlinearity', 'lrelu'),
138
+ mapping_kwargs = dnnlib.EasyDict(
139
+ num_layers = kwarg('mapping_layers', 8),
140
+ embed_features = kwarg('label_fmaps', None),
141
+ layer_features = kwarg('mapping_fmaps', None),
142
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
143
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
144
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
145
+ ),
146
+ )
147
+
148
+ # Check for unknown kwargs.
149
+ kwarg('truncation_psi')
150
+ kwarg('truncation_cutoff')
151
+ kwarg('style_mixing_prob')
152
+ kwarg('structure')
153
+ kwarg('conditioning')
154
+ kwarg('fused_modconv')
155
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
156
+ if len(unknown_kwargs) > 0:
157
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
158
+
159
+ # Collect params.
160
+ tf_params = _collect_tf_params(tf_G)
161
+ for name, value in list(tf_params.items()):
162
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
163
+ if match:
164
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
165
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
166
+ kwargs.synthesis.kwargs.architecture = 'orig'
167
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
168
+
169
+ # Convert params.
170
+ G = network_class(**kwargs).eval().requires_grad_(False)
171
+ # pylint: disable=unnecessary-lambda
172
+ # pylint: disable=f-string-without-interpolation
173
+ _populate_module_params(G,
174
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
175
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
176
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
177
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
178
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
179
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
180
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
181
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
182
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
183
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
184
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
185
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
186
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
187
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
188
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
189
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
190
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
191
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
192
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
193
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
194
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
195
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
196
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
197
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
198
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
199
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
200
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
201
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
202
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
203
+ r'.*\.resample_filter', None,
204
+ r'.*\.act_filter', None,
205
+ )
206
+ return G
207
+
208
+ #----------------------------------------------------------------------------
209
+
210
+ def convert_tf_discriminator(tf_D):
211
+ if tf_D.version < 4:
212
+ raise ValueError('TensorFlow pickle version too low')
213
+
214
+ # Collect kwargs.
215
+ tf_kwargs = tf_D.static_kwargs
216
+ known_kwargs = set()
217
+ def kwarg(tf_name, default=None):
218
+ known_kwargs.add(tf_name)
219
+ return tf_kwargs.get(tf_name, default)
220
+
221
+ # Convert kwargs.
222
+ kwargs = dnnlib.EasyDict(
223
+ c_dim = kwarg('label_size', 0),
224
+ img_resolution = kwarg('resolution', 1024),
225
+ img_channels = kwarg('num_channels', 3),
226
+ architecture = kwarg('architecture', 'resnet'),
227
+ channel_base = kwarg('fmap_base', 16384) * 2,
228
+ channel_max = kwarg('fmap_max', 512),
229
+ num_fp16_res = kwarg('num_fp16_res', 0),
230
+ conv_clamp = kwarg('conv_clamp', None),
231
+ cmap_dim = kwarg('mapping_fmaps', None),
232
+ block_kwargs = dnnlib.EasyDict(
233
+ activation = kwarg('nonlinearity', 'lrelu'),
234
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
235
+ freeze_layers = kwarg('freeze_layers', 0),
236
+ ),
237
+ mapping_kwargs = dnnlib.EasyDict(
238
+ num_layers = kwarg('mapping_layers', 0),
239
+ embed_features = kwarg('mapping_fmaps', None),
240
+ layer_features = kwarg('mapping_fmaps', None),
241
+ activation = kwarg('nonlinearity', 'lrelu'),
242
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
243
+ ),
244
+ epilogue_kwargs = dnnlib.EasyDict(
245
+ mbstd_group_size = kwarg('mbstd_group_size', None),
246
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
247
+ activation = kwarg('nonlinearity', 'lrelu'),
248
+ ),
249
+ )
250
+
251
+ # Check for unknown kwargs.
252
+ kwarg('structure')
253
+ kwarg('conditioning')
254
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
255
+ if len(unknown_kwargs) > 0:
256
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
257
+
258
+ # Collect params.
259
+ tf_params = _collect_tf_params(tf_D)
260
+ for name, value in list(tf_params.items()):
261
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
262
+ if match:
263
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
264
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
265
+ kwargs.architecture = 'orig'
266
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
267
+
268
+ # Convert params.
269
+ from training import networks_stylegan2
270
+ D = networks_stylegan2.Discriminator(**kwargs).eval().requires_grad_(False)
271
+ # pylint: disable=unnecessary-lambda
272
+ # pylint: disable=f-string-without-interpolation
273
+ _populate_module_params(D,
274
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
275
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
276
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
277
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
278
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
279
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
280
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
281
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
282
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
283
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
284
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
285
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
286
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
287
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
288
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
289
+ r'.*\.resample_filter', None,
290
+ )
291
+ return D
292
+
293
+ #----------------------------------------------------------------------------
294
+
295
+ @click.command()
296
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
297
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
298
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
299
+ def convert_network_pickle(source, dest, force_fp16):
300
+ """Convert legacy network pickle into the native PyTorch format.
301
+
302
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
303
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
304
+
305
+ Example:
306
+
307
+ \b
308
+ python legacy.py \\
309
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
310
+ --dest=stylegan2-cat-config-f.pkl
311
+ """
312
+ print(f'Loading "{source}"...')
313
+ with dnnlib.util.open_url(source) as f:
314
+ data = load_network_pkl(f, force_fp16=force_fp16)
315
+ print(f'Saving "{dest}"...')
316
+ with open(dest, 'wb') as f:
317
+ pickle.dump(data, f)
318
+ print('Done.')
319
+
320
+ #----------------------------------------------------------------------------
321
+
322
+ if __name__ == "__main__":
323
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
324
+
325
+ #----------------------------------------------------------------------------
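Besides the CLI above, load_network_pkl can be called directly. The sketch below reuses the URL from the docstring and assumes the modules the pickle references (torch_utils, the training networks) resolve on sys.path, which legacy.py arranges for its own directory; the package-style import is an assumption.

from models.draggan import dnnlib, legacy

url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl'
with dnnlib.util.open_url(url) as f:
    data = legacy.load_network_pkl(f)            # dict with 'G', 'D', 'G_ema', ...
G = data['G_ema'].eval().requires_grad_(False)
print(G.img_resolution, G.mapping.num_ws)        # attributes the inversion code relies on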
models/draggan/torch_utils/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ # empty
models/draggan/torch_utils/custom_ops.py ADDED
@@ -0,0 +1,157 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import glob
10
+ import hashlib
11
+ import importlib
12
+ import os
13
+ import re
14
+ import shutil
15
+ import uuid
16
+
17
+ import torch
18
+ import torch.utils.cpp_extension
19
+ from torch.utils.file_baton import FileBaton
20
+
21
+ #----------------------------------------------------------------------------
22
+ # Global options.
23
+
24
+ verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
25
+
26
+ #----------------------------------------------------------------------------
27
+ # Internal helper funcs.
28
+
29
+ def _find_compiler_bindir():
30
+ patterns = [
31
+ 'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
32
+ 'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
33
+ 'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
34
+ 'C:/Program Files*/Microsoft Visual Studio */vc/bin',
35
+ ]
36
+ for pattern in patterns:
37
+ matches = sorted(glob.glob(pattern))
38
+ if len(matches):
39
+ return matches[-1]
40
+ return None
41
+
42
+ #----------------------------------------------------------------------------
43
+
44
+ def _get_mangled_gpu_name():
45
+ name = torch.cuda.get_device_name().lower()
46
+ out = []
47
+ for c in name:
48
+ if re.match('[a-z0-9_-]+', c):
49
+ out.append(c)
50
+ else:
51
+ out.append('-')
52
+ return ''.join(out)
53
+
54
+ #----------------------------------------------------------------------------
55
+ # Main entry point for compiling and loading C++/CUDA plugins.
56
+
57
+ _cached_plugins = dict()
58
+
59
+ def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
60
+ assert verbosity in ['none', 'brief', 'full']
61
+ if headers is None:
62
+ headers = []
63
+ if source_dir is not None:
64
+ sources = [os.path.join(source_dir, fname) for fname in sources]
65
+ headers = [os.path.join(source_dir, fname) for fname in headers]
66
+
67
+ # Already cached?
68
+ if module_name in _cached_plugins:
69
+ return _cached_plugins[module_name]
70
+
71
+ # Print status.
72
+ if verbosity == 'full':
73
+ print(f'Setting up PyTorch plugin "{module_name}"...')
74
+ elif verbosity == 'brief':
75
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
76
+ verbose_build = (verbosity == 'full')
77
+
78
+ # Compile and load.
79
+ try: # pylint: disable=too-many-nested-blocks
80
+ # Make sure we can find the necessary compiler binaries.
81
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
82
+ compiler_bindir = _find_compiler_bindir()
83
+ if compiler_bindir is None:
84
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
85
+ os.environ['PATH'] += ';' + compiler_bindir
86
+
87
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
88
+ # break the build or unnecessarily restrict what's available to nvcc.
89
+ # Unset it to let nvcc decide based on what's available on the
90
+ # machine.
91
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
92
+
93
+ # Incremental build md5sum trickery. Copies all the input source files
94
+ # into a cached build directory under a combined md5 digest of the input
95
+ # source files. Copying is done only if the combined digest has changed.
96
+ # This keeps input file timestamps and filenames the same as in previous
97
+ # extension builds, allowing for fast incremental rebuilds.
98
+ #
99
+ # This optimization is done only in case all the source files reside in
100
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
101
+ # environment variable is set (we take this as a signal that the user
102
+ # actually cares about this.)
103
+ #
104
+ # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
105
+ # around the *.cu dependency bug in ninja config.
106
+ #
107
+ all_source_files = sorted(sources + headers)
108
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
109
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
110
+
111
+ # Compute combined hash digest for all source files.
112
+ hash_md5 = hashlib.md5()
113
+ for src in all_source_files:
114
+ with open(src, 'rb') as f:
115
+ hash_md5.update(f.read())
116
+
117
+ # Select cached build directory name.
118
+ source_digest = hash_md5.hexdigest()
119
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
120
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
121
+
122
+ if not os.path.isdir(cached_build_dir):
123
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
124
+ os.makedirs(tmpdir)
125
+ for src in all_source_files:
126
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
127
+ try:
128
+ os.replace(tmpdir, cached_build_dir) # atomic
129
+ except OSError:
130
+ # source directory already exists, delete tmpdir and its contents.
131
+ shutil.rmtree(tmpdir)
132
+ if not os.path.isdir(cached_build_dir): raise
133
+
134
+ # Compile.
135
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
136
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
137
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
138
+ else:
139
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
140
+
141
+ # Load.
142
+ module = importlib.import_module(module_name)
143
+
144
+ except:
145
+ if verbosity == 'brief':
146
+ print('Failed!')
147
+ raise
148
+
149
+ # Print status and add to cache dict.
150
+ if verbosity == 'full':
151
+ print(f'Done setting up PyTorch plugin "{module_name}".')
152
+ elif verbosity == 'brief':
153
+ print('Done.')
154
+ _cached_plugins[module_name] = module
155
+ return module
156
+
157
+ #----------------------------------------------------------------------------
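A minimal sketch of how get_plugin() above can be driven. This is a hedged illustration, not part of the upload: it assumes the models/ directory is on sys.path (so that draggan.torch_utils resolves, as the draggan.dnnlib import elsewhere in this commit suggests), and the module name and source files below are hypothetical placeholders.

# Hedged sketch: 'example_plugin' and its source files are hypothetical placeholders.
import os
from draggan.torch_utils import custom_ops

custom_ops.verbosity = 'full'  # print the full torch.utils.cpp_extension build log

plugin = custom_ops.get_plugin(
    module_name='example_plugin',
    sources=['example.cpp', 'example.cu'],   # hypothetical C++/CUDA sources
    headers=['example.h'],                   # hashed into the cached build directory name
    source_dir=os.path.dirname(__file__),    # resolve the paths relative to the caller
    extra_cuda_cflags=['--use_fast_math'],
)

# A second call with the same module_name is served from _cached_plugins
# without touching the compiler again.
assert custom_ops.get_plugin('example_plugin', sources=['example.cpp']) is plugin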
models/draggan/torch_utils/misc.py ADDED
@@ -0,0 +1,266 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import re
10
+ import contextlib
11
+ import numpy as np
12
+ import torch
13
+ import warnings
14
+ import draggan.dnnlib as dnnlib
15
+
16
+ #----------------------------------------------------------------------------
17
+ # Cached construction of constant tensors. Avoids CPU=>GPU copy when the
18
+ # same constant is used multiple times.
19
+
20
+ _constant_cache = dict()
21
+
22
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
23
+ value = np.asarray(value)
24
+ if shape is not None:
25
+ shape = tuple(shape)
26
+ if dtype is None:
27
+ dtype = torch.get_default_dtype()
28
+ if device is None:
29
+ device = torch.device('cpu')
30
+ if memory_format is None:
31
+ memory_format = torch.contiguous_format
32
+
33
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
34
+ tensor = _constant_cache.get(key, None)
35
+ if tensor is None:
36
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
37
+ if shape is not None:
38
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
39
+ tensor = tensor.contiguous(memory_format=memory_format)
40
+ _constant_cache[key] = tensor
41
+ return tensor
42
+
43
+ #----------------------------------------------------------------------------
44
+ # Replace NaN/Inf with specified numerical values.
45
+
46
+ try:
47
+ nan_to_num = torch.nan_to_num # 1.8.0a0
48
+ except AttributeError:
49
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
50
+ assert isinstance(input, torch.Tensor)
51
+ if posinf is None:
52
+ posinf = torch.finfo(input.dtype).max
53
+ if neginf is None:
54
+ neginf = torch.finfo(input.dtype).min
55
+ assert nan == 0
56
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
57
+
58
+ #----------------------------------------------------------------------------
59
+ # Symbolic assert.
60
+
61
+ try:
62
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
63
+ except AttributeError:
64
+ symbolic_assert = torch.Assert # 1.7.0
65
+
66
+ #----------------------------------------------------------------------------
67
+ # Context manager to temporarily suppress known warnings in torch.jit.trace().
68
+ # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
69
+
70
+ @contextlib.contextmanager
71
+ def suppress_tracer_warnings():
72
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
73
+ warnings.filters.insert(0, flt)
74
+ yield
75
+ warnings.filters.remove(flt)
76
+
77
+ #----------------------------------------------------------------------------
78
+ # Assert that the shape of a tensor matches the given list of integers.
79
+ # None indicates that the size of a dimension is allowed to vary.
80
+ # Performs symbolic assertion when used in torch.jit.trace().
81
+
82
+ def assert_shape(tensor, ref_shape):
83
+ if tensor.ndim != len(ref_shape):
84
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
85
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
86
+ if ref_size is None:
87
+ pass
88
+ elif isinstance(ref_size, torch.Tensor):
89
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
90
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
91
+ elif isinstance(size, torch.Tensor):
92
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
93
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
94
+ elif size != ref_size:
95
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
96
+
97
+ #----------------------------------------------------------------------------
98
+ # Function decorator that calls torch.autograd.profiler.record_function().
99
+
100
+ def profiled_function(fn):
101
+ def decorator(*args, **kwargs):
102
+ with torch.autograd.profiler.record_function(fn.__name__):
103
+ return fn(*args, **kwargs)
104
+ decorator.__name__ = fn.__name__
105
+ return decorator
106
+
107
+ #----------------------------------------------------------------------------
108
+ # Sampler for torch.utils.data.DataLoader that loops over the dataset
109
+ # indefinitely, shuffling items as it goes.
110
+
111
+ class InfiniteSampler(torch.utils.data.Sampler):
112
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
113
+ assert len(dataset) > 0
114
+ assert num_replicas > 0
115
+ assert 0 <= rank < num_replicas
116
+ assert 0 <= window_size <= 1
117
+ super().__init__(dataset)
118
+ self.dataset = dataset
119
+ self.rank = rank
120
+ self.num_replicas = num_replicas
121
+ self.shuffle = shuffle
122
+ self.seed = seed
123
+ self.window_size = window_size
124
+
125
+ def __iter__(self):
126
+ order = np.arange(len(self.dataset))
127
+ rnd = None
128
+ window = 0
129
+ if self.shuffle:
130
+ rnd = np.random.RandomState(self.seed)
131
+ rnd.shuffle(order)
132
+ window = int(np.rint(order.size * self.window_size))
133
+
134
+ idx = 0
135
+ while True:
136
+ i = idx % order.size
137
+ if idx % self.num_replicas == self.rank:
138
+ yield order[i]
139
+ if window >= 2:
140
+ j = (i - rnd.randint(window)) % order.size
141
+ order[i], order[j] = order[j], order[i]
142
+ idx += 1
143
+
144
+ #----------------------------------------------------------------------------
145
+ # Utilities for operating with torch.nn.Module parameters and buffers.
146
+
147
+ def params_and_buffers(module):
148
+ assert isinstance(module, torch.nn.Module)
149
+ return list(module.parameters()) + list(module.buffers())
150
+
151
+ def named_params_and_buffers(module):
152
+ assert isinstance(module, torch.nn.Module)
153
+ return list(module.named_parameters()) + list(module.named_buffers())
154
+
155
+ def copy_params_and_buffers(src_module, dst_module, require_all=False):
156
+ assert isinstance(src_module, torch.nn.Module)
157
+ assert isinstance(dst_module, torch.nn.Module)
158
+ src_tensors = dict(named_params_and_buffers(src_module))
159
+ for name, tensor in named_params_and_buffers(dst_module):
160
+ assert (name in src_tensors) or (not require_all)
161
+ if name in src_tensors:
162
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
163
+
164
+ #----------------------------------------------------------------------------
165
+ # Context manager for easily enabling/disabling DistributedDataParallel
166
+ # synchronization.
167
+
168
+ @contextlib.contextmanager
169
+ def ddp_sync(module, sync):
170
+ assert isinstance(module, torch.nn.Module)
171
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
172
+ yield
173
+ else:
174
+ with module.no_sync():
175
+ yield
176
+
177
+ #----------------------------------------------------------------------------
178
+ # Check DistributedDataParallel consistency across processes.
179
+
180
+ def check_ddp_consistency(module, ignore_regex=None):
181
+ assert isinstance(module, torch.nn.Module)
182
+ for name, tensor in named_params_and_buffers(module):
183
+ fullname = type(module).__name__ + '.' + name
184
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
185
+ continue
186
+ tensor = tensor.detach()
187
+ if tensor.is_floating_point():
188
+ tensor = nan_to_num(tensor)
189
+ other = tensor.clone()
190
+ torch.distributed.broadcast(tensor=other, src=0)
191
+ assert (tensor == other).all(), fullname
192
+
193
+ #----------------------------------------------------------------------------
194
+ # Print summary table of module hierarchy.
195
+
196
+ def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
197
+ assert isinstance(module, torch.nn.Module)
198
+ assert not isinstance(module, torch.jit.ScriptModule)
199
+ assert isinstance(inputs, (tuple, list))
200
+
201
+ # Register hooks.
202
+ entries = []
203
+ nesting = [0]
204
+ def pre_hook(_mod, _inputs):
205
+ nesting[0] += 1
206
+ def post_hook(mod, _inputs, outputs):
207
+ nesting[0] -= 1
208
+ if nesting[0] <= max_nesting:
209
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
210
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
211
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
212
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
213
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
214
+
215
+ # Run module.
216
+ outputs = module(*inputs)
217
+ for hook in hooks:
218
+ hook.remove()
219
+
220
+ # Identify unique outputs, parameters, and buffers.
221
+ tensors_seen = set()
222
+ for e in entries:
223
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
224
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
225
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
226
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
227
+
228
+ # Filter out redundant entries.
229
+ if skip_redundant:
230
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
231
+
232
+ # Construct table.
233
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
234
+ rows += [['---'] * len(rows[0])]
235
+ param_total = 0
236
+ buffer_total = 0
237
+ submodule_names = {mod: name for name, mod in module.named_modules()}
238
+ for e in entries:
239
+ name = '<top-level>' if e.mod is module else submodule_names[e.mod]
240
+ param_size = sum(t.numel() for t in e.unique_params)
241
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
242
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
243
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
244
+ rows += [[
245
+ name + (':0' if len(e.outputs) >= 2 else ''),
246
+ str(param_size) if param_size else '-',
247
+ str(buffer_size) if buffer_size else '-',
248
+ (output_shapes + ['-'])[0],
249
+ (output_dtypes + ['-'])[0],
250
+ ]]
251
+ for idx in range(1, len(e.outputs)):
252
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
253
+ param_total += param_size
254
+ buffer_total += buffer_size
255
+ rows += [['---'] * len(rows[0])]
256
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
257
+
258
+ # Print table.
259
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
260
+ print()
261
+ for row in rows:
262
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
263
+ print()
264
+ return outputs
265
+
266
+ #----------------------------------------------------------------------------
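A short hedged sketch exercising a few of the helpers defined above (constant(), assert_shape(), InfiniteSampler). It assumes the models/ directory is on sys.path so that draggan.torch_utils imports, matching this module's own draggan.dnnlib import; everything runs on CPU with no extensions required.

# Hedged sketch of the utilities above.
import torch
from draggan.torch_utils import misc

# constant(): a second call with identical arguments hits _constant_cache,
# so the very same tensor object is returned.
ones_a = misc.constant(1.0, shape=[4, 4], device=torch.device('cpu'))
ones_b = misc.constant(1.0, shape=[4, 4], device=torch.device('cpu'))
assert ones_a is ones_b

# assert_shape(): None leaves a dimension unconstrained.
x = torch.randn(2, 3, 8, 8)
misc.assert_shape(x, [None, 3, 8, 8])    # passes
# misc.assert_shape(x, [None, 4, 8, 8])  # would raise AssertionError

# InfiniteSampler: loops over the dataset forever with windowed shuffling,
# so the consumer decides when to stop drawing indices.
dataset = torch.utils.data.TensorDataset(torch.arange(10))
sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, seed=0)
it = iter(sampler)
first_indices = [next(it) for _ in range(5)]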
models/draggan/torch_utils/ops/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ # empty
models/draggan/torch_utils/ops/bias_act.cpp ADDED
@@ -0,0 +1,99 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "bias_act.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17
+ {
18
+ if (x.dim() != y.dim())
19
+ return false;
20
+ for (int64_t i = 0; i < x.dim(); i++)
21
+ {
22
+ if (x.size(i) != y.size(i))
23
+ return false;
24
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25
+ return false;
26
+ }
27
+ return true;
28
+ }
29
+
30
+ //------------------------------------------------------------------------
31
+
32
+ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33
+ {
34
+ // Validate arguments.
35
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
40
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
45
+
46
+ // Validate layout.
47
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52
+
53
+ // Create output tensor.
54
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55
+ torch::Tensor y = torch::empty_like(x);
56
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57
+
58
+ // Initialize CUDA kernel parameters.
59
+ bias_act_kernel_params p;
60
+ p.x = x.data_ptr();
61
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
62
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65
+ p.y = y.data_ptr();
66
+ p.grad = grad;
67
+ p.act = act;
68
+ p.alpha = alpha;
69
+ p.gain = gain;
70
+ p.clamp = clamp;
71
+ p.sizeX = (int)x.numel();
72
+ p.sizeB = (int)b.numel();
73
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74
+
75
+ // Choose CUDA kernel.
76
+ void* kernel;
77
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78
+ {
79
+ kernel = choose_bias_act_kernel<scalar_t>(p);
80
+ });
81
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82
+
83
+ // Launch CUDA kernel.
84
+ p.loopX = 4;
85
+ int blockSize = 4 * 32;
86
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87
+ void* args[] = {&p};
88
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89
+ return y;
90
+ }
91
+
92
+ //------------------------------------------------------------------------
93
+
94
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95
+ {
96
+ m.def("bias_act", &bias_act);
97
+ }
98
+
99
+ //------------------------------------------------------------------------
models/draggan/torch_utils/ops/bias_act.cu ADDED
@@ -0,0 +1,173 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "bias_act.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ //------------------------------------------------------------------------
21
+ // CUDA kernel.
22
+
23
+ template <class T, int A>
24
+ __global__ void bias_act_kernel(bias_act_kernel_params p)
25
+ {
26
+ typedef typename InternalType<T>::scalar_t scalar_t;
27
+ int G = p.grad;
28
+ scalar_t alpha = (scalar_t)p.alpha;
29
+ scalar_t gain = (scalar_t)p.gain;
30
+ scalar_t clamp = (scalar_t)p.clamp;
31
+ scalar_t one = (scalar_t)1;
32
+ scalar_t two = (scalar_t)2;
33
+ scalar_t expRange = (scalar_t)80;
34
+ scalar_t halfExpRange = (scalar_t)40;
35
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37
+
38
+ // Loop over elements.
39
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41
+ {
42
+ // Load.
43
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
44
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
49
+ scalar_t y = 0;
50
+
51
+ // Apply bias.
52
+ ((G == 0) ? x : xref) += b;
53
+
54
+ // linear
55
+ if (A == 1)
56
+ {
57
+ if (G == 0) y = x;
58
+ if (G == 1) y = x;
59
+ }
60
+
61
+ // relu
62
+ if (A == 2)
63
+ {
64
+ if (G == 0) y = (x > 0) ? x : 0;
65
+ if (G == 1) y = (yy > 0) ? x : 0;
66
+ }
67
+
68
+ // lrelu
69
+ if (A == 3)
70
+ {
71
+ if (G == 0) y = (x > 0) ? x : x * alpha;
72
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
73
+ }
74
+
75
+ // tanh
76
+ if (A == 4)
77
+ {
78
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79
+ if (G == 1) y = x * (one - yy * yy);
80
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81
+ }
82
+
83
+ // sigmoid
84
+ if (A == 5)
85
+ {
86
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87
+ if (G == 1) y = x * yy * (one - yy);
88
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89
+ }
90
+
91
+ // elu
92
+ if (A == 6)
93
+ {
94
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97
+ }
98
+
99
+ // selu
100
+ if (A == 7)
101
+ {
102
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105
+ }
106
+
107
+ // softplus
108
+ if (A == 8)
109
+ {
110
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111
+ if (G == 1) y = x * (one - exp(-yy));
112
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113
+ }
114
+
115
+ // swish
116
+ if (A == 9)
117
+ {
118
+ if (G == 0)
119
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120
+ else
121
+ {
122
+ scalar_t c = exp(xref);
123
+ scalar_t d = c + one;
124
+ if (G == 1)
125
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126
+ else
127
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129
+ }
130
+ }
131
+
132
+ // Apply gain.
133
+ y *= gain * dy;
134
+
135
+ // Clamp.
136
+ if (clamp >= 0)
137
+ {
138
+ if (G == 0)
139
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140
+ else
141
+ y = (yref > -clamp & yref < clamp) ? y : 0;
142
+ }
143
+
144
+ // Store.
145
+ ((T*)p.y)[xi] = (T)y;
146
+ }
147
+ }
148
+
149
+ //------------------------------------------------------------------------
150
+ // CUDA kernel selection.
151
+
152
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153
+ {
154
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163
+ return NULL;
164
+ }
165
+
166
+ //------------------------------------------------------------------------
167
+ // Template specializations.
168
+
169
+ template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170
+ template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171
+ template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172
+
173
+ //------------------------------------------------------------------------
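A hedged consistency check for the kernel above, written against the Python wrapper bias_act() added later in this commit. It assumes a CUDA device is available, that the bias_act extension compiles, and that the wrapper's own dnnlib import resolves; tolerances are illustrative.

# Hedged sketch: compare the CUDA path against the pure-PyTorch reference path
# for the 'lrelu' activation, including first-order gradients.
import torch
from draggan.torch_utils.ops import bias_act

x = torch.randn(8, 16, 4, 4, device='cuda', requires_grad=True)
b = torch.randn(16, device='cuda', requires_grad=True)

y_cuda = bias_act.bias_act(x, b, act='lrelu', gain=1.0, clamp=5.0, impl='cuda')
y_ref  = bias_act.bias_act(x, b, act='lrelu', gain=1.0, clamp=5.0, impl='ref')
assert torch.allclose(y_cuda, y_ref, atol=1e-5)

g_cuda = torch.autograd.grad(y_cuda.sum(), [x, b])
g_ref  = torch.autograd.grad(y_ref.sum(),  [x, b])
for gc, gr in zip(g_cuda, g_ref):
    assert torch.allclose(gc, gr, atol=1e-5)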
models/draggan/torch_utils/ops/bias_act.h ADDED
@@ -0,0 +1,38 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ //------------------------------------------------------------------------
10
+ // CUDA kernel parameters.
11
+
12
+ struct bias_act_kernel_params
13
+ {
14
+ const void* x; // [sizeX]
15
+ const void* b; // [sizeB] or NULL
16
+ const void* xref; // [sizeX] or NULL
17
+ const void* yref; // [sizeX] or NULL
18
+ const void* dy; // [sizeX] or NULL
19
+ void* y; // [sizeX]
20
+
21
+ int grad;
22
+ int act;
23
+ float alpha;
24
+ float gain;
25
+ float clamp;
26
+
27
+ int sizeX;
28
+ int sizeB;
29
+ int stepB;
30
+ int loopX;
31
+ };
32
+
33
+ //------------------------------------------------------------------------
34
+ // CUDA kernel selection.
35
+
36
+ template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37
+
38
+ //------------------------------------------------------------------------
models/draggan/torch_utils/ops/bias_act.py ADDED
@@ -0,0 +1,209 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient bias and activation."""
10
+
11
+ import os
12
+ import numpy as np
13
+ import torch
14
+ import dnnlib
15
+
16
+ from .. import custom_ops
17
+ from .. import misc
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ activation_funcs = {
22
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
23
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
24
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
25
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
26
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
27
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
28
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
29
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
30
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
31
+ }
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ _plugin = None
36
+ _null_tensor = torch.empty([0])
37
+
38
+ def _init():
39
+ global _plugin
40
+ if _plugin is None:
41
+ _plugin = custom_ops.get_plugin(
42
+ module_name='bias_act_plugin',
43
+ sources=['bias_act.cpp', 'bias_act.cu'],
44
+ headers=['bias_act.h'],
45
+ source_dir=os.path.dirname(__file__),
46
+ extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
47
+ )
48
+ return True
49
+
50
+ #----------------------------------------------------------------------------
51
+
52
+ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
53
+ r"""Fused bias and activation function.
54
+
55
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
56
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
57
+ the fused op is considerably more efficient than performing the same calculation
58
+ using standard PyTorch ops. It supports first and second order gradients,
59
+ but not third order gradients.
60
+
61
+ Args:
62
+ x: Input activation tensor. Can be of any shape.
63
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
64
+ as `x`. The shape must be known, and it must match the dimension of `x`
65
+ corresponding to `dim`.
66
+ dim: The dimension in `x` corresponding to the elements of `b`.
67
+ The value of `dim` is ignored if `b` is not specified.
68
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
69
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
70
+ See `activation_funcs` for a full list. `None` is not allowed.
71
+ alpha: Shape parameter for the activation function, or `None` to use the default.
72
+ gain: Scaling factor for the output tensor, or `None` to use default.
73
+ See `activation_funcs` for the default scaling of each activation function.
74
+ If unsure, consider specifying 1.
75
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
76
+ the clamping (default).
77
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
78
+
79
+ Returns:
80
+ Tensor of the same shape and datatype as `x`.
81
+ """
82
+ assert isinstance(x, torch.Tensor)
83
+ assert impl in ['ref', 'cuda']
84
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
85
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
86
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
87
+
88
+ #----------------------------------------------------------------------------
89
+
90
+ @misc.profiled_function
91
+ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
92
+ """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
93
+ """
94
+ assert isinstance(x, torch.Tensor)
95
+ assert clamp is None or clamp >= 0
96
+ spec = activation_funcs[act]
97
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
98
+ gain = float(gain if gain is not None else spec.def_gain)
99
+ clamp = float(clamp if clamp is not None else -1)
100
+
101
+ # Add bias.
102
+ if b is not None:
103
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
104
+ assert 0 <= dim < x.ndim
105
+ assert b.shape[0] == x.shape[dim]
106
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
107
+
108
+ # Evaluate activation function.
109
+ alpha = float(alpha)
110
+ x = spec.func(x, alpha=alpha)
111
+
112
+ # Scale by gain.
113
+ gain = float(gain)
114
+ if gain != 1:
115
+ x = x * gain
116
+
117
+ # Clamp.
118
+ if clamp >= 0:
119
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
120
+ return x
121
+
122
+ #----------------------------------------------------------------------------
123
+
124
+ _bias_act_cuda_cache = dict()
125
+
126
+ def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
127
+ """Fast CUDA implementation of `bias_act()` using custom ops.
128
+ """
129
+ # Parse arguments.
130
+ assert clamp is None or clamp >= 0
131
+ spec = activation_funcs[act]
132
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
133
+ gain = float(gain if gain is not None else spec.def_gain)
134
+ clamp = float(clamp if clamp is not None else -1)
135
+
136
+ # Lookup from cache.
137
+ key = (dim, act, alpha, gain, clamp)
138
+ if key in _bias_act_cuda_cache:
139
+ return _bias_act_cuda_cache[key]
140
+
141
+ # Forward op.
142
+ class BiasActCuda(torch.autograd.Function):
143
+ @staticmethod
144
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
145
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
146
+ x = x.contiguous(memory_format=ctx.memory_format)
147
+ b = b.contiguous() if b is not None else _null_tensor
148
+ y = x
149
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
150
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
151
+ ctx.save_for_backward(
152
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
153
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
154
+ y if 'y' in spec.ref else _null_tensor)
155
+ return y
156
+
157
+ @staticmethod
158
+ def backward(ctx, dy): # pylint: disable=arguments-differ
159
+ dy = dy.contiguous(memory_format=ctx.memory_format)
160
+ x, b, y = ctx.saved_tensors
161
+ dx = None
162
+ db = None
163
+
164
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
165
+ dx = dy
166
+ if act != 'linear' or gain != 1 or clamp >= 0:
167
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
168
+
169
+ if ctx.needs_input_grad[1]:
170
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
171
+
172
+ return dx, db
173
+
174
+ # Backward op.
175
+ class BiasActCudaGrad(torch.autograd.Function):
176
+ @staticmethod
177
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
178
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
179
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
180
+ ctx.save_for_backward(
181
+ dy if spec.has_2nd_grad else _null_tensor,
182
+ x, b, y)
183
+ return dx
184
+
185
+ @staticmethod
186
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
187
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
188
+ dy, x, b, y = ctx.saved_tensors
189
+ d_dy = None
190
+ d_x = None
191
+ d_b = None
192
+ d_y = None
193
+
194
+ if ctx.needs_input_grad[0]:
195
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
196
+
197
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
198
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
199
+
200
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
201
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
202
+
203
+ return d_dy, d_x, d_b, d_y
204
+
205
+ # Add to cache.
206
+ _bias_act_cuda_cache[key] = BiasActCuda
207
+ return BiasActCuda
208
+
209
+ #----------------------------------------------------------------------------
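A minimal usage sketch for the bias_act() entry point above. Passing impl='ref' keeps it on the pure-PyTorch path so it also runs on CPU; it still assumes the module's own dnnlib import resolves.

# Hedged sketch: fused bias + activation on a batch of feature vectors.
import torch
from draggan.torch_utils.ops import bias_act

x = torch.randn(2, 512)    # [batch, channels]
b = torch.zeros(512)       # per-channel bias along dim=1

# Leaky ReLU with the default sqrt(2) gain, output clamped to [-256, 256].
y = bias_act.bias_act(x, b, dim=1, act='lrelu', clamp=256, impl='ref')
print(y.shape)             # torch.Size([2, 512])

# act='linear' disables the nonlinearity but still applies bias and gain.
y_lin = bias_act.bias_act(x, b, dim=1, act='linear', gain=2.0, impl='ref')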
models/draggan/torch_utils/ops/conv2d_gradfix.py ADDED
@@ -0,0 +1,198 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.conv2d` that supports
10
+ arbitrarily high order gradients with zero performance penalty."""
11
+
12
+ import contextlib
13
+ import torch
14
+
15
+ # pylint: disable=redefined-builtin
16
+ # pylint: disable=arguments-differ
17
+ # pylint: disable=protected-access
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ enabled = False # Enable the custom op by setting this to true.
22
+ weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
23
+
24
+ @contextlib.contextmanager
25
+ def no_weight_gradients(disable=True):
26
+ global weight_gradients_disabled
27
+ old = weight_gradients_disabled
28
+ if disable:
29
+ weight_gradients_disabled = True
30
+ yield
31
+ weight_gradients_disabled = old
32
+
33
+ #----------------------------------------------------------------------------
34
+
35
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
36
+ if _should_use_custom_op(input):
37
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
38
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
39
+
40
+ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
41
+ if _should_use_custom_op(input):
42
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
43
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
44
+
45
+ #----------------------------------------------------------------------------
46
+
47
+ def _should_use_custom_op(input):
48
+ assert isinstance(input, torch.Tensor)
49
+ if (not enabled) or (not torch.backends.cudnn.enabled):
50
+ return False
51
+ if input.device.type != 'cuda':
52
+ return False
53
+ return True
54
+
55
+ def _tuple_of_ints(xs, ndim):
56
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
57
+ assert len(xs) == ndim
58
+ assert all(isinstance(x, int) for x in xs)
59
+ return xs
60
+
61
+ #----------------------------------------------------------------------------
62
+
63
+ _conv2d_gradfix_cache = dict()
64
+ _null_tensor = torch.empty([0])
65
+
66
+ def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
67
+ # Parse arguments.
68
+ ndim = 2
69
+ weight_shape = tuple(weight_shape)
70
+ stride = _tuple_of_ints(stride, ndim)
71
+ padding = _tuple_of_ints(padding, ndim)
72
+ output_padding = _tuple_of_ints(output_padding, ndim)
73
+ dilation = _tuple_of_ints(dilation, ndim)
74
+
75
+ # Lookup from cache.
76
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
77
+ if key in _conv2d_gradfix_cache:
78
+ return _conv2d_gradfix_cache[key]
79
+
80
+ # Validate arguments.
81
+ assert groups >= 1
82
+ assert len(weight_shape) == ndim + 2
83
+ assert all(stride[i] >= 1 for i in range(ndim))
84
+ assert all(padding[i] >= 0 for i in range(ndim))
85
+ assert all(dilation[i] >= 0 for i in range(ndim))
86
+ if not transpose:
87
+ assert all(output_padding[i] == 0 for i in range(ndim))
88
+ else: # transpose
89
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
90
+
91
+ # Helpers.
92
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
93
+ def calc_output_padding(input_shape, output_shape):
94
+ if transpose:
95
+ return [0, 0]
96
+ return [
97
+ input_shape[i + 2]
98
+ - (output_shape[i + 2] - 1) * stride[i]
99
+ - (1 - 2 * padding[i])
100
+ - dilation[i] * (weight_shape[i + 2] - 1)
101
+ for i in range(ndim)
102
+ ]
103
+
104
+ # Forward & backward.
105
+ class Conv2d(torch.autograd.Function):
106
+ @staticmethod
107
+ def forward(ctx, input, weight, bias):
108
+ assert weight.shape == weight_shape
109
+ ctx.save_for_backward(
110
+ input if weight.requires_grad else _null_tensor,
111
+ weight if input.requires_grad else _null_tensor,
112
+ )
113
+ ctx.input_shape = input.shape
114
+
115
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
116
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
117
+ a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
118
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
119
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
120
+ c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
121
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
122
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
123
+
124
+ # General case => cuDNN.
125
+ if transpose:
126
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
127
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
128
+
129
+ @staticmethod
130
+ def backward(ctx, grad_output):
131
+ input, weight = ctx.saved_tensors
132
+ input_shape = ctx.input_shape
133
+ grad_input = None
134
+ grad_weight = None
135
+ grad_bias = None
136
+
137
+ if ctx.needs_input_grad[0]:
138
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
139
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
140
+ grad_input = op.apply(grad_output, weight, None)
141
+ assert grad_input.shape == input_shape
142
+
143
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
144
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
145
+ assert grad_weight.shape == weight_shape
146
+
147
+ if ctx.needs_input_grad[2]:
148
+ grad_bias = grad_output.sum([0, 2, 3])
149
+
150
+ return grad_input, grad_weight, grad_bias
151
+
152
+ # Gradient with respect to the weights.
153
+ class Conv2dGradWeight(torch.autograd.Function):
154
+ @staticmethod
155
+ def forward(ctx, grad_output, input):
156
+ ctx.save_for_backward(
157
+ grad_output if input.requires_grad else _null_tensor,
158
+ input if grad_output.requires_grad else _null_tensor,
159
+ )
160
+ ctx.grad_output_shape = grad_output.shape
161
+ ctx.input_shape = input.shape
162
+
163
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
164
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
165
+ a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
166
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
167
+ c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
168
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
169
+
170
+ # General case => cuDNN.
171
+ name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'
172
+ flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]
173
+ return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
174
+
175
+ @staticmethod
176
+ def backward(ctx, grad2_grad_weight):
177
+ grad_output, input = ctx.saved_tensors
178
+ grad_output_shape = ctx.grad_output_shape
179
+ input_shape = ctx.input_shape
180
+ grad2_grad_output = None
181
+ grad2_input = None
182
+
183
+ if ctx.needs_input_grad[0]:
184
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
185
+ assert grad2_grad_output.shape == grad_output_shape
186
+
187
+ if ctx.needs_input_grad[1]:
188
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
189
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
190
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
191
+ assert grad2_input.shape == input_shape
192
+
193
+ return grad2_grad_output, grad2_input
194
+
195
+ _conv2d_gradfix_cache[key] = Conv2d
196
+ return Conv2d
197
+
198
+ #----------------------------------------------------------------------------
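A hedged sketch of how this wrapper is typically switched on, and how weight gradients are suppressed during a regularization-style backward pass. It assumes a CUDA device, since _should_use_custom_op() falls back to the stock torch.nn.functional ops otherwise.

# Hedged sketch: opt in to the gradfix path and suppress weight gradients.
import torch
from draggan.torch_utils.ops import conv2d_gradfix

conv2d_gradfix.enabled = True   # the custom op is a pass-through until this is set

x = torch.randn(1, 3, 16, 16, device='cuda', requires_grad=True)
w = torch.randn(8, 3, 3, 3, device='cuda', requires_grad=True)
y = conv2d_gradfix.conv2d(x, w, padding=1)

# Inside this context, Conv2d.backward returns None for grad_weight
# (weight_gradients_disabled), which StyleGAN-style regularizers use to
# skip gradient work they do not need.
with conv2d_gradfix.no_weight_gradients():
    y.sum().backward()
assert x.grad is not None
assert w.grad is None   # weight gradient was skipped inside the context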
models/draggan/torch_utils/ops/conv2d_resample.py ADDED
@@ -0,0 +1,143 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """2D convolution with optional up/downsampling."""
10
+
11
+ import torch
12
+
13
+ from .. import misc
14
+ from . import conv2d_gradfix
15
+ from . import upfirdn2d
16
+ from .upfirdn2d import _parse_padding
17
+ from .upfirdn2d import _get_filter_size
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ def _get_weight_shape(w):
22
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
23
+ shape = [int(sz) for sz in w.shape]
24
+ misc.assert_shape(w, shape)
25
+ return shape
26
+
27
+ #----------------------------------------------------------------------------
28
+
29
+ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
30
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
31
+ """
32
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
33
+
34
+ # Flip weight if requested.
35
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
36
+ if not flip_weight and (kw > 1 or kh > 1):
37
+ w = w.flip([2, 3])
38
+
39
+ # Execute using conv2d_gradfix.
40
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
41
+ return op(x, w, stride=stride, padding=padding, groups=groups)
42
+
43
+ #----------------------------------------------------------------------------
44
+
45
+ @misc.profiled_function
46
+ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
47
+ r"""2D convolution with optional up/downsampling.
48
+
49
+ Padding is performed only once at the beginning, not between the operations.
50
+
51
+ Args:
52
+ x: Input tensor of shape
53
+ `[batch_size, in_channels, in_height, in_width]`.
54
+ w: Weight tensor of shape
55
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
56
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
57
+ calling upfirdn2d.setup_filter(). None = identity (default).
58
+ up: Integer upsampling factor (default: 1).
59
+ down: Integer downsampling factor (default: 1).
60
+ padding: Padding with respect to the upsampled image. Can be a single number
61
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
62
+ (default: 0).
63
+ groups: Split input channels into N groups (default: 1).
64
+ flip_weight: False = convolution, True = correlation (default: True).
65
+ flip_filter: False = convolution, True = correlation (default: False).
66
+
67
+ Returns:
68
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
69
+ """
70
+ # Validate arguments.
71
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
72
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
73
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
74
+ assert isinstance(up, int) and (up >= 1)
75
+ assert isinstance(down, int) and (down >= 1)
76
+ assert isinstance(groups, int) and (groups >= 1)
77
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
78
+ fw, fh = _get_filter_size(f)
79
+ px0, px1, py0, py1 = _parse_padding(padding)
80
+
81
+ # Adjust padding to account for up/downsampling.
82
+ if up > 1:
83
+ px0 += (fw + up - 1) // 2
84
+ px1 += (fw - up) // 2
85
+ py0 += (fh + up - 1) // 2
86
+ py1 += (fh - up) // 2
87
+ if down > 1:
88
+ px0 += (fw - down + 1) // 2
89
+ px1 += (fw - down) // 2
90
+ py0 += (fh - down + 1) // 2
91
+ py1 += (fh - down) // 2
92
+
93
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
94
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
95
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
96
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
97
+ return x
98
+
99
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
100
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
101
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
102
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
103
+ return x
104
+
105
+ # Fast path: downsampling only => use strided convolution.
106
+ if down > 1 and up == 1:
107
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
108
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
109
+ return x
110
+
111
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
112
+ if up > 1:
113
+ if groups == 1:
114
+ w = w.transpose(0, 1)
115
+ else:
116
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
117
+ w = w.transpose(1, 2)
118
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
119
+ px0 -= kw - 1
120
+ px1 -= kw - up
121
+ py0 -= kh - 1
122
+ py1 -= kh - up
123
+ pxt = max(min(-px0, -px1), 0)
124
+ pyt = max(min(-py0, -py1), 0)
125
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
126
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
127
+ if down > 1:
128
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
129
+ return x
130
+
131
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
132
+ if up == 1 and down == 1:
133
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
134
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
135
+
136
+ # Fallback: Generic reference implementation.
137
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
138
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
139
+ if down > 1:
140
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
141
+ return x
142
+
143
+ #----------------------------------------------------------------------------
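A hedged usage sketch for conv2d_resample() above, with the low-pass filter prepared via upfirdn2d.setup_filter() as the docstring requires. On CPU both upfirdn2d and conv2d_gradfix fall back to their reference paths, so no extension build is needed; the expected output shapes are stated under these exact settings.

# Hedged sketch: 2x upsampled and 2x downsampled 3x3 convolutions.
import torch
from draggan.torch_utils.ops import conv2d_resample, upfirdn2d

x = torch.randn(1, 16, 32, 32)               # [batch, channels, height, width]
w = torch.randn(16, 16, 3, 3)                # [out_ch, in_ch // groups, kh, kw]
f = upfirdn2d.setup_filter([1, 3, 3, 1])     # separable low-pass filter

y_up = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
y_dn = conv2d_resample.conv2d_resample(x=x, w=w, f=f, down=2, padding=1)
print(y_up.shape, y_dn.shape)                # expect [1, 16, 64, 64] and [1, 16, 16, 16]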
models/draggan/torch_utils/ops/filtered_lrelu.cpp ADDED
@@ -0,0 +1,300 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "filtered_lrelu.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
17
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
18
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
19
+ {
20
+ // Set CUDA device.
21
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
22
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
23
+
24
+ // Validate arguments.
25
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
26
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
27
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
28
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
29
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
30
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
31
+ TORCH_CHECK(x.numel() > 0, "x is empty");
32
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
33
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
34
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
35
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
36
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
37
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
38
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
39
+
40
+ // Figure out how much shared memory is available on the device.
41
+ int maxSharedBytes = 0;
42
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
43
+ int sharedKB = maxSharedBytes >> 10;
44
+
45
+ // Populate enough launch parameters to check if a CUDA kernel exists.
46
+ filtered_lrelu_kernel_params p;
47
+ p.up = up;
48
+ p.down = down;
49
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
50
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
51
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
52
+ if (!test_spec.exec)
53
+ {
54
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
55
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
56
+ }
57
+
58
+ // Input/output element size.
59
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
60
+
61
+ // Input sizes.
62
+ int64_t xw = (int)x.size(3);
63
+ int64_t xh = (int)x.size(2);
64
+ int64_t fut_w = (int)fu.size(-1) - 1;
65
+ int64_t fut_h = (int)fu.size(0) - 1;
66
+ int64_t fdt_w = (int)fd.size(-1) - 1;
67
+ int64_t fdt_h = (int)fd.size(0) - 1;
68
+
69
+ // Logical size of upsampled buffer.
70
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
71
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
72
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
73
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
74
+
75
+ // Compute output size and allocate.
76
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
77
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
78
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
79
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
80
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
81
+
82
+ // Allocate sign tensor.
83
+ torch::Tensor so;
84
+ torch::Tensor s = si;
85
+ bool readSigns = !!s.numel();
86
+ int64_t sw_active = 0; // Active width of sign tensor.
87
+ if (writeSigns)
88
+ {
89
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
90
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
91
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
92
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
93
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
94
+ }
95
+ else if (readSigns)
96
+ sw_active = s.size(3) << 2;
97
+
98
+ // Validate sign tensor if in use.
99
+ if (readSigns || writeSigns)
100
+ {
101
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
102
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
103
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
104
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
105
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
106
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
107
+ }
108
+
109
+ // Populate rest of CUDA kernel parameters.
110
+ p.x = x.data_ptr();
111
+ p.y = y.data_ptr();
112
+ p.b = b.data_ptr();
113
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
114
+ p.fu = fu.data_ptr<float>();
115
+ p.fd = fd.data_ptr<float>();
116
+ p.pad0 = make_int2(px0, py0);
117
+ p.gain = gain;
118
+ p.slope = slope;
119
+ p.clamp = clamp;
120
+ p.flip = (flip_filters) ? 1 : 0;
121
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
122
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
123
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
124
+ p.sOfs = make_int2(sx, sy);
125
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
126
+
127
+ // x, y, b strides are in bytes.
128
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
129
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
130
+ p.bStride = sz * b.stride(0);
131
+
132
+ // fu, fd strides are in elements.
133
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
134
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
135
+
136
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
137
+ bool index64b = false;
138
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
139
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
140
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
141
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
142
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
143
+ if (s.numel() > INT_MAX) index64b = true;
144
+
145
+ // Choose CUDA kernel.
146
+ filtered_lrelu_kernel_spec spec = { 0 };
147
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
148
+ {
149
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
150
+ {
151
+ // Choose kernel based on index type, datatype and sign read/write modes.
152
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
153
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);
154
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
155
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
156
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);
157
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
158
+ }
159
+ });
160
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
161
+
162
+ // Launch CUDA kernel.
163
+ void* args[] = {&p};
164
+ int bx = spec.numWarps * 32;
165
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
166
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
167
+ int gz = p.yShape.z * p.yShape.w;
168
+
169
+ // Repeat multiple horizontal tiles in a CTA?
170
+ if (spec.xrep)
171
+ {
172
+ p.tilesXrep = spec.xrep;
173
+ p.tilesXdim = gx;
174
+
175
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
176
+ std::swap(gx, gy);
177
+ }
178
+ else
179
+ {
180
+ p.tilesXrep = 0;
181
+ p.tilesXdim = 0;
182
+ }
183
+
184
+ // Launch filter setup kernel.
185
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
186
+
187
+ // Copy kernels to constant memory.
188
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
189
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));
190
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
191
+
192
+ // Set cache and shared memory configurations for main kernel.
193
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
194
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
195
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
196
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
197
+
198
+ // Launch main kernel.
199
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
200
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
201
+ {
202
+ p.blockZofs = zofs;
203
+ int subGz = std::min(maxSubGz, gz - zofs);
204
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
205
+ }
206
+
207
+ // Done.
208
+ return std::make_tuple(y, so, 0);
209
+ }
210
+
211
+ //------------------------------------------------------------------------
212
+
213
+ static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
214
+ {
215
+ // Set CUDA device.
216
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
217
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
218
+
219
+ // Validate arguments.
220
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
221
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
222
+ TORCH_CHECK(x.numel() > 0, "x is empty");
223
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
224
+
225
+ // Output signs if we don't have sign input.
226
+ torch::Tensor so;
227
+ torch::Tensor s = si;
228
+ bool readSigns = !!s.numel();
229
+ if (writeSigns)
230
+ {
231
+ int64_t sw = x.size(3);
232
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
233
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
234
+ }
235
+
236
+ // Validate sign tensor if in use.
237
+ if (readSigns || writeSigns)
238
+ {
239
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
240
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
241
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
242
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
243
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
244
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
245
+ }
246
+
247
+ // Initialize CUDA kernel parameters.
248
+ filtered_lrelu_act_kernel_params p;
249
+ p.x = x.data_ptr();
250
+ p.s = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;
251
+ p.gain = gain;
252
+ p.slope = slope;
253
+ p.clamp = clamp;
254
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
255
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
256
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
257
+ p.sOfs = make_int2(sx, sy);
258
+
259
+ // Choose CUDA kernel.
260
+ void* func = 0;
261
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
262
+ {
263
+ if (writeSigns)
264
+ func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
265
+ else if (readSigns)
266
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
267
+ else
268
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
269
+ });
270
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
271
+
272
+ // Launch CUDA kernel.
273
+ void* args[] = {&p};
274
+ int bx = 128; // 4 warps per block.
275
+
276
+ // Logical size of launch = writeSigns ? p.s : p.x
277
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
278
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
279
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
280
+ gx = (gx - 1) / bx + 1;
281
+
282
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
283
+ const uint32_t gmax = 65535;
284
+ gy = std::min(gy, gmax);
285
+ gz = std::min(gz, gmax);
286
+
287
+ // Launch.
288
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
289
+ return so;
290
+ }
291
+
292
+ //------------------------------------------------------------------------
293
+
294
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
295
+ {
296
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
297
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
298
+ }
299
+
300
+ //------------------------------------------------------------------------
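The binding above exposes two entry points: filtered_lrelu, which fuses upsampling, bias, the leaky-ReLU nonlinearity and downsampling in a single launch, and filtered_lrelu_act_, which applies only the element-wise part in place. A minimal plain-PyTorch sketch of that element-wise part, based only on the gain/slope/clamp handling visible in the kernel code (the function name and example tensor are illustrative assumptions, not the extension's API):

import torch

def filtered_lrelu_act_reference(x: torch.Tensor, gain: float, slope: float, clamp: float) -> torch.Tensor:
    # Element-wise path of the kernel: scale by the combined gain, apply leaky ReLU
    # with the given negative slope, then clamp symmetrically to +/-clamp.
    # The CUDA op additionally packs two status bits per element (negative branch
    # taken, clamp hit) four-to-a-byte into the uint8 sign tensor so the backward
    # pass can reuse them; this sketch omits that bookkeeping.
    y = x * gain
    y = torch.where(y < 0, y * slope, y)
    return y.clamp(min=-clamp, max=clamp)

# Example: slope-0.2 leaky ReLU with unit gain and a clamp of 256 on a random map.
y = filtered_lrelu_act_reference(torch.randn(1, 4, 8, 8), gain=1.0, slope=0.2, clamp=256.0)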
models/draggan/torch_utils/ops/filtered_lrelu.cu ADDED
@@ -0,0 +1,1284 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "filtered_lrelu.h"
11
+ #include <cstdint>
12
+
13
+ //------------------------------------------------------------------------
14
+ // Helpers.
15
+
16
+ enum // Filter modes.
17
+ {
18
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
19
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
20
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
21
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
22
+ };
23
+
24
+ template <class T> struct InternalType;
25
+ template <> struct InternalType<double>
26
+ {
27
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
28
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
29
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
30
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
31
+ };
32
+ template <> struct InternalType<float>
33
+ {
34
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
35
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
36
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
37
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
38
+ };
39
+ template <> struct InternalType<c10::Half>
40
+ {
41
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
42
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
43
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
44
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
45
+ };
46
+
47
+ #define MIN(A, B) ((A) < (B) ? (A) : (B))
48
+ #define MAX(A, B) ((A) > (B) ? (A) : (B))
49
+ #define CEIL_DIV(A, B) (((B)==1) ? (A) : \
50
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
51
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
52
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
53
+
54
+ // This works only up to blocks of size 256 x 256 and for all N that are powers of two.
55
+ template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
56
+ {
57
+ if ((N & (N-1)) && N <= 256)
58
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
59
+ else
60
+ y = i/N;
61
+
62
+ x = i - y*N;
63
+ }
64
+
65
+ // Type cast stride before reading it.
66
+ template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
67
+ {
68
+ return *reinterpret_cast<const T*>(&x);
69
+ }
70
+
71
+ //------------------------------------------------------------------------
72
+ // Filters, setup kernel, copying function.
73
+
74
+ #define MAX_FILTER_SIZE 32
75
+
76
+ // Combined up/down filter buffers so that transfer can be done with one copy.
77
+ __device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
78
+ __device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
79
+
80
+ // Accessors to combined buffers to index up/down filters individually.
81
+ #define c_fu (c_fbuf)
82
+ #define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
83
+ #define g_fu (g_fbuf)
84
+ #define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
85
+
86
+ // Set up filters into global memory buffer.
87
+ static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
88
+ {
89
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
90
+ {
91
+ int x, y;
92
+ fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
93
+
94
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
95
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
96
+ if (p.fuShape.y > 0)
97
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
98
+ else
99
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
100
+
101
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
102
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
103
+ if (p.fdShape.y > 0)
104
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
105
+ else
106
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
107
+ }
108
+ }
109
+
110
+ // Host function to copy filters written by setup kernel into constant buffer for main kernel.
111
+ template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
112
+ {
113
+ void* src = 0;
114
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
115
+ if (err) return err;
116
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
117
+ }
118
+
119
+ //------------------------------------------------------------------------
120
+ // Coordinate spaces:
121
+ // - Relative to input tensor: inX, inY, tileInX, tileInY
122
+ // - Relative to input tile: relInX, relInY, tileInW, tileInH
123
+ // - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
124
+ // - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
125
+ // - Relative to output tensor: outX, outY, tileOutX, tileOutY
126
+ //
127
+ // Relationships between coordinate spaces:
128
+ // - inX = tileInX + relInX
129
+ // - inY = tileInY + relInY
130
+ // - relUpX = relInX * up + phaseInX
131
+ // - relUpY = relInY * up + phaseInY
132
+ // - relUpX = relOutX * down
133
+ // - relUpY = relOutY * down
134
+ // - outX = tileOutX + relOutX
135
+ // - outY = tileOutY + relOutY
136
+
137
+ extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
138
+
139
+ template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
140
+ static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
141
+ {
142
+ // Check that we don't try to support non-existing filter modes.
143
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
144
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
145
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
146
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
147
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
148
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
149
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
150
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
151
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
152
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
153
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
154
+
155
+ // Static definitions.
156
+ typedef typename InternalType<T>::scalar_t scalar_t;
157
+ typedef typename InternalType<T>::vec2_t vec2_t;
158
+ typedef typename InternalType<T>::vec4_t vec4_t;
159
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
160
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
161
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
162
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
163
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
164
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
165
+
166
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
167
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
168
+
169
+ // Sizes of logical buffers.
170
+ const int szIn = tileInH_up * tileInW;
171
+ const int szUpX = tileInH_up * tileUpW;
172
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
173
+ const int szDownX = tileUpH * tileOutW;
174
+
175
+ // Sizes for shared memory arrays.
176
+ const int s_buf0_size_base =
177
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
178
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
179
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
180
+ (filterMode == MODE_FUFD) ? szIn :
181
+ -1;
182
+ const int s_buf1_size_base =
183
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
184
+ (filterMode == MODE_FUSD) ? szUpXY :
185
+ (filterMode == MODE_SUFD) ? szUpX :
186
+ (filterMode == MODE_FUFD) ? szUpXY :
187
+ -1;
188
+
189
+ // Ensure U128 alignment.
190
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
191
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
192
+
193
+ // Check at compile time that we don't use too much shared memory.
194
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
195
+
196
+ // Declare shared memory arrays.
197
+ scalar_t* s_buf0;
198
+ scalar_t* s_buf1;
199
+ if (sharedKB <= 48)
200
+ {
201
+ // Allocate shared memory arrays here.
202
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
203
+ s_buf0 = s_buf0_st;
204
+ s_buf1 = s_buf0 + s_buf0_size;
205
+ }
206
+ else
207
+ {
208
+ // Use the dynamically allocated shared memory array.
209
+ s_buf0 = (scalar_t*)s_buf_raw;
210
+ s_buf1 = s_buf0 + s_buf0_size;
211
+ }
212
+
213
+ // Pointers to the buffers.
214
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
215
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
216
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
217
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
218
+ if (filterMode == MODE_SUSD)
219
+ {
220
+ s_tileIn = s_buf0;
221
+ s_tileUpX = s_buf1;
222
+ s_tileUpXY = s_buf0;
223
+ s_tileDownX = s_buf1;
224
+ }
225
+ else if (filterMode == MODE_FUSD)
226
+ {
227
+ s_tileIn = s_buf0;
228
+ s_tileUpXY = s_buf1;
229
+ s_tileDownX = s_buf0;
230
+ }
231
+ else if (filterMode == MODE_SUFD)
232
+ {
233
+ s_tileIn = s_buf0;
234
+ s_tileUpX = s_buf1;
235
+ s_tileUpXY = s_buf0;
236
+ }
237
+ else if (filterMode == MODE_FUFD)
238
+ {
239
+ s_tileIn = s_buf0;
240
+ s_tileUpXY = s_buf1;
241
+ }
242
+
243
+ // Allow large grids in z direction via per-launch offset.
244
+ int channelIdx = blockIdx.z + p.blockZofs;
245
+ int batchIdx = channelIdx / p.yShape.z;
246
+ channelIdx -= batchIdx * p.yShape.z;
247
+
248
+ // Offset to output feature map. In bytes.
249
+ index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);
250
+
251
+ // Sign shift amount.
252
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
253
+
254
+ // Inner tile loop.
255
+ #pragma unroll 1
256
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
257
+ {
258
+ // Locate output tile.
259
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
260
+ int tileOutX = tileX * tileOutW;
261
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
262
+
263
+ // Locate input tile.
264
+ int tmpX = tileOutX * down - p.pad0.x;
265
+ int tmpY = tileOutY * down - p.pad0.y;
266
+ int tileInX = CEIL_DIV(tmpX, up);
267
+ int tileInY = CEIL_DIV(tmpY, up);
268
+ const int phaseInX = tileInX * up - tmpX;
269
+ const int phaseInY = tileInY * up - tmpY;
270
+
271
+ // Extra sync if input and output buffers are the same and we are not on first tile.
272
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
273
+ __syncthreads();
274
+
275
+ // Load input tile & apply bias. Unrolled.
276
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));
277
+ index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);
278
+ int idx = threadIdx.x;
279
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
280
+ #pragma unroll
281
+ for (int loop = 0; loop < loopCountIN; loop++)
282
+ {
283
+ int relInX, relInY;
284
+ fast_div_mod<tileInW>(relInX, relInY, idx);
285
+ int inX = tileInX + relInX;
286
+ int inY = tileInY + relInY;
287
+ scalar_t v = 0;
288
+
289
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
290
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;
291
+
292
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
293
+ if (!skip)
294
+ s_tileIn[idx] = v;
295
+
296
+ idx += threadsPerBlock;
297
+ }
298
+
299
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
300
+ {
301
+ // Horizontal upsampling.
302
+ __syncthreads();
303
+ if (up == 4)
304
+ {
305
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
306
+ {
307
+ int relUpX0, relInY;
308
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
309
+ int relInX0 = relUpX0 / up;
310
+ int src0 = relInX0 + tileInW * relInY;
311
+ int dst = relInY * tileUpW + relUpX0;
312
+ vec4_t v = InternalType<T>::zero_vec4();
313
+ scalar_t a = s_tileIn[src0];
314
+ if (phaseInX == 0)
315
+ {
316
+ #pragma unroll
317
+ for (int step = 0; step < fuSize / up; step++)
318
+ {
319
+ v.x += a * (scalar_t)c_fu[step * up + 0];
320
+ a = s_tileIn[src0 + step + 1];
321
+ v.y += a * (scalar_t)c_fu[step * up + 3];
322
+ v.z += a * (scalar_t)c_fu[step * up + 2];
323
+ v.w += a * (scalar_t)c_fu[step * up + 1];
324
+ }
325
+ }
326
+ else if (phaseInX == 1)
327
+ {
328
+ #pragma unroll
329
+ for (int step = 0; step < fuSize / up; step++)
330
+ {
331
+ v.x += a * (scalar_t)c_fu[step * up + 1];
332
+ v.y += a * (scalar_t)c_fu[step * up + 0];
333
+ a = s_tileIn[src0 + step + 1];
334
+ v.z += a * (scalar_t)c_fu[step * up + 3];
335
+ v.w += a * (scalar_t)c_fu[step * up + 2];
336
+ }
337
+ }
338
+ else if (phaseInX == 2)
339
+ {
340
+ #pragma unroll
341
+ for (int step = 0; step < fuSize / up; step++)
342
+ {
343
+ v.x += a * (scalar_t)c_fu[step * up + 2];
344
+ v.y += a * (scalar_t)c_fu[step * up + 1];
345
+ v.z += a * (scalar_t)c_fu[step * up + 0];
346
+ a = s_tileIn[src0 + step + 1];
347
+ v.w += a * (scalar_t)c_fu[step * up + 3];
348
+ }
349
+ }
350
+ else // (phaseInX == 3)
351
+ {
352
+ #pragma unroll
353
+ for (int step = 0; step < fuSize / up; step++)
354
+ {
355
+ v.x += a * (scalar_t)c_fu[step * up + 3];
356
+ v.y += a * (scalar_t)c_fu[step * up + 2];
357
+ v.z += a * (scalar_t)c_fu[step * up + 1];
358
+ v.w += a * (scalar_t)c_fu[step * up + 0];
359
+ a = s_tileIn[src0 + step + 1];
360
+ }
361
+ }
362
+ s_tileUpX[dst+0] = v.x;
363
+ s_tileUpX[dst+1] = v.y;
364
+ s_tileUpX[dst+2] = v.z;
365
+ s_tileUpX[dst+3] = v.w;
366
+ }
367
+ }
368
+ else if (up == 2)
369
+ {
370
+ bool p0 = (phaseInX == 0);
371
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
372
+ {
373
+ int relUpX0, relInY;
374
+ fast_div_mod<tileUpW>(relUpX0, relInY, idx);
375
+ int relInX0 = relUpX0 / up;
376
+ int src0 = relInX0 + tileInW * relInY;
377
+ int dst = relInY * tileUpW + relUpX0;
378
+ vec2_t v = InternalType<T>::zero_vec2();
379
+ scalar_t a = s_tileIn[src0];
380
+ if (p0) // (phaseInX == 0)
381
+ {
382
+ #pragma unroll
383
+ for (int step = 0; step < fuSize / up; step++)
384
+ {
385
+ v.x += a * (scalar_t)c_fu[step * up + 0];
386
+ a = s_tileIn[src0 + step + 1];
387
+ v.y += a * (scalar_t)c_fu[step * up + 1];
388
+ }
389
+ }
390
+ else // (phaseInX == 1)
391
+ {
392
+ #pragma unroll
393
+ for (int step = 0; step < fuSize / up; step++)
394
+ {
395
+ v.x += a * (scalar_t)c_fu[step * up + 1];
396
+ v.y += a * (scalar_t)c_fu[step * up + 0];
397
+ a = s_tileIn[src0 + step + 1];
398
+ }
399
+ }
400
+ s_tileUpX[dst+0] = v.x;
401
+ s_tileUpX[dst+1] = v.y;
402
+ }
403
+ }
404
+
405
+ // Vertical upsampling & nonlinearity.
406
+
407
+ __syncthreads();
408
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
409
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
410
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
411
+ if (up == 4)
412
+ {
413
+ minY -= 3; // Adjust according to block height.
414
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
415
+ {
416
+ int relUpX, relInY0;
417
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
418
+ int relUpY0 = relInY0 * up;
419
+ int src0 = relInY0 * tileUpW + relUpX;
420
+ int dst = relUpY0 * tileUpW + relUpX;
421
+ vec4_t v = InternalType<T>::zero_vec4();
422
+
423
+ scalar_t a = s_tileUpX[src0];
424
+ if (phaseInY == 0)
425
+ {
426
+ #pragma unroll
427
+ for (int step = 0; step < fuSize / up; step++)
428
+ {
429
+ v.x += a * (scalar_t)c_fu[step * up + 0];
430
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
431
+ v.y += a * (scalar_t)c_fu[step * up + 3];
432
+ v.z += a * (scalar_t)c_fu[step * up + 2];
433
+ v.w += a * (scalar_t)c_fu[step * up + 1];
434
+ }
435
+ }
436
+ else if (phaseInY == 1)
437
+ {
438
+ #pragma unroll
439
+ for (int step = 0; step < fuSize / up; step++)
440
+ {
441
+ v.x += a * (scalar_t)c_fu[step * up + 1];
442
+ v.y += a * (scalar_t)c_fu[step * up + 0];
443
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
444
+ v.z += a * (scalar_t)c_fu[step * up + 3];
445
+ v.w += a * (scalar_t)c_fu[step * up + 2];
446
+ }
447
+ }
448
+ else if (phaseInY == 2)
449
+ {
450
+ #pragma unroll
451
+ for (int step = 0; step < fuSize / up; step++)
452
+ {
453
+ v.x += a * (scalar_t)c_fu[step * up + 2];
454
+ v.y += a * (scalar_t)c_fu[step * up + 1];
455
+ v.z += a * (scalar_t)c_fu[step * up + 0];
456
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
457
+ v.w += a * (scalar_t)c_fu[step * up + 3];
458
+ }
459
+ }
460
+ else // (phaseInY == 3)
461
+ {
462
+ #pragma unroll
463
+ for (int step = 0; step < fuSize / up; step++)
464
+ {
465
+ v.x += a * (scalar_t)c_fu[step * up + 3];
466
+ v.y += a * (scalar_t)c_fu[step * up + 2];
467
+ v.z += a * (scalar_t)c_fu[step * up + 1];
468
+ v.w += a * (scalar_t)c_fu[step * up + 0];
469
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
470
+ }
471
+ }
472
+
473
+ int x = tileOutX * down + relUpX;
474
+ int y = tileOutY * down + relUpY0;
475
+ int signX = x + p.sOfs.x;
476
+ int signY = y + p.sOfs.y;
477
+ int signZ = blockIdx.z + p.blockZofs;
478
+ int signXb = signX >> 2;
479
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
480
+ index_t si1 = si0 + p.sShape.x;
481
+ index_t si2 = si0 + p.sShape.x * 2;
482
+ index_t si3 = si0 + p.sShape.x * 3;
483
+
484
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
485
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
486
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
487
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
488
+
489
+ if (signWrite)
490
+ {
491
+ if (!enableWriteSkip)
492
+ {
493
+ // Determine and write signs.
494
+ int sx = __float_as_uint(v.x) >> 31 << 0;
495
+ int sy = __float_as_uint(v.y) >> 31 << 8;
496
+ int sz = __float_as_uint(v.z) >> 31 << 16;
497
+ int sw = __float_as_uint(v.w) >> 31 << 24;
498
+ if (sx) v.x *= p.slope;
499
+ if (sy) v.y *= p.slope;
500
+ if (sz) v.z *= p.slope;
501
+ if (sw) v.w *= p.slope;
502
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
503
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
504
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
505
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
506
+
507
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
508
+ {
509
+ // Combine signs.
510
+ uint32_t s = sx + sy + sw + sz;
511
+ s <<= (signX & 3) << 1;
512
+ s |= __shfl_xor_sync(groupMask, s, 1);
513
+ s |= __shfl_xor_sync(groupMask, s, 2);
514
+
515
+ // Write signs.
516
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
517
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
518
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
519
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
520
+ }
521
+ }
522
+ else
523
+ {
524
+ // Determine and write signs.
525
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
526
+ {
527
+ int sx = __float_as_uint(v.x) >> 31 << 0;
528
+ int sy = __float_as_uint(v.y) >> 31 << 8;
529
+ int sz = __float_as_uint(v.z) >> 31 << 16;
530
+ int sw = __float_as_uint(v.w) >> 31 << 24;
531
+ if (sx) v.x *= p.slope;
532
+ if (sy) v.y *= p.slope;
533
+ if (sz) v.z *= p.slope;
534
+ if (sw) v.w *= p.slope;
535
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
536
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
537
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }
538
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }
539
+
540
+ // Combine signs.
541
+ uint32_t s = sx + sy + sw + sz;
542
+ s <<= (signX & 3) << 1;
543
+ s |= __shfl_xor_sync(groupMask, s, 1);
544
+ s |= __shfl_xor_sync(groupMask, s, 2);
545
+
546
+ // Write signs.
547
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
548
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
549
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
550
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
551
+ }
552
+ else
553
+ {
554
+ // Just compute the values.
555
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
556
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
557
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
558
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
559
+ }
560
+ }
561
+ }
562
+ else if (signRead) // Read signs and apply.
563
+ {
564
+ if ((uint32_t)signXb < p.swLimit)
565
+ {
566
+ int ss = (signX & 3) << 1;
567
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
568
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
569
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
570
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
571
+ }
572
+ }
573
+ else // Forward pass with no sign write.
574
+ {
575
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
576
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
577
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
578
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
579
+ }
580
+
581
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
582
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
583
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
584
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
585
+ }
586
+ }
587
+ else if (up == 2)
588
+ {
589
+ minY -= 1; // Adjust according to block height.
590
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
591
+ {
592
+ int relUpX, relInY0;
593
+ fast_div_mod<tileUpW>(relUpX, relInY0, idx);
594
+ int relUpY0 = relInY0 * up;
595
+ int src0 = relInY0 * tileUpW + relUpX;
596
+ int dst = relUpY0 * tileUpW + relUpX;
597
+ vec2_t v = InternalType<T>::zero_vec2();
598
+
599
+ scalar_t a = s_tileUpX[src0];
600
+ if (phaseInY == 0)
601
+ {
602
+ #pragma unroll
603
+ for (int step = 0; step < fuSize / up; step++)
604
+ {
605
+ v.x += a * (scalar_t)c_fu[step * up + 0];
606
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
607
+ v.y += a * (scalar_t)c_fu[step * up + 1];
608
+ }
609
+ }
610
+ else // (phaseInY == 1)
611
+ {
612
+ #pragma unroll
613
+ for (int step = 0; step < fuSize / up; step++)
614
+ {
615
+ v.x += a * (scalar_t)c_fu[step * up + 1];
616
+ v.y += a * (scalar_t)c_fu[step * up + 0];
617
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
618
+ }
619
+ }
620
+
621
+ int x = tileOutX * down + relUpX;
622
+ int y = tileOutY * down + relUpY0;
623
+ int signX = x + p.sOfs.x;
624
+ int signY = y + p.sOfs.y;
625
+ int signZ = blockIdx.z + p.blockZofs;
626
+ int signXb = signX >> 2;
627
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
628
+ index_t si1 = si0 + p.sShape.x;
629
+
630
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
631
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
632
+
633
+ if (signWrite)
634
+ {
635
+ if (!enableWriteSkip)
636
+ {
637
+ // Determine and write signs.
638
+ int sx = __float_as_uint(v.x) >> 31 << 0;
639
+ int sy = __float_as_uint(v.y) >> 31 << 8;
640
+ if (sx) v.x *= p.slope;
641
+ if (sy) v.y *= p.slope;
642
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
643
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
644
+
645
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
646
+ {
647
+ // Combine signs.
648
+ int s = sx + sy;
649
+ s <<= signXo;
650
+ s |= __shfl_xor_sync(groupMask, s, 1);
651
+ s |= __shfl_xor_sync(groupMask, s, 2);
652
+
653
+ // Write signs.
654
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
655
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
656
+ }
657
+ }
658
+ else
659
+ {
660
+ // Determine and write signs.
661
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
662
+ {
663
+ int sx = __float_as_uint(v.x) >> 31 << 0;
664
+ int sy = __float_as_uint(v.y) >> 31 << 8;
665
+ if (sx) v.x *= p.slope;
666
+ if (sy) v.y *= p.slope;
667
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }
668
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }
669
+
670
+ // Combine signs.
671
+ int s = sx + sy;
672
+ s <<= signXo;
673
+ s |= __shfl_xor_sync(groupMask, s, 1);
674
+ s |= __shfl_xor_sync(groupMask, s, 2);
675
+
676
+ // Write signs.
677
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
678
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
679
+ }
680
+ else
681
+ {
682
+ // Just compute the values.
683
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
684
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
685
+ }
686
+ }
687
+ }
688
+ else if (signRead) // Read signs and apply.
689
+ {
690
+ if ((uint32_t)signXb < p.swLimit)
691
+ {
692
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
693
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
694
+ }
695
+ }
696
+ else // Forward pass with no sign write.
697
+ {
698
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
699
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
700
+ }
701
+
702
+ if (!downInline)
703
+ {
704
+ // Write into temporary buffer.
705
+ s_tileUpXY[dst] = v.x;
706
+ if (relUpY0 < tileUpH - 1)
707
+ s_tileUpXY[dst + tileUpW] = v.y;
708
+ }
709
+ else
710
+ {
711
+ // Write directly into output buffer.
712
+ if ((uint32_t)x < p.yShape.x)
713
+ {
714
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
715
+ index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;
716
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
717
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
718
+ }
719
+ }
720
+ }
721
+ }
722
+ }
723
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
724
+ {
725
+ // Full upsampling filter.
726
+
727
+ if (up == 2)
728
+ {
729
+ // 2 x 2-wide.
730
+ __syncthreads();
731
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
732
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
733
+ {
734
+ int relUpX0, relUpY0;
735
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
736
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
737
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
738
+ int src0 = relInX0 + tileInW * relInY0;
739
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
740
+
741
+ #define X_LOOP(TAPY, PX) \
742
+ for (int sx = 0; sx < fuSize / up; sx++) \
743
+ { \
744
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
745
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
746
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
747
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
748
+ }
749
+
750
+ vec4_t v = InternalType<T>::zero_vec4();
751
+ if (tap0y == 0 && phaseInX == 0)
752
+ #pragma unroll
753
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
754
+ #pragma unroll
755
+ X_LOOP(0, 0) }
756
+ if (tap0y == 0 && phaseInX == 1)
757
+ #pragma unroll
758
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
759
+ #pragma unroll
760
+ X_LOOP(0, 1) }
761
+ if (tap0y == 1 && phaseInX == 0)
762
+ #pragma unroll
763
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
764
+ #pragma unroll
765
+ X_LOOP(1, 0) }
766
+ if (tap0y == 1 && phaseInX == 1)
767
+ #pragma unroll
768
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
769
+ #pragma unroll
770
+ X_LOOP(1, 1) }
771
+
772
+ #undef X_LOOP
773
+
774
+ int x = tileOutX * down + relUpX0;
775
+ int y = tileOutY * down + relUpY0;
776
+ int signX = x + p.sOfs.x;
777
+ int signY = y + p.sOfs.y;
778
+ int signZ = blockIdx.z + p.blockZofs;
779
+ int signXb = signX >> 2;
780
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
781
+
782
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
783
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
784
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
785
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
786
+
787
+ if (signWrite)
788
+ {
789
+ if (!enableWriteSkip)
790
+ {
791
+ // Determine and write signs.
792
+ int sx = __float_as_uint(v.x) >> 31;
793
+ int sy = __float_as_uint(v.y) >> 31;
794
+ int sz = __float_as_uint(v.z) >> 31;
795
+ int sw = __float_as_uint(v.w) >> 31;
796
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
797
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
798
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
799
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
800
+
801
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
802
+ {
803
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
804
+ }
805
+ }
806
+ else
807
+ {
808
+ // Determine and write signs.
809
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
810
+ {
811
+ int sx = __float_as_uint(v.x) >> 31;
812
+ int sy = __float_as_uint(v.y) >> 31;
813
+ int sz = __float_as_uint(v.z) >> 31;
814
+ int sw = __float_as_uint(v.w) >> 31;
815
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }
816
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }
817
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }
818
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }
819
+
820
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
821
+ }
822
+ else
823
+ {
824
+ // Just compute the values.
825
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
826
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
827
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
828
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
829
+ }
830
+ }
831
+ }
832
+ else if (signRead) // Read sign and apply.
833
+ {
834
+ if ((uint32_t)signY < p.sShape.y)
835
+ {
836
+ int s = 0;
837
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
838
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
839
+ s >>= (signX & 3) << 1;
840
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
841
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
842
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
843
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
844
+ }
845
+ }
846
+ else // Forward pass with no sign write.
847
+ {
848
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);
849
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);
850
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);
851
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);
852
+ }
853
+
854
+ s_tileUpXY[idx + 0] = v.x;
855
+ s_tileUpXY[idx + 1] = v.y;
856
+ s_tileUpXY[idx + 2] = v.z;
857
+ s_tileUpXY[idx + 3] = v.w;
858
+ }
859
+ }
860
+ else if (up == 1)
861
+ {
862
+ __syncthreads();
863
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
864
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
865
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
866
+ {
867
+ int relUpX0, relUpY0;
868
+ fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);
869
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
870
+
871
+ int x = tileOutX * down + relUpX0;
872
+ int y = tileOutY * down + relUpY0;
873
+ int signX = x + p.sOfs.x;
874
+ int signY = y + p.sOfs.y;
875
+ int signZ = blockIdx.z + p.blockZofs;
876
+ int signXb = signX >> 2;
877
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
878
+ v *= (scalar_t)((float)up * (float)up * p.gain);
879
+
880
+ if (signWrite)
881
+ {
882
+ if (!enableWriteSkip)
883
+ {
884
+ // Determine and write sign.
885
+ uint32_t s = 0;
886
+ uint32_t signXbit = (1u << signXo);
887
+ if (v < 0.f)
888
+ {
889
+ s = signXbit;
890
+ v *= p.slope;
891
+ }
892
+ if (fabsf(v) > p.clamp)
893
+ {
894
+ s = signXbit * 2;
895
+ v = InternalType<T>::clamp(v, p.clamp);
896
+ }
897
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
898
+ {
899
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
900
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
901
+ p.s[si] = s; // Write.
902
+ }
903
+ }
904
+ else
905
+ {
906
+ // Determine and write sign.
907
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
908
+ {
909
+ uint32_t s = 0;
910
+ uint32_t signXbit = (1u << signXo);
911
+ if (v < 0.f)
912
+ {
913
+ s = signXbit;
914
+ v *= p.slope;
915
+ }
916
+ if (fabsf(v) > p.clamp)
917
+ {
918
+ s = signXbit * 2;
919
+ v = InternalType<T>::clamp(v, p.clamp);
920
+ }
921
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
922
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
923
+ p.s[si] = s; // Write.
924
+ }
925
+ else
926
+ {
927
+ // Just compute the value.
928
+ if (v < 0.f) v *= p.slope;
929
+ v = InternalType<T>::clamp(v, p.clamp);
930
+ }
931
+ }
932
+ }
933
+ else if (signRead)
934
+ {
935
+ // Read sign and apply if within sign tensor bounds.
936
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
937
+ {
938
+ int s = p.s[si];
939
+ s >>= signXo;
940
+ if (s & 1) v *= p.slope;
941
+ if (s & 2) v = 0.f;
942
+ }
943
+ }
944
+ else // Forward pass with no sign write.
945
+ {
946
+ if (v < 0.f) v *= p.slope;
947
+ v = InternalType<T>::clamp(v, p.clamp);
948
+ }
949
+
950
+ if (!downInline) // Write into temporary buffer.
951
+ s_tileUpXY[idx] = v;
952
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
953
+ *((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
954
+ }
955
+ }
956
+ }
957
+
958
+ // Downsampling.
959
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
960
+ {
961
+ // Horizontal downsampling.
962
+ __syncthreads();
963
+ if (down == 4 && tileOutW % 4 == 0)
964
+ {
965
+ // Calculate 4 pixels at a time.
966
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
967
+ {
968
+ int relOutX0, relUpY;
969
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
970
+ int relUpX0 = relOutX0 * down;
971
+ int src0 = relUpY * tileUpW + relUpX0;
972
+ vec4_t v = InternalType<T>::zero_vec4();
973
+ #pragma unroll
974
+ for (int step = 0; step < fdSize; step++)
975
+ {
976
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
977
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
978
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
979
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
980
+ }
981
+ s_tileDownX[idx+0] = v.x;
982
+ s_tileDownX[idx+1] = v.y;
983
+ s_tileDownX[idx+2] = v.z;
984
+ s_tileDownX[idx+3] = v.w;
985
+ }
986
+ }
987
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
988
+ {
989
+ // Calculate 2 pixels at a time.
990
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
991
+ {
992
+ int relOutX0, relUpY;
993
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
994
+ int relUpX0 = relOutX0 * down;
995
+ int src0 = relUpY * tileUpW + relUpX0;
996
+ vec2_t v = InternalType<T>::zero_vec2();
997
+ #pragma unroll
998
+ for (int step = 0; step < fdSize; step++)
999
+ {
1000
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
1001
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
1002
+ }
1003
+ s_tileDownX[idx+0] = v.x;
1004
+ s_tileDownX[idx+1] = v.y;
1005
+ }
1006
+ }
1007
+ else
1008
+ {
1009
+ // Calculate 1 pixel at a time.
1010
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
1011
+ {
1012
+ int relOutX0, relUpY;
1013
+ fast_div_mod<tileOutW>(relOutX0, relUpY, idx);
1014
+ int relUpX0 = relOutX0 * down;
1015
+ int src = relUpY * tileUpW + relUpX0;
1016
+ scalar_t v = 0.f;
1017
+ #pragma unroll
1018
+ for (int step = 0; step < fdSize; step++)
1019
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
1020
+ s_tileDownX[idx] = v;
1021
+ }
1022
+ }
1023
+
1024
+ // Vertical downsampling & store output tile.
1025
+ __syncthreads();
1026
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1027
+ {
1028
+ int relOutX, relOutY0;
1029
+ fast_div_mod<tileOutW>(relOutX, relOutY0, idx);
1030
+ int relUpY0 = relOutY0 * down;
1031
+ int src0 = relUpY0 * tileOutW + relOutX;
1032
+ scalar_t v = 0;
1033
+ #pragma unroll
1034
+ for (int step = 0; step < fdSize; step++)
1035
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
1036
+
1037
+ int outX = tileOutX + relOutX;
1038
+ int outY = tileOutY + relOutY0;
1039
+
1040
+ if (outX < p.yShape.x & outY < p.yShape.y)
1041
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1042
+ }
1043
+ }
1044
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
1045
+ {
1046
+ // Full downsampling filter.
1047
+ if (down == 2)
1048
+ {
1049
+ // 2-wide.
1050
+ __syncthreads();
1051
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
1052
+ {
1053
+ int relOutX0, relOutY0;
1054
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1055
+ int relUpX0 = relOutX0 * down;
1056
+ int relUpY0 = relOutY0 * down;
1057
+ int src0 = relUpY0 * tileUpW + relUpX0;
1058
+ vec2_t v = InternalType<T>::zero_vec2();
1059
+ #pragma unroll
1060
+ for (int sy = 0; sy < fdSize; sy++)
1061
+ #pragma unroll
1062
+ for (int sx = 0; sx < fdSize; sx++)
1063
+ {
1064
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1065
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
1066
+ }
1067
+
1068
+ int outX = tileOutX + relOutX0;
1069
+ int outY = tileOutY + relOutY0;
1070
+ if ((uint32_t)outY < p.yShape.y)
1071
+ {
1072
+ index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;
1073
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
1074
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;
1075
+ }
1076
+ }
1077
+ }
1078
+ else if (down == 1 && !downInline)
1079
+ {
1080
+ // Thread per pixel.
1081
+ __syncthreads();
1082
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
1083
+ {
1084
+ int relOutX0, relOutY0;
1085
+ fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);
1086
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
1087
+
1088
+ int outX = tileOutX + relOutX0;
1089
+ int outY = tileOutY + relOutY0;
1090
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
1091
+ *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;
1092
+ }
1093
+ }
1094
+ }
1095
+
1096
+ if (!enableXrep)
1097
+ break;
1098
+ }
1099
+ }
1100
+
1101
+ //------------------------------------------------------------------------
1102
+ // Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
1103
+ // Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
1104
+
1105
+ template <class T, bool signWrite, bool signRead>
1106
+ static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
1107
+ {
1108
+ typedef typename InternalType<T>::scalar_t scalar_t;
1109
+
1110
+ // Indexing.
1111
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
1112
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
1113
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
1114
+
1115
+ // Loop to accommodate oversized tensors.
1116
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
1117
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
1118
+ {
1119
+ // Extract z and w (channel, minibatch index).
1120
+ int32_t w = q / p.xShape.z;
1121
+ int32_t z = q - w * p.xShape.z;
1122
+
1123
+ // Choose behavior based on sign read/write mode.
1124
+ if (signWrite)
1125
+ {
1126
+ // Process value if in p.x.
1127
+ uint32_t s = 0;
1128
+ if (x < p.xShape.x && y < p.xShape.y)
1129
+ {
1130
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1131
+ T* pv = ((T*)p.x) + ix;
1132
+ scalar_t v = (scalar_t)(*pv);
1133
+
1134
+ // Gain, LReLU, clamp.
1135
+ v *= p.gain;
1136
+ if (v < 0.f)
1137
+ {
1138
+ v *= p.slope;
1139
+ s = 1; // Sign.
1140
+ }
1141
+ if (fabsf(v) > p.clamp)
1142
+ {
1143
+ v = InternalType<T>::clamp(v, p.clamp);
1144
+ s = 2; // Clamp.
1145
+ }
1146
+
1147
+ *pv = (T)v; // Write value.
1148
+ }
1149
+
1150
+ // Coalesce into threads 0 and 16 of warp.
1151
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
1152
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
1153
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
1154
+ s |= __shfl_xor_sync(m, s, 2);
1155
+ s |= __shfl_xor_sync(m, s, 4);
1156
+ s |= __shfl_xor_sync(m, s, 8);
1157
+
1158
+ // Write signs if leader and in p.s.
1159
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
1160
+ {
1161
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
1162
+ ((uint32_t*)p.s)[is >> 4] = s;
1163
+ }
1164
+ }
1165
+ else if (signRead)
1166
+ {
1167
+ // Process value if in p.x.
1168
+ if (x < p.xShape.x) // y is always in.
1169
+ {
1170
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1171
+ T* pv = ((T*)p.x) + ix;
1172
+ scalar_t v = (scalar_t)(*pv);
1173
+ v *= p.gain;
1174
+
1175
+ // Apply sign buffer offset.
1176
+ uint32_t sx = x + p.sOfs.x;
1177
+ uint32_t sy = y + p.sOfs.y;
1178
+
1179
+ // Read and apply signs if we land inside valid region of sign buffer.
1180
+ if (sx < p.sShape.x && sy < p.sShape.y)
1181
+ {
1182
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
1183
+ unsigned char s = p.s[is];
1184
+ s >>= (sx & 3) << 1; // Shift into place.
1185
+ if (s & 1) // Sign?
1186
+ v *= p.slope;
1187
+ if (s & 2) // Clamp?
1188
+ v = 0.f;
1189
+ }
1190
+
1191
+ *pv = (T)v; // Write value.
1192
+ }
1193
+ }
1194
+ else
1195
+ {
1196
+ // Forward pass with no sign write. Process value if in p.x.
1197
+ if (x < p.xShape.x) // y is always in.
1198
+ {
1199
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
1200
+ T* pv = ((T*)p.x) + ix;
1201
+ scalar_t v = (scalar_t)(*pv);
1202
+ v *= p.gain;
1203
+ if (v < 0.f)
1204
+ v *= p.slope;
1205
+ if (fabsf(v) > p.clamp)
1206
+ v = InternalType<T>::clamp(v, p.clamp);
1207
+ *pv = (T)v; // Write value.
1208
+ }
1209
+ }
1210
+ }
1211
+ }
1212
+
1213
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
1214
+ {
1215
+ return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
1216
+ }
1217
+
1218
+ //------------------------------------------------------------------------
1219
+ // CUDA kernel selection.
1220
+
1221
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
1222
+ {
1223
+ filtered_lrelu_kernel_spec s = { 0 };
1224
+
1225
+ // Return the first matching kernel.
1226
+ #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
1227
+ if (sharedKB >= SH) \
1228
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
1229
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
1230
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
1231
+ { \
1232
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
1233
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
1234
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
1235
+ s.setup = (void*)setup_filters_kernel; \
1236
+ s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
1237
+ s.tileOut = make_int2(TW, TH); \
1238
+ s.numWarps = W; \
1239
+ s.xrep = XR; \
1240
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
1241
+ return s; \
1242
+ }
1243
+
1244
+ // Launch parameters for various kernel specializations.
1245
+ // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
1246
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
1247
+
1248
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
1249
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
1250
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
1251
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
1252
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
1253
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
1254
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
1255
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
1256
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
1257
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
1258
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
1259
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
1260
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
1261
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
1262
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
1263
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
1264
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
1265
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
1266
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
1267
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
1268
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
1269
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
1270
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
1271
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
1272
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
1273
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
1274
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
1275
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
1276
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
1277
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
1278
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
1279
+
1280
+ #undef CASE
1281
+ return s; // No kernel found.
1282
+ }
1283
+
1284
+ //------------------------------------------------------------------------
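The activation kernel above records one 2-bit code per element: bit 0 is set when the value was negative (the leaky slope was applied) and bit 1 when it exceeded the clamp, and four such codes are packed into each byte of the sign tensor (`sx >> 2` selects the byte, `(sx & 3) << 1` the bit pair). The following standalone Python sketch of that bit layout uses hypothetical `pack_signs`/`unpack_sign` helpers that are not part of the repository:

```python
# Sketch of the 2-bit-per-element sign encoding assumed by filtered_lrelu_act_kernel.
def pack_signs(codes):                      # codes: iterable of 2-bit values (0..3)
    out = bytearray((len(codes) + 3) // 4)  # 4 elements per byte
    for i, c in enumerate(codes):
        out[i >> 2] |= (c & 0x3) << ((i & 3) << 1)
    return bytes(out)

def unpack_sign(packed, i):                 # read the 2-bit code of element i
    return (packed[i >> 2] >> ((i & 3) << 1)) & 0x3  # same indexing as the read path above

codes = [0, 1, 2, 0, 1]   # 0 = pass-through, 1 = negative (slope applied), 2 = clamped
packed = pack_signs(codes)
assert [unpack_sign(packed, i) for i in range(len(codes))] == codes
```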
models/draggan/torch_utils/ops/filtered_lrelu.h ADDED
@@ -0,0 +1,90 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <cuda_runtime.h>
10
+
11
+ //------------------------------------------------------------------------
12
+ // CUDA kernel parameters.
13
+
14
+ struct filtered_lrelu_kernel_params
15
+ {
16
+ // These parameters decide which kernel to use.
17
+ int up; // upsampling ratio (1, 2, 4)
18
+ int down; // downsampling ratio (1, 2, 4)
19
+ int2 fuShape; // [size, 1] | [size, size]
20
+ int2 fdShape; // [size, 1] | [size, size]
21
+
22
+ int _dummy; // Alignment.
23
+
24
+ // Rest of the parameters.
25
+ const void* x; // Input tensor.
26
+ void* y; // Output tensor.
27
+ const void* b; // Bias tensor.
28
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
29
+ const float* fu; // Upsampling filter.
30
+ const float* fd; // Downsampling filter.
31
+
32
+ int2 pad0; // Left/top padding.
33
+ float gain; // Additional gain factor.
34
+ float slope; // Leaky ReLU slope on negative side.
35
+ float clamp; // Clamp after nonlinearity.
36
+ int flip; // Filter kernel flip for gradient computation.
37
+
38
+ int tilesXdim; // Original number of horizontal output tiles.
39
+ int tilesXrep; // Number of horizontal tiles per CTA.
40
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
41
+
42
+ int4 xShape; // [width, height, channel, batch]
43
+ int4 yShape; // [width, height, channel, batch]
44
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
45
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
46
+ int swLimit; // Active width of sign tensor in bytes.
47
+
48
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
49
+ longlong4 yStride; //
50
+ int64_t bStride; //
51
+ longlong3 fuStride; //
52
+ longlong3 fdStride; //
53
+ };
54
+
55
+ struct filtered_lrelu_act_kernel_params
56
+ {
57
+ void* x; // Input/output, modified in-place.
58
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
59
+
60
+ float gain; // Additional gain factor.
61
+ float slope; // Leaky ReLU slope on negative side.
62
+ float clamp; // Clamp after nonlinearity.
63
+
64
+ int4 xShape; // [width, height, channel, batch]
65
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
66
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
67
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
68
+ };
69
+
70
+ //------------------------------------------------------------------------
71
+ // CUDA kernel specialization.
72
+
73
+ struct filtered_lrelu_kernel_spec
74
+ {
75
+ void* setup; // Function for filter kernel setup.
76
+ void* exec; // Function for main operation.
77
+ int2 tileOut; // Width/height of launch tile.
78
+ int numWarps; // Number of warps per thread block, determines launch block size.
79
+ int xrep; // For processing multiple horizontal tiles per thread block.
80
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
81
+ };
82
+
83
+ //------------------------------------------------------------------------
84
+ // CUDA kernel selection.
85
+
86
+ template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
87
+ template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
88
+ template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
89
+
90
+ //------------------------------------------------------------------------
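The parameter structs above describe tensors in `[width, height, channel, batch]` order, i.e. the reverse of PyTorch's NCHW layout. A minimal sketch of that correspondence (an assumption for illustration, not the repository's C++ binding code):

```python
# How the [width, height, channel, batch] fields relate to an NCHW PyTorch tensor.
import torch

x = torch.randn(8, 512, 64, 64)         # [batch, channel, height, width]
xShape  = tuple(reversed(x.shape))      # (64, 64, 512, 8) -> [width, height, channel, batch]
xStride = tuple(reversed(x.stride()))   # strides listed in the same component order
print(xShape, xStride)
```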
models/draggan/torch_utils/ops/filtered_lrelu.py ADDED
@@ -0,0 +1,274 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import os
10
+ import numpy as np
11
+ import torch
12
+ import warnings
13
+
14
+ from .. import custom_ops
15
+ from .. import misc
16
+ from . import upfirdn2d
17
+ from . import bias_act
18
+
19
+ #----------------------------------------------------------------------------
20
+
21
+ _plugin = None
22
+
23
+ def _init():
24
+ global _plugin
25
+ if _plugin is None:
26
+ _plugin = custom_ops.get_plugin(
27
+ module_name='filtered_lrelu_plugin',
28
+ sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
29
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
30
+ source_dir=os.path.dirname(__file__),
31
+ extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
32
+ )
33
+ return True
34
+
35
+ def _get_filter_size(f):
36
+ if f is None:
37
+ return 1, 1
38
+ assert isinstance(f, torch.Tensor)
39
+ assert 1 <= f.ndim <= 2
40
+ return f.shape[-1], f.shape[0] # width, height
41
+
42
+ def _parse_padding(padding):
43
+ if isinstance(padding, int):
44
+ padding = [padding, padding]
45
+ assert isinstance(padding, (list, tuple))
46
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
47
+ padding = [int(x) for x in padding]
48
+ if len(padding) == 2:
49
+ px, py = padding
50
+ padding = [px, px, py, py]
51
+ px0, px1, py0, py1 = padding
52
+ return px0, px1, py0, py1
53
+
54
+ #----------------------------------------------------------------------------
55
+
56
+ def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
57
+ r"""Filtered leaky ReLU for a batch of 2D images.
58
+
59
+ Performs the following sequence of operations for each channel:
60
+
61
+ 1. Add channel-specific bias if provided (`b`).
62
+
63
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
64
+
65
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
66
+ Negative padding corresponds to cropping the image.
67
+
68
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
69
+ so that the footprint of all output pixels lies within the input image.
70
+
71
+ 5. Multiply each value by the provided gain factor (`gain`).
72
+
73
+ 6. Apply leaky ReLU activation function to each value.
74
+
75
+ 7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
76
+
77
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
78
+ it so that the footprint of all output pixels lies within the input image.
79
+
80
+ 9. Downsample the image by keeping every Nth pixel (`down`).
81
+
82
+ The fused op is considerably more efficient than performing the same calculation
83
+ using standard PyTorch ops. It supports gradients of arbitrary order.
84
+
85
+ Args:
86
+ x: Float32/float16/float64 input tensor of the shape
87
+ `[batch_size, num_channels, in_height, in_width]`.
88
+ fu: Float32 upsampling FIR filter of the shape
89
+ `[filter_height, filter_width]` (non-separable),
90
+ `[filter_taps]` (separable), or
91
+ `None` (identity).
92
+ fd: Float32 downsampling FIR filter of the shape
93
+ `[filter_height, filter_width]` (non-separable),
94
+ `[filter_taps]` (separable), or
95
+ `None` (identity).
96
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
97
+ as `x`. The length of the vector must match the channel dimension of `x`.
98
+ up: Integer upsampling factor (default: 1).
99
+ down: Integer downsampling factor (default: 1).
100
+ padding: Padding with respect to the upsampled image. Can be a single number
101
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
102
+ (default: 0).
103
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
104
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
105
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
106
+ flip_filter: False = convolution, True = correlation (default: False).
107
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
108
+
109
+ Returns:
110
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
111
+ """
112
+ assert isinstance(x, torch.Tensor)
113
+ assert impl in ['ref', 'cuda']
114
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
115
+ return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
116
+ return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
+ @misc.profiled_function
121
+ def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
122
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
123
+ existing `upfirdn2d()` and `bias_act()` ops.
124
+ """
125
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
126
+ fu_w, fu_h = _get_filter_size(fu)
127
+ fd_w, fd_h = _get_filter_size(fd)
128
+ if b is not None:
129
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
130
+ misc.assert_shape(b, [x.shape[1]])
131
+ assert isinstance(up, int) and up >= 1
132
+ assert isinstance(down, int) and down >= 1
133
+ px0, px1, py0, py1 = _parse_padding(padding)
134
+ assert gain == float(gain) and gain > 0
135
+ assert slope == float(slope) and slope >= 0
136
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
137
+
138
+ # Calculate output size.
139
+ batch_size, channels, in_h, in_w = x.shape
140
+ in_dtype = x.dtype
141
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
142
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
143
+
144
+ # Compute using existing ops.
145
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
146
+ x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
147
+ x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
148
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
149
+
150
+ # Check output shape & dtype.
151
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
152
+ assert x.dtype == in_dtype
153
+ return x
154
+
155
+ #----------------------------------------------------------------------------
156
+
157
+ _filtered_lrelu_cuda_cache = dict()
158
+
159
+ def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
160
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
161
+ """
162
+ assert isinstance(up, int) and up >= 1
163
+ assert isinstance(down, int) and down >= 1
164
+ px0, px1, py0, py1 = _parse_padding(padding)
165
+ assert gain == float(gain) and gain > 0
166
+ gain = float(gain)
167
+ assert slope == float(slope) and slope >= 0
168
+ slope = float(slope)
169
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
170
+ clamp = float(clamp if clamp is not None else 'inf')
171
+
172
+ # Lookup from cache.
173
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
174
+ if key in _filtered_lrelu_cuda_cache:
175
+ return _filtered_lrelu_cuda_cache[key]
176
+
177
+ # Forward op.
178
+ class FilteredLReluCuda(torch.autograd.Function):
179
+ @staticmethod
180
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
181
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
182
+
183
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
184
+ if fu is None:
185
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
186
+ if fd is None:
187
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
188
+ assert 1 <= fu.ndim <= 2
189
+ assert 1 <= fd.ndim <= 2
190
+
191
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
192
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
193
+ fu = fu.square()[None]
194
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
195
+ fd = fd.square()[None]
196
+
197
+ # Missing sign input tensor.
198
+ if si is None:
199
+ si = torch.empty([0])
200
+
201
+ # Missing bias tensor.
202
+ if b is None:
203
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
204
+
205
+ # Construct internal sign tensor only if gradients are needed.
206
+ write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
207
+
208
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
209
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
210
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
211
+ warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
212
+
213
+ # Call C++/Cuda plugin if datatype is supported.
214
+ if x.dtype in [torch.float16, torch.float32]:
215
+ if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
216
+ warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
217
+ y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
218
+ else:
219
+ return_code = -1
220
+
221
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
222
+ # only the bit-packed sign tensor is retained for gradient computation.
223
+ if return_code < 0:
224
+ warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
225
+
226
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
227
+ y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
228
+ so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
229
+ y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
230
+
231
+ # Prepare for gradient computation.
232
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
233
+ ctx.x_shape = x.shape
234
+ ctx.y_shape = y.shape
235
+ ctx.s_ofs = sx, sy
236
+ return y
237
+
238
+ @staticmethod
239
+ def backward(ctx, dy): # pylint: disable=arguments-differ
240
+ fu, fd, si = ctx.saved_tensors
241
+ _, _, xh, xw = ctx.x_shape
242
+ _, _, yh, yw = ctx.y_shape
243
+ sx, sy = ctx.s_ofs
244
+ dx = None # 0
245
+ dfu = None; assert not ctx.needs_input_grad[1]
246
+ dfd = None; assert not ctx.needs_input_grad[2]
247
+ db = None # 3
248
+ dsi = None; assert not ctx.needs_input_grad[4]
249
+ dsx = None; assert not ctx.needs_input_grad[5]
250
+ dsy = None; assert not ctx.needs_input_grad[6]
251
+
252
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
253
+ pp = [
254
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
255
+ xw * up - yw * down + px0 - (up - 1),
256
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
257
+ xh * up - yh * down + py0 - (up - 1),
258
+ ]
259
+ gg = gain * (up ** 2) / (down ** 2)
260
+ ff = (not flip_filter)
261
+ sx = sx - (fu.shape[-1] - 1) + px0
262
+ sy = sy - (fu.shape[0] - 1) + py0
263
+ dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
264
+
265
+ if ctx.needs_input_grad[3]:
266
+ db = dx.sum([0, 2, 3])
267
+
268
+ return dx, dfu, dfd, db, dsi, dsx, dsy
269
+
270
+ # Add to cache.
271
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
272
+ return FilteredLReluCuda
273
+
274
+ #----------------------------------------------------------------------------
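A small usage sketch of `filtered_lrelu()` as defined above, assuming the repository root is on `PYTHONPATH` so the package imports resolve; `impl='ref'` on a CPU tensor exercises the reference path built from `upfirdn2d()` and `bias_act()`, so no CUDA extension needs to be compiled:

```python
import torch
from models.draggan.torch_utils.ops.filtered_lrelu import filtered_lrelu

x  = torch.randn(1, 3, 16, 16)               # NCHW input
fu = torch.tensor([1., 3., 3., 1.]) / 8.0    # separable 4-tap upsampling FIR filter
fd = torch.tensor([1., 3., 3., 1.]) / 8.0    # separable 4-tap downsampling FIR filter
b  = torch.zeros(3)                          # per-channel bias

y = filtered_lrelu(x, fu=fu, fd=fd, b=b, up=2, down=2, padding=3, impl='ref')
print(y.shape)                               # torch.Size([1, 3, 16, 16]) for these settings
```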
models/draggan/torch_utils/ops/filtered_lrelu_ns.cu ADDED
@@ -0,0 +1,27 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for no signs mode (no gradients required).
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, false>(cudaStream_t stream);
models/draggan/torch_utils/ops/filtered_lrelu_rd.cu ADDED
@@ -0,0 +1,27 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign read mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, false, true>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, false, true>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<false, true>(cudaStream_t stream);
models/draggan/torch_utils/ops/filtered_lrelu_wr.cu ADDED
@@ -0,0 +1,27 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include "filtered_lrelu.cu"
10
+
11
+ // Template/kernel specializations for sign write mode.
12
+
13
+ // Full op, 32-bit indexing.
14
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
15
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
16
+
17
+ // Full op, 64-bit indexing.
18
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
19
+ template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
20
+
21
+ // Activation/signs only for generic variant. 64-bit indexing.
22
+ template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);
23
+ template void* choose_filtered_lrelu_act_kernel<float, true, false>(void);
24
+ template void* choose_filtered_lrelu_act_kernel<double, true, false>(void);
25
+
26
+ // Copy filters to constant memory.
27
+ template cudaError_t copy_filters<true, false>(cudaStream_t stream);
models/draggan/torch_utils/ops/fma.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`."""
10
+
11
+ import torch
12
+
13
+ #----------------------------------------------------------------------------
14
+
15
+ def fma(a, b, c): # => a * b + c
16
+ return _FusedMultiplyAdd.apply(a, b, c)
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c
21
+ @staticmethod
22
+ def forward(ctx, a, b, c): # pylint: disable=arguments-differ
23
+ out = torch.addcmul(c, a, b)
24
+ ctx.save_for_backward(a, b)
25
+ ctx.c_shape = c.shape
26
+ return out
27
+
28
+ @staticmethod
29
+ def backward(ctx, dout): # pylint: disable=arguments-differ
30
+ a, b = ctx.saved_tensors
31
+ c_shape = ctx.c_shape
32
+ da = None
33
+ db = None
34
+ dc = None
35
+
36
+ if ctx.needs_input_grad[0]:
37
+ da = _unbroadcast(dout * b, a.shape)
38
+
39
+ if ctx.needs_input_grad[1]:
40
+ db = _unbroadcast(dout * a, b.shape)
41
+
42
+ if ctx.needs_input_grad[2]:
43
+ dc = _unbroadcast(dout, c_shape)
44
+
45
+ return da, db, dc
46
+
47
+ #----------------------------------------------------------------------------
48
+
49
+ def _unbroadcast(x, shape):
50
+ extra_dims = x.ndim - len(shape)
51
+ assert extra_dims >= 0
52
+ dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]
53
+ if len(dim):
54
+ x = x.sum(dim=dim, keepdim=True)
55
+ if extra_dims:
56
+ x = x.reshape(-1, *x.shape[extra_dims+1:])
57
+ assert x.shape == shape
58
+ return x
59
+
60
+ #----------------------------------------------------------------------------
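A tiny usage sketch of `fma()` above (again assuming the package imports resolve): when an operand is broadcast, `_unbroadcast()` sums its gradient back to the original shape:

```python
import torch
from models.draggan.torch_utils.ops.fma import fma

a = torch.randn(4, 3, requires_grad=True)
b = torch.randn(1, 3, requires_grad=True)    # broadcast along the first dimension
c = torch.randn(4, 3, requires_grad=True)

out = fma(a, b, c)                           # a * b + c
out.sum().backward()
print(b.grad.shape)                          # torch.Size([1, 3])
```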
models/draggan/torch_utils/ops/grid_sample_gradfix.py ADDED
@@ -0,0 +1,77 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom replacement for `torch.nn.functional.grid_sample` that
10
+ supports arbitrarily high order gradients between the input and output.
11
+ Only works on 2D images and assumes
12
+ `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
13
+
14
+ import torch
15
+
16
+ # pylint: disable=redefined-builtin
17
+ # pylint: disable=arguments-differ
18
+ # pylint: disable=protected-access
19
+
20
+ #----------------------------------------------------------------------------
21
+
22
+ enabled = False # Enable the custom op by setting this to true.
23
+
24
+ #----------------------------------------------------------------------------
25
+
26
+ def grid_sample(input, grid):
27
+ if _should_use_custom_op():
28
+ return _GridSample2dForward.apply(input, grid)
29
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
30
+
31
+ #----------------------------------------------------------------------------
32
+
33
+ def _should_use_custom_op():
34
+ return enabled
35
+
36
+ #----------------------------------------------------------------------------
37
+
38
+ class _GridSample2dForward(torch.autograd.Function):
39
+ @staticmethod
40
+ def forward(ctx, input, grid):
41
+ assert input.ndim == 4
42
+ assert grid.ndim == 4
43
+ output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
44
+ ctx.save_for_backward(input, grid)
45
+ return output
46
+
47
+ @staticmethod
48
+ def backward(ctx, grad_output):
49
+ input, grid = ctx.saved_tensors
50
+ grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
51
+ return grad_input, grad_grid
52
+
53
+ #----------------------------------------------------------------------------
54
+
55
+ class _GridSample2dBackward(torch.autograd.Function):
56
+ @staticmethod
57
+ def forward(ctx, grad_output, input, grid):
58
+ op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
59
+ grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
60
+ ctx.save_for_backward(grid)
61
+ return grad_input, grad_grid
62
+
63
+ @staticmethod
64
+ def backward(ctx, grad2_grad_input, grad2_grad_grid):
65
+ _ = grad2_grad_grid # unused
66
+ grid, = ctx.saved_tensors
67
+ grad2_grad_output = None
68
+ grad2_input = None
69
+ grad2_grid = None
70
+
71
+ if ctx.needs_input_grad[0]:
72
+ grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
73
+
74
+ assert not ctx.needs_input_grad[2]
75
+ return grad2_grad_output, grad2_input, grad2_grid
76
+
77
+ #----------------------------------------------------------------------------
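A short usage sketch (package imports assumed to resolve): setting `enabled = True` routes calls through the double-differentiable custom op, otherwise `grid_sample()` simply forwards to `torch.nn.functional.grid_sample`:

```python
import torch
from models.draggan.torch_utils.ops import grid_sample_gradfix

grid_sample_gradfix.enabled = True           # opt in to the custom op
img  = torch.randn(1, 3, 32, 32, requires_grad=True)
grid = torch.rand(1, 16, 16, 2) * 2 - 1      # sampling coordinates in [-1, 1]
out  = grid_sample_gradfix.grid_sample(img, grid)
print(out.shape)                             # torch.Size([1, 3, 16, 16])
```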
models/draggan/torch_utils/ops/upfirdn2d.cpp ADDED
@@ -0,0 +1,107 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <torch/extension.h>
10
+ #include <ATen/cuda/CUDAContext.h>
11
+ #include <c10/cuda/CUDAGuard.h>
12
+ #include "upfirdn2d.h"
13
+
14
+ //------------------------------------------------------------------------
15
+
16
+ static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17
+ {
18
+ // Validate arguments.
19
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20
+ TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21
+ TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23
+ TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24
+ TORCH_CHECK(x.numel() > 0, "x has zero size");
25
+ TORCH_CHECK(f.numel() > 0, "f has zero size");
26
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
27
+ TORCH_CHECK(f.dim() == 2, "f must be rank 2");
28
+ TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large");
29
+ TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
30
+ TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
31
+ TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
32
+
33
+ // Create output tensor.
34
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
35
+ int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
36
+ int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
37
+ TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
38
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
39
+ TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
40
+ TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large");
41
+
42
+ // Initialize CUDA kernel parameters.
43
+ upfirdn2d_kernel_params p;
44
+ p.x = x.data_ptr();
45
+ p.f = f.data_ptr<float>();
46
+ p.y = y.data_ptr();
47
+ p.up = make_int2(upx, upy);
48
+ p.down = make_int2(downx, downy);
49
+ p.pad0 = make_int2(padx0, pady0);
50
+ p.flip = (flip) ? 1 : 0;
51
+ p.gain = gain;
52
+ p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
53
+ p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
54
+ p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
55
+ p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
56
+ p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
57
+ p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
58
+ p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
59
+ p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
60
+
61
+ // Choose CUDA kernel.
62
+ upfirdn2d_kernel_spec spec;
63
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
64
+ {
65
+ spec = choose_upfirdn2d_kernel<scalar_t>(p);
66
+ });
67
+
68
+ // Set looping options.
69
+ p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
70
+ p.loopMinor = spec.loopMinor;
71
+ p.loopX = spec.loopX;
72
+ p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
73
+ p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
74
+
75
+ // Compute grid size.
76
+ dim3 blockSize, gridSize;
77
+ if (spec.tileOutW < 0) // large
78
+ {
79
+ blockSize = dim3(4, 32, 1);
80
+ gridSize = dim3(
81
+ ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
82
+ (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
83
+ p.launchMajor);
84
+ }
85
+ else // small
86
+ {
87
+ blockSize = dim3(256, 1, 1);
88
+ gridSize = dim3(
89
+ ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
90
+ (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
91
+ p.launchMajor);
92
+ }
93
+
94
+ // Launch CUDA kernel.
95
+ void* args[] = {&p};
96
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
97
+ return y;
98
+ }
99
+
100
+ //------------------------------------------------------------------------
101
+
102
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
103
+ {
104
+ m.def("upfirdn2d", &upfirdn2d);
105
+ }
106
+
107
+ //------------------------------------------------------------------------
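The binding above derives the output size from the upsample/pad/filter/downsample parameters. A minimal sketch of that rule (not repository code; symmetric x/y padding is assumed for brevity):

```python
# Output size of upfirdn2d: upsample by zero insertion, pad, convolve, keep every down-th sample.
def upfirdn2d_out_size(in_w, in_h, up, down, pad0, pad1, fw, fh):
    out_w = (in_w * up + pad0 + pad1 - fw + down) // down
    out_h = (in_h * up + pad0 + pad1 - fh + down) // down
    return out_w, out_h

# 2x upsampling with a 4-tap filter and padding [2, 1] doubles the resolution:
print(upfirdn2d_out_size(64, 64, up=2, down=1, pad0=2, pad1=1, fw=4, fh=4))  # (128, 128)
```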
models/draggan/torch_utils/ops/upfirdn2d.cu ADDED
@@ -0,0 +1,384 @@
1
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ //
3
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ // and proprietary rights in and to this software, related documentation
5
+ // and any modifications thereto. Any use, reproduction, disclosure or
6
+ // distribution of this software and related documentation without an express
7
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ #include <c10/util/Half.h>
10
+ #include "upfirdn2d.h"
11
+
12
+ //------------------------------------------------------------------------
13
+ // Helpers.
14
+
15
+ template <class T> struct InternalType;
16
+ template <> struct InternalType<double> { typedef double scalar_t; };
17
+ template <> struct InternalType<float> { typedef float scalar_t; };
18
+ template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19
+
20
+ static __device__ __forceinline__ int floor_div(int a, int b)
21
+ {
22
+ int t = 1 - a / b;
23
+ return (a + t * b) / b - t;
24
+ }
25
+
26
+ //------------------------------------------------------------------------
27
+ // Generic CUDA implementation for large filters.
28
+
29
+ template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
30
+ {
31
+ typedef typename InternalType<T>::scalar_t scalar_t;
32
+
33
+ // Calculate thread index.
34
+ int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
35
+ int outY = minorBase / p.launchMinor;
36
+ minorBase -= outY * p.launchMinor;
37
+ int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
38
+ int majorBase = blockIdx.z * p.loopMajor;
39
+ if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
40
+ return;
41
+
42
+ // Setup Y receptive field.
43
+ int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
44
+ int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
45
+ int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
46
+ int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
47
+ if (p.flip)
48
+ filterY = p.filterSize.y - 1 - filterY;
49
+
50
+ // Loop over major, minor, and X.
51
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
52
+ for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
53
+ {
54
+ int nc = major * p.sizeMinor + minor;
55
+ int n = nc / p.inSize.z;
56
+ int c = nc - n * p.inSize.z;
57
+ for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
58
+ {
59
+ // Setup X receptive field.
60
+ int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
61
+ int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
62
+ int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
63
+ int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
64
+ if (p.flip)
65
+ filterX = p.filterSize.x - 1 - filterX;
66
+
67
+ // Initialize pointers.
68
+ const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
69
+ const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
70
+ int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
71
+ int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;
72
+
73
+ // Inner loop.
74
+ scalar_t v = 0;
75
+ for (int y = 0; y < h; y++)
76
+ {
77
+ for (int x = 0; x < w; x++)
78
+ {
79
+ v += (scalar_t)(*xp) * (scalar_t)(*fp);
80
+ xp += p.inStride.x;
81
+ fp += filterStepX;
82
+ }
83
+ xp += p.inStride.y - w * p.inStride.x;
84
+ fp += filterStepY - w * filterStepX;
85
+ }
86
+
87
+ // Store result.
88
+ v *= p.gain;
89
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
90
+ }
91
+ }
92
+ }
93
+
94
+ //------------------------------------------------------------------------
95
+ // Specialized CUDA implementation for small filters.
96
+
97
+ template <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>
98
+ static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)
99
+ {
100
+ typedef typename InternalType<T>::scalar_t scalar_t;
101
+ const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;
102
+ const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;
103
+ __shared__ volatile scalar_t sf[filterH][filterW];
104
+ __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];
105
+
106
+ // Calculate tile index.
107
+ int minorBase = blockIdx.x;
108
+ int tileOutY = minorBase / p.launchMinor;
109
+ minorBase -= tileOutY * p.launchMinor;
110
+ minorBase *= loopMinor;
111
+ tileOutY *= tileOutH;
112
+ int tileOutXBase = blockIdx.y * p.loopX * tileOutW;
113
+ int majorBase = blockIdx.z * p.loopMajor;
114
+ if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)
115
+ return;
116
+
117
+ // Load filter (flipped).
118
+ for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)
119
+ {
120
+ int fy = tapIdx / filterW;
121
+ int fx = tapIdx - fy * filterW;
122
+ scalar_t v = 0;
123
+ if (fx < p.filterSize.x & fy < p.filterSize.y)
124
+ {
125
+ int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;
126
+ int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;
127
+ v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];
128
+ }
129
+ sf[fy][fx] = v;
130
+ }
131
+
132
+ // Loop over major and X.
133
+ for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
134
+ {
135
+ int baseNC = major * p.sizeMinor + minorBase;
136
+ int n = baseNC / p.inSize.z;
137
+ int baseC = baseNC - n * p.inSize.z;
138
+ for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)
139
+ {
140
+ // Load input pixels.
141
+ int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;
142
+ int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;
143
+ int tileInX = floor_div(tileMidX, upx);
144
+ int tileInY = floor_div(tileMidY, upy);
145
+ __syncthreads();
146
+ for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)
+ {
+ int relC = inIdx;
+ int relInX = relC / loopMinor;
+ int relInY = relInX / tileInW;
+ relC -= relInX * loopMinor;
+ relInX -= relInY * tileInW;
+ int c = baseC + relC;
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+ if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)
+ v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
+ sx[relInY][relInX][relC] = v;
+ }
+
+ // Loop over output pixels.
+ __syncthreads();
+ for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)
+ {
+ int relC = outIdx;
+ int relOutX = relC / loopMinor;
+ int relOutY = relOutX / tileOutW;
+ relC -= relOutX * loopMinor;
+ relOutX -= relOutY * tileOutW;
+ int c = baseC + relC;
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY;
+
+ // Setup receptive field.
+ int midX = tileMidX + relOutX * downx;
+ int midY = tileMidY + relOutY * downy;
+ int inX = floor_div(midX, upx);
+ int inY = floor_div(midY, upy);
+ int relInX = inX - tileInX;
+ int relInY = inY - tileInY;
+ int filterX = (inX + 1) * upx - midX - 1; // flipped
+ int filterY = (inY + 1) * upy - midY - 1; // flipped
+
+ // Inner loop.
+ if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)
+ {
+ scalar_t v = 0;
+ #pragma unroll
+ for (int y = 0; y < filterH / upy; y++)
+ #pragma unroll
+ for (int x = 0; x < filterW / upx; x++)
+ v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];
+ v *= p.gain;
+ ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;
+ }
+ }
+ }
+ }
+ }
+
+ //------------------------------------------------------------------------
+ // CUDA kernel selection.
+
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)
+ {
+ int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;
+ upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous
+ if (s == 1) spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last
+
+ // No up/downsampling.
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
+ if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
+ if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
+ }
+
+ // 2x upsampling.
+ if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 64,16,1>, 64,16,1, 1};
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 64,16,1>, 64,16,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4, 16,16,8>, 16,16,8, 1};
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2, 16,16,8>, 16,16,8, 1};
+ }
+ if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,8,1>, 128,8,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1, 128,1,16>, 128,1,16, 1};
+ }
+ if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 32,32,1>, 32,32,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8, 1,128,16>, 1,128,16, 1};
+ }
+
+ // 2x downsampling.
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 32,8,1>, 32,8,1, 1};
+ if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 32,8,1>, 32,8,1, 1};
+ if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 32,8,1>, 32,8,1, 1};
+ if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 32,8,1>, 32,8,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};
+ if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};
+ if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8, 8,8,8>, 8,8,8, 1};
+ if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6, 8,8,8>, 8,8,8, 1};
+ if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4, 8,8,8>, 8,8,8, 1};
+ if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2, 8,8,8>, 8,8,8, 1};
+ }
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};
+ if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};
+ if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,8,1>, 64,8,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};
+ if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};
+ if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1, 64,1,8>, 64,1,8, 1};
+ }
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)
+ {
+ // contiguous
+ if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 32,16,1>, 32,16,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};
+ if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};
+ if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8, 1,64,8>, 1,64,8, 1};
+ }
+
+ // 4x upsampling.
+ if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};
+ if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};
+ if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};
+ }
+ if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};
+ if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};
+ if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};
+ }
+ if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};
+ if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};
+ }
+
+ // 4x downsampling (inefficient).
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)
+ {
+ // contiguous
+ if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};
+ if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};
+ if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};
+ }
+ if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)
+ {
+ // contiguous
+ if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};
+ if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};
+ // channels_last
+ if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};
+ if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};
+ }
+ return spec;
+ }
+
+ //------------------------------------------------------------------------
+ // Template specializations.
+
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double> (const upfirdn2d_kernel_params& p);
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float> (const upfirdn2d_kernel_params& p);
+ template upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);
+
+ //------------------------------------------------------------------------
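Note: choose_upfirdn2d_kernel only returns an entry point plus tile geometry; launching the kernel is handled by the host-side wrapper in upfirdn2d.cpp, which is not part of this excerpt. The sketch below is a hypothetical illustration, not code from this commit: it shows how a specialized small-filter spec could be mapped onto a CUDA launch configuration. The helper name and the 256-thread block size are assumptions; the kernels stride their inner loops by blockDim.x, so the exact block size is a tuning choice.

// Hypothetical sketch only -- assumes p is fully populated and that
// choose_upfirdn2d_kernel<T>(p) selected a small-filter kernel (spec.tileOutW > 0).
#include <cuda_runtime.h>
#include "upfirdn2d.h"

static cudaError_t launch_upfirdn2d_sketch(
    const upfirdn2d_kernel_params& p,
    const upfirdn2d_kernel_spec&   spec,
    cudaStream_t                   stream)
{
    // Grid layout mirrors the indexing inside upfirdn2d_kernel_small:
    //   blockIdx.x -> (output tile row) * launchMinor + minor slice
    //   blockIdx.y -> group of loopX output tile columns
    //   blockIdx.z -> group of loopMajor batch/channel "major" slices
    dim3 block(256, 1, 1);                                   // illustrative block size
    dim3 grid(
        ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
        (p.outSize.x - 1) / (spec.loopX * spec.tileOutW) + 1,
        p.launchMajor);

    void* args[] = { (void*)&p };                            // kernel takes params by value
    return cudaLaunchKernel(spec.kernel, grid, block, args, /*sharedMem=*/0, stream);
}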
models/draggan/torch_utils/ops/upfirdn2d.h ADDED
@@ -0,0 +1,59 @@
+ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ //
+ // NVIDIA CORPORATION and its licensors retain all intellectual property
+ // and proprietary rights in and to this software, related documentation
+ // and any modifications thereto. Any use, reproduction, disclosure or
+ // distribution of this software and related documentation without an express
+ // license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+ #include <cuda_runtime.h>
+
+ //------------------------------------------------------------------------
+ // CUDA kernel parameters.
+
+ struct upfirdn2d_kernel_params
+ {
+ const void* x;
+ const float* f;
+ void* y;
+
+ int2 up;
+ int2 down;
+ int2 pad0;
+ int flip;
+ float gain;
+
+ int4 inSize; // [width, height, channel, batch]
+ int4 inStride;
+ int2 filterSize; // [width, height]
+ int2 filterStride;
+ int4 outSize; // [width, height, channel, batch]
+ int4 outStride;
+ int sizeMinor;
+ int sizeMajor;
+
+ int loopMinor;
+ int loopMajor;
+ int loopX;
+ int launchMinor;
+ int launchMajor;
+ };
+
+ //------------------------------------------------------------------------
+ // CUDA kernel specialization.
+
+ struct upfirdn2d_kernel_spec
+ {
+ void* kernel;
+ int tileOutW;
+ int tileOutH;
+ int loopMinor;
+ int loopX;
+ };
+
+ //------------------------------------------------------------------------
+ // CUDA kernel selection.
+
+ template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
+
+ //------------------------------------------------------------------------
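For orientation, the sketch below is a hypothetical illustration, not code from this commit: it fills in the geometry fields of upfirdn2d_kernel_params for a contiguous float NCHW tensor and a row-major fh x fw filter. The helper name and argument list are assumptions, and the loop*/launch* bookkeeping is omitted because the real host code derives those values together with the kernel spec returned by choose_upfirdn2d_kernel.

// Hypothetical sketch only: geometry fields for a contiguous NCHW input.
#include <cuda_runtime.h>
#include "upfirdn2d.h"

static upfirdn2d_kernel_params make_upfirdn2d_params_sketch(
    const float* x, const float* f, float* y,
    int N, int C, int H, int W,            // input tensor shape
    int fh, int fw,                        // filter shape
    int2 up, int2 down, int2 pad0,         // resampling factors and leading padding
    int outW, int outH)                    // output spatial size, computed by the caller
{
    upfirdn2d_kernel_params p = {};
    p.x = x;  p.f = f;  p.y = y;
    p.up = up;  p.down = down;  p.pad0 = pad0;
    p.flip = 0;                            // see the flip handling in the kernels above
    p.gain = 1.0f;                         // unit gain

    // Sizes are [width, height, channel, batch]; strides are in elements.
    p.inSize       = make_int4(W, H, C, N);
    p.inStride     = make_int4(1, W, W * H, W * H * C);      // contiguous NCHW layout
    p.filterSize   = make_int2(fw, fh);
    p.filterStride = make_int2(1, fw);                       // row-major filter taps
    p.outSize      = make_int4(outW, outH, C, N);
    p.outStride    = make_int4(1, outW, outW * outH, outW * outH * C);
    return p;
}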