Adapter commited on
Commit
fb6c2da
1 Parent(s): 41d366b

first upload

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +201 -0
  2. README.md +214 -12
  3. app.py +59 -0
  4. configs/stable-diffusion/app.yaml +87 -0
  5. configs/stable-diffusion/test_keypose.yaml +87 -0
  6. configs/stable-diffusion/test_mask.yaml +87 -0
  7. configs/stable-diffusion/test_mask_sketch.yaml +87 -0
  8. configs/stable-diffusion/test_sketch.yaml +87 -0
  9. configs/stable-diffusion/test_sketch_edit.yaml +87 -0
  10. configs/stable-diffusion/train_keypose.yaml +87 -0
  11. configs/stable-diffusion/train_mask.yaml +87 -0
  12. configs/stable-diffusion/train_sketch.yaml +87 -0
  13. dataset_coco.py +138 -0
  14. demo/demos.py +82 -0
  15. demo/model.py +390 -0
  16. dist_util.py +91 -0
  17. environment.yaml +31 -0
  18. examples/edit_cat/edge.png +0 -0
  19. examples/edit_cat/edge_2.png +0 -0
  20. examples/edit_cat/im.png +0 -0
  21. examples/edit_cat/mask.png +0 -0
  22. examples/keypose/iron.png +0 -0
  23. examples/seg/dinner.png +0 -0
  24. examples/seg/motor.png +0 -0
  25. examples/seg_sketch/edge.png +0 -0
  26. examples/seg_sketch/mask.png +0 -0
  27. examples/sketch/car.png +0 -0
  28. examples/sketch/girl.jpeg +0 -0
  29. examples/sketch/human.png +0 -0
  30. examples/sketch/scenery.jpg +0 -0
  31. examples/sketch/scenery2.jpg +0 -0
  32. experiments/README.md +0 -0
  33. gradio_keypose.py +254 -0
  34. gradio_sketch.py +147 -0
  35. ldm/data/__init__.py +0 -0
  36. ldm/data/base.py +23 -0
  37. ldm/data/imagenet.py +394 -0
  38. ldm/data/lsun.py +92 -0
  39. ldm/lr_scheduler.py +98 -0
  40. ldm/models/autoencoder.py +443 -0
  41. ldm/models/diffusion/__init__.py +0 -0
  42. ldm/models/diffusion/classifier.py +267 -0
  43. ldm/models/diffusion/ddim.py +241 -0
  44. ldm/models/diffusion/ddpm.py +1446 -0
  45. ldm/models/diffusion/dpm_solver/__init__.py +1 -0
  46. ldm/models/diffusion/dpm_solver/dpm_solver.py +1184 -0
  47. ldm/models/diffusion/dpm_solver/sampler.py +82 -0
  48. ldm/models/diffusion/plms.py +254 -0
  49. ldm/modules/attention.py +261 -0
  50. ldm/modules/diffusionmodules/__init__.py +0 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,214 @@
1
- ---
2
- title: T2I Adapter
3
- emoji: 📉
4
- colorFrom: yellow
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 3.19.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="assets/logo2.png" height=65>
3
+ </p>
4
+
5
+ <div align="center">
6
+
7
+ ⏬[**Download Models**](#-download-models) **|** 💻[**How to Test**](#-how-to-test)
8
+
9
+ </div>
10
+
11
+ Official implementation of T2I-Adapter: Learning Adapters to Dig out More Controllable Ability for Text-to-Image Diffusion Models.
12
+
13
+ #### [Paper](https://arxiv.org/abs/2302.08453)
14
+
15
+ <p align="center">
16
+ <img src="assets/overview1.png" height=250>
17
+ </p>
18
+
19
+ We propose T2I-Adapter, a **simple and small (~70M parameters, ~300M storage space)** network that can provide extra guidance to pre-trained text-to-image models while **freezing** the original large text-to-image models.
20
+
21
+ T2I-Adapter aligns internal knowledge in T2I models with external control signals.
22
+ We can train various adapters according to different conditions, and achieve rich control and editing effects.
23
+
24
+ <p align="center">
25
+ <img src="assets/teaser.png" height=500>
26
+ </p>
27
+
28
+ ### ⏬ Download Models
29
+
30
+ Put the downloaded models in the `T2I-Adapter/models` folder.
31
+
32
+ 1. The **T2I-Adapters** can be download from <https://huggingface.co/TencentARC/T2I-Adapter>.
33
+ 2. The pretrained **Stable Diffusion v1.4** models can be download from <https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/tree/main>. You need to download the `sd-v1-4.ckpt
34
+ ` file.
35
+ 3. [Optional] If you want to use **Anything v4.0** models, you can download the pretrained models from <https://huggingface.co/andite/anything-v4.0/tree/main>. You need to download the `anything-v4.0-pruned.ckpt` file.
36
+ 4. The pretrained **clip-vit-large-patch14** folder can be download from <https://huggingface.co/openai/clip-vit-large-patch14/tree/main>. Remember to download the whole folder!
37
+ 5. The pretrained keypose detection models include FasterRCNN (human detection) from <https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth> and HRNet (pose detection) from <https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth>.
38
+
39
+ After downloading, the folder structure should be like this:
40
+
41
+ <p align="center">
42
+ <img src="assets/downloaded_models.png" height=100>
43
+ </p>
44
+
45
+ ### 🔧 Dependencies and Installation
46
+
47
+ - Python >= 3.6 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
48
+ - [PyTorch >= 1.4](https://pytorch.org/)
49
+ ```bash
50
+ pip install -r requirements.txt
51
+ ```
52
+ - If you want to use the full function of keypose-guided generation, you need to install MMPose. For details please refer to <https://github.com/open-mmlab/mmpose>.
53
+
54
+ ### 💻 How to Test
55
+
56
+ - The results are in the `experiments` folder.
57
+ - If you want to use the `Anything v4.0`, please add `--ckpt models/anything-v4.0-pruned.ckpt` in the following commands.
58
+
59
+ #### **For Simple Experience**
60
+
61
+ > python app.py
62
+
63
+ #### **Sketch Adapter**
64
+
65
+ - Sketch to Image Generation
66
+
67
+ > python test_sketch.py --plms --auto_resume --prompt "A car with flying wings" --path_cond examples/sketch/car.png --ckpt models/sd-v1-4.ckpt --type_in sketch
68
+
69
+ - Image to Image Generation
70
+
71
+ > python test_sketch.py --plms --auto_resume --prompt "A beautiful girl" --path_cond examples/anything_sketch/human.png --ckpt models/sd-v1-4.ckpt --type_in image
72
+
73
+ - Generation with **Anything** setting
74
+
75
+ > python test_sketch.py --plms --auto_resume --prompt "A beautiful girl" --path_cond examples/anything_sketch/human.png --ckpt models/anything-v4.0-pruned.ckpt --type_in image
76
+
77
+ ##### Gradio Demo
78
+ <p align="center">
79
+ <img src="assets/gradio_sketch.png">
80
+ </p>
81
+ You can use gradio to experience all these three functions at once. CPU is also supported by setting device to 'cpu'.
82
+
83
+ ```bash
84
+ python gradio_sketch.py
85
+ ```
86
+
87
+ #### **Keypose Adapter**
88
+
89
+ - Keypose to Image Generation
90
+
91
+ > python test_keypose.py --plms --auto_resume --prompt "A beautiful girl" --path_cond examples/keypose/iron.png --type_in pose
92
+
93
+ - Image to Image Generation
94
+
95
+ > python test_keypose.py --plms --auto_resume --prompt "A beautiful girl" --path_cond examples/sketch/human.png --type_in image
96
+
97
+ - Generation with **Anything** setting
98
+
99
+ > python test_keypose.py --plms --auto_resume --prompt "A beautiful girl" --path_cond examples/sketch/human.png --ckpt models/anything-v4.0-pruned.ckpt --type_in image
100
+
101
+ ##### Gradio Demo
102
+ <p align="center">
103
+ <img src="assets/gradio_keypose.png">
104
+ </p>
105
+ You can use gradio to experience all these three functions at once. CPU is also supported by setting device to 'cpu'.
106
+
107
+ ```bash
108
+ python gradio_keypose.py
109
+ ```
110
+
111
+ #### **Segmentation Adapter**
112
+
113
+ > python test_seg.py --plms --auto_resume --prompt "A black Honda motorcycle parked in front of a garage" --path_cond examples/seg/motor.png
114
+
115
+ #### **Two adapters: Segmentation and Sketch Adapters**
116
+
117
+ > python test_seg_sketch.py --plms --auto_resume --prompt "An all white kitchen with an electric stovetop" --path_cond examples/seg_sketch/mask.png --path_cond2 examples/seg_sketch/edge.png
118
+
119
+ #### **Local editing with adapters**
120
+
121
+ > python test_sketch_edit.py --plms --auto_resume --prompt "A white cat" --path_cond examples/edit_cat/edge_2.png --path_x0 examples/edit_cat/im.png --path_mask examples/edit_cat/mask.png
122
+
123
+ ## Stable Diffusion + T2I-Adapters (only ~70M parameters, ~300M storage space)
124
+
125
+ The following is the detailed structure of a **Stable Diffusion** model with the **T2I-Adapter**.
126
+ <p align="center">
127
+ <img src="assets/overview2.png" height=300>
128
+ </p>
129
+
130
+ <!-- ## Web Demo
131
+
132
+ * All the usage of three T2I-Adapters (i.e, sketch, keypose and segmentation) are integrated into [Huggingface Spaces]() 🤗 using [Gradio](). Have fun with the Web Demo. -->
133
+
134
+ ## 🚀 Interesting Applications
135
+
136
+ ### Stable Diffusion results guided with the sketch T2I-Adapter
137
+
138
+ The corresponding edge maps are predicted by PiDiNet. The sketch T2I-Adapter can well generalize to other similar sketch types, for example, sketches from the Internet and user scribbles.
139
+
140
+ <p align="center">
141
+ <img src="assets/sketch_base.png" height=800>
142
+ </p>
143
+
144
+ ### Stable Diffusion results guided with the keypose T2I-Adapter
145
+
146
+ The keypose results predicted by the [MMPose](https://github.com/open-mmlab/mmpose).
147
+ With the keypose guidance, the keypose T2I-Adapter can also help to generate animals with the same keypose, for example, pandas and tigers.
148
+
149
+ <p align="center">
150
+ <img src="assets/keypose_base.png" height=600>
151
+ </p>
152
+
153
+ ### T2I-Adapter with Anything-v4.0
154
+
155
+ Once the T2I-Adapter is trained, it can act as a **plug-and-play module** and can be seamlessly integrated into the finetuned diffusion models **without re-training**, for example, Anything-4.0.
156
+
157
+ #### ✨ Anything results with the plug-and-play sketch T2I-Adapter (no extra training)
158
+
159
+ <p align="center">
160
+ <img src="assets/sketch_anything.png" height=600>
161
+ </p>
162
+
163
+ #### Anything results with the plug-and-play keypose T2I-Adapter (no extra training)
164
+
165
+ <p align="center">
166
+ <img src="assets/keypose_anything.png" height=600>
167
+ </p>
168
+
169
+ ### Local editing with the sketch adapter
170
+
171
+ When combined with the inpaiting mode of Stable Diffusion, we can realize local editing with user specific guidance.
172
+
173
+ #### ✨ Change the head direction of the cat
174
+
175
+ <p align="center">
176
+ <img src="assets/local_editing_cat.png" height=300>
177
+ </p>
178
+
179
+ #### ✨ Add rabbit ears on the head of the Iron Man.
180
+
181
+ <p align="center">
182
+ <img src="assets/local_editing_ironman.png" height=400>
183
+ </p>
184
+
185
+ ### Combine different concepts with adapter
186
+
187
+ Adapter can be used to enhance the SD ability to combine different concepts.
188
+
189
+ #### ✨ A car with flying wings. / A doll in the shape of letter ‘A’.
190
+
191
+ <p align="center">
192
+ <img src="assets/enhance_SD2.png" height=600>
193
+ </p>
194
+
195
+ ### Sequential editing with the sketch adapter
196
+
197
+ We can realize the sequential editing with the adapter guidance.
198
+
199
+ <p align="center">
200
+ <img src="assets/sequential_edit.png">
201
+ </p>
202
+
203
+ ### Composable Guidance with multiple adapters
204
+
205
+ Stable Diffusion results guided with the segmentation and sketch adapters together.
206
+
207
+ <p align="center">
208
+ <img src="assets/multiple_adapters.png">
209
+ </p>
210
+
211
+
212
+ ![visitors](https://visitor-badge.glitch.me/badge?page_id=TencentARC/T2I-Adapter)
213
+
214
+ Logo materials: [adapter](https://www.flaticon.com/free-icon/adapter_4777242), [lightbulb](https://www.flaticon.com/free-icon/lightbulb_3176369)
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from demo.model import Model_all
2
+ import gradio as gr
3
+ from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw
4
+ import torch
5
+ import subprocess
6
+ import os
7
+ import shlex
8
+ from huggingface_hub import hf_hub_url
9
+
10
+ urls = {
11
+ 'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth'],
12
+ 'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
13
+ 'andite/anything-v4.0':['anything-v4.0-pruned.ckpt'],
14
+ }
15
+ urls_mmpose = [
16
+ 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
17
+ 'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
18
+ ]
19
+ if os.path.exists('models') == False:
20
+ os.mkdir('models')
21
+ for repo in urls:
22
+ files = urls[repo]
23
+ for file in files:
24
+ url = hf_hub_url(repo, file)
25
+ name_ckp = url.split('/')[-1]
26
+ save_path = os.path.join('models',name_ckp)
27
+ if os.path.exists(save_path) == False:
28
+ subprocess.run(shlex.split(f'wget {url} -O {save_path}'))
29
+
30
+ for url in urls_mmpose:
31
+ name_ckp = url.split('/')[-1]
32
+ save_path = os.path.join('models',name_ckp)
33
+ if os.path.exists(save_path) == False:
34
+ subprocess.run(shlex.split(f'wget {url} -O {save_path}'))
35
+
36
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
37
+ model = Model_all(device)
38
+
39
+ DESCRIPTION = '''# T2I-Adapter (Sketch & Keypose)
40
+ [Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)
41
+
42
+ This gradio demo is for a simple experience of T2I-Adapter:
43
+ - Keypose/Sketch to Image Generation
44
+ - Image to Image Generation
45
+ - Support the base model of Stable Diffusion v1.4 and Anything 4.0
46
+ '''
47
+
48
+ with gr.Blocks(css='style.css') as demo:
49
+ gr.Markdown(DESCRIPTION)
50
+
51
+ with gr.Tabs():
52
+ with gr.TabItem('Keypose'):
53
+ create_demo_keypose(model.process_keypose)
54
+ with gr.TabItem('Sketch'):
55
+ create_demo_sketch(model.process_sketch)
56
+ with gr.TabItem('Draw'):
57
+ create_demo_draw(model.process_draw)
58
+
59
+ demo.queue(api_open=False).launch(server_name='0.0.0.0')
configs/stable-diffusion/app.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: app
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config:
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ device: 'cuda'
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
configs/stable-diffusion/test_keypose.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: test_keypose
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: models/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
configs/stable-diffusion/test_mask.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: test_mask
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: models/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
configs/stable-diffusion/test_mask_sketch.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: test_mask_sketch
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: models/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
configs/stable-diffusion/test_sketch.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: test_sketch
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: models/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
configs/stable-diffusion/test_sketch_edit.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: test_sketch_edit
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: models/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
configs/stable-diffusion/train_keypose.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: train_keypose
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: models/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
configs/stable-diffusion/train_mask.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: train_mask
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: models/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
configs/stable-diffusion/train_sketch.yaml ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: train_sketch
2
+ model:
3
+ base_learning_rate: 1.0e-04
4
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
5
+ params:
6
+ linear_start: 0.00085
7
+ linear_end: 0.0120
8
+ num_timesteps_cond: 1
9
+ log_every_t: 200
10
+ timesteps: 1000
11
+ first_stage_key: "jpg"
12
+ cond_stage_key: "txt"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false # Note: different from the one we trained before
16
+ conditioning_key: crossattn
17
+ monitor: val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ use_ema: False
20
+
21
+ scheduler_config: # 10000 warmup steps
22
+ target: ldm.lr_scheduler.LambdaLinearScheduler
23
+ params:
24
+ warm_up_steps: [ 10000 ]
25
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
26
+ f_start: [ 1.e-6 ]
27
+ f_max: [ 1. ]
28
+ f_min: [ 1. ]
29
+
30
+ unet_config:
31
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
+ params:
33
+ image_size: 32 # unused
34
+ in_channels: 4
35
+ out_channels: 4
36
+ model_channels: 320
37
+ attention_resolutions: [ 4, 2, 1 ]
38
+ num_res_blocks: 2
39
+ channel_mult: [ 1, 2, 4, 4 ]
40
+ num_heads: 8
41
+ use_spatial_transformer: True
42
+ transformer_depth: 1
43
+ context_dim: 768
44
+ use_checkpoint: True
45
+ legacy: False
46
+
47
+ first_stage_config:
48
+ target: ldm.models.autoencoder.AutoencoderKL
49
+ params:
50
+ embed_dim: 4
51
+ monitor: val/rec_loss
52
+ ddconfig:
53
+ double_z: true
54
+ z_channels: 4
55
+ resolution: 256
56
+ in_channels: 3
57
+ out_ch: 3
58
+ ch: 128
59
+ ch_mult:
60
+ - 1
61
+ - 2
62
+ - 4
63
+ - 4
64
+ num_res_blocks: 2
65
+ attn_resolutions: []
66
+ dropout: 0.0
67
+ lossconfig:
68
+ target: torch.nn.Identity
69
+
70
+ cond_stage_config: #__is_unconditional__
71
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
72
+ params:
73
+ version: models/clip-vit-large-patch14
74
+
75
+ logger:
76
+ print_freq: 100
77
+ save_checkpoint_freq: !!float 1e4
78
+ use_tb_logger: true
79
+ wandb:
80
+ project: ~
81
+ resume_id: ~
82
+ dist_params:
83
+ backend: nccl
84
+ port: 29500
85
+ training:
86
+ lr: !!float 1e-5
87
+ save_freq: 1e4
dataset_coco.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import json
3
+ import cv2
4
+ import torch
5
+ import os
6
+ from basicsr.utils import img2tensor, tensor2img
7
+ import random
8
+
9
+ class dataset_coco():
10
+ def __init__(self, path_json, root_path, image_size, mode='train'):
11
+ super(dataset_coco, self).__init__()
12
+ with open(path_json, 'r', encoding='utf-8') as fp:
13
+ data = json.load(fp)
14
+ data = data['images']
15
+ self.paths = []
16
+ self.root_path = root_path
17
+ for file in data:
18
+ input_path = file['filepath']
19
+ if mode == 'train':
20
+ if 'val' not in input_path:
21
+ self.paths.append(file)
22
+ else:
23
+ if 'val' in input_path:
24
+ self.paths.append(file)
25
+
26
+ def __getitem__(self, idx):
27
+ file = self.paths[idx]
28
+ input_path = file['filepath']
29
+ input_name = file['filename']
30
+ path = os.path.join(self.root_path, input_path, input_name)
31
+ im = cv2.imread(path)
32
+ im = cv2.resize(im, (512,512))
33
+ im = img2tensor(im, bgr2rgb=True, float32=True)/255.
34
+ sentences = file['sentences']
35
+ sentence = sentences[int(random.random()*len(sentences))]['raw'].strip('.')
36
+ return {'im':im, 'sentence':sentence}
37
+
38
+ def __len__(self):
39
+ return len(self.paths)
40
+
41
+
42
+ class dataset_coco_mask():
43
+ def __init__(self, path_json, root_path_im, root_path_mask, image_size):
44
+ super(dataset_coco_mask, self).__init__()
45
+ with open(path_json, 'r', encoding='utf-8') as fp:
46
+ data = json.load(fp)
47
+ data = data['annotations']
48
+ self.files = []
49
+ self.root_path_im = root_path_im
50
+ self.root_path_mask = root_path_mask
51
+ for file in data:
52
+ name = "%012d.png"%file['image_id']
53
+ self.files.append({'name':name, 'sentence':file['caption']})
54
+
55
+ def __getitem__(self, idx):
56
+ file = self.files[idx]
57
+ name = file['name']
58
+ # print(os.path.join(self.root_path_im, name))
59
+ im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg')))
60
+ im = cv2.resize(im, (512,512))
61
+ im = img2tensor(im, bgr2rgb=True, float32=True)/255.
62
+
63
+ mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0]
64
+ mask = cv2.resize(mask, (512,512))
65
+ mask = img2tensor(mask, bgr2rgb=True, float32=True)[0].unsqueeze(0)#/255.
66
+
67
+ sentence = file['sentence']
68
+ return {'im':im, 'mask':mask, 'sentence':sentence}
69
+
70
+ def __len__(self):
71
+ return len(self.files)
72
+
73
+
74
+ class dataset_coco_mask_color():
75
+ def __init__(self, path_json, root_path_im, root_path_mask, image_size):
76
+ super(dataset_coco_mask_color, self).__init__()
77
+ with open(path_json, 'r', encoding='utf-8') as fp:
78
+ data = json.load(fp)
79
+ data = data['annotations']
80
+ self.files = []
81
+ self.root_path_im = root_path_im
82
+ self.root_path_mask = root_path_mask
83
+ for file in data:
84
+ name = "%012d.png"%file['image_id']
85
+ self.files.append({'name':name, 'sentence':file['caption']})
86
+
87
+ def __getitem__(self, idx):
88
+ file = self.files[idx]
89
+ name = file['name']
90
+ # print(os.path.join(self.root_path_im, name))
91
+ im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg')))
92
+ im = cv2.resize(im, (512,512))
93
+ im = img2tensor(im, bgr2rgb=True, float32=True)/255.
94
+
95
+ mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0]
96
+ mask = cv2.resize(mask, (512,512))
97
+ mask = img2tensor(mask, bgr2rgb=True, float32=True)/255.#[0].unsqueeze(0)#/255.
98
+
99
+ sentence = file['sentence']
100
+ return {'im':im, 'mask':mask, 'sentence':sentence}
101
+
102
+ def __len__(self):
103
+ return len(self.files)
104
+
105
+ class dataset_coco_mask_color_sig():
106
+ def __init__(self, path_json, root_path_im, root_path_mask, image_size):
107
+ super(dataset_coco_mask_color_sig, self).__init__()
108
+ with open(path_json, 'r', encoding='utf-8') as fp:
109
+ data = json.load(fp)
110
+ data = data['annotations']
111
+ self.files = []
112
+ self.root_path_im = root_path_im
113
+ self.root_path_mask = root_path_mask
114
+ reg = {}
115
+ for file in data:
116
+ name = "%012d.png"%file['image_id']
117
+ if name in reg:
118
+ continue
119
+ self.files.append({'name':name, 'sentence':file['caption']})
120
+ reg[name] = name
121
+
122
+ def __getitem__(self, idx):
123
+ file = self.files[idx]
124
+ name = file['name']
125
+ # print(os.path.join(self.root_path_im, name))
126
+ im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg')))
127
+ im = cv2.resize(im, (512,512))
128
+ im = img2tensor(im, bgr2rgb=True, float32=True)/255.
129
+
130
+ mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0]
131
+ mask = cv2.resize(mask, (512,512))
132
+ mask = img2tensor(mask, bgr2rgb=True, float32=True)/255.#[0].unsqueeze(0)#/255.
133
+
134
+ sentence = file['sentence']
135
+ return {'im':im, 'mask':mask, 'sentence':sentence, 'name': name}
136
+
137
+ def __len__(self):
138
+ return len(self.files)
demo/demos.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+
4
+ def create_map():
5
+ return np.zeros(shape=(512, 1024), dtype=np.uint8)+255
6
+
7
+
8
+ def create_demo_keypose(process):
9
+ with gr.Blocks() as demo:
10
+ with gr.Row():
11
+ gr.Markdown('T2I-Adapter (Keypose)')
12
+ with gr.Row():
13
+ with gr.Column():
14
+ input_img = gr.Image(source='upload', type="numpy")
15
+ prompt = gr.Textbox(label="Prompt")
16
+ neg_prompt = gr.Textbox(label="Negative Prompt",
17
+ value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
18
+ pos_prompt = gr.Textbox(label="Positive Prompt",
19
+ value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
20
+ with gr.Row():
21
+ type_in = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a keypose map)')
22
+ fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed to produce a fixed output)')
23
+ run_button = gr.Button(label="Run")
24
+ con_strength = gr.Slider(label="Controling Strength (The guidance strength of the keypose to the result)", minimum=0, maximum=1, value=1, step=0.1)
25
+ scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1)
26
+ base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
27
+ with gr.Column():
28
+ result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
29
+ ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
30
+ run_button.click(fn=process, inputs=ips, outputs=[result])
31
+ return demo
32
+
33
+ def create_demo_sketch(process):
34
+ with gr.Blocks() as demo:
35
+ with gr.Row():
36
+ gr.Markdown('T2I-Adapter (Sketch)')
37
+ with gr.Row():
38
+ with gr.Column():
39
+ input_img = gr.Image(source='upload', type="numpy")
40
+ prompt = gr.Textbox(label="Prompt")
41
+ neg_prompt = gr.Textbox(label="Negative Prompt",
42
+ value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
43
+ pos_prompt = gr.Textbox(label="Positive Prompt",
44
+ value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
45
+ with gr.Row():
46
+ type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a sketch)')
47
+ color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
48
+ run_button = gr.Button(label="Run")
49
+ con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
50
+ scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1)
51
+ fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
52
+ base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
53
+ with gr.Column():
54
+ result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
55
+ ips = [input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
56
+ run_button.click(fn=process, inputs=ips, outputs=[result])
57
+ return demo
58
+
59
+ def create_demo_draw(process):
60
+ with gr.Blocks() as demo:
61
+ with gr.Row():
62
+ gr.Markdown('T2I-Adapter (Hand-free drawing)')
63
+ with gr.Row():
64
+ with gr.Column():
65
+ create_button = gr.Button(label="Start", value='Hand-free drawing')
66
+ input_img = gr.Image(source='upload', type="numpy",tool='sketch')
67
+ create_button.click(fn=create_map, outputs=[input_img])
68
+ prompt = gr.Textbox(label="Prompt")
69
+ neg_prompt = gr.Textbox(label="Negative Prompt",
70
+ value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
71
+ pos_prompt = gr.Textbox(label="Positive Prompt",
72
+ value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
73
+ run_button = gr.Button(label="Run")
74
+ con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
75
+ scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1)
76
+ fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
77
+ base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
78
+ with gr.Column():
79
+ result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
80
+ ips = [input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
81
+ run_button.click(fn=process, inputs=ips, outputs=[result])
82
+ return demo
demo/model.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from basicsr.utils import img2tensor, tensor2img
3
+ from pytorch_lightning import seed_everything
4
+ from ldm.models.diffusion.plms import PLMSSampler
5
+ from ldm.modules.encoders.adapter import Adapter
6
+ from ldm.util import instantiate_from_config
7
+ from model_edge import pidinet
8
+ import gradio as gr
9
+ from omegaconf import OmegaConf
10
+ import mmcv
11
+ from mmdet.apis import inference_detector, init_detector
12
+ from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result)
13
+ import os
14
+ import cv2
15
+ import numpy as np
16
+
17
+
18
+ def imshow_keypoints(img,
19
+ pose_result,
20
+ skeleton=None,
21
+ kpt_score_thr=0.1,
22
+ pose_kpt_color=None,
23
+ pose_link_color=None,
24
+ radius=4,
25
+ thickness=1):
26
+ """Draw keypoints and links on an image.
27
+
28
+ Args:
29
+ img (ndarry): The image to draw poses on.
30
+ pose_result (list[kpts]): The poses to draw. Each element kpts is
31
+ a set of K keypoints as an Kx3 numpy.ndarray, where each
32
+ keypoint is represented as x, y, score.
33
+ kpt_score_thr (float, optional): Minimum score of keypoints
34
+ to be shown. Default: 0.3.
35
+ pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
36
+ the keypoint will not be drawn.
37
+ pose_link_color (np.array[Mx3]): Color of M links. If None, the
38
+ links will not be drawn.
39
+ thickness (int): Thickness of lines.
40
+ """
41
+
42
+ img_h, img_w, _ = img.shape
43
+ img = np.zeros(img.shape)
44
+
45
+ for idx, kpts in enumerate(pose_result):
46
+ if idx > 1:
47
+ continue
48
+ kpts = kpts['keypoints']
49
+ # print(kpts)
50
+ kpts = np.array(kpts, copy=False)
51
+
52
+ # draw each point on image
53
+ if pose_kpt_color is not None:
54
+ assert len(pose_kpt_color) == len(kpts)
55
+
56
+ for kid, kpt in enumerate(kpts):
57
+ x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
58
+
59
+ if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
60
+ # skip the point that should not be drawn
61
+ continue
62
+
63
+ color = tuple(int(c) for c in pose_kpt_color[kid])
64
+ cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
65
+
66
+ # draw links
67
+ if skeleton is not None and pose_link_color is not None:
68
+ assert len(pose_link_color) == len(skeleton)
69
+
70
+ for sk_id, sk in enumerate(skeleton):
71
+ pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
72
+ pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
73
+
74
+ if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
75
+ or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
76
+ or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
77
+ # skip the link that should not be drawn
78
+ continue
79
+ color = tuple(int(c) for c in pose_link_color[sk_id])
80
+ cv2.line(img, pos1, pos2, color, thickness=thickness)
81
+
82
+ return img
83
+
84
+ def load_model_from_config(config, ckpt, verbose=False):
85
+ print(f"Loading model from {ckpt}")
86
+ pl_sd = torch.load(ckpt, map_location="cpu")
87
+ if "global_step" in pl_sd:
88
+ print(f"Global Step: {pl_sd['global_step']}")
89
+ if "state_dict" in pl_sd:
90
+ sd = pl_sd["state_dict"]
91
+ else:
92
+ sd = pl_sd
93
+ model = instantiate_from_config(config.model)
94
+ _, _ = model.load_state_dict(sd, strict=False)
95
+
96
+ model.cuda()
97
+ model.eval()
98
+ return model
99
+
100
+ class Model_all:
101
+ def __init__(self, device='cpu'):
102
+ # common part
103
+ self.device = device
104
+ self.config = OmegaConf.load("configs/stable-diffusion/app.yaml")
105
+ self.config.model.params.cond_stage_config.params.device = device
106
+ self.base_model = load_model_from_config(self.config, "models/sd-v1-4.ckpt").to(device)
107
+ self.current_base_pose = 'sd-v1-4.ckpt'
108
+ self.current_base_sketch = 'sd-v1-4.ckpt'
109
+ self.sampler = PLMSSampler(self.base_model)
110
+
111
+ # sketch part
112
+ self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
113
+ self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
114
+ self.model_edge = pidinet()
115
+ ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
116
+ self.model_edge.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
117
+ self.model_edge.to(device)
118
+
119
+ # keypose part
120
+ self.model_pose = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
121
+ self.model_pose.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth", map_location=device))
122
+ ## mmpose
123
+ det_config = 'models/faster_rcnn_r50_fpn_coco.py'
124
+ det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
125
+ pose_config = 'models/hrnet_w48_coco_256x192.py'
126
+ pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
127
+ self.det_cat_id = 1
128
+ self.bbox_thr = 0.2
129
+ ## detector
130
+ det_config_mmcv = mmcv.Config.fromfile(det_config)
131
+ self.det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
132
+ pose_config_mmcv = mmcv.Config.fromfile(pose_config)
133
+ self.pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
134
+ ## color
135
+ self.skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10],
136
+ [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
137
+ self.pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
138
+ [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0],
139
+ [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]
140
+ self.pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
141
+ [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], [255, 128, 0],
142
+ [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
143
+ [51, 153, 255], [51, 153, 255], [51, 153, 255]]
144
+
145
+ @torch.no_grad()
146
+ def process_sketch(self, input_img, type_in, color_back, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
147
+ if self.current_base_sketch != base_model:
148
+ ckpt = os.path.join("models", base_model)
149
+ pl_sd = torch.load(ckpt, map_location="cpu")
150
+ if "state_dict" in pl_sd:
151
+ sd = pl_sd["state_dict"]
152
+ else:
153
+ sd = pl_sd
154
+ self.base_model.load_state_dict(sd, strict=False)
155
+ self.current_base_sketch = base_model
156
+ # del sd
157
+ # del pl_sd
158
+ con_strength = int((1-con_strength)*50)
159
+ if fix_sample == 'True':
160
+ seed_everything(42)
161
+ im = cv2.resize(input_img,(512,512))
162
+
163
+ if type_in == 'Sketch':
164
+ if color_back == 'White':
165
+ im = 255-im
166
+ im_edge = im.copy()
167
+ im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0)/255.
168
+ im = im>0.5
169
+ im = im.float()
170
+ elif type_in == 'Image':
171
+ im = img2tensor(im).unsqueeze(0)/255.
172
+ im = self.model_edge(im.to(self.device))[-1]
173
+ im = im>0.5
174
+ im = im.float()
175
+ im_edge = tensor2img(im)
176
+
177
+ # save gpu memory
178
+ self.base_model.model = self.base_model.model.cpu()
179
+ self.model_sketch = self.model_sketch.cuda()
180
+ self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
181
+ self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
182
+
183
+ # extract condition features
184
+ c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
185
+ nc = self.base_model.get_learned_conditioning([neg_prompt])
186
+ features_adapter = self.model_sketch(im.to(self.device))
187
+ shape = [4, 64, 64]
188
+
189
+ # save gpu memory
190
+ self.model_sketch = self.model_sketch.cpu()
191
+ self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
192
+ self.base_model.model = self.base_model.model.cuda()
193
+
194
+ # sampling
195
+ samples_ddim, _ = self.sampler.sample(S=50,
196
+ conditioning=c,
197
+ batch_size=1,
198
+ shape=shape,
199
+ verbose=False,
200
+ unconditional_guidance_scale=scale,
201
+ unconditional_conditioning=nc,
202
+ eta=0.0,
203
+ x_T=None,
204
+ features_adapter1=features_adapter,
205
+ mode = 'sketch',
206
+ con_strength = con_strength)
207
+ # save gpu memory
208
+ self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
209
+
210
+ x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
211
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
212
+ x_samples_ddim = x_samples_ddim.to('cpu')
213
+ x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
214
+ x_samples_ddim = 255.*x_samples_ddim
215
+ x_samples_ddim = x_samples_ddim.astype(np.uint8)
216
+
217
+ return [im_edge, x_samples_ddim]
218
+
219
+ @torch.no_grad()
220
+ def process_draw(self, input_img, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
221
+ if self.current_base_sketch != base_model:
222
+ ckpt = os.path.join("models", base_model)
223
+ pl_sd = torch.load(ckpt, map_location="cpu")
224
+ if "state_dict" in pl_sd:
225
+ sd = pl_sd["state_dict"]
226
+ else:
227
+ sd = pl_sd
228
+ self.base_model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
229
+ self.current_base_sketch = base_model
230
+ con_strength = int((1-con_strength)*50)
231
+ if fix_sample == 'True':
232
+ seed_everything(42)
233
+ input_img = input_img['mask']
234
+ c = input_img[:, :, 0:3].astype(np.float32)
235
+ a = input_img[:, :, 3:4].astype(np.float32) / 255.0
236
+ im = c * a + 255.0 * (1.0 - a)
237
+ im = im.clip(0, 255).astype(np.uint8)
238
+ im = cv2.resize(im,(512,512))
239
+
240
+ # im = 255-im
241
+ im_edge = im.copy()
242
+ im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0)/255.
243
+ im = im>0.5
244
+ im = im.float()
245
+
246
+ # save gpu memory
247
+ self.base_model.model = self.base_model.model.cpu()
248
+ self.model_sketch = self.model_sketch.cuda()
249
+ self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
250
+ self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
251
+
252
+ # extract condition features
253
+ c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
254
+ nc = self.base_model.get_learned_conditioning([neg_prompt])
255
+ features_adapter = self.model_sketch(im.to(self.device))
256
+ shape = [4, 64, 64]
257
+
258
+ # save gpu memory
259
+ self.model_sketch = self.model_sketch.cpu()
260
+ self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
261
+ self.base_model.model = self.base_model.model.cuda()
262
+
263
+ # sampling
264
+ samples_ddim, _ = self.sampler.sample(S=50,
265
+ conditioning=c,
266
+ batch_size=1,
267
+ shape=shape,
268
+ verbose=False,
269
+ unconditional_guidance_scale=scale,
270
+ unconditional_conditioning=nc,
271
+ eta=0.0,
272
+ x_T=None,
273
+ features_adapter1=features_adapter,
274
+ mode = 'sketch',
275
+ con_strength = con_strength)
276
+
277
+ # save gpu memory
278
+ self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
279
+
280
+ x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
281
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
282
+ x_samples_ddim = x_samples_ddim.to('cpu')
283
+ x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
284
+ x_samples_ddim = 255.*x_samples_ddim
285
+ x_samples_ddim = x_samples_ddim.astype(np.uint8)
286
+
287
+ return [im_edge, x_samples_ddim]
288
+
289
+ @torch.no_grad()
290
+ def process_keypose(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model):
291
+ if self.current_base_pose != base_model:
292
+ ckpt = os.path.join("models", base_model)
293
+ pl_sd = torch.load(ckpt, map_location="cpu")
294
+ if "state_dict" in pl_sd:
295
+ sd = pl_sd["state_dict"]
296
+ else:
297
+ sd = pl_sd
298
+ self.base_model.load_state_dict(sd, strict=False)
299
+ self.current_base_pose = base_model
300
+ con_strength = int((1-con_strength)*50)
301
+ if fix_sample == 'True':
302
+ seed_everything(42)
303
+ im = cv2.resize(input_img,(512,512))
304
+
305
+ if type_in == 'Keypose':
306
+ im_pose = im.copy()
307
+ im = img2tensor(im).unsqueeze(0)/255.
308
+ elif type_in == 'Image':
309
+ image = im.copy()
310
+ im = img2tensor(im).unsqueeze(0)/255.
311
+ mmdet_results = inference_detector(self.det_model, image)
312
+ # keep the person class bounding boxes.
313
+ person_results = process_mmdet_results(mmdet_results, self.det_cat_id)
314
+
315
+ # optional
316
+ return_heatmap = False
317
+ dataset = self.pose_model.cfg.data['test']['type']
318
+
319
+ # e.g. use ('backbone', ) to return backbone feature
320
+ output_layer_names = None
321
+ pose_results, _ = inference_top_down_pose_model(
322
+ self.pose_model,
323
+ image,
324
+ person_results,
325
+ bbox_thr=self.bbox_thr,
326
+ format='xyxy',
327
+ dataset=dataset,
328
+ dataset_info=None,
329
+ return_heatmap=return_heatmap,
330
+ outputs=output_layer_names)
331
+
332
+ # show the results
333
+ im_pose = imshow_keypoints(
334
+ image,
335
+ pose_results,
336
+ skeleton=self.skeleton,
337
+ pose_kpt_color=self.pose_kpt_color,
338
+ pose_link_color=self.pose_link_color,
339
+ radius=2,
340
+ thickness=2)
341
+ im_pose = cv2.resize(im_pose,(512,512))
342
+
343
+ # save gpu memory
344
+ self.base_model.model = self.base_model.model.cpu()
345
+ self.model_pose = self.model_pose.cuda()
346
+ self.base_model.first_stage_model = self.base_model.first_stage_model.cpu()
347
+ self.base_model.cond_stage_model = self.base_model.cond_stage_model.cuda()
348
+
349
+ # extract condition features
350
+ c = self.base_model.get_learned_conditioning([prompt+', '+pos_prompt])
351
+ nc = self.base_model.get_learned_conditioning([neg_prompt])
352
+ pose = img2tensor(im_pose, bgr2rgb=True, float32=True)/255.
353
+ pose = pose.unsqueeze(0)
354
+ features_adapter = self.model_pose(pose.to(self.device))
355
+
356
+ # save gpu memory
357
+ self.model_pose = self.model_pose.cpu()
358
+ self.base_model.cond_stage_model = self.base_model.cond_stage_model.cpu()
359
+ self.base_model.model = self.base_model.model.cuda()
360
+
361
+ shape = [4, 64, 64]
362
+
363
+ # sampling
364
+ samples_ddim, _ = self.sampler.sample(S=50,
365
+ conditioning=c,
366
+ batch_size=1,
367
+ shape=shape,
368
+ verbose=False,
369
+ unconditional_guidance_scale=scale,
370
+ unconditional_conditioning=nc,
371
+ eta=0.0,
372
+ x_T=None,
373
+ features_adapter1=features_adapter,
374
+ mode = 'sketch',
375
+ con_strength = con_strength)
376
+
377
+ # save gpu memory
378
+ self.base_model.first_stage_model = self.base_model.first_stage_model.cuda()
379
+
380
+ x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
381
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
382
+ x_samples_ddim = x_samples_ddim.to('cpu')
383
+ x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
384
+ x_samples_ddim = 255.*x_samples_ddim
385
+ x_samples_ddim = x_samples_ddim.astype(np.uint8)
386
+
387
+ return [im_pose[:,:,::-1].astype(np.uint8), x_samples_ddim]
388
+
389
+ if __name__ == '__main__':
390
+ model = Model_all('cpu')
dist_util.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
+ import functools
3
+ import os
4
+ import subprocess
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torch.multiprocessing as mp
8
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
9
+
10
+
11
+ def init_dist(launcher, backend='nccl', **kwargs):
12
+ if mp.get_start_method(allow_none=True) is None:
13
+ mp.set_start_method('spawn')
14
+ if launcher == 'pytorch':
15
+ _init_dist_pytorch(backend, **kwargs)
16
+ elif launcher == 'slurm':
17
+ _init_dist_slurm(backend, **kwargs)
18
+ else:
19
+ raise ValueError(f'Invalid launcher type: {launcher}')
20
+
21
+
22
+ def _init_dist_pytorch(backend, **kwargs):
23
+ rank = int(os.environ['RANK'])
24
+ num_gpus = torch.cuda.device_count()
25
+ torch.cuda.set_device(rank % num_gpus)
26
+ dist.init_process_group(backend=backend, **kwargs)
27
+
28
+
29
+ def _init_dist_slurm(backend, port=None):
30
+ """Initialize slurm distributed training environment.
31
+
32
+ If argument ``port`` is not specified, then the master port will be system
33
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
34
+ environment variable, then a default port ``29500`` will be used.
35
+
36
+ Args:
37
+ backend (str): Backend of torch.distributed.
38
+ port (int, optional): Master port. Defaults to None.
39
+ """
40
+ proc_id = int(os.environ['SLURM_PROCID'])
41
+ ntasks = int(os.environ['SLURM_NTASKS'])
42
+ node_list = os.environ['SLURM_NODELIST']
43
+ num_gpus = torch.cuda.device_count()
44
+ torch.cuda.set_device(proc_id % num_gpus)
45
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
46
+ # specify master port
47
+ if port is not None:
48
+ os.environ['MASTER_PORT'] = str(port)
49
+ elif 'MASTER_PORT' in os.environ:
50
+ pass # use MASTER_PORT in the environment variable
51
+ else:
52
+ # 29500 is torch.distributed default port
53
+ os.environ['MASTER_PORT'] = '29500'
54
+ os.environ['MASTER_ADDR'] = addr
55
+ os.environ['WORLD_SIZE'] = str(ntasks)
56
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
57
+ os.environ['RANK'] = str(proc_id)
58
+ dist.init_process_group(backend=backend)
59
+
60
+
61
+ def get_dist_info():
62
+ if dist.is_available():
63
+ initialized = dist.is_initialized()
64
+ else:
65
+ initialized = False
66
+ if initialized:
67
+ rank = dist.get_rank()
68
+ world_size = dist.get_world_size()
69
+ else:
70
+ rank = 0
71
+ world_size = 1
72
+ return rank, world_size
73
+
74
+
75
+ def master_only(func):
76
+
77
+ @functools.wraps(func)
78
+ def wrapper(*args, **kwargs):
79
+ rank, _ = get_dist_info()
80
+ if rank == 0:
81
+ return func(*args, **kwargs)
82
+
83
+ return wrapper
84
+
85
+ def get_bare_model(net):
86
+ """Get bare model, especially under wrapping with
87
+ DistributedDataParallel or DataParallel.
88
+ """
89
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
90
+ net = net.module
91
+ return net
environment.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: ldm
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pytorch=1.11.0
10
+ - torchvision=0.12.0
11
+ - numpy=1.19.2
12
+ - pip:
13
+ - albumentations==0.4.3
14
+ - diffusers
15
+ - opencv-python==4.1.2.30
16
+ - pudb==2019.2
17
+ - invisible-watermark
18
+ - imageio==2.9.0
19
+ - imageio-ffmpeg==0.4.2
20
+ - pytorch-lightning==1.4.2
21
+ - omegaconf==2.1.1
22
+ - test-tube>=0.7.5
23
+ - streamlit>=0.73.1
24
+ - einops==0.3.0
25
+ - torch-fidelity==0.3.0
26
+ - transformers==4.19.2
27
+ - torchmetrics==0.6.0
28
+ - kornia==0.6
29
+ - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
30
+ - -e git+https://github.com/openai/CLIP.git@main#egg=clip
31
+ - -e .
examples/edit_cat/edge.png ADDED
examples/edit_cat/edge_2.png ADDED
examples/edit_cat/im.png ADDED
examples/edit_cat/mask.png ADDED
examples/keypose/iron.png ADDED
examples/seg/dinner.png ADDED
examples/seg/motor.png ADDED
examples/seg_sketch/edge.png ADDED
examples/seg_sketch/mask.png ADDED
examples/sketch/car.png ADDED
examples/sketch/girl.jpeg ADDED
examples/sketch/human.png ADDED
examples/sketch/scenery.jpg ADDED
examples/sketch/scenery2.jpg ADDED
experiments/README.md ADDED
File without changes
gradio_keypose.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from basicsr.utils import img2tensor, tensor2img
8
+ from pytorch_lightning import seed_everything
9
+ from ldm.models.diffusion.plms import PLMSSampler
10
+ from ldm.modules.encoders.adapter import Adapter
11
+ from ldm.util import instantiate_from_config
12
+ from model_edge import pidinet
13
+ import gradio as gr
14
+ from omegaconf import OmegaConf
15
+ import mmcv
16
+ from mmdet.apis import inference_detector, init_detector
17
+ from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result)
18
+
19
+ skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10],
20
+ [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
21
+
22
+ pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
23
+ [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0],
24
+ [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]
25
+
26
+ pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
27
+ [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], [255, 128, 0],
28
+ [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
29
+ [51, 153, 255], [51, 153, 255], [51, 153, 255]]
30
+
31
+ def imshow_keypoints(img,
32
+ pose_result,
33
+ skeleton=None,
34
+ kpt_score_thr=0.1,
35
+ pose_kpt_color=None,
36
+ pose_link_color=None,
37
+ radius=4,
38
+ thickness=1):
39
+ """Draw keypoints and links on an image.
40
+
41
+ Args:
42
+ img (ndarry): The image to draw poses on.
43
+ pose_result (list[kpts]): The poses to draw. Each element kpts is
44
+ a set of K keypoints as an Kx3 numpy.ndarray, where each
45
+ keypoint is represented as x, y, score.
46
+ kpt_score_thr (float, optional): Minimum score of keypoints
47
+ to be shown. Default: 0.3.
48
+ pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
49
+ the keypoint will not be drawn.
50
+ pose_link_color (np.array[Mx3]): Color of M links. If None, the
51
+ links will not be drawn.
52
+ thickness (int): Thickness of lines.
53
+ """
54
+
55
+ img_h, img_w, _ = img.shape
56
+ img = np.zeros(img.shape)
57
+
58
+ for idx, kpts in enumerate(pose_result):
59
+ if idx > 1:
60
+ continue
61
+ kpts = kpts['keypoints']
62
+ # print(kpts)
63
+ kpts = np.array(kpts, copy=False)
64
+
65
+ # draw each point on image
66
+ if pose_kpt_color is not None:
67
+ assert len(pose_kpt_color) == len(kpts)
68
+
69
+ for kid, kpt in enumerate(kpts):
70
+ x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
71
+
72
+ if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
73
+ # skip the point that should not be drawn
74
+ continue
75
+
76
+ color = tuple(int(c) for c in pose_kpt_color[kid])
77
+ cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
78
+
79
+ # draw links
80
+ if skeleton is not None and pose_link_color is not None:
81
+ assert len(pose_link_color) == len(skeleton)
82
+
83
+ for sk_id, sk in enumerate(skeleton):
84
+ pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
85
+ pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
86
+
87
+ if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
88
+ or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
89
+ or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
90
+ # skip the link that should not be drawn
91
+ continue
92
+ color = tuple(int(c) for c in pose_link_color[sk_id])
93
+ cv2.line(img, pos1, pos2, color, thickness=thickness)
94
+
95
+ return img
96
+
97
+ def load_model_from_config(config, ckpt, verbose=False):
98
+ print(f"Loading model from {ckpt}")
99
+ pl_sd = torch.load(ckpt, map_location="cpu")
100
+ if "global_step" in pl_sd:
101
+ print(f"Global Step: {pl_sd['global_step']}")
102
+ if "state_dict" in pl_sd:
103
+ sd = pl_sd["state_dict"]
104
+ else:
105
+ sd = pl_sd
106
+ model = instantiate_from_config(config.model)
107
+ m, u = model.load_state_dict(sd, strict=False)
108
+
109
+ model.cuda()
110
+ model.eval()
111
+ return model
112
+
113
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
114
+ config = OmegaConf.load("configs/stable-diffusion/test_keypose.yaml")
115
+ config.model.params.cond_stage_config.params.device = device
116
+ model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device)
117
+ current_base = 'sd-v1-4.ckpt'
118
+ model_ad = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
119
+ model_ad.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth"))
120
+ sampler = PLMSSampler(model)
121
+ ## mmpose
122
+ det_config = 'models/faster_rcnn_r50_fpn_coco.py'
123
+ det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
124
+ pose_config = 'models/hrnet_w48_coco_256x192.py'
125
+ pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
126
+ det_cat_id = 1
127
+ bbox_thr = 0.2
128
+ ## detector
129
+ det_config_mmcv = mmcv.Config.fromfile(det_config)
130
+ det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
131
+ pose_config_mmcv = mmcv.Config.fromfile(pose_config)
132
+ pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
133
+ W, H = 512, 512
134
+
135
+
136
+ def process(input_img, type_in, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
137
+ global current_base
138
+ if current_base != base_model:
139
+ ckpt = os.path.join("models", base_model)
140
+ pl_sd = torch.load(ckpt, map_location="cpu")
141
+ if "state_dict" in pl_sd:
142
+ sd = pl_sd["state_dict"]
143
+ else:
144
+ sd = pl_sd
145
+ model.load_state_dict(sd, strict=False)
146
+ current_base = base_model
147
+ con_strength = int((1-con_strength)*50)
148
+ if fix_sample == 'True':
149
+ seed_everything(42)
150
+ im = cv2.resize(input_img,(W,H))
151
+
152
+ if type_in == 'Keypose':
153
+ im_pose = im.copy()
154
+ im = img2tensor(im).unsqueeze(0)/255.
155
+ elif type_in == 'Image':
156
+ image = im.copy()
157
+ im = img2tensor(im).unsqueeze(0)/255.
158
+ mmdet_results = inference_detector(det_model, image)
159
+ # keep the person class bounding boxes.
160
+ person_results = process_mmdet_results(mmdet_results, det_cat_id)
161
+
162
+ # optional
163
+ return_heatmap = False
164
+ dataset = pose_model.cfg.data['test']['type']
165
+
166
+ # e.g. use ('backbone', ) to return backbone feature
167
+ output_layer_names = None
168
+ pose_results, returned_outputs = inference_top_down_pose_model(
169
+ pose_model,
170
+ image,
171
+ person_results,
172
+ bbox_thr=bbox_thr,
173
+ format='xyxy',
174
+ dataset=dataset,
175
+ dataset_info=None,
176
+ return_heatmap=return_heatmap,
177
+ outputs=output_layer_names)
178
+
179
+ # show the results
180
+ im_pose = imshow_keypoints(
181
+ image,
182
+ pose_results,
183
+ skeleton=skeleton,
184
+ pose_kpt_color=pose_kpt_color,
185
+ pose_link_color=pose_link_color,
186
+ radius=2,
187
+ thickness=2)
188
+ im_pose = cv2.resize(im_pose,(W,H))
189
+
190
+ with torch.no_grad():
191
+ c = model.get_learned_conditioning([prompt])
192
+ nc = model.get_learned_conditioning([neg_prompt])
193
+ # extract condition features
194
+ pose = img2tensor(im_pose, bgr2rgb=True, float32=True)/255.
195
+ pose = pose.unsqueeze(0)
196
+ features_adapter = model_ad(pose.to(device))
197
+
198
+ shape = [4, W//8, H//8]
199
+
200
+ # sampling
201
+ samples_ddim, _ = sampler.sample(S=50,
202
+ conditioning=c,
203
+ batch_size=1,
204
+ shape=shape,
205
+ verbose=False,
206
+ unconditional_guidance_scale=scale,
207
+ unconditional_conditioning=nc,
208
+ eta=0.0,
209
+ x_T=None,
210
+ features_adapter1=features_adapter,
211
+ mode = 'sketch',
212
+ con_strength = con_strength)
213
+
214
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
215
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
216
+ x_samples_ddim = x_samples_ddim.to('cpu')
217
+ x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
218
+ x_samples_ddim = 255.*x_samples_ddim
219
+ x_samples_ddim = x_samples_ddim.astype(np.uint8)
220
+
221
+ return [im_pose[:,:,::-1].astype(np.uint8), x_samples_ddim]
222
+
223
+ DESCRIPTION = '''# T2I-Adapter (Keypose)
224
+ [Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)
225
+
226
+ This gradio demo is for keypose-guided generation. The current functions include:
227
+ - Keypose to Image Generation
228
+ - Image to Image Generation
229
+ - Generation with **Anything** setting
230
+ '''
231
+ block = gr.Blocks().queue()
232
+ with block:
233
+ with gr.Row():
234
+ gr.Markdown(DESCRIPTION)
235
+ with gr.Row():
236
+ with gr.Column():
237
+ input_img = gr.Image(source='upload', type="numpy")
238
+ prompt = gr.Textbox(label="Prompt")
239
+ neg_prompt = gr.Textbox(label="Negative Prompt",
240
+ value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
241
+ with gr.Row():
242
+ type_in = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a keypose map)')
243
+ fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed to produce a fixed output)')
244
+ run_button = gr.Button(label="Run")
245
+ con_strength = gr.Slider(label="Controling Strength (The guidance strength of the keypose to the result)", minimum=0, maximum=1, value=1, step=0.1)
246
+ scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1)
247
+ base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
248
+ with gr.Column():
249
+ result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
250
+ ips = [input_img, type_in, prompt, neg_prompt, fix_sample, scale, con_strength, base_model]
251
+ run_button.click(fn=process, inputs=ips, outputs=[result])
252
+
253
+ block.launch(server_name='0.0.0.0')
254
+
gradio_sketch.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import os.path as osp
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from basicsr.utils import img2tensor, tensor2img
8
+ from pytorch_lightning import seed_everything
9
+ from ldm.models.diffusion.plms import PLMSSampler
10
+ from ldm.modules.encoders.adapter import Adapter
11
+ from ldm.util import instantiate_from_config
12
+ from model_edge import pidinet
13
+ import gradio as gr
14
+ from omegaconf import OmegaConf
15
+
16
+
17
+ def load_model_from_config(config, ckpt, verbose=False):
18
+ print(f"Loading model from {ckpt}")
19
+ pl_sd = torch.load(ckpt, map_location="cpu")
20
+ if "global_step" in pl_sd:
21
+ print(f"Global Step: {pl_sd['global_step']}")
22
+ if "state_dict" in pl_sd:
23
+ sd = pl_sd["state_dict"]
24
+ else:
25
+ sd = pl_sd
26
+ model = instantiate_from_config(config.model)
27
+ m, u = model.load_state_dict(sd, strict=False)
28
+ # if len(m) > 0 and verbose:
29
+ # print("missing keys:")
30
+ # print(m)
31
+ # if len(u) > 0 and verbose:
32
+ # print("unexpected keys:")
33
+ # print(u)
34
+
35
+ model.cuda()
36
+ model.eval()
37
+ return model
38
+
39
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
40
+ config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
41
+ config.model.params.cond_stage_config.params.device = device
42
+ model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device)
43
+ current_base = 'sd-v1-4.ckpt'
44
+ model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
45
+ model_ad.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth"))
46
+ net_G = pidinet()
47
+ ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
48
+ net_G.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
49
+ net_G.to(device)
50
+ sampler = PLMSSampler(model)
51
+ save_memory=True
52
+ W, H = 512, 512
53
+
54
+
55
+ def process(input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
56
+ global current_base
57
+ if current_base != base_model:
58
+ ckpt = os.path.join("models", base_model)
59
+ pl_sd = torch.load(ckpt, map_location="cpu")
60
+ if "state_dict" in pl_sd:
61
+ sd = pl_sd["state_dict"]
62
+ else:
63
+ sd = pl_sd
64
+ model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device)
65
+ current_base = base_model
66
+ con_strength = int((1-con_strength)*50)
67
+ if fix_sample == 'True':
68
+ seed_everything(42)
69
+ im = cv2.resize(input_img,(W,H))
70
+
71
+ if type_in == 'Sketch':
72
+ if color_back == 'White':
73
+ im = 255-im
74
+ im_edge = im.copy()
75
+ im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0)/255.
76
+ im = im>0.5
77
+ im = im.float()
78
+ elif type_in == 'Image':
79
+ im = img2tensor(im).unsqueeze(0)/255.
80
+ im = net_G(im.to(device))[-1]
81
+ im = im>0.5
82
+ im = im.float()
83
+ im_edge = tensor2img(im)
84
+
85
+ with torch.no_grad():
86
+ c = model.get_learned_conditioning([prompt])
87
+ nc = model.get_learned_conditioning([neg_prompt])
88
+ # extract condition features
89
+ features_adapter = model_ad(im.to(device))
90
+ shape = [4, W//8, H//8]
91
+
92
+ # sampling
93
+ samples_ddim, _ = sampler.sample(S=50,
94
+ conditioning=c,
95
+ batch_size=1,
96
+ shape=shape,
97
+ verbose=False,
98
+ unconditional_guidance_scale=scale,
99
+ unconditional_conditioning=nc,
100
+ eta=0.0,
101
+ x_T=None,
102
+ features_adapter1=features_adapter,
103
+ mode = 'sketch',
104
+ con_strength = con_strength)
105
+
106
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
107
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
108
+ x_samples_ddim = x_samples_ddim.to('cpu')
109
+ x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
110
+ x_samples_ddim = 255.*x_samples_ddim
111
+ x_samples_ddim = x_samples_ddim.astype(np.uint8)
112
+
113
+ return [im_edge, x_samples_ddim]
114
+
115
+ DESCRIPTION = '''# T2I-Adapter (Sketch)
116
+ [Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)
117
+
118
+ This gradio demo is for sketch-guided generation. The current functions include:
119
+ - Sketch to Image Generation
120
+ - Image to Image Generation
121
+ - Generation with **Anything** setting
122
+ '''
123
+ block = gr.Blocks().queue()
124
+ with block:
125
+ with gr.Row():
126
+ gr.Markdown(DESCRIPTION)
127
+ with gr.Row():
128
+ with gr.Column():
129
+ input_img = gr.Image(source='upload', type="numpy")
130
+ prompt = gr.Textbox(label="Prompt")
131
+ neg_prompt = gr.Textbox(label="Negative Prompt",
132
+ value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
133
+ with gr.Row():
134
+ type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a sketch)')
135
+ color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
136
+ run_button = gr.Button(label="Run")
137
+ con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
138
+ scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1)
139
+ fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
140
+ base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
141
+ with gr.Column():
142
+ result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
143
+ ips = [input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model]
144
+ run_button.click(fn=process, inputs=ips, outputs=[result])
145
+
146
+ block.launch(server_name='0.0.0.0')
147
+
ldm/data/__init__.py ADDED
File without changes
ldm/data/base.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3
+
4
+
5
+ class Txt2ImgIterableBaseDataset(IterableDataset):
6
+ '''
7
+ Define an interface to make the IterableDatasets for text2img data chainable
8
+ '''
9
+ def __init__(self, num_records=0, valid_ids=None, size=256):
10
+ super().__init__()
11
+ self.num_records = num_records
12
+ self.valid_ids = valid_ids
13
+ self.sample_ids = valid_ids
14
+ self.size = size
15
+
16
+ print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17
+
18
+ def __len__(self):
19
+ return self.num_records
20
+
21
+ @abstractmethod
22
+ def __iter__(self):
23
+ pass
ldm/data/imagenet.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, yaml, pickle, shutil, tarfile, glob
2
+ import cv2
3
+ import albumentations
4
+ import PIL
5
+ import numpy as np
6
+ import torchvision.transforms.functional as TF
7
+ from omegaconf import OmegaConf
8
+ from functools import partial
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+ from torch.utils.data import Dataset, Subset
12
+
13
+ import taming.data.utils as tdu
14
+ from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
15
+ from taming.data.imagenet import ImagePaths
16
+
17
+ from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
18
+
19
+
20
+ def synset2idx(path_to_yaml="data/index_synset.yaml"):
21
+ with open(path_to_yaml) as f:
22
+ di2s = yaml.load(f)
23
+ return dict((v,k) for k,v in di2s.items())
24
+
25
+
26
+ class ImageNetBase(Dataset):
27
+ def __init__(self, config=None):
28
+ self.config = config or OmegaConf.create()
29
+ if not type(self.config)==dict:
30
+ self.config = OmegaConf.to_container(self.config)
31
+ self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
32
+ self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
33
+ self._prepare()
34
+ self._prepare_synset_to_human()
35
+ self._prepare_idx_to_synset()
36
+ self._prepare_human_to_integer_label()
37
+ self._load()
38
+
39
+ def __len__(self):
40
+ return len(self.data)
41
+
42
+ def __getitem__(self, i):
43
+ return self.data[i]
44
+
45
+ def _prepare(self):
46
+ raise NotImplementedError()
47
+
48
+ def _filter_relpaths(self, relpaths):
49
+ ignore = set([
50
+ "n06596364_9591.JPEG",
51
+ ])
52
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
53
+ if "sub_indices" in self.config:
54
+ indices = str_to_indices(self.config["sub_indices"])
55
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
56
+ self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
57
+ files = []
58
+ for rpath in relpaths:
59
+ syn = rpath.split("/")[0]
60
+ if syn in synsets:
61
+ files.append(rpath)
62
+ return files
63
+ else:
64
+ return relpaths
65
+
66
+ def _prepare_synset_to_human(self):
67
+ SIZE = 2655750
68
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
69
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
70
+ if (not os.path.exists(self.human_dict) or
71
+ not os.path.getsize(self.human_dict)==SIZE):
72
+ download(URL, self.human_dict)
73
+
74
+ def _prepare_idx_to_synset(self):
75
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
76
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
77
+ if (not os.path.exists(self.idx2syn)):
78
+ download(URL, self.idx2syn)
79
+
80
+ def _prepare_human_to_integer_label(self):
81
+ URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
82
+ self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
83
+ if (not os.path.exists(self.human2integer)):
84
+ download(URL, self.human2integer)
85
+ with open(self.human2integer, "r") as f:
86
+ lines = f.read().splitlines()
87
+ assert len(lines) == 1000
88
+ self.human2integer_dict = dict()
89
+ for line in lines:
90
+ value, key = line.split(":")
91
+ self.human2integer_dict[key] = int(value)
92
+
93
+ def _load(self):
94
+ with open(self.txt_filelist, "r") as f:
95
+ self.relpaths = f.read().splitlines()
96
+ l1 = len(self.relpaths)
97
+ self.relpaths = self._filter_relpaths(self.relpaths)
98
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
99
+
100
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
101
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
102
+
103
+ unique_synsets = np.unique(self.synsets)
104
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
105
+ if not self.keep_orig_class_label:
106
+ self.class_labels = [class_dict[s] for s in self.synsets]
107
+ else:
108
+ self.class_labels = [self.synset2idx[s] for s in self.synsets]
109
+
110
+ with open(self.human_dict, "r") as f:
111
+ human_dict = f.read().splitlines()
112
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
113
+
114
+ self.human_labels = [human_dict[s] for s in self.synsets]
115
+
116
+ labels = {
117
+ "relpath": np.array(self.relpaths),
118
+ "synsets": np.array(self.synsets),
119
+ "class_label": np.array(self.class_labels),
120
+ "human_label": np.array(self.human_labels),
121
+ }
122
+
123
+ if self.process_images:
124
+ self.size = retrieve(self.config, "size", default=256)
125
+ self.data = ImagePaths(self.abspaths,
126
+ labels=labels,
127
+ size=self.size,
128
+ random_crop=self.random_crop,
129
+ )
130
+ else:
131
+ self.data = self.abspaths
132
+
133
+
134
+ class ImageNetTrain(ImageNetBase):
135
+ NAME = "ILSVRC2012_train"
136
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
137
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
138
+ FILES = [
139
+ "ILSVRC2012_img_train.tar",
140
+ ]
141
+ SIZES = [
142
+ 147897477120,
143
+ ]
144
+
145
+ def __init__(self, process_images=True, data_root=None, **kwargs):
146
+ self.process_images = process_images
147
+ self.data_root = data_root
148
+ super().__init__(**kwargs)
149
+
150
+ def _prepare(self):
151
+ if self.data_root:
152
+ self.root = os.path.join(self.data_root, self.NAME)
153
+ else:
154
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
155
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
156
+
157
+ self.datadir = os.path.join(self.root, "data")
158
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
159
+ self.expected_length = 1281167
160
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
161
+ default=True)
162
+ if not tdu.is_prepared(self.root):
163
+ # prep
164
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
165
+
166
+ datadir = self.datadir
167
+ if not os.path.exists(datadir):
168
+ path = os.path.join(self.root, self.FILES[0])
169
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
170
+ import academictorrents as at
171
+ atpath = at.get(self.AT_HASH, datastore=self.root)
172
+ assert atpath == path
173
+
174
+ print("Extracting {} to {}".format(path, datadir))
175
+ os.makedirs(datadir, exist_ok=True)
176
+ with tarfile.open(path, "r:") as tar:
177
+ tar.extractall(path=datadir)
178
+
179
+ print("Extracting sub-tars.")
180
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
181
+ for subpath in tqdm(subpaths):
182
+ subdir = subpath[:-len(".tar")]
183
+ os.makedirs(subdir, exist_ok=True)
184
+ with tarfile.open(subpath, "r:") as tar:
185
+ tar.extractall(path=subdir)
186
+
187
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
188
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
189
+ filelist = sorted(filelist)
190
+ filelist = "\n".join(filelist)+"\n"
191
+ with open(self.txt_filelist, "w") as f:
192
+ f.write(filelist)
193
+
194
+ tdu.mark_prepared(self.root)
195
+
196
+
197
+ class ImageNetValidation(ImageNetBase):
198
+ NAME = "ILSVRC2012_validation"
199
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
200
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
201
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
202
+ FILES = [
203
+ "ILSVRC2012_img_val.tar",
204
+ "validation_synset.txt",
205
+ ]
206
+ SIZES = [
207
+ 6744924160,
208
+ 1950000,
209
+ ]
210
+
211
+ def __init__(self, process_images=True, data_root=None, **kwargs):
212
+ self.data_root = data_root
213
+ self.process_images = process_images
214
+ super().__init__(**kwargs)
215
+
216
+ def _prepare(self):
217
+ if self.data_root:
218
+ self.root = os.path.join(self.data_root, self.NAME)
219
+ else:
220
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
221
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
222
+ self.datadir = os.path.join(self.root, "data")
223
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
224
+ self.expected_length = 50000
225
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
226
+ default=False)
227
+ if not tdu.is_prepared(self.root):
228
+ # prep
229
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
230
+
231
+ datadir = self.datadir
232
+ if not os.path.exists(datadir):
233
+ path = os.path.join(self.root, self.FILES[0])
234
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
235
+ import academictorrents as at
236
+ atpath = at.get(self.AT_HASH, datastore=self.root)
237
+ assert atpath == path
238
+
239
+ print("Extracting {} to {}".format(path, datadir))
240
+ os.makedirs(datadir, exist_ok=True)
241
+ with tarfile.open(path, "r:") as tar:
242
+ tar.extractall(path=datadir)
243
+
244
+ vspath = os.path.join(self.root, self.FILES[1])
245
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
246
+ download(self.VS_URL, vspath)
247
+
248
+ with open(vspath, "r") as f:
249
+ synset_dict = f.read().splitlines()
250
+ synset_dict = dict(line.split() for line in synset_dict)
251
+
252
+ print("Reorganizing into synset folders")
253
+ synsets = np.unique(list(synset_dict.values()))
254
+ for s in synsets:
255
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
256
+ for k, v in synset_dict.items():
257
+ src = os.path.join(datadir, k)
258
+ dst = os.path.join(datadir, v)
259
+ shutil.move(src, dst)
260
+
261
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
262
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
263
+ filelist = sorted(filelist)
264
+ filelist = "\n".join(filelist)+"\n"
265
+ with open(self.txt_filelist, "w") as f:
266
+ f.write(filelist)
267
+
268
+ tdu.mark_prepared(self.root)
269
+
270
+
271
+
272
+ class ImageNetSR(Dataset):
273
+ def __init__(self, size=None,
274
+ degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
275
+ random_crop=True):
276
+ """
277
+ Imagenet Superresolution Dataloader
278
+ Performs following ops in order:
279
+ 1. crops a crop of size s from image either as random or center crop
280
+ 2. resizes crop to size with cv2.area_interpolation
281
+ 3. degrades resized crop with degradation_fn
282
+
283
+ :param size: resizing to size after cropping
284
+ :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
285
+ :param downscale_f: Low Resolution Downsample factor
286
+ :param min_crop_f: determines crop size s,
287
+ where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
288
+ :param max_crop_f: ""
289
+ :param data_root:
290
+ :param random_crop:
291
+ """
292
+ self.base = self.get_base()
293
+ assert size
294
+ assert (size / downscale_f).is_integer()
295
+ self.size = size
296
+ self.LR_size = int(size / downscale_f)
297
+ self.min_crop_f = min_crop_f
298
+ self.max_crop_f = max_crop_f
299
+ assert(max_crop_f <= 1.)
300
+ self.center_crop = not random_crop
301
+
302
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
303
+
304
+ self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
305
+
306
+ if degradation == "bsrgan":
307
+ self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
308
+
309
+ elif degradation == "bsrgan_light":
310
+ self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
311
+
312
+ else:
313
+ interpolation_fn = {
314
+ "cv_nearest": cv2.INTER_NEAREST,
315
+ "cv_bilinear": cv2.INTER_LINEAR,
316
+ "cv_bicubic": cv2.INTER_CUBIC,
317
+ "cv_area": cv2.INTER_AREA,
318
+ "cv_lanczos": cv2.INTER_LANCZOS4,
319
+ "pil_nearest": PIL.Image.NEAREST,
320
+ "pil_bilinear": PIL.Image.BILINEAR,
321
+ "pil_bicubic": PIL.Image.BICUBIC,
322
+ "pil_box": PIL.Image.BOX,
323
+ "pil_hamming": PIL.Image.HAMMING,
324
+ "pil_lanczos": PIL.Image.LANCZOS,
325
+ }[degradation]
326
+
327
+ self.pil_interpolation = degradation.startswith("pil_")
328
+
329
+ if self.pil_interpolation:
330
+ self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
331
+
332
+ else:
333
+ self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
334
+ interpolation=interpolation_fn)
335
+
336
+ def __len__(self):
337
+ return len(self.base)
338
+
339
+ def __getitem__(self, i):
340
+ example = self.base[i]
341
+ image = Image.open(example["file_path_"])
342
+
343
+ if not image.mode == "RGB":
344
+ image = image.convert("RGB")
345
+
346
+ image = np.array(image).astype(np.uint8)
347
+
348
+ min_side_len = min(image.shape[:2])
349
+ crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
350
+ crop_side_len = int(crop_side_len)
351
+
352
+ if self.center_crop:
353
+ self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
354
+
355
+ else:
356
+ self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
357
+
358
+ image = self.cropper(image=image)["image"]
359
+ image = self.image_rescaler(image=image)["image"]
360
+
361
+ if self.pil_interpolation:
362
+ image_pil = PIL.Image.fromarray(image)
363
+ LR_image = self.degradation_process(image_pil)
364
+ LR_image = np.array(LR_image).astype(np.uint8)
365
+
366
+ else:
367
+ LR_image = self.degradation_process(image=image)["image"]
368
+
369
+ example["image"] = (image/127.5 - 1.0).astype(np.float32)
370
+ example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
371
+
372
+ return example
373
+
374
+
375
+ class ImageNetSRTrain(ImageNetSR):
376
+ def __init__(self, **kwargs):
377
+ super().__init__(**kwargs)
378
+
379
+ def get_base(self):
380
+ with open("data/imagenet_train_hr_indices.p", "rb") as f:
381
+ indices = pickle.load(f)
382
+ dset = ImageNetTrain(process_images=False,)
383
+ return Subset(dset, indices)
384
+
385
+
386
+ class ImageNetSRValidation(ImageNetSR):
387
+ def __init__(self, **kwargs):
388
+ super().__init__(**kwargs)
389
+
390
+ def get_base(self):
391
+ with open("data/imagenet_val_hr_indices.p", "rb") as f:
392
+ indices = pickle.load(f)
393
+ dset = ImageNetValidation(process_images=False,)
394
+ return Subset(dset, indices)
ldm/data/lsun.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import PIL
4
+ from PIL import Image
5
+ from torch.utils.data import Dataset
6
+ from torchvision import transforms
7
+
8
+
9
+ class LSUNBase(Dataset):
10
+ def __init__(self,
11
+ txt_file,
12
+ data_root,
13
+ size=None,
14
+ interpolation="bicubic",
15
+ flip_p=0.5
16
+ ):
17
+ self.data_paths = txt_file
18
+ self.data_root = data_root
19
+ with open(self.data_paths, "r") as f:
20
+ self.image_paths = f.read().splitlines()
21
+ self._length = len(self.image_paths)
22
+ self.labels = {
23
+ "relative_file_path_": [l for l in self.image_paths],
24
+ "file_path_": [os.path.join(self.data_root, l)
25
+ for l in self.image_paths],
26
+ }
27
+
28
+ self.size = size
29
+ self.interpolation = {"linear": PIL.Image.LINEAR,
30
+ "bilinear": PIL.Image.BILINEAR,
31
+ "bicubic": PIL.Image.BICUBIC,
32
+ "lanczos": PIL.Image.LANCZOS,
33
+ }[interpolation]
34
+ self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35
+
36
+ def __len__(self):
37
+ return self._length
38
+
39
+ def __getitem__(self, i):
40
+ example = dict((k, self.labels[k][i]) for k in self.labels)
41
+ image = Image.open(example["file_path_"])
42
+ if not image.mode == "RGB":
43
+ image = image.convert("RGB")
44
+
45
+ # default to score-sde preprocessing
46
+ img = np.array(image).astype(np.uint8)
47
+ crop = min(img.shape[0], img.shape[1])
48
+ h, w, = img.shape[0], img.shape[1]
49
+ img = img[(h - crop) // 2:(h + crop) // 2,
50
+ (w - crop) // 2:(w + crop) // 2]
51
+
52
+ image = Image.fromarray(img)
53
+ if self.size is not None:
54
+ image = image.resize((self.size, self.size), resample=self.interpolation)
55
+
56
+ image = self.flip(image)
57
+ image = np.array(image).astype(np.uint8)
58
+ example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59
+ return example
60
+
61
+
62
+ class LSUNChurchesTrain(LSUNBase):
63
+ def __init__(self, **kwargs):
64
+ super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65
+
66
+
67
+ class LSUNChurchesValidation(LSUNBase):
68
+ def __init__(self, flip_p=0., **kwargs):
69
+ super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70
+ flip_p=flip_p, **kwargs)
71
+
72
+
73
+ class LSUNBedroomsTrain(LSUNBase):
74
+ def __init__(self, **kwargs):
75
+ super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76
+
77
+
78
+ class LSUNBedroomsValidation(LSUNBase):
79
+ def __init__(self, flip_p=0.0, **kwargs):
80
+ super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81
+ flip_p=flip_p, **kwargs)
82
+
83
+
84
+ class LSUNCatsTrain(LSUNBase):
85
+ def __init__(self, **kwargs):
86
+ super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87
+
88
+
89
+ class LSUNCatsValidation(LSUNBase):
90
+ def __init__(self, flip_p=0., **kwargs):
91
+ super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92
+ flip_p=flip_p, **kwargs)
ldm/lr_scheduler.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class LambdaWarmUpCosineScheduler:
5
+ """
6
+ note: use with a base_lr of 1.0
7
+ """
8
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9
+ self.lr_warm_up_steps = warm_up_steps
10
+ self.lr_start = lr_start
11
+ self.lr_min = lr_min
12
+ self.lr_max = lr_max
13
+ self.lr_max_decay_steps = max_decay_steps
14
+ self.last_lr = 0.
15
+ self.verbosity_interval = verbosity_interval
16
+
17
+ def schedule(self, n, **kwargs):
18
+ if self.verbosity_interval > 0:
19
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20
+ if n < self.lr_warm_up_steps:
21
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22
+ self.last_lr = lr
23
+ return lr
24
+ else:
25
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26
+ t = min(t, 1.0)
27
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28
+ 1 + np.cos(t * np.pi))
29
+ self.last_lr = lr
30
+ return lr
31
+
32
+ def __call__(self, n, **kwargs):
33
+ return self.schedule(n,**kwargs)
34
+
35
+
36
+ class LambdaWarmUpCosineScheduler2:
37
+ """
38
+ supports repeated iterations, configurable via lists
39
+ note: use with a base_lr of 1.0.
40
+ """
41
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43
+ self.lr_warm_up_steps = warm_up_steps
44
+ self.f_start = f_start
45
+ self.f_min = f_min
46
+ self.f_max = f_max
47
+ self.cycle_lengths = cycle_lengths
48
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49
+ self.last_f = 0.
50
+ self.verbosity_interval = verbosity_interval
51
+
52
+ def find_in_interval(self, n):
53
+ interval = 0
54
+ for cl in self.cum_cycles[1:]:
55
+ if n <= cl:
56
+ return interval
57
+ interval += 1
58
+
59
+ def schedule(self, n, **kwargs):
60
+ cycle = self.find_in_interval(n)
61
+ n = n - self.cum_cycles[cycle]
62
+ if self.verbosity_interval > 0:
63
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64
+ f"current cycle {cycle}")
65
+ if n < self.lr_warm_up_steps[cycle]:
66
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67
+ self.last_f = f
68
+ return f
69
+ else:
70
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71
+ t = min(t, 1.0)
72
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73
+ 1 + np.cos(t * np.pi))
74
+ self.last_f = f
75
+ return f
76
+
77
+ def __call__(self, n, **kwargs):
78
+ return self.schedule(n, **kwargs)
79
+
80
+
81
+ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82
+
83
+ def schedule(self, n, **kwargs):
84
+ cycle = self.find_in_interval(n)
85
+ n = n - self.cum_cycles[cycle]
86
+ if self.verbosity_interval > 0:
87
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88
+ f"current cycle {cycle}")
89
+
90
+ if n < self.lr_warm_up_steps[cycle]:
91
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92
+ self.last_f = f
93
+ return f
94
+ else:
95
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96
+ self.last_f = f
97
+ return f
98
+
ldm/models/autoencoder.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pytorch_lightning as pl
3
+ import torch.nn.functional as F
4
+ from contextlib import contextmanager
5
+
6
+ from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
7
+
8
+ from ldm.modules.diffusionmodules.model import Encoder, Decoder
9
+ from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
10
+
11
+ from ldm.util import instantiate_from_config
12
+
13
+
14
+ class VQModel(pl.LightningModule):
15
+ def __init__(self,
16
+ ddconfig,
17
+ lossconfig,
18
+ n_embed,
19
+ embed_dim,
20
+ ckpt_path=None,
21
+ ignore_keys=[],
22
+ image_key="image",
23
+ colorize_nlabels=None,
24
+ monitor=None,
25
+ batch_resize_range=None,
26
+ scheduler_config=None,
27
+ lr_g_factor=1.0,
28
+ remap=None,
29
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
30
+ use_ema=False
31
+ ):
32
+ super().__init__()
33
+ self.embed_dim = embed_dim
34
+ self.n_embed = n_embed
35
+ self.image_key = image_key
36
+ self.encoder = Encoder(**ddconfig)
37
+ self.decoder = Decoder(**ddconfig)
38
+ self.loss = instantiate_from_config(lossconfig)
39
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
40
+ remap=remap,
41
+ sane_index_shape=sane_index_shape)
42
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
43
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
44
+ if colorize_nlabels is not None:
45
+ assert type(colorize_nlabels)==int
46
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
47
+ if monitor is not None:
48
+ self.monitor = monitor
49
+ self.batch_resize_range = batch_resize_range
50
+ if self.batch_resize_range is not None:
51
+ print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
52
+
53
+ self.use_ema = use_ema
54
+ if self.use_ema:
55
+ self.model_ema = LitEma(self)
56
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
57
+
58
+ if ckpt_path is not None:
59
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
60
+ self.scheduler_config = scheduler_config
61
+ self.lr_g_factor = lr_g_factor
62
+
63
+ @contextmanager
64
+ def ema_scope(self, context=None):
65
+ if self.use_ema:
66
+ self.model_ema.store(self.parameters())
67
+ self.model_ema.copy_to(self)
68
+ if context is not None:
69
+ print(f"{context}: Switched to EMA weights")
70
+ try:
71
+ yield None
72
+ finally:
73
+ if self.use_ema:
74
+ self.model_ema.restore(self.parameters())
75
+ if context is not None:
76
+ print(f"{context}: Restored training weights")
77
+
78
+ def init_from_ckpt(self, path, ignore_keys=list()):
79
+ sd = torch.load(path, map_location="cpu")["state_dict"]
80
+ keys = list(sd.keys())
81
+ for k in keys:
82
+ for ik in ignore_keys:
83
+ if k.startswith(ik):
84
+ print("Deleting key {} from state_dict.".format(k))
85
+ del sd[k]
86
+ missing, unexpected = self.load_state_dict(sd, strict=False)
87
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
88
+ if len(missing) > 0:
89
+ print(f"Missing Keys: {missing}")
90
+ print(f"Unexpected Keys: {unexpected}")
91
+
92
+ def on_train_batch_end(self, *args, **kwargs):
93
+ if self.use_ema:
94
+ self.model_ema(self)
95
+
96
+ def encode(self, x):
97
+ h = self.encoder(x)
98
+ h = self.quant_conv(h)
99
+ quant, emb_loss, info = self.quantize(h)
100
+ return quant, emb_loss, info
101
+
102
+ def encode_to_prequant(self, x):
103
+ h = self.encoder(x)
104
+ h = self.quant_conv(h)
105
+ return h
106
+
107
+ def decode(self, quant):
108
+ quant = self.post_quant_conv(quant)
109
+ dec = self.decoder(quant)
110
+ return dec
111
+
112
+ def decode_code(self, code_b):
113
+ quant_b = self.quantize.embed_code(code_b)
114
+ dec = self.decode(quant_b)
115
+ return dec
116
+
117
+ def forward(self, input, return_pred_indices=False):
118
+ quant, diff, (_,_,ind) = self.encode(input)
119
+ dec = self.decode(quant)
120
+ if return_pred_indices:
121
+ return dec, diff, ind
122
+ return dec, diff
123
+
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
129
+ if self.batch_resize_range is not None:
130
+ lower_size = self.batch_resize_range[0]
131
+ upper_size = self.batch_resize_range[1]
132
+ if self.global_step <= 4:
133
+ # do the first few batches with max size to avoid later oom
134
+ new_resize = upper_size
135
+ else:
136
+ new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
137
+ if new_resize != x.shape[2]:
138
+ x = F.interpolate(x, size=new_resize, mode="bicubic")
139
+ x = x.detach()
140
+ return x
141
+
142
+ def training_step(self, batch, batch_idx, optimizer_idx):
143
+ # https://github.com/pytorch/pytorch/issues/37142
144
+ # try not to fool the heuristics
145
+ x = self.get_input(batch, self.image_key)
146
+ xrec, qloss, ind = self(x, return_pred_indices=True)
147
+
148
+ if optimizer_idx == 0:
149
+ # autoencode
150
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
151
+ last_layer=self.get_last_layer(), split="train",
152
+ predicted_indices=ind)
153
+
154
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
155
+ return aeloss
156
+
157
+ if optimizer_idx == 1:
158
+ # discriminator
159
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
160
+ last_layer=self.get_last_layer(), split="train")
161
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
162
+ return discloss
163
+
164
+ def validation_step(self, batch, batch_idx):
165
+ log_dict = self._validation_step(batch, batch_idx)
166
+ with self.ema_scope():
167
+ log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
168
+ return log_dict
169
+
170
+ def _validation_step(self, batch, batch_idx, suffix=""):
171
+ x = self.get_input(batch, self.image_key)
172
+ xrec, qloss, ind = self(x, return_pred_indices=True)
173
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
174
+ self.global_step,
175
+ last_layer=self.get_last_layer(),
176
+ split="val"+suffix,
177
+ predicted_indices=ind
178
+ )
179
+
180
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
181
+ self.global_step,
182
+ last_layer=self.get_last_layer(),
183
+ split="val"+suffix,
184
+ predicted_indices=ind
185
+ )
186
+ rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
187
+ self.log(f"val{suffix}/rec_loss", rec_loss,
188
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
189
+ self.log(f"val{suffix}/aeloss", aeloss,
190
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
191
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
192
+ del log_dict_ae[f"val{suffix}/rec_loss"]
193
+ self.log_dict(log_dict_ae)
194
+ self.log_dict(log_dict_disc)
195
+ return self.log_dict
196
+
197
+ def configure_optimizers(self):
198
+ lr_d = self.learning_rate
199
+ lr_g = self.lr_g_factor*self.learning_rate
200
+ print("lr_d", lr_d)
201
+ print("lr_g", lr_g)
202
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
203
+ list(self.decoder.parameters())+
204
+ list(self.quantize.parameters())+
205
+ list(self.quant_conv.parameters())+
206
+ list(self.post_quant_conv.parameters()),
207
+ lr=lr_g, betas=(0.5, 0.9))
208
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
209
+ lr=lr_d, betas=(0.5, 0.9))
210
+
211
+ if self.scheduler_config is not None:
212
+ scheduler = instantiate_from_config(self.scheduler_config)
213
+
214
+ print("Setting up LambdaLR scheduler...")
215
+ scheduler = [
216
+ {
217
+ 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
218
+ 'interval': 'step',
219
+ 'frequency': 1
220
+ },
221
+ {
222
+ 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
223
+ 'interval': 'step',
224
+ 'frequency': 1
225
+ },
226
+ ]
227
+ return [opt_ae, opt_disc], scheduler
228
+ return [opt_ae, opt_disc], []
229
+
230
+ def get_last_layer(self):
231
+ return self.decoder.conv_out.weight
232
+
233
+ def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
234
+ log = dict()
235
+ x = self.get_input(batch, self.image_key)
236
+ x = x.to(self.device)
237
+ if only_inputs:
238
+ log["inputs"] = x
239
+ return log
240
+ xrec, _ = self(x)
241
+ if x.shape[1] > 3:
242
+ # colorize with random projection
243
+ assert xrec.shape[1] > 3
244
+ x = self.to_rgb(x)
245
+ xrec = self.to_rgb(xrec)
246
+ log["inputs"] = x
247
+ log["reconstructions"] = xrec
248
+ if plot_ema:
249
+ with self.ema_scope():
250
+ xrec_ema, _ = self(x)
251
+ if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
252
+ log["reconstructions_ema"] = xrec_ema
253
+ return log
254
+
255
+ def to_rgb(self, x):
256
+ assert self.image_key == "segmentation"
257
+ if not hasattr(self, "colorize"):
258
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
259
+ x = F.conv2d(x, weight=self.colorize)
260
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
261
+ return x
262
+
263
+
264
+ class VQModelInterface(VQModel):
265
+ def __init__(self, embed_dim, *args, **kwargs):
266
+ super().__init__(embed_dim=embed_dim, *args, **kwargs)
267
+ self.embed_dim = embed_dim
268
+
269
+ def encode(self, x):
270
+ h = self.encoder(x)
271
+ h = self.quant_conv(h)
272
+ return h
273
+
274
+ def decode(self, h, force_not_quantize=False):
275
+ # also go through quantization layer
276
+ if not force_not_quantize:
277
+ quant, emb_loss, info = self.quantize(h)
278
+ else:
279
+ quant = h
280
+ quant = self.post_quant_conv(quant)
281
+ dec = self.decoder(quant)
282
+ return dec
283
+
284
+
285
+ class AutoencoderKL(pl.LightningModule):
286
+ def __init__(self,
287
+ ddconfig,
288
+ lossconfig,
289
+ embed_dim,
290
+ ckpt_path=None,
291
+ ignore_keys=[],
292
+ image_key="image",
293
+ colorize_nlabels=None,
294
+ monitor=None,
295
+ ):
296
+ super().__init__()
297
+ self.image_key = image_key
298
+ self.encoder = Encoder(**ddconfig)
299
+ self.decoder = Decoder(**ddconfig)
300
+ self.loss = instantiate_from_config(lossconfig)
301
+ assert ddconfig["double_z"]
302
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
303
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
304
+ self.embed_dim = embed_dim
305
+ if colorize_nlabels is not None:
306
+ assert type(colorize_nlabels)==int
307
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
308
+ if monitor is not None:
309
+ self.monitor = monitor
310
+ if ckpt_path is not None:
311
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
312
+
313
+ def init_from_ckpt(self, path, ignore_keys=list()):
314
+ sd = torch.load(path, map_location="cpu")["state_dict"]
315
+ keys = list(sd.keys())
316
+ for k in keys:
317
+ for ik in ignore_keys:
318
+ if k.startswith(ik):
319
+ print("Deleting key {} from state_dict.".format(k))
320
+ del sd[k]
321
+ self.load_state_dict(sd, strict=False)
322
+ print(f"Restored from {path}")
323
+
324
+ def encode(self, x):
325
+ h = self.encoder(x)
326
+ moments = self.quant_conv(h)
327
+ posterior = DiagonalGaussianDistribution(moments)
328
+ return posterior
329
+
330
+ def decode(self, z):
331
+ z = self.post_quant_conv(z)
332
+ dec = self.decoder(z)
333
+ return dec
334
+
335
+ def forward(self, input, sample_posterior=True):
336
+ posterior = self.encode(input)
337
+ if sample_posterior:
338
+ z = posterior.sample()
339
+ else:
340
+ z = posterior.mode()
341
+ dec = self.decode(z)
342
+ return dec, posterior
343
+
344
+ def get_input(self, batch, k):
345
+ x = batch[k]
346
+ if len(x.shape) == 3:
347
+ x = x[..., None]
348
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
349
+ return x
350
+
351
+ def training_step(self, batch, batch_idx, optimizer_idx):
352
+ inputs = self.get_input(batch, self.image_key)
353
+ reconstructions, posterior = self(inputs)
354
+
355
+ if optimizer_idx == 0:
356
+ # train encoder+decoder+logvar
357
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
358
+ last_layer=self.get_last_layer(), split="train")
359
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
360
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
361
+ return aeloss
362
+
363
+ if optimizer_idx == 1:
364
+ # train the discriminator
365
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
366
+ last_layer=self.get_last_layer(), split="train")
367
+
368
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
369
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
370
+ return discloss
371
+
372
+ def validation_step(self, batch, batch_idx):
373
+ inputs = self.get_input(batch, self.image_key)
374
+ reconstructions, posterior = self(inputs)
375
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
376
+ last_layer=self.get_last_layer(), split="val")
377
+
378
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
379
+ last_layer=self.get_last_layer(), split="val")
380
+
381
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
382
+ self.log_dict(log_dict_ae)
383
+ self.log_dict(log_dict_disc)
384
+ return self.log_dict
385
+
386
+ def configure_optimizers(self):
387
+ lr = self.learning_rate
388
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
389
+ list(self.decoder.parameters())+
390
+ list(self.quant_conv.parameters())+
391
+ list(self.post_quant_conv.parameters()),
392
+ lr=lr, betas=(0.5, 0.9))
393
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
394
+ lr=lr, betas=(0.5, 0.9))
395
+ return [opt_ae, opt_disc], []
396
+
397
+ def get_last_layer(self):
398
+ return self.decoder.conv_out.weight
399
+
400
+ @torch.no_grad()
401
+ def log_images(self, batch, only_inputs=False, **kwargs):
402
+ log = dict()
403
+ x = self.get_input(batch, self.image_key)
404
+ x = x.to(self.device)
405
+ if not only_inputs:
406
+ xrec, posterior = self(x)
407
+ if x.shape[1] > 3:
408
+ # colorize with random projection
409
+ assert xrec.shape[1] > 3
410
+ x = self.to_rgb(x)
411
+ xrec = self.to_rgb(xrec)
412
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
413
+ log["reconstructions"] = xrec
414
+ log["inputs"] = x
415
+ return log
416
+
417
+ def to_rgb(self, x):
418
+ assert self.image_key == "segmentation"
419
+ if not hasattr(self, "colorize"):
420
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
421
+ x = F.conv2d(x, weight=self.colorize)
422
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
423
+ return x
424
+
425
+
426
+ class IdentityFirstStage(torch.nn.Module):
427
+ def __init__(self, *args, vq_interface=False, **kwargs):
428
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
429
+ super().__init__()
430
+
431
+ def encode(self, x, *args, **kwargs):
432
+ return x
433
+
434
+ def decode(self, x, *args, **kwargs):
435
+ return x
436
+
437
+ def quantize(self, x, *args, **kwargs):
438
+ if self.vq_interface:
439
+ return x, None, [None, None, None]
440
+ return x
441
+
442
+ def forward(self, x, *args, **kwargs):
443
+ return x
ldm/models/diffusion/__init__.py ADDED
File without changes
ldm/models/diffusion/classifier.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import pytorch_lightning as pl
4
+ from omegaconf import OmegaConf
5
+ from torch.nn import functional as F
6
+ from torch.optim import AdamW
7
+ from torch.optim.lr_scheduler import LambdaLR
8
+ from copy import deepcopy
9
+ from einops import rearrange
10
+ from glob import glob
11
+ from natsort import natsorted
12
+
13
+ from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14
+ from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15
+
16
+ __models__ = {
17
+ 'class_label': EncoderUNetModel,
18
+ 'segmentation': UNetModel
19
+ }
20
+
21
+
22
+ def disabled_train(self, mode=True):
23
+ """Overwrite model.train with this function to make sure train/eval mode
24
+ does not change anymore."""
25
+ return self
26
+
27
+
28
+ class NoisyLatentImageClassifier(pl.LightningModule):
29
+
30
+ def __init__(self,
31
+ diffusion_path,
32
+ num_classes,
33
+ ckpt_path=None,
34
+ pool='attention',
35
+ label_key=None,
36
+ diffusion_ckpt_path=None,
37
+ scheduler_config=None,
38
+ weight_decay=1.e-2,
39
+ log_steps=10,
40
+ monitor='val/loss',
41
+ *args,
42
+ **kwargs):
43
+ super().__init__(*args, **kwargs)
44
+ self.num_classes = num_classes
45
+ # get latest config of diffusion model
46
+ diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47
+ self.diffusion_config = OmegaConf.load(diffusion_config).model
48
+ self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49
+ self.load_diffusion()
50
+
51
+ self.monitor = monitor
52
+ self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53
+ self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54
+ self.log_steps = log_steps
55
+
56
+ self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57
+ else self.diffusion_model.cond_stage_key
58
+
59
+ assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params'
60
+
61
+ if self.label_key not in __models__:
62
+ raise NotImplementedError()
63
+
64
+ self.load_classifier(ckpt_path, pool)
65
+
66
+ self.scheduler_config = scheduler_config
67
+ self.use_scheduler = self.scheduler_config is not None
68
+ self.weight_decay = weight_decay
69
+
70
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71
+ sd = torch.load(path, map_location="cpu")
72
+ if "state_dict" in list(sd.keys()):
73
+ sd = sd["state_dict"]
74
+ keys = list(sd.keys())
75
+ for k in keys:
76
+ for ik in ignore_keys:
77
+ if k.startswith(ik):
78
+ print("Deleting key {} from state_dict.".format(k))
79
+ del sd[k]
80
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81
+ sd, strict=False)
82
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83
+ if len(missing) > 0:
84
+ print(f"Missing Keys: {missing}")
85
+ if len(unexpected) > 0:
86
+ print(f"Unexpected Keys: {unexpected}")
87
+
88
+ def load_diffusion(self):
89
+ model = instantiate_from_config(self.diffusion_config)
90
+ self.diffusion_model = model.eval()
91
+ self.diffusion_model.train = disabled_train
92
+ for param in self.diffusion_model.parameters():
93
+ param.requires_grad = False
94
+
95
+ def load_classifier(self, ckpt_path, pool):
96
+ model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97
+ model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98
+ model_config.out_channels = self.num_classes
99
+ if self.label_key == 'class_label':
100
+ model_config.pool = pool
101
+
102
+ self.model = __models__[self.label_key](**model_config)
103
+ if ckpt_path is not None:
104
+ print('#####################################################################')
105
+ print(f'load from ckpt "{ckpt_path}"')
106
+ print('#####################################################################')
107
+ self.init_from_ckpt(ckpt_path)
108
+
109
+ @torch.no_grad()
110
+ def get_x_noisy(self, x, t, noise=None):
111
+ noise = default(noise, lambda: torch.randn_like(x))
112
+ continuous_sqrt_alpha_cumprod = None
113
+ if self.diffusion_model.use_continuous_noise:
114
+ continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115
+ # todo: make sure t+1 is correct here
116
+
117
+ return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118
+ continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119
+
120
+ def forward(self, x_noisy, t, *args, **kwargs):
121
+ return self.model(x_noisy, t)
122
+
123
+ @torch.no_grad()
124
+ def get_input(self, batch, k):
125
+ x = batch[k]
126
+ if len(x.shape) == 3:
127
+ x = x[..., None]
128
+ x = rearrange(x, 'b h w c -> b c h w')
129
+ x = x.to(memory_format=torch.contiguous_format).float()
130
+ return x
131
+
132
+ @torch.no_grad()
133
+ def get_conditioning(self, batch, k=None):
134
+ if k is None:
135
+ k = self.label_key
136
+ assert k is not None, 'Needs to provide label key'
137
+
138
+ targets = batch[k].to(self.device)
139
+
140
+ if self.label_key == 'segmentation':
141
+ targets = rearrange(targets, 'b h w c -> b c h w')
142
+ for down in range(self.numd):
143
+ h, w = targets.shape[-2:]
144
+ targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145
+
146
+ # targets = rearrange(targets,'b c h w -> b h w c')
147
+
148
+ return targets
149
+
150
+ def compute_top_k(self, logits, labels, k, reduction="mean"):
151
+ _, top_ks = torch.topk(logits, k, dim=1)
152
+ if reduction == "mean":
153
+ return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154
+ elif reduction == "none":
155
+ return (top_ks == labels[:, None]).float().sum(dim=-1)
156
+
157
+ def on_train_epoch_start(self):
158
+ # save some memory
159
+ self.diffusion_model.model.to('cpu')
160
+
161
+ @torch.no_grad()
162
+ def write_logs(self, loss, logits, targets):
163
+ log_prefix = 'train' if self.training else 'val'
164
+ log = {}
165
+ log[f"{log_prefix}/loss"] = loss.mean()
166
+ log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167
+ logits, targets, k=1, reduction="mean"
168
+ )
169
+ log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170
+ logits, targets, k=5, reduction="mean"
171
+ )
172
+
173
+ self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174
+ self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175
+ self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176
+ lr = self.optimizers().param_groups[0]['lr']
177
+ self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178
+
179
+ def shared_step(self, batch, t=None):
180
+ x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181
+ targets = self.get_conditioning(batch)
182
+ if targets.dim() == 4:
183
+ targets = targets.argmax(dim=1)
184
+ if t is None:
185
+ t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186
+ else:
187
+ t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188
+ x_noisy = self.get_x_noisy(x, t)
189
+ logits = self(x_noisy, t)
190
+
191
+ loss = F.cross_entropy(logits, targets, reduction='none')
192
+
193
+ self.write_logs(loss.detach(), logits.detach(), targets.detach())
194
+
195
+ loss = loss.mean()
196
+ return loss, logits, x_noisy, targets
197
+
198
+ def training_step(self, batch, batch_idx):
199
+ loss, *_ = self.shared_step(batch)
200
+ return loss
201
+
202
+ def reset_noise_accs(self):
203
+ self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204
+ range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205
+
206
+ def on_validation_start(self):
207
+ self.reset_noise_accs()
208
+
209
+ @torch.no_grad()
210
+ def validation_step(self, batch, batch_idx):
211
+ loss, *_ = self.shared_step(batch)
212
+
213
+ for t in self.noisy_acc:
214
+ _, logits, _, targets = self.shared_step(batch, t)
215
+ self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216
+ self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217
+
218
+ return loss
219
+
220
+ def configure_optimizers(self):
221
+ optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222
+
223
+ if self.use_scheduler:
224
+ scheduler = instantiate_from_config(self.scheduler_config)
225
+
226
+ print("Setting up LambdaLR scheduler...")
227
+ scheduler = [
228
+ {
229
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230
+ 'interval': 'step',
231
+ 'frequency': 1
232
+ }]
233
+ return [optimizer], scheduler
234
+
235
+ return optimizer
236
+
237
+ @torch.no_grad()
238
+ def log_images(self, batch, N=8, *args, **kwargs):
239
+ log = dict()
240
+ x = self.get_input(batch, self.diffusion_model.first_stage_key)
241
+ log['inputs'] = x
242
+
243
+ y = self.get_conditioning(batch)
244
+
245
+ if self.label_key == 'class_label':
246
+ y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247
+ log['labels'] = y
248
+
249
+ if ismap(y):
250
+ log['labels'] = self.diffusion_model.to_rgb(y)
251
+
252
+ for step in range(self.log_steps):
253
+ current_time = step * self.log_time_interval
254
+
255
+ _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256
+
257
+ log[f'inputs@t{current_time}'] = x_noisy
258
+
259
+ pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260
+ pred = rearrange(pred, 'b h w c -> b c h w')
261
+
262
+ log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263
+
264
+ for key in log:
265
+ log[key] = log[key][:N]
266
+
267
+ return log
ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
9
+ extract_into_tensor
10
+
11
+
12
+ class DDIMSampler(object):
13
+ def __init__(self, model, schedule="linear", **kwargs):
14
+ super().__init__()
15
+ self.model = model
16
+ self.ddpm_num_timesteps = model.num_timesteps
17
+ self.schedule = schedule
18
+
19
+ def register_buffer(self, name, attr):
20
+ if type(attr) == torch.Tensor:
21
+ if attr.device != torch.device("cuda"):
22
+ attr = attr.to(torch.device("cuda"))
23
+ setattr(self, name, attr)
24
+
25
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
26
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28
+ alphas_cumprod = self.model.alphas_cumprod
29
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
31
+
32
+ self.register_buffer('betas', to_torch(self.model.betas))
33
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
34
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
35
+
36
+ # calculations for diffusion q(x_t | x_{t-1}) and others
37
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
39
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
42
+
43
+ # ddim sampling parameters
44
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45
+ ddim_timesteps=self.ddim_timesteps,
46
+ eta=ddim_eta,verbose=verbose)
47
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
48
+ self.register_buffer('ddim_alphas', ddim_alphas)
49
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
50
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
51
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
52
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
53
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
54
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
55
+
56
+ @torch.no_grad()
57
+ def sample(self,
58
+ S,
59
+ batch_size,
60
+ shape,
61
+ conditioning=None,
62
+ callback=None,
63
+ normals_sequence=None,
64
+ img_callback=None,
65
+ quantize_x0=False,
66
+ eta=0.,
67
+ mask=None,
68
+ x0=None,
69
+ temperature=1.,
70
+ noise_dropout=0.,
71
+ score_corrector=None,
72
+ corrector_kwargs=None,
73
+ verbose=True,
74
+ x_T=None,
75
+ log_every_t=100,
76
+ unconditional_guidance_scale=1.,
77
+ unconditional_conditioning=None,
78
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
79
+ **kwargs
80
+ ):
81
+ if conditioning is not None:
82
+ if isinstance(conditioning, dict):
83
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
84
+ if cbs != batch_size:
85
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
86
+ else:
87
+ if conditioning.shape[0] != batch_size:
88
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
89
+
90
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
91
+ # sampling
92
+ C, H, W = shape
93
+ size = (batch_size, C, H, W)
94
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
95
+
96
+ samples, intermediates = self.ddim_sampling(conditioning, size,
97
+ callback=callback,
98
+ img_callback=img_callback,
99
+ quantize_denoised=quantize_x0,
100
+ mask=mask, x0=x0,
101
+ ddim_use_original_steps=False,
102
+ noise_dropout=noise_dropout,
103
+ temperature=temperature,
104
+ score_corrector=score_corrector,
105
+ corrector_kwargs=corrector_kwargs,
106
+ x_T=x_T,
107
+ log_every_t=log_every_t,
108
+ unconditional_guidance_scale=unconditional_guidance_scale,
109
+ unconditional_conditioning=unconditional_conditioning,
110
+ )
111
+ return samples, intermediates
112
+
113
+ @torch.no_grad()
114
+ def ddim_sampling(self, cond, shape,
115
+ x_T=None, ddim_use_original_steps=False,
116
+ callback=None, timesteps=None, quantize_denoised=False,
117
+ mask=None, x0=None, img_callback=None, log_every_t=100,
118
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
119
+ unconditional_guidance_scale=1., unconditional_conditioning=None,):
120
+ device = self.model.betas.device
121
+ b = shape[0]
122
+ if x_T is None:
123
+ img = torch.randn(shape, device=device)
124
+ else:
125
+ img = x_T
126
+
127
+ if timesteps is None:
128
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
129
+ elif timesteps is not None and not ddim_use_original_steps:
130
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
131
+ timesteps = self.ddim_timesteps[:subset_end]
132
+
133
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
134
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
135
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
136
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
137
+
138
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
139
+
140
+ for i, step in enumerate(iterator):
141
+ index = total_steps - i - 1
142
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
143
+
144
+ if mask is not None:
145
+ assert x0 is not None
146
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
147
+ img = img_orig * mask + (1. - mask) * img
148
+
149
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
150
+ quantize_denoised=quantize_denoised, temperature=temperature,
151
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
152
+ corrector_kwargs=corrector_kwargs,
153
+ unconditional_guidance_scale=unconditional_guidance_scale,
154
+ unconditional_conditioning=unconditional_conditioning)
155
+ img, pred_x0 = outs
156
+ if callback: callback(i)
157
+ if img_callback: img_callback(pred_x0, i)
158
+
159
+ if index % log_every_t == 0 or index == total_steps - 1:
160
+ intermediates['x_inter'].append(img)
161
+ intermediates['pred_x0'].append(pred_x0)
162
+
163
+ return img, intermediates
164
+
165
+ @torch.no_grad()
166
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
167
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
168
+ unconditional_guidance_scale=1., unconditional_conditioning=None):
169
+ b, *_, device = *x.shape, x.device
170
+
171
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
172
+ e_t = self.model.apply_model(x, t, c)
173
+ else:
174
+ x_in = torch.cat([x] * 2)
175
+ t_in = torch.cat([t] * 2)
176
+ c_in = torch.cat([unconditional_conditioning, c])
177
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
178
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
179
+
180
+ if score_corrector is not None:
181
+ assert self.model.parameterization == "eps"
182
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
183
+
184
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
185
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
186
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
187
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
188
+ # select parameters corresponding to the currently considered timestep
189
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
190
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
191
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
192
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
193
+
194
+ # current prediction for x_0
195
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
196
+ if quantize_denoised:
197
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
198
+ # direction pointing to x_t
199
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
200
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
201
+ if noise_dropout > 0.:
202
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
203
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
204
+ return x_prev, pred_x0
205
+
206
+ @torch.no_grad()
207
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
208
+ # fast, but does not allow for exact reconstruction
209
+ # t serves as an index to gather the correct alphas
210
+ if use_original_steps:
211
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
212
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
213
+ else:
214
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
215
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
216
+
217
+ if noise is None:
218
+ noise = torch.randn_like(x0)
219
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
220
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
221
+
222
+ @torch.no_grad()
223
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
224
+ use_original_steps=False):
225
+
226
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
227
+ timesteps = timesteps[:t_start]
228
+
229
+ time_range = np.flip(timesteps)
230
+ total_steps = timesteps.shape[0]
231
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
232
+
233
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
234
+ x_dec = x_latent
235
+ for i, step in enumerate(iterator):
236
+ index = total_steps - i - 1
237
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
238
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
239
+ unconditional_guidance_scale=unconditional_guidance_scale,
240
+ unconditional_conditioning=unconditional_conditioning)
241
+ return x_dec
ldm/models/diffusion/ddpm.py ADDED
@@ -0,0 +1,1446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wild mixture of
3
+ https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
+ https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
+ https://github.com/CompVis/taming-transformers
6
+ -- merci
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import numpy as np
12
+ import pytorch_lightning as pl
13
+ from torch.optim.lr_scheduler import LambdaLR
14
+ from einops import rearrange, repeat
15
+ from contextlib import contextmanager
16
+ from functools import partial
17
+ from tqdm import tqdm
18
+ from torchvision.utils import make_grid
19
+ from pytorch_lightning.utilities.distributed import rank_zero_only
20
+
21
+ from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
22
+ from ldm.modules.ema import LitEma
23
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
24
+ from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL
25
+ from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
26
+ from ldm.models.diffusion.ddim import DDIMSampler
27
+
28
+
29
+ __conditioning_keys__ = {'concat': 'c_concat',
30
+ 'crossattn': 'c_crossattn',
31
+ 'adm': 'y'}
32
+
33
+
34
+ def disabled_train(self, mode=True):
35
+ """Overwrite model.train with this function to make sure train/eval mode
36
+ does not change anymore."""
37
+ return self
38
+
39
+
40
+ def uniform_on_device(r1, r2, shape, device):
41
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
42
+
43
+
44
+ class DDPM(pl.LightningModule):
45
+ # classic DDPM with Gaussian diffusion, in image space
46
+ def __init__(self,
47
+ unet_config,
48
+ timesteps=1000,
49
+ beta_schedule="linear",
50
+ loss_type="l2",
51
+ ckpt_path=None,
52
+ ignore_keys=[],
53
+ load_only_unet=False,
54
+ monitor="val/loss",
55
+ use_ema=True,
56
+ first_stage_key="image",
57
+ image_size=256,
58
+ channels=3,
59
+ log_every_t=100,
60
+ clip_denoised=True,
61
+ linear_start=1e-4,
62
+ linear_end=2e-2,
63
+ cosine_s=8e-3,
64
+ given_betas=None,
65
+ original_elbo_weight=0.,
66
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
67
+ l_simple_weight=1.,
68
+ conditioning_key=None,
69
+ parameterization="eps", # all assuming fixed variance schedules
70
+ scheduler_config=None,
71
+ use_positional_encodings=False,
72
+ learn_logvar=False,
73
+ logvar_init=0.,
74
+ ):
75
+ super().__init__()
76
+ assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"'
77
+ self.parameterization = parameterization
78
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
79
+ self.cond_stage_model = None
80
+ self.clip_denoised = clip_denoised
81
+ self.log_every_t = log_every_t
82
+ self.first_stage_key = first_stage_key
83
+ self.image_size = image_size # try conv?
84
+ self.channels = channels
85
+ self.use_positional_encodings = use_positional_encodings
86
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
87
+ count_params(self.model, verbose=True)
88
+ self.use_ema = use_ema
89
+ if self.use_ema:
90
+ self.model_ema = LitEma(self.model)
91
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
92
+
93
+ self.use_scheduler = scheduler_config is not None
94
+ if self.use_scheduler:
95
+ self.scheduler_config = scheduler_config
96
+
97
+ self.v_posterior = v_posterior
98
+ self.original_elbo_weight = original_elbo_weight
99
+ self.l_simple_weight = l_simple_weight
100
+
101
+ if monitor is not None:
102
+ self.monitor = monitor
103
+ if ckpt_path is not None:
104
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
105
+
106
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
107
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
108
+
109
+ self.loss_type = loss_type
110
+
111
+ self.learn_logvar = learn_logvar
112
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
113
+ if self.learn_logvar:
114
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
115
+
116
+
117
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
118
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
119
+ if exists(given_betas):
120
+ betas = given_betas
121
+ else:
122
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
123
+ cosine_s=cosine_s)
124
+ alphas = 1. - betas
125
+ alphas_cumprod = np.cumprod(alphas, axis=0)
126
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
127
+
128
+ timesteps, = betas.shape
129
+ self.num_timesteps = int(timesteps)
130
+ self.linear_start = linear_start
131
+ self.linear_end = linear_end
132
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
133
+
134
+ to_torch = partial(torch.tensor, dtype=torch.float32)
135
+
136
+ self.register_buffer('betas', to_torch(betas))
137
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
138
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
139
+
140
+ # calculations for diffusion q(x_t | x_{t-1}) and others
141
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
142
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
143
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
144
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
145
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
146
+
147
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
148
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
149
+ 1. - alphas_cumprod) + self.v_posterior * betas
150
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
151
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
152
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
153
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
154
+ self.register_buffer('posterior_mean_coef1', to_torch(
155
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
156
+ self.register_buffer('posterior_mean_coef2', to_torch(
157
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
158
+
159
+ if self.parameterization == "eps":
160
+ lvlb_weights = self.betas ** 2 / (
161
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
162
+ elif self.parameterization == "x0":
163
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
164
+ else:
165
+ raise NotImplementedError("mu not supported")
166
+ # TODO how to choose this term
167
+ lvlb_weights[0] = lvlb_weights[1]
168
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
169
+ assert not torch.isnan(self.lvlb_weights).all()
170
+
171
+ @contextmanager
172
+ def ema_scope(self, context=None):
173
+ if self.use_ema:
174
+ self.model_ema.store(self.model.parameters())
175
+ self.model_ema.copy_to(self.model)
176
+ if context is not None:
177
+ print(f"{context}: Switched to EMA weights")
178
+ try:
179
+ yield None
180
+ finally:
181
+ if self.use_ema:
182
+ self.model_ema.restore(self.model.parameters())
183
+ if context is not None:
184
+ print(f"{context}: Restored training weights")
185
+
186
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
187
+ sd = torch.load(path, map_location="cpu")
188
+ if "state_dict" in list(sd.keys()):
189
+ sd = sd["state_dict"]
190
+ keys = list(sd.keys())
191
+ for k in keys:
192
+ for ik in ignore_keys:
193
+ if k.startswith(ik):
194
+ print("Deleting key {} from state_dict.".format(k))
195
+ del sd[k]
196
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
197
+ sd, strict=False)
198
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
199
+ if len(missing) > 0:
200
+ print(f"Missing Keys: {missing}")
201
+ if len(unexpected) > 0:
202
+ print(f"Unexpected Keys: {unexpected}")
203
+
204
+ def q_mean_variance(self, x_start, t):
205
+ """
206
+ Get the distribution q(x_t | x_0).
207
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
208
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
209
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
210
+ """
211
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
212
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
213
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
214
+ return mean, variance, log_variance
215
+
216
+ def predict_start_from_noise(self, x_t, t, noise):
217
+ return (
218
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
219
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
220
+ )
221
+
222
+ def q_posterior(self, x_start, x_t, t):
223
+ posterior_mean = (
224
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
225
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
226
+ )
227
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
228
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
229
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
230
+
231
+ def p_mean_variance(self, x, t, clip_denoised: bool):
232
+ model_out = self.model(x, t)
233
+ if self.parameterization == "eps":
234
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
235
+ elif self.parameterization == "x0":
236
+ x_recon = model_out
237
+ if clip_denoised:
238
+ x_recon.clamp_(-1., 1.)
239
+
240
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
241
+ return model_mean, posterior_variance, posterior_log_variance
242
+
243
+ @torch.no_grad()
244
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
245
+ b, *_, device = *x.shape, x.device
246
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
247
+ noise = noise_like(x.shape, device, repeat_noise)
248
+ # no noise when t == 0
249
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
250
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
251
+
252
+ @torch.no_grad()
253
+ def p_sample_loop(self, shape, return_intermediates=False):
254
+ device = self.betas.device
255
+ b = shape[0]
256
+ img = torch.randn(shape, device=device)
257
+ intermediates = [img]
258
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
259
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
260
+ clip_denoised=self.clip_denoised)
261
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
262
+ intermediates.append(img)
263
+ if return_intermediates:
264
+ return img, intermediates
265
+ return img
266
+
267
+ @torch.no_grad()
268
+ def sample(self, batch_size=16, return_intermediates=False):
269
+ image_size = self.image_size
270
+ channels = self.channels
271
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
272
+ return_intermediates=return_intermediates)
273
+
274
+ def q_sample(self, x_start, t, noise=None):
275
+ noise = default(noise, lambda: torch.randn_like(x_start))
276
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
277
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
278
+
279
+ def get_loss(self, pred, target, mean=True):
280
+ if self.loss_type == 'l1':
281
+ loss = (target - pred).abs()
282
+ if mean:
283
+ loss = loss.mean()
284
+ elif self.loss_type == 'l2':
285
+ if mean:
286
+ loss = torch.nn.functional.mse_loss(target, pred)
287
+ else:
288
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
289
+ else:
290
+ raise NotImplementedError("unknown loss type '{loss_type}'")
291
+
292
+ return loss
293
+
294
+ def p_losses(self, x_start, t, noise=None):
295
+ noise = default(noise, lambda: torch.randn_like(x_start))
296
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
297
+ model_out = self.model(x_noisy, t)
298
+
299
+ loss_dict = {}
300
+ if self.parameterization == "eps":
301
+ target = noise
302
+ elif self.parameterization == "x0":
303
+ target = x_start
304
+ else:
305
+ raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
306
+
307
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
308
+
309
+ log_prefix = 'train' if self.training else 'val'
310
+
311
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
312
+ loss_simple = loss.mean() * self.l_simple_weight
313
+
314
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
315
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
316
+
317
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
318
+
319
+ loss_dict.update({f'{log_prefix}/loss': loss})
320
+
321
+ return loss, loss_dict
322
+
323
+ def forward(self, x, *args, **kwargs):
324
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
325
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
326
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
327
+ return self.p_losses(x, t, *args, **kwargs)
328
+
329
+ def get_input(self, batch, k):
330
+ x = batch[k]
331
+ if len(x.shape) == 3:
332
+ x = x[..., None]
333
+ x = rearrange(x, 'b h w c -> b c h w')
334
+ x = x.to(memory_format=torch.contiguous_format).float()
335
+ return x
336
+
337
+ def shared_step(self, batch):
338
+ x = self.get_input(batch, self.first_stage_key)
339
+ loss, loss_dict = self(x)
340
+ return loss, loss_dict
341
+
342
+ def training_step(self, batch, batch_idx):
343
+ loss, loss_dict = self.shared_step(batch)
344
+
345
+ self.log_dict(loss_dict, prog_bar=True,
346
+ logger=True, on_step=True, on_epoch=True)
347
+
348
+ self.log("global_step", self.global_step,
349
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
350
+
351
+ if self.use_scheduler:
352
+ lr = self.optimizers().param_groups[0]['lr']
353
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
354
+
355
+ return loss
356
+
357
+ @torch.no_grad()
358
+ def validation_step(self, batch, batch_idx):
359
+ _, loss_dict_no_ema = self.shared_step(batch)
360
+ with self.ema_scope():
361
+ _, loss_dict_ema = self.shared_step(batch)
362
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
363
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
364
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
365
+
366
+ def on_train_batch_end(self, *args, **kwargs):
367
+ if self.use_ema:
368
+ self.model_ema(self.model)
369
+
370
+ def _get_rows_from_list(self, samples):
371
+ n_imgs_per_row = len(samples)
372
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
373
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
374
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
375
+ return denoise_grid
376
+
377
+ @torch.no_grad()
378
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
379
+ log = dict()
380
+ x = self.get_input(batch, self.first_stage_key)
381
+ N = min(x.shape[0], N)
382
+ n_row = min(x.shape[0], n_row)
383
+ x = x.to(self.device)[:N]
384
+ log["inputs"] = x
385
+
386
+ # get diffusion row
387
+ diffusion_row = list()
388
+ x_start = x[:n_row]
389
+
390
+ for t in range(self.num_timesteps):
391
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
392
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
393
+ t = t.to(self.device).long()
394
+ noise = torch.randn_like(x_start)
395
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
396
+ diffusion_row.append(x_noisy)
397
+
398
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
399
+
400
+ if sample:
401
+ # get denoise row
402
+ with self.ema_scope("Plotting"):
403
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
404
+
405
+ log["samples"] = samples
406
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
407
+
408
+ if return_keys:
409
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
410
+ return log
411
+ else:
412
+ return {key: log[key] for key in return_keys}
413
+ return log
414
+
415
+ def configure_optimizers(self):
416
+ lr = self.learning_rate
417
+ params = list(self.model.parameters())
418
+ if self.learn_logvar:
419
+ params = params + [self.logvar]
420
+ opt = torch.optim.AdamW(params, lr=lr)
421
+ return opt
422
+
423
+
424
+ class DiffusionWrapper(pl.LightningModule):
425
+ def __init__(self, diff_model_config, conditioning_key):
426
+ super().__init__()
427
+ self.diffusion_model = instantiate_from_config(diff_model_config)
428
+ self.conditioning_key = conditioning_key
429
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm']
430
+
431
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, features_adapter=None):
432
+ if self.conditioning_key is None:
433
+ out = self.diffusion_model(x, t, features_adapter=features_adapter)
434
+ elif self.conditioning_key == 'concat':
435
+ xc = torch.cat([x] + c_concat, dim=1)
436
+ out = self.diffusion_model(xc, t, features_adapter=features_adapter)
437
+ elif self.conditioning_key == 'crossattn':
438
+ cc = torch.cat(c_crossattn, 1)
439
+ out = self.diffusion_model(x, t, context=cc, features_adapter=features_adapter)
440
+ elif self.conditioning_key == 'hybrid':
441
+ xc = torch.cat([x] + c_concat, dim=1)
442
+ cc = torch.cat(c_crossattn, 1)
443
+ out = self.diffusion_model(xc, t, context=cc, features_adapter=features_adapter)
444
+ elif self.conditioning_key == 'adm':
445
+ cc = c_crossattn[0]
446
+ out = self.diffusion_model(x, t, y=cc, features_adapter=features_adapter)
447
+ else:
448
+ raise NotImplementedError()
449
+
450
+ return out
451
+
452
+
453
+ class LatentDiffusion(DDPM):
454
+ """main class"""
455
+ def __init__(self,
456
+ first_stage_config,
457
+ cond_stage_config,
458
+ unet_config,
459
+ num_timesteps_cond=None,
460
+ cond_stage_key="image",
461
+ cond_stage_trainable=False,
462
+ concat_mode=True,
463
+ cond_stage_forward=None,
464
+ conditioning_key=None,
465
+ scale_factor=1.0,
466
+ scale_by_std=False,
467
+ *args, **kwargs):
468
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
469
+ self.scale_by_std = scale_by_std
470
+ assert self.num_timesteps_cond <= kwargs['timesteps']
471
+ # for backwards compatibility after implementation of DiffusionWrapper
472
+ if conditioning_key is None:
473
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
474
+ if cond_stage_config == '__is_unconditional__':
475
+ conditioning_key = None
476
+ ckpt_path = kwargs.pop("ckpt_path", None)
477
+ ignore_keys = kwargs.pop("ignore_keys", [])
478
+ super().__init__(conditioning_key=conditioning_key, unet_config=unet_config, *args, **kwargs)
479
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
480
+ self.concat_mode = concat_mode
481
+ self.cond_stage_trainable = cond_stage_trainable
482
+ self.cond_stage_key = cond_stage_key
483
+ try:
484
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
485
+ except:
486
+ self.num_downs = 0
487
+ if not scale_by_std:
488
+ self.scale_factor = scale_factor
489
+ else:
490
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
491
+ self.instantiate_first_stage(first_stage_config)
492
+ self.instantiate_cond_stage(cond_stage_config)
493
+ self.cond_stage_forward = cond_stage_forward
494
+ self.clip_denoised = False
495
+ self.bbox_tokenizer = None
496
+
497
+ self.restarted_from_ckpt = False
498
+ if ckpt_path is not None:
499
+ self.init_from_ckpt(ckpt_path, ignore_keys)
500
+ self.restarted_from_ckpt = True
501
+
502
+ def make_cond_schedule(self, ):
503
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
504
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
505
+ self.cond_ids[:self.num_timesteps_cond] = ids
506
+
507
+ @rank_zero_only
508
+ @torch.no_grad()
509
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
510
+ # only for very first batch
511
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
512
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
513
+ # set rescale weight to 1./std of encodings
514
+ print("### USING STD-RESCALING ###")
515
+ x = super().get_input(batch, self.first_stage_key)
516
+ x = x.to(self.device)
517
+ encoder_posterior = self.encode_first_stage(x)
518
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
519
+ del self.scale_factor
520
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
521
+ print(f"setting self.scale_factor to {self.scale_factor}")
522
+ print("### USING STD-RESCALING ###")
523
+
524
+ def register_schedule(self,
525
+ given_betas=None, beta_schedule="linear", timesteps=1000,
526
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
527
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
528
+
529
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
530
+ if self.shorten_cond_schedule:
531
+ self.make_cond_schedule()
532
+
533
+ def instantiate_first_stage(self, config):
534
+ model = instantiate_from_config(config)
535
+ self.first_stage_model = model.eval()
536
+ self.first_stage_model.train = disabled_train
537
+ for param in self.first_stage_model.parameters():
538
+ param.requires_grad = False
539
+
540
+ def instantiate_cond_stage(self, config):
541
+ if not self.cond_stage_trainable:
542
+ if config == "__is_first_stage__":
543
+ print("Using first stage also as cond stage.")
544
+ self.cond_stage_model = self.first_stage_model
545
+ elif config == "__is_unconditional__":
546
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
547
+ self.cond_stage_model = None
548
+ # self.be_unconditional = True
549
+ else:
550
+ model = instantiate_from_config(config)
551
+ self.cond_stage_model = model.eval()
552
+ self.cond_stage_model.train = disabled_train
553
+ for param in self.cond_stage_model.parameters():
554
+ param.requires_grad = False
555
+ else:
556
+ assert config != '__is_first_stage__'
557
+ assert config != '__is_unconditional__'
558
+ model = instantiate_from_config(config)
559
+ self.cond_stage_model = model
560
+
561
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
562
+ denoise_row = []
563
+ for zd in tqdm(samples, desc=desc):
564
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
565
+ force_not_quantize=force_no_decoder_quantization))
566
+ n_imgs_per_row = len(denoise_row)
567
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
568
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
569
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
570
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
571
+ return denoise_grid
572
+
573
+ def get_first_stage_encoding(self, encoder_posterior):
574
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
575
+ z = encoder_posterior.sample()
576
+ elif isinstance(encoder_posterior, torch.Tensor):
577
+ z = encoder_posterior
578
+ else:
579
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
580
+ return self.scale_factor * z
581
+
582
+ def get_learned_conditioning(self, c):
583
+ if self.cond_stage_forward is None:
584
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
585
+ c = self.cond_stage_model.encode(c)
586
+ if isinstance(c, DiagonalGaussianDistribution):
587
+ c = c.mode()
588
+ else:
589
+ c = self.cond_stage_model(c)
590
+ else:
591
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
592
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
593
+ return c
594
+
595
+ def meshgrid(self, h, w):
596
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
597
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
598
+
599
+ arr = torch.cat([y, x], dim=-1)
600
+ return arr
601
+
602
+ def delta_border(self, h, w):
603
+ """
604
+ :param h: height
605
+ :param w: width
606
+ :return: normalized distance to image border,
607
+ wtith min distance = 0 at border and max dist = 0.5 at image center
608
+ """
609
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
610
+ arr = self.meshgrid(h, w) / lower_right_corner
611
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
612
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
613
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
614
+ return edge_dist
615
+
616
+ def get_weighting(self, h, w, Ly, Lx, device):
617
+ weighting = self.delta_border(h, w)
618
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
619
+ self.split_input_params["clip_max_weight"], )
620
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
621
+
622
+ if self.split_input_params["tie_braker"]:
623
+ L_weighting = self.delta_border(Ly, Lx)
624
+ L_weighting = torch.clip(L_weighting,
625
+ self.split_input_params["clip_min_tie_weight"],
626
+ self.split_input_params["clip_max_tie_weight"])
627
+
628
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
629
+ weighting = weighting * L_weighting
630
+ return weighting
631
+
632
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
633
+ """
634
+ :param x: img of size (bs, c, h, w)
635
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
636
+ """
637
+ bs, nc, h, w = x.shape
638
+
639
+ # number of crops in image
640
+ Ly = (h - kernel_size[0]) // stride[0] + 1
641
+ Lx = (w - kernel_size[1]) // stride[1] + 1
642
+
643
+ if uf == 1 and df == 1:
644
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
645
+ unfold = torch.nn.Unfold(**fold_params)
646
+
647
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
648
+
649
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
650
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
651
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
652
+
653
+ elif uf > 1 and df == 1:
654
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
655
+ unfold = torch.nn.Unfold(**fold_params)
656
+
657
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
658
+ dilation=1, padding=0,
659
+ stride=(stride[0] * uf, stride[1] * uf))
660
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
661
+
662
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
663
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
664
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
665
+
666
+ elif df > 1 and uf == 1:
667
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
668
+ unfold = torch.nn.Unfold(**fold_params)
669
+
670
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
671
+ dilation=1, padding=0,
672
+ stride=(stride[0] // df, stride[1] // df))
673
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
674
+
675
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
676
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
677
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
678
+
679
+ else:
680
+ raise NotImplementedError
681
+
682
+ return fold, unfold, normalization, weighting
683
+
684
+ @torch.no_grad()
685
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
686
+ cond_key=None, return_original_cond=False, bs=None):
687
+ x = super().get_input(batch, k)
688
+ if bs is not None:
689
+ x = x[:bs]
690
+ x = x.to(self.device)
691
+ encoder_posterior = self.encode_first_stage(x)
692
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
693
+
694
+ if self.model.conditioning_key is not None:
695
+ if cond_key is None:
696
+ cond_key = self.cond_stage_key
697
+ if cond_key != self.first_stage_key:
698
+ if cond_key in ['caption', 'coordinates_bbox']:
699
+ xc = batch[cond_key]
700
+ elif cond_key == 'class_label':
701
+ xc = batch
702
+ else:
703
+ xc = super().get_input(batch, cond_key).to(self.device)
704
+ else:
705
+ xc = x
706
+ if not self.cond_stage_trainable or force_c_encode:
707
+ if isinstance(xc, dict) or isinstance(xc, list):
708
+ # import pudb; pudb.set_trace()
709
+ c = self.get_learned_conditioning(xc)
710
+ else:
711
+ c = self.get_learned_conditioning(xc.to(self.device))
712
+ else:
713
+ c = xc
714
+ if bs is not None:
715
+ c = c[:bs]
716
+
717
+ if self.use_positional_encodings:
718
+ pos_x, pos_y = self.compute_latent_shifts(batch)
719
+ ckey = __conditioning_keys__[self.model.conditioning_key]
720
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
721
+
722
+ else:
723
+ c = None
724
+ xc = None
725
+ if self.use_positional_encodings:
726
+ pos_x, pos_y = self.compute_latent_shifts(batch)
727
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
728
+ out = [z, c]
729
+ if return_first_stage_outputs:
730
+ xrec = self.decode_first_stage(z)
731
+ out.extend([x, xrec])
732
+ if return_original_cond:
733
+ out.append(xc)
734
+ return out
735
+
736
+ @torch.no_grad()
737
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
738
+ if predict_cids:
739
+ if z.dim() == 4:
740
+ z = torch.argmax(z.exp(), dim=1).long()
741
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
742
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
743
+
744
+ z = 1. / self.scale_factor * z
745
+
746
+ if hasattr(self, "split_input_params"):
747
+ if self.split_input_params["patch_distributed_vq"]:
748
+ ks = self.split_input_params["ks"] # eg. (128, 128)
749
+ stride = self.split_input_params["stride"] # eg. (64, 64)
750
+ uf = self.split_input_params["vqf"]
751
+ bs, nc, h, w = z.shape
752
+ if ks[0] > h or ks[1] > w:
753
+ ks = (min(ks[0], h), min(ks[1], w))
754
+ print("reducing Kernel")
755
+
756
+ if stride[0] > h or stride[1] > w:
757
+ stride = (min(stride[0], h), min(stride[1], w))
758
+ print("reducing stride")
759
+
760
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
761
+
762
+ z = unfold(z) # (bn, nc * prod(**ks), L)
763
+ # 1. Reshape to img shape
764
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
765
+
766
+ # 2. apply model loop over last dim
767
+ if isinstance(self.first_stage_model, VQModelInterface):
768
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
769
+ force_not_quantize=predict_cids or force_not_quantize)
770
+ for i in range(z.shape[-1])]
771
+ else:
772
+
773
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
774
+ for i in range(z.shape[-1])]
775
+
776
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
777
+ o = o * weighting
778
+ # Reverse 1. reshape to img shape
779
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
780
+ # stitch crops together
781
+ decoded = fold(o)
782
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
783
+ return decoded
784
+ else:
785
+ if isinstance(self.first_stage_model, VQModelInterface):
786
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
787
+ else:
788
+ return self.first_stage_model.decode(z)
789
+
790
+ else:
791
+ if isinstance(self.first_stage_model, VQModelInterface):
792
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
793
+ else:
794
+ return self.first_stage_model.decode(z)
795
+
796
+ # same as above but without decorator
797
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
798
+ if predict_cids:
799
+ if z.dim() == 4:
800
+ z = torch.argmax(z.exp(), dim=1).long()
801
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
802
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
803
+
804
+ z = 1. / self.scale_factor * z
805
+
806
+ if hasattr(self, "split_input_params"):
807
+ if self.split_input_params["patch_distributed_vq"]:
808
+ ks = self.split_input_params["ks"] # eg. (128, 128)
809
+ stride = self.split_input_params["stride"] # eg. (64, 64)
810
+ uf = self.split_input_params["vqf"]
811
+ bs, nc, h, w = z.shape
812
+ if ks[0] > h or ks[1] > w:
813
+ ks = (min(ks[0], h), min(ks[1], w))
814
+ print("reducing Kernel")
815
+
816
+ if stride[0] > h or stride[1] > w:
817
+ stride = (min(stride[0], h), min(stride[1], w))
818
+ print("reducing stride")
819
+
820
+ fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf)
821
+
822
+ z = unfold(z) # (bn, nc * prod(**ks), L)
823
+ # 1. Reshape to img shape
824
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
825
+
826
+ # 2. apply model loop over last dim
827
+ if isinstance(self.first_stage_model, VQModelInterface):
828
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
829
+ force_not_quantize=predict_cids or force_not_quantize)
830
+ for i in range(z.shape[-1])]
831
+ else:
832
+
833
+ output_list = [self.first_stage_model.decode(z[:, :, :, :, i])
834
+ for i in range(z.shape[-1])]
835
+
836
+ o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L)
837
+ o = o * weighting
838
+ # Reverse 1. reshape to img shape
839
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
840
+ # stitch crops together
841
+ decoded = fold(o)
842
+ decoded = decoded / normalization # norm is shape (1, 1, h, w)
843
+ return decoded
844
+ else:
845
+ if isinstance(self.first_stage_model, VQModelInterface):
846
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
847
+ else:
848
+ return self.first_stage_model.decode(z)
849
+
850
+ else:
851
+ if isinstance(self.first_stage_model, VQModelInterface):
852
+ return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
853
+ else:
854
+ return self.first_stage_model.decode(z)
855
+
856
+ @torch.no_grad()
857
+ def encode_first_stage(self, x):
858
+ if hasattr(self, "split_input_params"):
859
+ if self.split_input_params["patch_distributed_vq"]:
860
+ ks = self.split_input_params["ks"] # eg. (128, 128)
861
+ stride = self.split_input_params["stride"] # eg. (64, 64)
862
+ df = self.split_input_params["vqf"]
863
+ self.split_input_params['original_image_size'] = x.shape[-2:]
864
+ bs, nc, h, w = x.shape
865
+ if ks[0] > h or ks[1] > w:
866
+ ks = (min(ks[0], h), min(ks[1], w))
867
+ print("reducing Kernel")
868
+
869
+ if stride[0] > h or stride[1] > w:
870
+ stride = (min(stride[0], h), min(stride[1], w))
871
+ print("reducing stride")
872
+
873
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df)
874
+ z = unfold(x) # (bn, nc * prod(**ks), L)
875
+ # Reshape to img shape
876
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
877
+
878
+ output_list = [self.first_stage_model.encode(z[:, :, :, :, i])
879
+ for i in range(z.shape[-1])]
880
+
881
+ o = torch.stack(output_list, axis=-1)
882
+ o = o * weighting
883
+
884
+ # Reverse reshape to img shape
885
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
886
+ # stitch crops together
887
+ decoded = fold(o)
888
+ decoded = decoded / normalization
889
+ return decoded
890
+
891
+ else:
892
+ return self.first_stage_model.encode(x)
893
+ else:
894
+ return self.first_stage_model.encode(x)
895
+
896
+ def shared_step(self, batch, **kwargs):
897
+ x, c = self.get_input(batch, self.first_stage_key)
898
+ loss = self(x, c)
899
+ return loss
900
+
901
+ def forward(self, x, c, features_adapter=None, *args, **kwargs):
902
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
903
+
904
+ return self.p_losses(x, c, t, features_adapter, *args, **kwargs)
905
+
906
+ def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
907
+ def rescale_bbox(bbox):
908
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
909
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
910
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
911
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
912
+ return x0, y0, w, h
913
+
914
+ return [rescale_bbox(b) for b in bboxes]
915
+
916
+ def apply_model(self, x_noisy, t, cond, features_adapter=None, return_ids=False):
917
+
918
+ if isinstance(cond, dict):
919
+ # hybrid case, cond is exptected to be a dict
920
+ pass
921
+ else:
922
+ if not isinstance(cond, list):
923
+ cond = [cond]
924
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
925
+ cond = {key: cond}
926
+
927
+ if hasattr(self, "split_input_params"):
928
+ assert len(cond) == 1 # todo can only deal with one conditioning atm
929
+ assert not return_ids
930
+ ks = self.split_input_params["ks"] # eg. (128, 128)
931
+ stride = self.split_input_params["stride"] # eg. (64, 64)
932
+
933
+ h, w = x_noisy.shape[-2:]
934
+
935
+ fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride)
936
+
937
+ z = unfold(x_noisy) # (bn, nc * prod(**ks), L)
938
+ # Reshape to img shape
939
+ z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
940
+ z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])]
941
+
942
+ if self.cond_stage_key in ["image", "LR_image", "segmentation",
943
+ 'bbox_img'] and self.model.conditioning_key: # todo check for completeness
944
+ c_key = next(iter(cond.keys())) # get key
945
+ c = next(iter(cond.values())) # get value
946
+ assert (len(c) == 1) # todo extend to list with more than one elem
947
+ c = c[0] # get element
948
+
949
+ c = unfold(c)
950
+ c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L )
951
+
952
+ cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])]
953
+
954
+ elif self.cond_stage_key == 'coordinates_bbox':
955
+ assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size'
956
+
957
+ # assuming padding of unfold is always 0 and its dilation is always 1
958
+ n_patches_per_row = int((w - ks[0]) / stride[0] + 1)
959
+ full_img_h, full_img_w = self.split_input_params['original_image_size']
960
+ # as we are operating on latents, we need the factor from the original image size to the
961
+ # spatial latent size to properly rescale the crops for regenerating the bbox annotations
962
+ num_downs = self.first_stage_model.encoder.num_resolutions - 1
963
+ rescale_latent = 2 ** (num_downs)
964
+
965
+ # get top left postions of patches as conforming for the bbbox tokenizer, therefore we
966
+ # need to rescale the tl patch coordinates to be in between (0,1)
967
+ tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w,
968
+ rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h)
969
+ for patch_nr in range(z.shape[-1])]
970
+
971
+ # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w)
972
+ patch_limits = [(x_tl, y_tl,
973
+ rescale_latent * ks[0] / full_img_w,
974
+ rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates]
975
+ # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates]
976
+
977
+ # tokenize crop coordinates for the bounding boxes of the respective patches
978
+ patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device)
979
+ for bbox in patch_limits] # list of length l with tensors of shape (1, 2)
980
+ print(patch_limits_tknzd[0].shape)
981
+ # cut tknzd crop position from conditioning
982
+ assert isinstance(cond, dict), 'cond must be dict to be fed into model'
983
+ cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device)
984
+ print(cut_cond.shape)
985
+
986
+ adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd])
987
+ adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n')
988
+ print(adapted_cond.shape)
989
+ adapted_cond = self.get_learned_conditioning(adapted_cond)
990
+ print(adapted_cond.shape)
991
+ adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1])
992
+ print(adapted_cond.shape)
993
+
994
+ cond_list = [{'c_crossattn': [e]} for e in adapted_cond]
995
+
996
+ else:
997
+ cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient
998
+
999
+ # apply model by loop over crops
1000
+ if features_adapter is not None:
1001
+ output_list = [self.model(z_list[i], t, **cond_list[i], features_adapter=features_adapter) for i in range(z.shape[-1])]
1002
+ else:
1003
+ output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])]
1004
+ assert not isinstance(output_list[0],
1005
+ tuple) # todo cant deal with multiple model outputs check this never happens
1006
+
1007
+ o = torch.stack(output_list, axis=-1)
1008
+ o = o * weighting
1009
+ # Reverse reshape to img shape
1010
+ o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L)
1011
+ # stitch crops together
1012
+ x_recon = fold(o) / normalization
1013
+
1014
+ else:
1015
+ if features_adapter is not None:
1016
+ x_recon = self.model(x_noisy, t, **cond, features_adapter=features_adapter)
1017
+ else:
1018
+ x_recon = self.model(x_noisy, t, **cond)
1019
+
1020
+ if isinstance(x_recon, tuple) and not return_ids:
1021
+ return x_recon[0]
1022
+ else:
1023
+ return x_recon
1024
+
1025
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
1026
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
1027
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
1028
+
1029
+ def _prior_bpd(self, x_start):
1030
+ """
1031
+ Get the prior KL term for the variational lower-bound, measured in
1032
+ bits-per-dim.
1033
+ This term can't be optimized, as it only depends on the encoder.
1034
+ :param x_start: the [N x C x ...] tensor of inputs.
1035
+ :return: a batch of [N] KL values (in bits), one per batch element.
1036
+ """
1037
+ batch_size = x_start.shape[0]
1038
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1039
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1040
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
1041
+ return mean_flat(kl_prior) / np.log(2.0)
1042
+
1043
+ def p_losses(self, x_start, cond, t, features_adapter=None, noise=None):
1044
+ noise = default(noise, lambda: torch.randn_like(x_start))
1045
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
1046
+ model_output = self.apply_model(x_noisy, t, cond, features_adapter)
1047
+
1048
+ loss_dict = {}
1049
+ prefix = 'train' if self.training else 'val'
1050
+
1051
+ if self.parameterization == "x0":
1052
+ target = x_start
1053
+ elif self.parameterization == "eps":
1054
+ target = noise
1055
+ else:
1056
+ raise NotImplementedError()
1057
+
1058
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
1059
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
1060
+
1061
+ logvar_t = self.logvar[t].to(self.device)
1062
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
1063
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
1064
+ if self.learn_logvar:
1065
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
1066
+ loss_dict.update({'logvar': self.logvar.data.mean()})
1067
+
1068
+ loss = self.l_simple_weight * loss.mean()
1069
+
1070
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
1071
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
1072
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
1073
+ loss += (self.original_elbo_weight * loss_vlb)
1074
+ loss_dict.update({f'{prefix}/loss': loss})
1075
+
1076
+ return loss, loss_dict
1077
+
1078
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
1079
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
1080
+ t_in = t
1081
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
1082
+
1083
+ if score_corrector is not None:
1084
+ assert self.parameterization == "eps"
1085
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
1086
+
1087
+ if return_codebook_ids:
1088
+ model_out, logits = model_out
1089
+
1090
+ if self.parameterization == "eps":
1091
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
1092
+ elif self.parameterization == "x0":
1093
+ x_recon = model_out
1094
+ else:
1095
+ raise NotImplementedError()
1096
+
1097
+ if clip_denoised:
1098
+ x_recon.clamp_(-1., 1.)
1099
+ if quantize_denoised:
1100
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
1101
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
1102
+ if return_codebook_ids:
1103
+ return model_mean, posterior_variance, posterior_log_variance, logits
1104
+ elif return_x0:
1105
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
1106
+ else:
1107
+ return model_mean, posterior_variance, posterior_log_variance
1108
+
1109
+ @torch.no_grad()
1110
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
1111
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
1112
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
1113
+ b, *_, device = *x.shape, x.device
1114
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
1115
+ return_codebook_ids=return_codebook_ids,
1116
+ quantize_denoised=quantize_denoised,
1117
+ return_x0=return_x0,
1118
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1119
+ if return_codebook_ids:
1120
+ raise DeprecationWarning("Support dropped.")
1121
+ model_mean, _, model_log_variance, logits = outputs
1122
+ elif return_x0:
1123
+ model_mean, _, model_log_variance, x0 = outputs
1124
+ else:
1125
+ model_mean, _, model_log_variance = outputs
1126
+
1127
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
1128
+ if noise_dropout > 0.:
1129
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
1130
+ # no noise when t == 0
1131
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
1132
+
1133
+ if return_codebook_ids:
1134
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
1135
+ if return_x0:
1136
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
1137
+ else:
1138
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
1139
+
1140
+ @torch.no_grad()
1141
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
1142
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
1143
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
1144
+ log_every_t=None):
1145
+ if not log_every_t:
1146
+ log_every_t = self.log_every_t
1147
+ timesteps = self.num_timesteps
1148
+ if batch_size is not None:
1149
+ b = batch_size if batch_size is not None else shape[0]
1150
+ shape = [batch_size] + list(shape)
1151
+ else:
1152
+ b = batch_size = shape[0]
1153
+ if x_T is None:
1154
+ img = torch.randn(shape, device=self.device)
1155
+ else:
1156
+ img = x_T
1157
+ intermediates = []
1158
+ if cond is not None:
1159
+ if isinstance(cond, dict):
1160
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1161
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1162
+ else:
1163
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1164
+
1165
+ if start_T is not None:
1166
+ timesteps = min(timesteps, start_T)
1167
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1168
+ total=timesteps) if verbose else reversed(
1169
+ range(0, timesteps))
1170
+ if type(temperature) == float:
1171
+ temperature = [temperature] * timesteps
1172
+
1173
+ for i in iterator:
1174
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1175
+ if self.shorten_cond_schedule:
1176
+ assert self.model.conditioning_key != 'hybrid'
1177
+ tc = self.cond_ids[ts].to(cond.device)
1178
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1179
+
1180
+ img, x0_partial = self.p_sample(img, cond, ts,
1181
+ clip_denoised=self.clip_denoised,
1182
+ quantize_denoised=quantize_denoised, return_x0=True,
1183
+ temperature=temperature[i], noise_dropout=noise_dropout,
1184
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1185
+ if mask is not None:
1186
+ assert x0 is not None
1187
+ img_orig = self.q_sample(x0, ts)
1188
+ img = img_orig * mask + (1. - mask) * img
1189
+
1190
+ if i % log_every_t == 0 or i == timesteps - 1:
1191
+ intermediates.append(x0_partial)
1192
+ if callback: callback(i)
1193
+ if img_callback: img_callback(img, i)
1194
+ return img, intermediates
1195
+
1196
+ @torch.no_grad()
1197
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
1198
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1199
+ mask=None, x0=None, img_callback=None, start_T=None,
1200
+ log_every_t=None):
1201
+
1202
+ if not log_every_t:
1203
+ log_every_t = self.log_every_t
1204
+ device = self.betas.device
1205
+ b = shape[0]
1206
+ if x_T is None:
1207
+ img = torch.randn(shape, device=device)
1208
+ else:
1209
+ img = x_T
1210
+
1211
+ intermediates = [img]
1212
+ if timesteps is None:
1213
+ timesteps = self.num_timesteps
1214
+
1215
+ if start_T is not None:
1216
+ timesteps = min(timesteps, start_T)
1217
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1218
+ range(0, timesteps))
1219
+
1220
+ if mask is not None:
1221
+ assert x0 is not None
1222
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1223
+
1224
+ for i in iterator:
1225
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
1226
+ if self.shorten_cond_schedule:
1227
+ assert self.model.conditioning_key != 'hybrid'
1228
+ tc = self.cond_ids[ts].to(cond.device)
1229
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1230
+
1231
+ img = self.p_sample(img, cond, ts,
1232
+ clip_denoised=self.clip_denoised,
1233
+ quantize_denoised=quantize_denoised)
1234
+ if mask is not None:
1235
+ img_orig = self.q_sample(x0, ts)
1236
+ img = img_orig * mask + (1. - mask) * img
1237
+
1238
+ if i % log_every_t == 0 or i == timesteps - 1:
1239
+ intermediates.append(img)
1240
+ if callback: callback(i)
1241
+ if img_callback: img_callback(img, i)
1242
+
1243
+ if return_intermediates:
1244
+ return img, intermediates
1245
+ return img
1246
+
1247
+ @torch.no_grad()
1248
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1249
+ verbose=True, timesteps=None, quantize_denoised=False,
1250
+ mask=None, x0=None, shape=None,**kwargs):
1251
+ if shape is None:
1252
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
1253
+ if cond is not None:
1254
+ if isinstance(cond, dict):
1255
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1256
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1257
+ else:
1258
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1259
+ return self.p_sample_loop(cond,
1260
+ shape,
1261
+ return_intermediates=return_intermediates, x_T=x_T,
1262
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1263
+ mask=mask, x0=x0)
1264
+
1265
+ @torch.no_grad()
1266
+ def sample_log(self,cond,batch_size,ddim, ddim_steps,**kwargs):
1267
+
1268
+ if ddim:
1269
+ ddim_sampler = DDIMSampler(self)
1270
+ shape = (self.channels, self.image_size, self.image_size)
1271
+ samples, intermediates =ddim_sampler.sample(ddim_steps,batch_size,
1272
+ shape,cond,verbose=False,**kwargs)
1273
+
1274
+ else:
1275
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1276
+ return_intermediates=True,**kwargs)
1277
+
1278
+ return samples, intermediates
1279
+
1280
+
1281
+ @torch.no_grad()
1282
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1283
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1284
+ plot_diffusion_rows=True, **kwargs):
1285
+
1286
+ use_ddim = ddim_steps is not None
1287
+
1288
+ log = dict()
1289
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1290
+ return_first_stage_outputs=True,
1291
+ force_c_encode=True,
1292
+ return_original_cond=True,
1293
+ bs=N)
1294
+ N = min(x.shape[0], N)
1295
+ n_row = min(x.shape[0], n_row)
1296
+ log["inputs"] = x
1297
+ log["reconstruction"] = xrec
1298
+ if self.model.conditioning_key is not None:
1299
+ if hasattr(self.cond_stage_model, "decode"):
1300
+ xc = self.cond_stage_model.decode(c)
1301
+ log["conditioning"] = xc
1302
+ elif self.cond_stage_key in ["caption"]:
1303
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["caption"])
1304
+ log["conditioning"] = xc
1305
+ elif self.cond_stage_key == 'class_label':
1306
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
1307
+ log['conditioning'] = xc
1308
+ elif isimage(xc):
1309
+ log["conditioning"] = xc
1310
+ if ismap(xc):
1311
+ log["original_conditioning"] = self.to_rgb(xc)
1312
+
1313
+ if plot_diffusion_rows:
1314
+ # get diffusion row
1315
+ diffusion_row = list()
1316
+ z_start = z[:n_row]
1317
+ for t in range(self.num_timesteps):
1318
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1319
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1320
+ t = t.to(self.device).long()
1321
+ noise = torch.randn_like(z_start)
1322
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1323
+ diffusion_row.append(self.decode_first_stage(z_noisy))
1324
+
1325
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1326
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1327
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1328
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1329
+ log["diffusion_row"] = diffusion_grid
1330
+
1331
+ if sample:
1332
+ # get denoise row
1333
+ with self.ema_scope("Plotting"):
1334
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1335
+ ddim_steps=ddim_steps,eta=ddim_eta)
1336
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1337
+ x_samples = self.decode_first_stage(samples)
1338
+ log["samples"] = x_samples
1339
+ if plot_denoise_rows:
1340
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1341
+ log["denoise_row"] = denoise_grid
1342
+
1343
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1344
+ self.first_stage_model, IdentityFirstStage):
1345
+ # also display when quantizing x0 while sampling
1346
+ with self.ema_scope("Plotting Quantized Denoised"):
1347
+ samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim,
1348
+ ddim_steps=ddim_steps,eta=ddim_eta,
1349
+ quantize_denoised=True)
1350
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1351
+ # quantize_denoised=True)
1352
+ x_samples = self.decode_first_stage(samples.to(self.device))
1353
+ log["samples_x0_quantized"] = x_samples
1354
+
1355
+ if inpaint:
1356
+ # make a simple center square
1357
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
1358
+ mask = torch.ones(N, h, w).to(self.device)
1359
+ # zeros will be filled in
1360
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1361
+ mask = mask[:, None, ...]
1362
+ with self.ema_scope("Plotting Inpaint"):
1363
+
1364
+ samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta,
1365
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1366
+ x_samples = self.decode_first_stage(samples.to(self.device))
1367
+ log["samples_inpainting"] = x_samples
1368
+ log["mask"] = mask
1369
+
1370
+ # outpaint
1371
+ with self.ema_scope("Plotting Outpaint"):
1372
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta,
1373
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1374
+ x_samples = self.decode_first_stage(samples.to(self.device))
1375
+ log["samples_outpainting"] = x_samples
1376
+
1377
+ if plot_progressive_rows:
1378
+ with self.ema_scope("Plotting Progressives"):
1379
+ img, progressives = self.progressive_denoising(c,
1380
+ shape=(self.channels, self.image_size, self.image_size),
1381
+ batch_size=N)
1382
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1383
+ log["progressive_row"] = prog_row
1384
+
1385
+ if return_keys:
1386
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1387
+ return log
1388
+ else:
1389
+ return {key: log[key] for key in return_keys}
1390
+ return log
1391
+
1392
+ def configure_optimizers(self):
1393
+ lr = self.learning_rate
1394
+ params = list(self.model.parameters())
1395
+ if self.cond_stage_trainable:
1396
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1397
+ params = params + list(self.cond_stage_model.parameters())
1398
+ if self.learn_logvar:
1399
+ print('Diffusion model optimizing logvar')
1400
+ params.append(self.logvar)
1401
+ opt = torch.optim.AdamW(params, lr=lr)
1402
+ if self.use_scheduler:
1403
+ assert 'target' in self.scheduler_config
1404
+ scheduler = instantiate_from_config(self.scheduler_config)
1405
+
1406
+ print("Setting up LambdaLR scheduler...")
1407
+ scheduler = [
1408
+ {
1409
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1410
+ 'interval': 'step',
1411
+ 'frequency': 1
1412
+ }]
1413
+ return [opt], scheduler
1414
+ return opt
1415
+
1416
+ @torch.no_grad()
1417
+ def to_rgb(self, x):
1418
+ x = x.float()
1419
+ if not hasattr(self, "colorize"):
1420
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1421
+ x = nn.functional.conv2d(x, weight=self.colorize)
1422
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1423
+ return x
1424
+
1425
+ class Layout2ImgDiffusion(LatentDiffusion):
1426
+ # TODO: move all layout-specific hacks to this class
1427
+ def __init__(self, cond_stage_key, *args, **kwargs):
1428
+ assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"'
1429
+ super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs)
1430
+
1431
+ def log_images(self, batch, N=8, *args, **kwargs):
1432
+ logs = super().log_images(batch=batch, N=N, *args, **kwargs)
1433
+
1434
+ key = 'train' if self.training else 'validation'
1435
+ dset = self.trainer.datamodule.datasets[key]
1436
+ mapper = dset.conditional_builders[self.cond_stage_key]
1437
+
1438
+ bbox_imgs = []
1439
+ map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno))
1440
+ for tknzd_bbox in batch[self.cond_stage_key][:N]:
1441
+ bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256))
1442
+ bbox_imgs.append(bboximg)
1443
+
1444
+ cond_img = torch.stack(bbox_imgs, dim=0)
1445
+ logs['bbox_image'] = cond_img
1446
+ return logs
ldm/models/diffusion/dpm_solver/__init__.py ADDED
@@ -0,0 +1 @@
 
1
+ from .sampler import DPMSolverSampler
ldm/models/diffusion/dpm_solver/dpm_solver.py ADDED
@@ -0,0 +1,1184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import math
4
+
5
+
6
+ class NoiseScheduleVP:
7
+ def __init__(
8
+ self,
9
+ schedule='discrete',
10
+ betas=None,
11
+ alphas_cumprod=None,
12
+ continuous_beta_0=0.1,
13
+ continuous_beta_1=20.,
14
+ ):
15
+ """Create a wrapper class for the forward SDE (VP type).
16
+
17
+ ***
18
+ Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
19
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
20
+ ***
21
+
22
+ The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
23
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
24
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
25
+
26
+ log_alpha_t = self.marginal_log_mean_coeff(t)
27
+ sigma_t = self.marginal_std(t)
28
+ lambda_t = self.marginal_lambda(t)
29
+
30
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
31
+
32
+ t = self.inverse_lambda(lambda_t)
33
+
34
+ ===============================================================
35
+
36
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
37
+
38
+ 1. For discrete-time DPMs:
39
+
40
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
41
+ t_i = (i + 1) / N
42
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
43
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
44
+
45
+ Args:
46
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
47
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
48
+
49
+ Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
50
+
51
+ **Important**: Please pay special attention for the args for `alphas_cumprod`:
52
+ The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
53
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
54
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
55
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
56
+ and
57
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
58
+
59
+
60
+ 2. For continuous-time DPMs:
61
+
62
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
63
+ schedule are the default settings in DDPM and improved-DDPM:
64
+
65
+ Args:
66
+ beta_min: A `float` number. The smallest beta for the linear schedule.
67
+ beta_max: A `float` number. The largest beta for the linear schedule.
68
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
69
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
70
+ T: A `float` number. The ending time of the forward process.
71
+
72
+ ===============================================================
73
+
74
+ Args:
75
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
76
+ 'linear' or 'cosine' for continuous-time DPMs.
77
+ Returns:
78
+ A wrapper object of the forward SDE (VP type).
79
+
80
+ ===============================================================
81
+
82
+ Example:
83
+
84
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
85
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
86
+
87
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
88
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
89
+
90
+ # For continuous-time DPMs (VPSDE), linear schedule:
91
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
92
+
93
+ """
94
+
95
+ if schedule not in ['discrete', 'linear', 'cosine']:
96
+ raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
97
+
98
+ self.schedule = schedule
99
+ if schedule == 'discrete':
100
+ if betas is not None:
101
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
102
+ else:
103
+ assert alphas_cumprod is not None
104
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
105
+ self.total_N = len(log_alphas)
106
+ self.T = 1.
107
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
108
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
109
+ else:
110
+ self.total_N = 1000
111
+ self.beta_0 = continuous_beta_0
112
+ self.beta_1 = continuous_beta_1
113
+ self.cosine_s = 0.008
114
+ self.cosine_beta_max = 999.
115
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
116
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
117
+ self.schedule = schedule
118
+ if schedule == 'cosine':
119
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
120
+ # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
121
+ self.T = 0.9946
122
+ else:
123
+ self.T = 1.
124
+
125
+ def marginal_log_mean_coeff(self, t):
126
+ """
127
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
128
+ """
129
+ if self.schedule == 'discrete':
130
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
131
+ elif self.schedule == 'linear':
132
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
133
+ elif self.schedule == 'cosine':
134
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
135
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
136
+ return log_alpha_t
137
+
138
+ def marginal_alpha(self, t):
139
+ """
140
+ Compute alpha_t of a given continuous-time label t in [0, T].
141
+ """
142
+ return torch.exp(self.marginal_log_mean_coeff(t))
143
+
144
+ def marginal_std(self, t):
145
+ """
146
+ Compute sigma_t of a given continuous-time label t in [0, T].
147
+ """
148
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
149
+
150
+ def marginal_lambda(self, t):
151
+ """
152
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
153
+ """
154
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
155
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
156
+ return log_mean_coeff - log_std
157
+
158
+ def inverse_lambda(self, lamb):
159
+ """
160
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
161
+ """
162
+ if self.schedule == 'linear':
163
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
164
+ Delta = self.beta_0**2 + tmp
165
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
166
+ elif self.schedule == 'discrete':
167
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
168
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
169
+ return t.reshape((-1,))
170
+ else:
171
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
172
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
173
+ t = t_fn(log_alpha)
174
+ return t
175
+
176
+
177
+ def model_wrapper(
178
+ model,
179
+ noise_schedule,
180
+ model_type="noise",
181
+ model_kwargs={},
182
+ guidance_type="uncond",
183
+ condition=None,
184
+ unconditional_condition=None,
185
+ guidance_scale=1.,
186
+ classifier_fn=None,
187
+ classifier_kwargs={},
188
+ ):
189
+ """Create a wrapper function for the noise prediction model.
190
+
191
+ DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
192
+ firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
193
+
194
+ We support four types of the diffusion model by setting `model_type`:
195
+
196
+ 1. "noise": noise prediction model. (Trained by predicting noise).
197
+
198
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
199
+
200
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
201
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
202
+
203
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
204
+ arXiv preprint arXiv:2202.00512 (2022).
205
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
206
+ arXiv preprint arXiv:2210.02303 (2022).
207
+
208
+ 4. "score": marginal score function. (Trained by denoising score matching).
209
+ Note that the score function and the noise prediction model follows a simple relationship:
210
+ ```
211
+ noise(x_t, t) = -sigma_t * score(x_t, t)
212
+ ```
213
+
214
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
215
+ 1. "uncond": unconditional sampling by DPMs.
216
+ The input `model` has the following format:
217
+ ``
218
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
219
+ ``
220
+
221
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
222
+ The input `model` has the following format:
223
+ ``
224
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
225
+ ``
226
+
227
+ The input `classifier_fn` has the following format:
228
+ ``
229
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
230
+ ``
231
+
232
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
233
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
234
+
235
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
236
+ The input `model` has the following format:
237
+ ``
238
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
239
+ ``
240
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
241
+
242
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
243
+ arXiv preprint arXiv:2207.12598 (2022).
244
+
245
+
246
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
247
+ or continuous-time labels (i.e. epsilon to T).
248
+
249
+ We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
250
+ ``
251
+ def model_fn(x, t_continuous) -> noise:
252
+ t_input = get_model_input_time(t_continuous)
253
+ return noise_pred(model, x, t_input, **model_kwargs)
254
+ ``
255
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
256
+
257
+ ===============================================================
258
+
259
+ Args:
260
+ model: A diffusion model with the corresponding format described above.
261
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
262
+ model_type: A `str`. The parameterization type of the diffusion model.
263
+ "noise" or "x_start" or "v" or "score".
264
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
265
+ guidance_type: A `str`. The type of the guidance for sampling.
266
+ "uncond" or "classifier" or "classifier-free".
267
+ condition: A pytorch tensor. The condition for the guided sampling.
268
+ Only used for "classifier" or "classifier-free" guidance type.
269
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
270
+ Only used for "classifier-free" guidance type.
271
+ guidance_scale: A `float`. The scale for the guided sampling.
272
+ classifier_fn: A classifier function. Only used for the classifier guidance.
273
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
274
+ Returns:
275
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
276
+ """
277
+
278
+ def get_model_input_time(t_continuous):
279
+ """
280
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
281
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
282
+ For continuous-time DPMs, we just use `t_continuous`.
283
+ """
284
+ if noise_schedule.schedule == 'discrete':
285
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
286
+ else:
287
+ return t_continuous
288
+
289
+ def noise_pred_fn(x, t_continuous, cond=None):
290
+ if t_continuous.reshape((-1,)).shape[0] == 1:
291
+ t_continuous = t_continuous.expand((x.shape[0]))
292
+ t_input = get_model_input_time(t_continuous)
293
+ if cond is None:
294
+ output = model(x, t_input, **model_kwargs)
295
+ else:
296
+ output = model(x, t_input, cond, **model_kwargs)
297
+ if model_type == "noise":
298
+ return output
299
+ elif model_type == "x_start":
300
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
301
+ dims = x.dim()
302
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
303
+ elif model_type == "v":
304
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
305
+ dims = x.dim()
306
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
307
+ elif model_type == "score":
308
+ sigma_t = noise_schedule.marginal_std(t_continuous)
309
+ dims = x.dim()
310
+ return -expand_dims(sigma_t, dims) * output
311
+
312
+ def cond_grad_fn(x, t_input):
313
+ """
314
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
315
+ """
316
+ with torch.enable_grad():
317
+ x_in = x.detach().requires_grad_(True)
318
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
319
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
320
+
321
+ def model_fn(x, t_continuous):
322
+ """
323
+ The noise predicition model function that is used for DPM-Solver.
324
+ """
325
+ if t_continuous.reshape((-1,)).shape[0] == 1:
326
+ t_continuous = t_continuous.expand((x.shape[0]))
327
+ if guidance_type == "uncond":
328
+ return noise_pred_fn(x, t_continuous)
329
+ elif guidance_type == "classifier":
330
+ assert classifier_fn is not None
331
+ t_input = get_model_input_time(t_continuous)
332
+ cond_grad = cond_grad_fn(x, t_input)
333
+ sigma_t = noise_schedule.marginal_std(t_continuous)
334
+ noise = noise_pred_fn(x, t_continuous)
335
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
336
+ elif guidance_type == "classifier-free":
337
+ if guidance_scale == 1. or unconditional_condition is None:
338
+ return noise_pred_fn(x, t_continuous, cond=condition)
339
+ else:
340
+ x_in = torch.cat([x] * 2)
341
+ t_in = torch.cat([t_continuous] * 2)
342
+ c_in = torch.cat([unconditional_condition, condition])
343
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
344
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
345
+
346
+ assert model_type in ["noise", "x_start", "v"]
347
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
348
+ return model_fn
349
+
350
+
351
+ class DPM_Solver:
352
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
353
+ """Construct a DPM-Solver.
354
+
355
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
356
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
357
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
358
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
359
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
360
+
361
+ Args:
362
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
363
+ ``
364
+ def model_fn(x, t_continuous):
365
+ return noise
366
+ ``
367
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
368
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
369
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
370
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
371
+
372
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
373
+ """
374
+ self.model = model_fn
375
+ self.noise_schedule = noise_schedule
376
+ self.predict_x0 = predict_x0
377
+ self.thresholding = thresholding
378
+ self.max_val = max_val
379
+
380
+ def noise_prediction_fn(self, x, t):
381
+ """
382
+ Return the noise prediction model.
383
+ """
384
+ return self.model(x, t)
385
+
386
+ def data_prediction_fn(self, x, t):
387
+ """
388
+ Return the data prediction model (with thresholding).
389
+ """
390
+ noise = self.noise_prediction_fn(x, t)
391
+ dims = x.dim()
392
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
393
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
394
+ if self.thresholding:
395
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
396
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
397
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
398
+ x0 = torch.clamp(x0, -s, s) / s
399
+ return x0
400
+
401
+ def model_fn(self, x, t):
402
+ """
403
+ Convert the model to the noise prediction model or the data prediction model.
404
+ """
405
+ if self.predict_x0:
406
+ return self.data_prediction_fn(x, t)
407
+ else:
408
+ return self.noise_prediction_fn(x, t)
409
+
410
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
411
+ """Compute the intermediate time steps for sampling.
412
+
413
+ Args:
414
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
415
+ - 'logSNR': uniform logSNR for the time steps.
416
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
417
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
418
+ t_T: A `float`. The starting time of the sampling (default is T).
419
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
420
+ N: A `int`. The total number of the spacing of the time steps.
421
+ device: A torch device.
422
+ Returns:
423
+ A pytorch tensor of the time steps, with the shape (N + 1,).
424
+ """
425
+ if skip_type == 'logSNR':
426
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
427
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
428
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
429
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
430
+ elif skip_type == 'time_uniform':
431
+ return torch.linspace(t_T, t_0, N + 1).to(device)
432
+ elif skip_type == 'time_quadratic':
433
+ t_order = 2
434
+ t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
435
+ return t
436
+ else:
437
+ raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
438
+
439
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
440
+ """
441
+ Get the order of each step for sampling by the singlestep DPM-Solver.
442
+
443
+ We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
444
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
445
+ - If order == 1:
446
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
447
+ - If order == 2:
448
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
449
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
450
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
451
+ - If order == 3:
452
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
453
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
454
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
455
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
456
+
457
+ ============================================
458
+ Args:
459
+ order: A `int`. The max order for the solver (2 or 3).
460
+ steps: A `int`. The total number of function evaluations (NFE).
461
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
462
+ - 'logSNR': uniform logSNR for the time steps.
463
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
464
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
465
+ t_T: A `float`. The starting time of the sampling (default is T).
466
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
467
+ device: A torch device.
468
+ Returns:
469
+ orders: A list of the solver order of each step.
470
+ """
471
+ if order == 3:
472
+ K = steps // 3 + 1
473
+ if steps % 3 == 0:
474
+ orders = [3,] * (K - 2) + [2, 1]
475
+ elif steps % 3 == 1:
476
+ orders = [3,] * (K - 1) + [1]
477
+ else:
478
+ orders = [3,] * (K - 1) + [2]
479
+ elif order == 2:
480
+ if steps % 2 == 0:
481
+ K = steps // 2
482
+ orders = [2,] * K
483
+ else:
484
+ K = steps // 2 + 1
485
+ orders = [2,] * (K - 1) + [1]
486
+ elif order == 1:
487
+ K = 1
488
+ orders = [1,] * steps
489
+ else:
490
+ raise ValueError("'order' must be '1' or '2' or '3'.")
491
+ if skip_type == 'logSNR':
492
+ # To reproduce the results in DPM-Solver paper
493
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
494
+ else:
495
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders)).to(device)]
496
+ return timesteps_outer, orders
497
+
498
+ def denoise_to_zero_fn(self, x, s):
499
+ """
500
+ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
501
+ """
502
+ return self.data_prediction_fn(x, s)
503
+
504
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
505
+ """
506
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
507
+
508
+ Args:
509
+ x: A pytorch tensor. The initial value at time `s`.
510
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
511
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
512
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
513
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
514
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
515
+ Returns:
516
+ x_t: A pytorch tensor. The approximated solution at time `t`.
517
+ """
518
+ ns = self.noise_schedule
519
+ dims = x.dim()
520
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
521
+ h = lambda_t - lambda_s
522
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
523
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
524
+ alpha_t = torch.exp(log_alpha_t)
525
+
526
+ if self.predict_x0:
527
+ phi_1 = torch.expm1(-h)
528
+ if model_s is None:
529
+ model_s = self.model_fn(x, s)
530
+ x_t = (
531
+ expand_dims(sigma_t / sigma_s, dims) * x
532
+ - expand_dims(alpha_t * phi_1, dims) * model_s
533
+ )
534
+ if return_intermediate:
535
+ return x_t, {'model_s': model_s}
536
+ else:
537
+ return x_t
538
+ else:
539
+ phi_1 = torch.expm1(h)
540
+ if model_s is None:
541
+ model_s = self.model_fn(x, s)
542
+ x_t = (
543
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
544
+ - expand_dims(sigma_t * phi_1, dims) * model_s
545
+ )
546
+ if return_intermediate:
547
+ return x_t, {'model_s': model_s}
548
+ else:
549
+ return x_t
550
+
551
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpm_solver'):
552
+ """
553
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
554
+
555
+ Args:
556
+ x: A pytorch tensor. The initial value at time `s`.
557
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
558
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
559
+ r1: A `float`. The hyperparameter of the second-order solver.
560
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
561
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
562
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
563
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
564
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
565
+ Returns:
566
+ x_t: A pytorch tensor. The approximated solution at time `t`.
567
+ """
568
+ if solver_type not in ['dpm_solver', 'taylor']:
569
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
570
+ if r1 is None:
571
+ r1 = 0.5
572
+ ns = self.noise_schedule
573
+ dims = x.dim()
574
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
575
+ h = lambda_t - lambda_s
576
+ lambda_s1 = lambda_s + r1 * h
577
+ s1 = ns.inverse_lambda(lambda_s1)
578
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
579
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
580
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
581
+
582
+ if self.predict_x0:
583
+ phi_11 = torch.expm1(-r1 * h)
584
+ phi_1 = torch.expm1(-h)
585
+
586
+ if model_s is None:
587
+ model_s = self.model_fn(x, s)
588
+ x_s1 = (
589
+ expand_dims(sigma_s1 / sigma_s, dims) * x
590
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
591
+ )
592
+ model_s1 = self.model_fn(x_s1, s1)
593
+ if solver_type == 'dpm_solver':
594
+ x_t = (
595
+ expand_dims(sigma_t / sigma_s, dims) * x
596
+ - expand_dims(alpha_t * phi_1, dims) * model_s
597
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
598
+ )
599
+ elif solver_type == 'taylor':
600
+ x_t = (
601
+ expand_dims(sigma_t / sigma_s, dims) * x
602
+ - expand_dims(alpha_t * phi_1, dims) * model_s
603
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (model_s1 - model_s)
604
+ )
605
+ else:
606
+ phi_11 = torch.expm1(r1 * h)
607
+ phi_1 = torch.expm1(h)
608
+
609
+ if model_s is None:
610
+ model_s = self.model_fn(x, s)
611
+ x_s1 = (
612
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
613
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
614
+ )
615
+ model_s1 = self.model_fn(x_s1, s1)
616
+ if solver_type == 'dpm_solver':
617
+ x_t = (
618
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
619
+ - expand_dims(sigma_t * phi_1, dims) * model_s
620
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
621
+ )
622
+ elif solver_type == 'taylor':
623
+ x_t = (
624
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
625
+ - expand_dims(sigma_t * phi_1, dims) * model_s
626
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
627
+ )
628
+ if return_intermediate:
629
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
630
+ else:
631
+ return x_t
632
+
633
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpm_solver'):
634
+ """
635
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
636
+
637
+ Args:
638
+ x: A pytorch tensor. The initial value at time `s`.
639
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
640
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
641
+ r1: A `float`. The hyperparameter of the third-order solver.
642
+ r2: A `float`. The hyperparameter of the third-order solver.
643
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
644
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
645
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
646
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
647
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
648
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
649
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
650
+ Returns:
651
+ x_t: A pytorch tensor. The approximated solution at time `t`.
652
+ """
653
+ if solver_type not in ['dpm_solver', 'taylor']:
654
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
655
+ if r1 is None:
656
+ r1 = 1. / 3.
657
+ if r2 is None:
658
+ r2 = 2. / 3.
659
+ ns = self.noise_schedule
660
+ dims = x.dim()
661
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
662
+ h = lambda_t - lambda_s
663
+ lambda_s1 = lambda_s + r1 * h
664
+ lambda_s2 = lambda_s + r2 * h
665
+ s1 = ns.inverse_lambda(lambda_s1)
666
+ s2 = ns.inverse_lambda(lambda_s2)
667
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
668
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
669
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
670
+
671
+ if self.predict_x0:
672
+ phi_11 = torch.expm1(-r1 * h)
673
+ phi_12 = torch.expm1(-r2 * h)
674
+ phi_1 = torch.expm1(-h)
675
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
676
+ phi_2 = phi_1 / h + 1.
677
+ phi_3 = phi_2 / h - 0.5
678
+
679
+ if model_s is None:
680
+ model_s = self.model_fn(x, s)
681
+ if model_s1 is None:
682
+ x_s1 = (
683
+ expand_dims(sigma_s1 / sigma_s, dims) * x
684
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
685
+ )
686
+ model_s1 = self.model_fn(x_s1, s1)
687
+ x_s2 = (
688
+ expand_dims(sigma_s2 / sigma_s, dims) * x
689
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
690
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
691
+ )
692
+ model_s2 = self.model_fn(x_s2, s2)
693
+ if solver_type == 'dpm_solver':
694
+ x_t = (
695
+ expand_dims(sigma_t / sigma_s, dims) * x
696
+ - expand_dims(alpha_t * phi_1, dims) * model_s
697
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
698
+ )
699
+ elif solver_type == 'taylor':
700
+ D1_0 = (1. / r1) * (model_s1 - model_s)
701
+ D1_1 = (1. / r2) * (model_s2 - model_s)
702
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
703
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
704
+ x_t = (
705
+ expand_dims(sigma_t / sigma_s, dims) * x
706
+ - expand_dims(alpha_t * phi_1, dims) * model_s
707
+ + expand_dims(alpha_t * phi_2, dims) * D1
708
+ - expand_dims(alpha_t * phi_3, dims) * D2
709
+ )
710
+ else:
711
+ phi_11 = torch.expm1(r1 * h)
712
+ phi_12 = torch.expm1(r2 * h)
713
+ phi_1 = torch.expm1(h)
714
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
715
+ phi_2 = phi_1 / h - 1.
716
+ phi_3 = phi_2 / h - 0.5
717
+
718
+ if model_s is None:
719
+ model_s = self.model_fn(x, s)
720
+ if model_s1 is None:
721
+ x_s1 = (
722
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
723
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
724
+ )
725
+ model_s1 = self.model_fn(x_s1, s1)
726
+ x_s2 = (
727
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
728
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
729
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
730
+ )
731
+ model_s2 = self.model_fn(x_s2, s2)
732
+ if solver_type == 'dpm_solver':
733
+ x_t = (
734
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
735
+ - expand_dims(sigma_t * phi_1, dims) * model_s
736
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
737
+ )
738
+ elif solver_type == 'taylor':
739
+ D1_0 = (1. / r1) * (model_s1 - model_s)
740
+ D1_1 = (1. / r2) * (model_s2 - model_s)
741
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
742
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
743
+ x_t = (
744
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
745
+ - expand_dims(sigma_t * phi_1, dims) * model_s
746
+ - expand_dims(sigma_t * phi_2, dims) * D1
747
+ - expand_dims(sigma_t * phi_3, dims) * D2
748
+ )
749
+
750
+ if return_intermediate:
751
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
752
+ else:
753
+ return x_t
754
+
755
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
756
+ """
757
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
758
+
759
+ Args:
760
+ x: A pytorch tensor. The initial value at time `s`.
761
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
762
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
763
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
764
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
765
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
766
+ Returns:
767
+ x_t: A pytorch tensor. The approximated solution at time `t`.
768
+ """
769
+ if solver_type not in ['dpm_solver', 'taylor']:
770
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
771
+ ns = self.noise_schedule
772
+ dims = x.dim()
773
+ model_prev_1, model_prev_0 = model_prev_list
774
+ t_prev_1, t_prev_0 = t_prev_list
775
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
776
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
777
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
778
+ alpha_t = torch.exp(log_alpha_t)
779
+
780
+ h_0 = lambda_prev_0 - lambda_prev_1
781
+ h = lambda_t - lambda_prev_0
782
+ r0 = h_0 / h
783
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
784
+ if self.predict_x0:
785
+ if solver_type == 'dpm_solver':
786
+ x_t = (
787
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
788
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
789
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
790
+ )
791
+ elif solver_type == 'taylor':
792
+ x_t = (
793
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
794
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
795
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
796
+ )
797
+ else:
798
+ if solver_type == 'dpm_solver':
799
+ x_t = (
800
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
801
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
802
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
803
+ )
804
+ elif solver_type == 'taylor':
805
+ x_t = (
806
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
807
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
808
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
809
+ )
810
+ return x_t
811
+
812
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
813
+ """
814
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
815
+
816
+ Args:
817
+ x: A pytorch tensor. The initial value at time `s`.
818
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
819
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
820
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
821
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
822
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
823
+ Returns:
824
+ x_t: A pytorch tensor. The approximated solution at time `t`.
825
+ """
826
+ ns = self.noise_schedule
827
+ dims = x.dim()
828
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
829
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
830
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
831
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
832
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
833
+ alpha_t = torch.exp(log_alpha_t)
834
+
835
+ h_1 = lambda_prev_1 - lambda_prev_2
836
+ h_0 = lambda_prev_0 - lambda_prev_1
837
+ h = lambda_t - lambda_prev_0
838
+ r0, r1 = h_0 / h, h_1 / h
839
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
840
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
841
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
842
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
843
+ if self.predict_x0:
844
+ x_t = (
845
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
846
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
847
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
848
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h**2 - 0.5), dims) * D2
849
+ )
850
+ else:
851
+ x_t = (
852
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
853
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
854
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
855
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h**2 - 0.5), dims) * D2
856
+ )
857
+ return x_t
858
+
859
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, r2=None):
860
+ """
861
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
862
+
863
+ Args:
864
+ x: A pytorch tensor. The initial value at time `s`.
865
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
866
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
867
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
868
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
869
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
870
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
871
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
872
+ r2: A `float`. The hyperparameter of the third-order solver.
873
+ Returns:
874
+ x_t: A pytorch tensor. The approximated solution at time `t`.
875
+ """
876
+ if order == 1:
877
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
878
+ elif order == 2:
879
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1)
880
+ elif order == 3:
881
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2)
882
+ else:
883
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
884
+
885
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
886
+ """
887
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
888
+
889
+ Args:
890
+ x: A pytorch tensor. The initial value at time `s`.
891
+ model_prev_list: A list of pytorch tensor. The previous computed model values.
892
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
893
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
894
+ order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
895
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
896
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
897
+ Returns:
898
+ x_t: A pytorch tensor. The approximated solution at time `t`.
899
+ """
900
+ if order == 1:
901
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
902
+ elif order == 2:
903
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
904
+ elif order == 3:
905
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
906
+ else:
907
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
908
+
909
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpm_solver'):
910
+ """
911
+ The adaptive step size solver based on singlestep DPM-Solver.
912
+
913
+ Args:
914
+ x: A pytorch tensor. The initial value at time `t_T`.
915
+ order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
916
+ t_T: A `float`. The starting time of the sampling (default is T).
917
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
918
+ h_init: A `float`. The initial step size (for logSNR).
919
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
920
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
921
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
922
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
923
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
924
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
925
+ The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
926
+ Returns:
927
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
928
+
929
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
930
+ """
931
+ ns = self.noise_schedule
932
+ s = t_T * torch.ones((x.shape[0],)).to(x)
933
+ lambda_s = ns.marginal_lambda(s)
934
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
935
+ h = h_init * torch.ones_like(s).to(x)
936
+ x_prev = x
937
+ nfe = 0
938
+ if order == 2:
939
+ r1 = 0.5
940
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
941
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
942
+ elif order == 3:
943
+ r1, r2 = 1. / 3., 2. / 3.
944
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type)
945
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
946
+ else:
947
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
948
+ while torch.abs((s - t_0)).mean() > t_err:
949
+ t = ns.inverse_lambda(lambda_s + h)
950
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
951
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
952
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
953
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
954
+ E = norm_fn((x_higher - x_lower) / delta).max()
955
+ if torch.all(E <= 1.):
956
+ x = x_higher
957
+ s = t
958
+ x_prev = x_lower
959
+ lambda_s = ns.marginal_lambda(s)
960
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
961
+ nfe += order
962
+ print('adaptive solver nfe', nfe)
963
+ return x
964
+
965
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
966
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
967
+ atol=0.0078, rtol=0.05,
968
+ ):
969
+ """
970
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
971
+
972
+ =====================================================
973
+
974
+ We support the following algorithms for both noise prediction model and data prediction model:
975
+ - 'singlestep':
976
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
977
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
978
+ The total number of function evaluations (NFE) == `steps`.
979
+ Given a fixed NFE == `steps`, the sampling procedure is:
980
+ - If `order` == 1:
981
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
982
+ - If `order` == 2:
983
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
984
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
985
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
986
+ - If `order` == 3:
987
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
988
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
989
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
990
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
991
+ - 'multistep':
992
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
993
+ We initialize the first `order` values by lower order multistep solvers.
994
+ Given a fixed NFE == `steps`, the sampling procedure is:
995
+ Denote K = steps.
996
+ - If `order` == 1:
997
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
998
+ - If `order` == 2:
999
+ - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
1000
+ - If `order` == 3:
1001
+ - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
1002
+ - 'singlestep_fixed':
1003
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1004
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1005
+ - 'adaptive':
1006
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1007
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1008
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
1009
+ (NFE) and the sample quality.
1010
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1011
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1012
+
1013
+ =====================================================
1014
+
1015
+ Some advices for choosing the algorithm:
1016
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1017
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
1018
+ e.g.
1019
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
1020
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1021
+ skip_type='time_uniform', method='singlestep')
1022
+ - For **guided sampling with large guidance scale** by DPMs:
1023
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
1024
+ e.g.
1025
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
1026
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1027
+ skip_type='time_uniform', method='multistep')
1028
+
1029
+ We support three types of `skip_type`:
1030
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1031
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1032
+ - 'time_quadratic': quadratic time for the time steps.
1033
+
1034
+ =====================================================
1035
+ Args:
1036
+ x: A pytorch tensor. The initial value at time `t_start`
1037
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1038
+ steps: A `int`. The total number of function evaluations (NFE).
1039
+ t_start: A `float`. The starting time of the sampling.
1040
+ If `T` is None, we use self.noise_schedule.T (default is 1.0).
1041
+ t_end: A `float`. The ending time of the sampling.
1042
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1043
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
1044
+ For discrete-time DPMs:
1045
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1046
+ For continuous-time DPMs:
1047
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1048
+ order: A `int`. The order of DPM-Solver.
1049
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1050
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1051
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1052
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1053
+
1054
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1055
+ score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
1056
+ for diffusion models sampling by diffusion SDEs for low-resolutional images
1057
+ (such as CIFAR-10). However, we observed that such trick does not matter for
1058
+ high-resolutional images. As it needs an additional NFE, we do not recommend
1059
+ it for high-resolutional images.
1060
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1061
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
1062
+ this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1063
+ (especially for steps <= 10). So we recommend to set it to be `True`.
1064
+ solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1065
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1066
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1067
+ Returns:
1068
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
1069
+
1070
+ """
1071
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1072
+ t_T = self.noise_schedule.T if t_start is None else t_start
1073
+ device = x.device
1074
+ if method == 'adaptive':
1075
+ with torch.no_grad():
1076
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
1077
+ elif method == 'multistep':
1078
+ assert steps >= order
1079
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1080
+ assert timesteps.shape[0] - 1 == steps
1081
+ with torch.no_grad():
1082
+ vec_t = timesteps[0].expand((x.shape[0]))
1083
+ model_prev_list = [self.model_fn(x, vec_t)]
1084
+ t_prev_list = [vec_t]
1085
+ # Init the first `order` values by lower order multistep DPM-Solver.
1086
+ for init_order in range(1, order):
1087
+ vec_t = timesteps[init_order].expand(x.shape[0])
1088
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, solver_type=solver_type)
1089
+ model_prev_list.append(self.model_fn(x, vec_t))
1090
+ t_prev_list.append(vec_t)
1091
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
1092
+ for step in range(order, steps + 1):
1093
+ vec_t = timesteps[step].expand(x.shape[0])
1094
+ if lower_order_final and steps < 15:
1095
+ step_order = min(order, steps + 1 - step)
1096
+ else:
1097
+ step_order = order
1098
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type)
1099
+ for i in range(order - 1):
1100
+ t_prev_list[i] = t_prev_list[i + 1]
1101
+ model_prev_list[i] = model_prev_list[i + 1]
1102
+ t_prev_list[-1] = vec_t
1103
+ # We do not need to evaluate the final model value.
1104
+ if step < steps:
1105
+ model_prev_list[-1] = self.model_fn(x, vec_t)
1106
+ elif method in ['singlestep', 'singlestep_fixed']:
1107
+ if method == 'singlestep':
1108
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
1109
+ elif method == 'singlestep_fixed':
1110
+ K = steps // order
1111
+ orders = [order,] * K
1112
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1113
+ for i, order in enumerate(orders):
1114
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1115
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device)
1116
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1117
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1118
+ h = lambda_inner[-1] - lambda_inner[0]
1119
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1120
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1121
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1122
+ if denoise_to_zero:
1123
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1124
+ return x
1125
+
1126
+
1127
+
1128
+ #############################################################
1129
+ # other utility functions
1130
+ #############################################################
1131
+
1132
+ def interpolate_fn(x, xp, yp):
1133
+ """
1134
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
1135
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
1136
+ The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1137
+
1138
+ Args:
1139
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1140
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1141
+ yp: PyTorch tensor with shape [C, K].
1142
+ Returns:
1143
+ The function values f(x), with shape [N, C].
1144
+ """
1145
+ N, K = x.shape[0], xp.shape[1]
1146
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1147
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1148
+ x_idx = torch.argmin(x_indices, dim=2)
1149
+ cand_start_idx = x_idx - 1
1150
+ start_idx = torch.where(
1151
+ torch.eq(x_idx, 0),
1152
+ torch.tensor(1, device=x.device),
1153
+ torch.where(
1154
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1155
+ ),
1156
+ )
1157
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1158
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1159
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1160
+ start_idx2 = torch.where(
1161
+ torch.eq(x_idx, 0),
1162
+ torch.tensor(0, device=x.device),
1163
+ torch.where(
1164
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1165
+ ),
1166
+ )
1167
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1168
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1169
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1170
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1171
+ return cand
1172
+
1173
+
1174
+ def expand_dims(v, dims):
1175
+ """
1176
+ Expand the tensor `v` to the dim `dims`.
1177
+
1178
+ Args:
1179
+ `v`: a PyTorch tensor with shape [N].
1180
+ `dim`: a `int`.
1181
+ Returns:
1182
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1183
+ """
1184
+ return v[(...,) + (None,)*(dims - 1)]
ldm/models/diffusion/dpm_solver/sampler.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+
5
+ from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6
+
7
+
8
+ class DPMSolverSampler(object):
9
+ def __init__(self, model, **kwargs):
10
+ super().__init__()
11
+ self.model = model
12
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
13
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
14
+
15
+ def register_buffer(self, name, attr):
16
+ if type(attr) == torch.Tensor:
17
+ if attr.device != torch.device("cuda"):
18
+ attr = attr.to(torch.device("cuda"))
19
+ setattr(self, name, attr)
20
+
21
+ @torch.no_grad()
22
+ def sample(self,
23
+ S,
24
+ batch_size,
25
+ shape,
26
+ conditioning=None,
27
+ callback=None,
28
+ normals_sequence=None,
29
+ img_callback=None,
30
+ quantize_x0=False,
31
+ eta=0.,
32
+ mask=None,
33
+ x0=None,
34
+ temperature=1.,
35
+ noise_dropout=0.,
36
+ score_corrector=None,
37
+ corrector_kwargs=None,
38
+ verbose=True,
39
+ x_T=None,
40
+ log_every_t=100,
41
+ unconditional_guidance_scale=1.,
42
+ unconditional_conditioning=None,
43
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
44
+ **kwargs
45
+ ):
46
+ if conditioning is not None:
47
+ if isinstance(conditioning, dict):
48
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
49
+ if cbs != batch_size:
50
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
51
+ else:
52
+ if conditioning.shape[0] != batch_size:
53
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
54
+
55
+ # sampling
56
+ C, H, W = shape
57
+ size = (batch_size, C, H, W)
58
+
59
+ # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
60
+
61
+ device = self.model.betas.device
62
+ if x_T is None:
63
+ img = torch.randn(size, device=device)
64
+ else:
65
+ img = x_T
66
+
67
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
68
+
69
+ model_fn = model_wrapper(
70
+ lambda x, t, c: self.model.apply_model(x, t, c),
71
+ ns,
72
+ model_type="noise",
73
+ guidance_type="classifier-free",
74
+ condition=conditioning,
75
+ unconditional_condition=unconditional_conditioning,
76
+ guidance_scale=unconditional_guidance_scale,
77
+ )
78
+
79
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
80
+ x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
81
+
82
+ return x.to(device), None
ldm/models/diffusion/plms.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+ import copy
8
+ from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
9
+
10
+ class PLMSSampler(object):
11
+ def __init__(self, model, schedule="linear", **kwargs):
12
+ super().__init__()
13
+ self.model = model
14
+ self.ddpm_num_timesteps = model.num_timesteps
15
+ self.schedule = schedule
16
+
17
+ def register_buffer(self, name, attr):
18
+ if type(attr) == torch.Tensor:
19
+ if attr.device != torch.device("cuda"):
20
+ attr = attr.to(torch.device("cuda"))
21
+ setattr(self, name, attr)
22
+
23
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
+ if ddim_eta != 0:
25
+ raise ValueError('ddim_eta must be 0 for PLMS')
26
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
27
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
28
+ alphas_cumprod = self.model.alphas_cumprod
29
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
30
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
31
+
32
+ self.register_buffer('betas', to_torch(self.model.betas))
33
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
34
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
35
+
36
+ # calculations for diffusion q(x_t | x_{t-1}) and others
37
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
38
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
39
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
40
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
41
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
42
+
43
+ # ddim sampling parameters
44
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
45
+ ddim_timesteps=self.ddim_timesteps,
46
+ eta=ddim_eta,verbose=verbose)
47
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
48
+ self.register_buffer('ddim_alphas', ddim_alphas)
49
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
50
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
51
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
52
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
53
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
54
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
55
+
56
+ @torch.no_grad()
57
+ def sample(self,
58
+ S,
59
+ batch_size,
60
+ shape,
61
+ conditioning=None,
62
+ callback=None,
63
+ normals_sequence=None,
64
+ img_callback=None,
65
+ quantize_x0=False,
66
+ eta=0.,
67
+ mask=None,
68
+ x0=None,
69
+ temperature=1.,
70
+ noise_dropout=0.,
71
+ score_corrector=None,
72
+ corrector_kwargs=None,
73
+ verbose=True,
74
+ x_T=None,
75
+ log_every_t=100,
76
+ unconditional_guidance_scale=1.,
77
+ unconditional_conditioning=None,
78
+ features_adapter1=None,
79
+ features_adapter2=None,
80
+ mode = 'sketch',
81
+ con_strength=30,
82
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
83
+ **kwargs
84
+ ):
85
+ # print('*'*20,x_T)
86
+ # exit(0)
87
+ if conditioning is not None:
88
+ if isinstance(conditioning, dict):
89
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
90
+ if cbs != batch_size:
91
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
92
+ else:
93
+ if conditioning.shape[0] != batch_size:
94
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
95
+
96
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
97
+ C, H, W = shape
98
+ size = (batch_size, C, H, W)
99
+ print(f'Data shape for PLMS sampling is {size}')
100
+
101
+ samples, intermediates = self.plms_sampling(conditioning, size,
102
+ callback=callback,
103
+ img_callback=img_callback,
104
+ quantize_denoised=quantize_x0,
105
+ mask=mask, x0=x0,
106
+ ddim_use_original_steps=False,
107
+ noise_dropout=noise_dropout,
108
+ temperature=temperature,
109
+ score_corrector=score_corrector,
110
+ corrector_kwargs=corrector_kwargs,
111
+ x_T=x_T,
112
+ log_every_t=log_every_t,
113
+ unconditional_guidance_scale=unconditional_guidance_scale,
114
+ unconditional_conditioning=unconditional_conditioning,
115
+ features_adapter1=copy.deepcopy(features_adapter1),
116
+ features_adapter2=copy.deepcopy(features_adapter2),
117
+ mode = mode,
118
+ con_strength = con_strength
119
+ )
120
+ return samples, intermediates
121
+
122
+ @torch.no_grad()
123
+ def plms_sampling(self, cond, shape,
124
+ x_T=None, ddim_use_original_steps=False,
125
+ callback=None, timesteps=None, quantize_denoised=False,
126
+ mask=None, x0=None, img_callback=None, log_every_t=100,
127
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
128
+ unconditional_guidance_scale=1., unconditional_conditioning=None,features_adapter1=None, features_adapter2=None, mode='sketch', con_strength=30):
129
+ device = self.model.betas.device
130
+ b = shape[0]
131
+ if x_T is None:
132
+ img = torch.randn(shape, device=device)
133
+ else:
134
+ img = x_T
135
+ if timesteps is None:
136
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
137
+ elif timesteps is not None and not ddim_use_original_steps:
138
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
139
+ timesteps = self.ddim_timesteps[:subset_end]
140
+
141
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
142
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
143
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
144
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
145
+
146
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
147
+ old_eps = []
148
+
149
+ for i, step in enumerate(iterator):
150
+ index = total_steps - i - 1
151
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
152
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
153
+
154
+ if mask is not None :#and index>=10:
155
+ assert x0 is not None
156
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
157
+ img = img_orig * mask + (1. - mask) * img
158
+
159
+ if mode == 'sketch':
160
+ if index<con_strength:
161
+ features_adapter = None
162
+ else:
163
+ features_adapter = features_adapter1
164
+ elif mode == 'mul':
165
+ features_adapter = [a1i*0.5 + a2i for a1i, a2i in zip(features_adapter1, features_adapter2)]
166
+ else:
167
+ features_adapter = features_adapter1
168
+
169
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
170
+ quantize_denoised=quantize_denoised, temperature=temperature,
171
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
172
+ corrector_kwargs=corrector_kwargs,
173
+ unconditional_guidance_scale=unconditional_guidance_scale,
174
+ unconditional_conditioning=unconditional_conditioning,
175
+ old_eps=old_eps, t_next=ts_next, features_adapter=copy.deepcopy(features_adapter))
176
+
177
+ img, pred_x0, e_t = outs
178
+ old_eps.append(e_t)
179
+ if len(old_eps) >= 4:
180
+ old_eps.pop(0)
181
+ if callback: callback(i)
182
+ if img_callback: img_callback(pred_x0, i)
183
+
184
+ if index % log_every_t == 0 or index == total_steps - 1:
185
+ intermediates['x_inter'].append(img)
186
+ intermediates['pred_x0'].append(pred_x0)
187
+
188
+ return img, intermediates
189
+
190
+ @torch.no_grad()
191
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
192
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
193
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, features_adapter=None):
194
+ b, *_, device = *x.shape, x.device
195
+
196
+ def get_model_output(x, t):
197
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
198
+ e_t = self.model.apply_model(x, t, c, copy.deepcopy(features_adapter))
199
+ else:
200
+ x_in = torch.cat([x] * 2)
201
+ t_in = torch.cat([t] * 2)
202
+ c_in = torch.cat([unconditional_conditioning, c])
203
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, copy.deepcopy(features_adapter)).chunk(2)
204
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
205
+
206
+ if score_corrector is not None:
207
+ assert self.model.parameterization == "eps"
208
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
209
+
210
+ return e_t
211
+
212
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
213
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
214
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
215
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
216
+
217
+ def get_x_prev_and_pred_x0(e_t, index):
218
+ # select parameters corresponding to the currently considered timestep
219
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
220
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
221
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
222
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
223
+
224
+ # current prediction for x_0
225
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
226
+ if quantize_denoised:
227
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
228
+ # direction pointing to x_t
229
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
230
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
231
+ if noise_dropout > 0.:
232
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
233
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
234
+ return x_prev, pred_x0
235
+
236
+ e_t = get_model_output(x, t)
237
+ if len(old_eps) == 0:
238
+ # Pseudo Improved Euler (2nd order)
239
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
240
+ e_t_next = get_model_output(x_prev, t_next)
241
+ e_t_prime = (e_t + e_t_next) / 2
242
+ elif len(old_eps) == 1:
243
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
244
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
245
+ elif len(old_eps) == 2:
246
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
247
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
248
+ elif len(old_eps) >= 3:
249
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
250
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
251
+
252
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
253
+
254
+ return x_prev, pred_x0, e_t
ldm/modules/attention.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from ldm.modules.diffusionmodules.util import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return{el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = nn.Sequential(
53
+ nn.Linear(dim, inner_dim),
54
+ nn.GELU()
55
+ ) if not glu else GEGLU(dim, inner_dim)
56
+
57
+ self.net = nn.Sequential(
58
+ project_in,
59
+ nn.Dropout(dropout),
60
+ nn.Linear(inner_dim, dim_out)
61
+ )
62
+
63
+ def forward(self, x):
64
+ return self.net(x)
65
+
66
+
67
+ def zero_module(module):
68
+ """
69
+ Zero out the parameters of a module and return it.
70
+ """
71
+ for p in module.parameters():
72
+ p.detach().zero_()
73
+ return module
74
+
75
+
76
+ def Normalize(in_channels):
77
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78
+
79
+
80
+ class LinearAttention(nn.Module):
81
+ def __init__(self, dim, heads=4, dim_head=32):
82
+ super().__init__()
83
+ self.heads = heads
84
+ hidden_dim = dim_head * heads
85
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87
+
88
+ def forward(self, x):
89
+ b, c, h, w = x.shape
90
+ qkv = self.to_qkv(x)
91
+ q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92
+ k = k.softmax(dim=-1)
93
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
94
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
95
+ out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96
+ return self.to_out(out)
97
+
98
+
99
+ class SpatialSelfAttention(nn.Module):
100
+ def __init__(self, in_channels):
101
+ super().__init__()
102
+ self.in_channels = in_channels
103
+
104
+ self.norm = Normalize(in_channels)
105
+ self.q = torch.nn.Conv2d(in_channels,
106
+ in_channels,
107
+ kernel_size=1,
108
+ stride=1,
109
+ padding=0)
110
+ self.k = torch.nn.Conv2d(in_channels,
111
+ in_channels,
112
+ kernel_size=1,
113
+ stride=1,
114
+ padding=0)
115
+ self.v = torch.nn.Conv2d(in_channels,
116
+ in_channels,
117
+ kernel_size=1,
118
+ stride=1,
119
+ padding=0)
120
+ self.proj_out = torch.nn.Conv2d(in_channels,
121
+ in_channels,
122
+ kernel_size=1,
123
+ stride=1,
124
+ padding=0)
125
+
126
+ def forward(self, x):
127
+ h_ = x
128
+ h_ = self.norm(h_)
129
+ q = self.q(h_)
130
+ k = self.k(h_)
131
+ v = self.v(h_)
132
+
133
+ # compute attention
134
+ b,c,h,w = q.shape
135
+ q = rearrange(q, 'b c h w -> b (h w) c')
136
+ k = rearrange(k, 'b c h w -> b c (h w)')
137
+ w_ = torch.einsum('bij,bjk->bik', q, k)
138
+
139
+ w_ = w_ * (int(c)**(-0.5))
140
+ w_ = torch.nn.functional.softmax(w_, dim=2)
141
+
142
+ # attend to values
143
+ v = rearrange(v, 'b c h w -> b c (h w)')
144
+ w_ = rearrange(w_, 'b i j -> b j i')
145
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
146
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147
+ h_ = self.proj_out(h_)
148
+
149
+ return x+h_
150
+
151
+
152
+ class CrossAttention(nn.Module):
153
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154
+ super().__init__()
155
+ inner_dim = dim_head * heads
156
+ context_dim = default(context_dim, query_dim)
157
+
158
+ self.scale = dim_head ** -0.5
159
+ self.heads = heads
160
+
161
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164
+
165
+ self.to_out = nn.Sequential(
166
+ nn.Linear(inner_dim, query_dim),
167
+ nn.Dropout(dropout)
168
+ )
169
+
170
+ def forward(self, x, context=None, mask=None):
171
+ h = self.heads
172
+
173
+ q = self.to_q(x)
174
+ context = default(context, x)
175
+ k = self.to_k(context)
176
+ v = self.to_v(context)
177
+
178
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179
+
180
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181
+
182
+ if exists(mask):
183
+ mask = rearrange(mask, 'b ... -> b (...)')
184
+ max_neg_value = -torch.finfo(sim.dtype).max
185
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
186
+ sim.masked_fill_(~mask, max_neg_value)
187
+
188
+ # attention, what we cannot get enough of
189
+ attn = sim.softmax(dim=-1)
190
+
191
+ out = einsum('b i j, b j d -> b i d', attn, v)
192
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193
+ return self.to_out(out)
194
+
195
+
196
+ class BasicTransformerBlock(nn.Module):
197
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198
+ super().__init__()
199
+ self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201
+ self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203
+ self.norm1 = nn.LayerNorm(dim)
204
+ self.norm2 = nn.LayerNorm(dim)
205
+ self.norm3 = nn.LayerNorm(dim)
206
+ self.checkpoint = checkpoint
207
+
208
+ def forward(self, x, context=None):
209
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210
+
211
+ def _forward(self, x, context=None):
212
+ x = self.attn1(self.norm1(x)) + x
213
+ x = self.attn2(self.norm2(x), context=context) + x
214
+ x = self.ff(self.norm3(x)) + x
215
+ return x
216
+
217
+
218
+ class SpatialTransformer(nn.Module):
219
+ """
220
+ Transformer block for image-like data.
221
+ First, project the input (aka embedding)
222
+ and reshape to b, t, d.
223
+ Then apply standard transformer action.
224
+ Finally, reshape to image
225
+ """
226
+ def __init__(self, in_channels, n_heads, d_head,
227
+ depth=1, dropout=0., context_dim=None):
228
+ super().__init__()
229
+ self.in_channels = in_channels
230
+ inner_dim = n_heads * d_head
231
+ self.norm = Normalize(in_channels)
232
+
233
+ self.proj_in = nn.Conv2d(in_channels,
234
+ inner_dim,
235
+ kernel_size=1,
236
+ stride=1,
237
+ padding=0)
238
+
239
+ self.transformer_blocks = nn.ModuleList(
240
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241
+ for d in range(depth)]
242
+ )
243
+
244
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
245
+ in_channels,
246
+ kernel_size=1,
247
+ stride=1,
248
+ padding=0))
249
+
250
+ def forward(self, x, context=None):
251
+ # note: if no context is given, cross-attention defaults to self-attention
252
+ b, c, h, w = x.shape
253
+ x_in = x
254
+ x = self.norm(x)
255
+ x = self.proj_in(x)
256
+ x = rearrange(x, 'b c h w -> b (h w) c')
257
+ for block in self.transformer_blocks:
258
+ x = block(x, context=context)
259
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260
+ x = self.proj_out(x)
261
+ return x + x_in
ldm/modules/diffusionmodules/__init__.py ADDED
File without changes