diff --git a/.gitattributes b/.gitattributes
new file mode 100755
index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..04dc5a589e6655d965d09876f45ec6bd99b196bf
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,124 @@
+# ignored folders
+tmp/*
+
+*.DS_Store
+.idea
+
+# ignored files
+version.py
+
+# ignored files with suffix
+# *.html
+# *.png
+# *.jpeg
+# *.jpg
+# *.gif
+# *.pth
+# *.zip
+
+# template
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
diff --git a/LICENSE b/LICENSE
new file mode 100755
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..3c1e8d13b64dc98c1148fc48350cfd41f1db3b0c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+---
+license: openrail
+title: T2I-Adapter
+sdk: gradio
+sdk_version: 3.19.1
+emoji: 😻
+colorFrom: pink
+colorTo: blue
+pinned: false
+python_version: 3.8.16
+app_file: app.py
+duplicated_from: Adapter/T2I-Adapter
+---
diff --git a/app.py b/app.py
new file mode 100755
index 0000000000000000000000000000000000000000..ed9d7699550cd3096f20b4d54f18ea30781b4ce3
--- /dev/null
+++ b/app.py
@@ -0,0 +1,392 @@
+# demo inspired by https://huggingface.co/spaces/lambdalabs/image-mixer-demo
+import argparse
+import copy
+import os
+import shlex
+import subprocess
+from functools import partial
+from itertools import chain
+
+import cv2
+import gradio as gr
+import torch
+from basicsr.utils import tensor2img
+from huggingface_hub import hf_hub_url
+from pytorch_lightning import seed_everything
+from torch import autocast
+
+from ldm.inference_base import (DEFAULT_NEGATIVE_PROMPT, diffusion_inference, get_adapters, get_sd_models)
+from ldm.modules.extra_condition import api
+from ldm.modules.extra_condition.api import (ExtraCondition, get_adapter_feature, get_cond_model)
+
+torch.set_grad_enabled(False)
+
+supported_cond = ['style', 'color', 'sketch', 'openpose', 'depth', 'canny']
+
+# download the checkpoints
+urls = {
+ 'TencentARC/T2I-Adapter': [
+ 'models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_color_sd14v1.pth',
+ 'models/t2iadapter_openpose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth',
+ 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth',
+ 'third-party-models/body_pose_model.pth', "models/t2iadapter_style_sd14v1.pth",
+ "models/t2iadapter_canny_sd14v1.pth", 'third-party-models/table5_pidinet.pth'
+ ],
+ 'runwayml/stable-diffusion-v1-5': ['v1-5-pruned-emaonly.ckpt'],
+ 'andite/anything-v4.0': ['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
+}
+
+# download image samples
+torch.hub.download_url_to_file(
+ 'https://user-images.githubusercontent.com/52127135/223114920-cae3e723-3683-424a-bebc-0875479f2409.jpg',
+ 'cyber_style.jpg')
+torch.hub.download_url_to_file(
+ 'https://user-images.githubusercontent.com/52127135/223114946-6ccc127f-cb58-443e-8677-805f5dbaf6f1.png',
+ 'sword.png')
+torch.hub.download_url_to_file(
+ 'https://user-images.githubusercontent.com/52127135/223121793-20c2ac6a-5a4f-4ff8-88ea-6d007a7959dd.png',
+ 'white.png')
+torch.hub.download_url_to_file(
+ 'https://user-images.githubusercontent.com/52127135/223127404-4a3748cf-85a6-40f3-af31-a74e206db96e.jpeg',
+ 'scream_style.jpeg')
+torch.hub.download_url_to_file(
+ 'https://user-images.githubusercontent.com/52127135/223127433-8768913f-9872-4d24-b883-a19a3eb20623.jpg',
+ 'motorcycle.jpg')
+
+if not os.path.exists('models'):
+    os.mkdir('models')
+for repo in urls:
+ files = urls[repo]
+ for file in files:
+ url = hf_hub_url(repo, file)
+ name_ckp = url.split('/')[-1]
+ save_path = os.path.join('models', name_ckp)
+        if not os.path.exists(save_path):
+ subprocess.run(shlex.split(f'wget {url} -O {save_path}'))
+
+# config
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ '--sd_ckpt',
+ type=str,
+ default='models/v1-5-pruned-emaonly.ckpt',
+    help='path to checkpoint of the stable diffusion model; both .ckpt and .safetensors are supported',
+)
+parser.add_argument(
+ '--vae_ckpt',
+ type=str,
+ default=None,
+    help='VAE checkpoint; anime SD models usually have a separate VAE checkpoint that needs to be loaded',
+)
+global_opt = parser.parse_args()
+global_opt.config = 'configs/stable-diffusion/sd-v1-inference.yaml'
+for cond_name in supported_cond:
+ setattr(global_opt, f'{cond_name}_adapter_ckpt', f'models/t2iadapter_{cond_name}_sd14v1.pth')
+global_opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+global_opt.max_resolution = 512 * 512
+global_opt.sampler = 'ddim'
+global_opt.cond_weight = 1.0
+global_opt.C = 4
+global_opt.f = 8
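+# C: number of latent channels; f: autoencoder downsampling factor (512 px -> 64x64 latents)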
+
+# stable-diffusion model
+sd_model, sampler = get_sd_models(global_opt)
+# adapters and models for processing condition inputs
+adapters = {}
+cond_models = {}
+torch.cuda.empty_cache()
+
+
+def run(*args):
+ with torch.inference_mode(), \
+ sd_model.ema_scope(), \
+ autocast('cuda'):
+
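+        # args packs the per-condition widgets grouped by type -- all radio buttons,
+        # then all images, then all raw condition maps, then all weights -- followed
+        # by the 8 global options; regroup them so that zip(*inps) yields one
+        # (btn, im1, im2, cond_weight) tuple per condition.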
+ inps = []
+ for i in range(0, len(args) - 8, len(supported_cond)):
+ inps.append(args[i:i + len(supported_cond)])
+
+ opt = copy.deepcopy(global_opt)
+ opt.prompt, opt.neg_prompt, opt.scale, opt.n_samples, opt.seed, opt.steps, opt.resize_short_edge, opt.cond_tau \
+ = args[-8:]
+
+ conds = []
+ activated_conds = []
+
+ ims1 = []
+ ims2 = []
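+        # take the target size (h, w) from the first structural condition image;
+        # style and color (indices 0 and 1) do not determine the spatial size.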
+ for idx, (b, im1, im2, cond_weight) in enumerate(zip(*inps)):
+ if idx > 1:
+ if im1 is not None or im2 is not None:
+ if im1 is not None:
+ h, w, _ = im1.shape
+ else:
+ h, w, _ = im2.shape
+ break
+ # resize all the images to the same size
+ for idx, (b, im1, im2, cond_weight) in enumerate(zip(*inps)):
+ if idx == 0:
+ ims1.append(im1)
+ ims2.append(im2)
+ continue
+ if im1 is not None:
+ im1 = cv2.resize(im1, (w, h), interpolation=cv2.INTER_CUBIC)
+ if im2 is not None:
+ im2 = cv2.resize(im2, (w, h), interpolation=cv2.INTER_CUBIC)
+ ims1.append(im1)
+ ims2.append(im2)
+
+ for idx, (b, _, _, cond_weight) in enumerate(zip(*inps)):
+ cond_name = supported_cond[idx]
+ if b == 'Nothing':
+ if cond_name in adapters:
+ adapters[cond_name]['model'] = adapters[cond_name]['model'].cpu()
+ else:
+ activated_conds.append(cond_name)
+ if cond_name in adapters:
+ adapters[cond_name]['model'] = adapters[cond_name]['model'].to(opt.device)
+ else:
+ adapters[cond_name] = get_adapters(opt, getattr(ExtraCondition, cond_name))
+ adapters[cond_name]['cond_weight'] = cond_weight
+
+ process_cond_module = getattr(api, f'get_cond_{cond_name}')
+
+ if b == 'Image':
+ if cond_name not in cond_models:
+ cond_models[cond_name] = get_cond_model(opt, getattr(ExtraCondition, cond_name))
+ conds.append(process_cond_module(opt, ims1[idx], 'image', cond_models[cond_name]))
+ else:
+ conds.append(process_cond_module(opt, ims2[idx], cond_name, None))
+
+ adapter_features, append_to_context = get_adapter_feature(
+ conds, [adapters[cond_name] for cond_name in activated_conds])
+
+ output_conds = []
+ for cond in conds:
+ output_conds.append(tensor2img(cond, rgb2bgr=False))
+
+ ims = []
+ seed_everything(opt.seed)
+ for _ in range(opt.n_samples):
+ result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
+ ims.append(tensor2img(result, rgb2bgr=False))
+
+        # Clear the GPU memory cache to make OOM less likely
+ torch.cuda.empty_cache()
+ return ims, output_conds
+
+
+def change_visible(im1, im2, val):
+ outputs = {}
+ if val == "Image":
+ outputs[im1] = gr.update(visible=True)
+ outputs[im2] = gr.update(visible=False)
+ elif val == "Nothing":
+ outputs[im1] = gr.update(visible=False)
+ outputs[im2] = gr.update(visible=False)
+ else:
+ outputs[im1] = gr.update(visible=False)
+ outputs[im2] = gr.update(visible=True)
+ return outputs
+
+
+DESCRIPTION = '# [Composable T2I-Adapter](https://github.com/TencentARC/T2I-Adapter)'
+
+DESCRIPTION += f'Gradio demo for **T2I-Adapter**: [[GitHub]](https://github.com/TencentARC/T2I-Adapter), [[Paper]](https://arxiv.org/abs/2302.08453). If T2I-Adapter is helpful, please help to ⭐ the [Github Repo](https://github.com/TencentARC/T2I-Adapter) and recommend it to your friends 😊'
+
+DESCRIPTION += f'For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. ![Duplicate Space](https://bit.ly/3gLdBN6)'
+
+with gr.Blocks(css='style.css') as demo:
+ gr.Markdown(DESCRIPTION)
+
+ btns = []
+ ims1 = []
+ ims2 = []
+ cond_weights = []
+
+ with gr.Row():
+ with gr.Column(scale=1.9):
+ with gr.Box():
+ gr.Markdown("Style & Color
")
+ with gr.Row():
+ for cond_name in supported_cond[:2]:
+ with gr.Box():
+ with gr.Column():
+ if cond_name == 'style':
+ btn1 = gr.Radio(
+ choices=["Image", "Nothing"],
+ label=f"Input type for {cond_name}",
+ interactive=True,
+ value="Nothing",
+ )
+ else:
+ btn1 = gr.Radio(
+ choices=["Image", cond_name, "Nothing"],
+ label=f"Input type for {cond_name}",
+ interactive=True,
+ value="Nothing",
+ )
+ im1 = gr.Image(
+ source='upload', label="Image", interactive=True, visible=False, type="numpy")
+ im2 = gr.Image(
+ source='upload', label=cond_name, interactive=True, visible=False, type="numpy")
+ cond_weight = gr.Slider(
+ label="Condition weight",
+ minimum=0,
+ maximum=5,
+ step=0.05,
+ value=1,
+ interactive=True)
+
+ fn = partial(change_visible, im1, im2)
+ btn1.change(fn=fn, inputs=[btn1], outputs=[im1, im2], queue=False)
+
+ btns.append(btn1)
+ ims1.append(im1)
+ ims2.append(im2)
+ cond_weights.append(cond_weight)
+ with gr.Column(scale=4):
+ with gr.Box():
+ gr.Markdown("Structure
")
+ with gr.Row():
+ for cond_name in supported_cond[2:6]:
+ with gr.Box():
+ with gr.Column():
+ if cond_name == 'openpose':
+ btn1 = gr.Radio(
+ choices=["Image", 'pose', "Nothing"],
+ label=f"Input type for {cond_name}",
+ interactive=True,
+ value="Nothing",
+ )
+ else:
+ btn1 = gr.Radio(
+ choices=["Image", cond_name, "Nothing"],
+ label=f"Input type for {cond_name}",
+ interactive=True,
+ value="Nothing",
+ )
+ im1 = gr.Image(
+ source='upload', label="Image", interactive=True, visible=False, type="numpy")
+ im2 = gr.Image(
+ source='upload', label=cond_name, interactive=True, visible=False, type="numpy")
+ cond_weight = gr.Slider(
+ label="Condition weight",
+ minimum=0,
+ maximum=5,
+ step=0.05,
+ value=1,
+ interactive=True)
+
+ fn = partial(change_visible, im1, im2)
+ btn1.change(fn=fn, inputs=[btn1], outputs=[im1, im2], queue=False)
+
+ btns.append(btn1)
+ ims1.append(im1)
+ ims2.append(im2)
+ cond_weights.append(cond_weight)
+
+ with gr.Column():
+ prompt = gr.Textbox(label="Prompt")
+
+ with gr.Accordion('Advanced options', open=False):
+ neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
+ scale = gr.Slider(
+ label="Guidance Scale (Classifier free guidance)", value=7.5, minimum=1, maximum=20, step=0.1)
+ n_samples = gr.Slider(label="Num samples", value=1, minimum=1, maximum=1, step=1)
+ seed = gr.Slider(label="Seed", value=42, minimum=0, maximum=10000, step=1, randomize=True)
+ steps = gr.Slider(label="Steps", value=50, minimum=10, maximum=100, step=1)
+ resize_short_edge = gr.Slider(label="Image resolution", value=512, minimum=320, maximum=1024, step=1)
+ cond_tau = gr.Slider(
+ label="timestamp parameter that determines until which step the adapter is applied",
+ value=1.0,
+ minimum=0.1,
+ maximum=1.0,
+ step=0.05)
+
+ with gr.Row():
+ submit = gr.Button("Generate")
+ output = gr.Gallery().style(grid=2, height='auto')
+ cond = gr.Gallery().style(grid=2, height='auto')
+
+ inps = list(chain(btns, ims1, ims2, cond_weights))
+
+ inps.extend([prompt, neg_prompt, scale, n_samples, seed, steps, resize_short_edge, cond_tau])
+ submit.click(fn=run, inputs=inps, outputs=[output, cond])
+
+ ex = gr.Examples([
+ [
+ "Image",
+ "Nothing",
+ "Image",
+ "Nothing",
+ "Nothing",
+ "Nothing",
+ "cyber_style.jpg",
+ "white.png",
+ "sword.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ "master sword",
+ "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+ 7.5,
+ 1,
+ 2500,
+ 50,
+ 512,
+ 1,
+ ],
+ [
+ "Image",
+ "Nothing",
+ "Image",
+ "Nothing",
+ "Nothing",
+ "Nothing",
+ "scream_style.jpeg",
+ "white.png",
+ "motorcycle.jpg",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ "white.png",
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ "motorcycle",
+ "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
+ 7.5,
+ 1,
+ 2500,
+ 50,
+ 512,
+ 1,
+ ],
+ ],
+ fn=run,
+ inputs=inps,
+ outputs=[output, cond],
+ cache_examples=True)
+
+demo.queue().launch(debug=True, server_name='0.0.0.0')
diff --git a/configs/mm/faster_rcnn_r50_fpn_coco.py b/configs/mm/faster_rcnn_r50_fpn_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ad9528b22163ae7ce1390375b69227fd6eafd9
--- /dev/null
+++ b/configs/mm/faster_rcnn_r50_fpn_coco.py
@@ -0,0 +1,182 @@
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ # dict(type='TensorboardLoggerHook')
+ ])
+# yapf:enable
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[8, 11])
+total_epochs = 12
+
+model = dict(
+ type='FasterRCNN',
+ pretrained='torchvision://resnet50',
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ style='pytorch'),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ num_outs=5),
+ rpn_head=dict(
+ type='RPNHead',
+ in_channels=256,
+ feat_channels=256,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[8],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+ roi_head=dict(
+ type='StandardRoIHead',
+ bbox_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ bbox_head=dict(
+ type='Shared2FCBBoxHead',
+ in_channels=256,
+ fc_out_channels=1024,
+ roi_feat_size=7,
+ num_classes=80,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
+ reg_class_agnostic=False,
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
+ # model training and testing settings
+ train_cfg=dict(
+ rpn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.7,
+ neg_iou_thr=0.3,
+ min_pos_iou=0.3,
+ match_low_quality=True,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=256,
+ pos_fraction=0.5,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=False),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ rpn_proposal=dict(
+ nms_pre=2000,
+ max_per_img=1000,
+ nms=dict(type='nms', iou_threshold=0.7),
+ min_bbox_size=0),
+ rcnn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.5,
+ min_pos_iou=0.5,
+ match_low_quality=False,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True),
+ pos_weight=-1,
+ debug=False)),
+ test_cfg=dict(
+ rpn=dict(
+ nms_pre=1000,
+ max_per_img=1000,
+ nms=dict(type='nms', iou_threshold=0.7),
+ min_bbox_size=0),
+ rcnn=dict(
+ score_thr=0.05,
+ nms=dict(type='nms', iou_threshold=0.5),
+ max_per_img=100)
+ # soft-nms is also supported for rcnn testing
+ # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
+ ))
+
+dataset_type = 'CocoDataset'
+data_root = 'data/coco'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=2,
+ train=dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/annotations/instances_train2017.json',
+ img_prefix=f'{data_root}/train2017/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ pipeline=test_pipeline))
+evaluation = dict(interval=1, metric='bbox')
diff --git a/configs/mm/hrnet_w48_coco_256x192.py b/configs/mm/hrnet_w48_coco_256x192.py
new file mode 100644
index 0000000000000000000000000000000000000000..9755e6773cd3a8c0d2ac684c612d716cfd44b0ca
--- /dev/null
+++ b/configs/mm/hrnet_w48_coco_256x192.py
@@ -0,0 +1,169 @@
+# _base_ = [
+# '../../../../_base_/default_runtime.py',
+# '../../../../_base_/datasets/coco.py'
+# ]
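+# NOTE: the dataset_info entries below are written as {{_base_.dataset_info}}, which
+# mmcv resolves from a _base_ config such as the commented-out includes above.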
+evaluation = dict(interval=10, metric='mAP', save_best='AP')
+
+optimizer = dict(
+ type='Adam',
+ lr=5e-4,
+)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[170, 200])
+total_epochs = 210
+channel_cfg = dict(
+ num_output_channels=17,
+ dataset_joints=17,
+ dataset_channel=[
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+ ],
+ inference_channel=[
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+ ])
+
+# model settings
+model = dict(
+ type='TopDown',
+ pretrained='https://download.openmmlab.com/mmpose/'
+ 'pretrain_models/hrnet_w48-8ef0771d.pth',
+ backbone=dict(
+ type='HRNet',
+ in_channels=3,
+ extra=dict(
+ stage1=dict(
+ num_modules=1,
+ num_branches=1,
+ block='BOTTLENECK',
+ num_blocks=(4, ),
+ num_channels=(64, )),
+ stage2=dict(
+ num_modules=1,
+ num_branches=2,
+ block='BASIC',
+ num_blocks=(4, 4),
+ num_channels=(48, 96)),
+ stage3=dict(
+ num_modules=4,
+ num_branches=3,
+ block='BASIC',
+ num_blocks=(4, 4, 4),
+ num_channels=(48, 96, 192)),
+ stage4=dict(
+ num_modules=3,
+ num_branches=4,
+ block='BASIC',
+ num_blocks=(4, 4, 4, 4),
+ num_channels=(48, 96, 192, 384))),
+ ),
+ keypoint_head=dict(
+ type='TopdownHeatmapSimpleHead',
+ in_channels=48,
+ out_channels=channel_cfg['num_output_channels'],
+ num_deconv_layers=0,
+ extra=dict(final_conv_kernel=1, ),
+ loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
+ train_cfg=dict(),
+ test_cfg=dict(
+ flip_test=True,
+ post_process='default',
+ shift_heatmap=True,
+ modulate_kernel=11))
+
+data_cfg = dict(
+ image_size=[192, 256],
+ heatmap_size=[48, 64],
+ num_output_channels=channel_cfg['num_output_channels'],
+ num_joints=channel_cfg['dataset_joints'],
+ dataset_channel=channel_cfg['dataset_channel'],
+ inference_channel=channel_cfg['inference_channel'],
+ soft_nms=False,
+ nms_thr=1.0,
+ oks_thr=0.9,
+ vis_thr=0.2,
+ use_gt_bbox=False,
+ det_bbox_thr=0.0,
+ bbox_file='data/coco/person_detection_results/'
+ 'COCO_val2017_detections_AP_H_56_person.json',
+)
+
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownGetBboxCenterScale', padding=1.25),
+ dict(type='TopDownRandomShiftBboxCenter', shift_factor=0.16, prob=0.3),
+ dict(type='TopDownRandomFlip', flip_prob=0.5),
+ dict(
+ type='TopDownHalfBodyTransform',
+ num_joints_half_body=8,
+ prob_half_body=0.3),
+ dict(
+ type='TopDownGetRandomScaleRotation', rot_factor=40, scale_factor=0.5),
+ dict(type='TopDownAffine'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(type='TopDownGenerateTarget', sigma=2),
+ dict(
+ type='Collect',
+ keys=['img', 'target', 'target_weight'],
+ meta_keys=[
+ 'image_file', 'joints_3d', 'joints_3d_visible', 'center', 'scale',
+ 'rotation', 'bbox_score', 'flip_pairs'
+ ]),
+]
+
+val_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='TopDownGetBboxCenterScale', padding=1.25),
+ dict(type='TopDownAffine'),
+ dict(type='ToTensor'),
+ dict(
+ type='NormalizeTensor',
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]),
+ dict(
+ type='Collect',
+ keys=['img'],
+ meta_keys=[
+ 'image_file', 'center', 'scale', 'rotation', 'bbox_score',
+ 'flip_pairs'
+ ]),
+]
+
+test_pipeline = val_pipeline
+
+data_root = 'data/coco'
+data = dict(
+ samples_per_gpu=32,
+ workers_per_gpu=2,
+ val_dataloader=dict(samples_per_gpu=32),
+ test_dataloader=dict(samples_per_gpu=32),
+ train=dict(
+ type='TopDownCocoDataset',
+ ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
+ img_prefix=f'{data_root}/train2017/',
+ data_cfg=data_cfg,
+ pipeline=train_pipeline,
+ dataset_info={{_base_.dataset_info}}),
+ val=dict(
+ type='TopDownCocoDataset',
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ data_cfg=data_cfg,
+ pipeline=val_pipeline,
+ dataset_info={{_base_.dataset_info}}),
+ test=dict(
+ type='TopDownCocoDataset',
+ ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ data_cfg=data_cfg,
+ pipeline=test_pipeline,
+ dataset_info={{_base_.dataset_info}}),
+)
diff --git a/configs/stable-diffusion/app.yaml b/configs/stable-diffusion/app.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..19431de3476af4315d9747016068b619de8f05ce
--- /dev/null
+++ b/configs/stable-diffusion/app.yaml
@@ -0,0 +1,87 @@
+name: app
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ device: 'cuda'
+
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 1e4
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+dist_params:
+ backend: nccl
+ port: 29500
+training:
+ lr: !!float 1e-5
+ save_freq: 1e4
\ No newline at end of file
diff --git a/configs/stable-diffusion/sd-v1-inference.yaml b/configs/stable-diffusion/sd-v1-inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dba409bc86df919bbaa687e1c85fefd641b963de
--- /dev/null
+++ b/configs/stable-diffusion/sd-v1-inference.yaml
@@ -0,0 +1,65 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_fp16: True
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 512
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.WebUIFrozenCLIPEmebedder
+ params:
+ version: openai/clip-vit-large-patch14
+ layer: last
diff --git a/configs/stable-diffusion/sd-v1-train.yaml b/configs/stable-diffusion/sd-v1-train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3c22ae71c977c229d0bbf0d618a838196c601804
--- /dev/null
+++ b/configs/stable-diffusion/sd-v1-train.yaml
@@ -0,0 +1,86 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config: #__is_unconditional__
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ version: openai/clip-vit-large-patch14
+
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 1e4
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+dist_params:
+ backend: nccl
+ port: 29500
+training:
+ lr: !!float 1e-5
+ save_freq: 1e4
\ No newline at end of file
diff --git a/configs/stable-diffusion/train_keypose.yaml b/configs/stable-diffusion/train_keypose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cd25843a0d854ee3a36807ed69b666a66ada16ab
--- /dev/null
+++ b/configs/stable-diffusion/train_keypose.yaml
@@ -0,0 +1,87 @@
+name: train_keypose
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config: #__is_unconditional__
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ version: openai/clip-vit-large-patch14
+
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 1e4
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+dist_params:
+ backend: nccl
+ port: 29500
+training:
+ lr: !!float 1e-5
+ save_freq: 1e4
\ No newline at end of file
diff --git a/configs/stable-diffusion/train_mask.yaml b/configs/stable-diffusion/train_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7ab298114683416d3687bfc9a2a0b24b51fb1e62
--- /dev/null
+++ b/configs/stable-diffusion/train_mask.yaml
@@ -0,0 +1,87 @@
+name: train_mask
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config: #__is_unconditional__
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ version: openai/clip-vit-large-patch14
+
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 1e4
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+dist_params:
+ backend: nccl
+ port: 29500
+training:
+ lr: !!float 1e-5
+ save_freq: 1e4
\ No newline at end of file
diff --git a/configs/stable-diffusion/train_sketch.yaml b/configs/stable-diffusion/train_sketch.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..90d44870ec68d327b2cf85a6dfae280bd397a825
--- /dev/null
+++ b/configs/stable-diffusion/train_sketch.yaml
@@ -0,0 +1,87 @@
+name: train_sketch
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [ 10000 ]
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+ f_start: [ 1.e-6 ]
+ f_max: [ 1. ]
+ f_min: [ 1. ]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config: #__is_unconditional__
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+ params:
+ version: openai/clip-vit-large-patch14
+
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 1e4
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+dist_params:
+ backend: nccl
+ port: 29500
+training:
+ lr: !!float 1e-5
+ save_freq: 1e4
\ No newline at end of file
diff --git a/dist_util.py b/dist_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..47441a48932a86d5556b1167ef327aa3b1ec8173
--- /dev/null
+++ b/dist_util.py
@@ -0,0 +1,91 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
+import functools
+import os
+import subprocess
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+    If the ``port`` argument is not specified, the master port is read from the
+    ``MASTER_PORT`` environment variable; if that is not set either, the default
+    port ``29500`` is used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
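+    """Decorator: run the wrapped function on rank 0 only (no-op on other ranks)."""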
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+def get_bare_model(net):
+ """Get bare model, especially under wrapping with
+ DistributedDataParallel or DataParallel.
+ """
+ if isinstance(net, (DataParallel, DistributedDataParallel)):
+ net = net.module
+ return net
diff --git a/docs/AdapterZoo.md b/docs/AdapterZoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..ffdf9a9c4588367796f463a575cccdddf65ab513
--- /dev/null
+++ b/docs/AdapterZoo.md
@@ -0,0 +1,16 @@
+# Adapter Zoo
+
+You can download the adapters from [Hugging Face](https://huggingface.co/TencentARC/T2I-Adapter).
+
+All the following adapters are trained with Stable Diffusion (SD) V1.4. They can be used directly on custom models as long as those models are fine-tuned from the same text-to-image base model, such as Anything-4.0 or similar community models.
+
+| Adapter Name | Adapter Description | Demos | Model Parameters | Model Storage | Notes |
+| --- | --- | --- | --- | --- | --- |
+| t2iadapter_color_sd14v1.pth | Spatial color palette → image | [Demos](examples.md#color-adapter-spatial-palette) | 18 M | 75 MB | |
+| t2iadapter_style_sd14v1.pth | Image style → image | [Demos](examples.md#style-adapter) | | 154 MB | Preliminary model. Style adapters with finer controls are on the way |
+| t2iadapter_openpose_sd14v1.pth | Openpose → image | [Demos](examples.md#openpose-adapter) | 77 M | 309 MB | |
+| t2iadapter_canny_sd14v1.pth | Canny edges → image | [Demos](examples.md#canny-adapter-edge) | 77 M | 309 MB | |
+| t2iadapter_sketch_sd14v1.pth | Sketch → image | | 77 M | 308 MB | |
+| t2iadapter_keypose_sd14v1.pth | Keypose → image | | 77 M | 309 MB | mmpose-style keypoints |
+| t2iadapter_seg_sd14v1.pth | Segmentation → image | | 77 M | 309 MB | |
+| t2iadapter_depth_sd14v1.pth | Depth maps → image | | 77 M | 309 MB | Not the final model, still under training |
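+
+For orientation, below is a minimal single-adapter inference sketch assembled from the helpers this repository uses in `app.py` (`ldm/inference_base.py` and `ldm/modules/extra_condition/api.py`). The option fields and file paths are illustrative assumptions mirroring the defaults `app.py` sets up, not a fixed API:
+
+```python
+import cv2
+import torch
+from types import SimpleNamespace
+
+from ldm.inference_base import get_sd_models, get_adapters, diffusion_inference
+from ldm.modules.extra_condition import api
+from ldm.modules.extra_condition.api import ExtraCondition, get_adapter_feature, get_cond_model
+
+# assumed option values, mirroring the defaults configured in app.py
+opt = SimpleNamespace(
+    config='configs/stable-diffusion/sd-v1-inference.yaml',
+    sd_ckpt='models/v1-5-pruned-emaonly.ckpt', vae_ckpt=None,
+    sketch_adapter_ckpt='models/t2iadapter_sketch_sd14v1.pth',
+    device=torch.device('cuda'), sampler='ddim', cond_weight=1.0, C=4, f=8,
+    steps=50, scale=7.5, neg_prompt='', cond_tau=1.0, resize_short_edge=512,
+    max_resolution=512 * 512, n_samples=1, prompt='a motorcycle')
+
+sd_model, sampler = get_sd_models(opt)
+adapter = get_adapters(opt, ExtraCondition.sketch)       # adapter dict for the sketch condition
+adapter['cond_weight'] = 1.0                             # same default app.py applies
+cond_model = get_cond_model(opt, ExtraCondition.sketch)  # image -> sketch extractor
+
+# extract the sketch condition from an RGB image, as app.py does with gradio inputs
+im = cv2.cvtColor(cv2.imread('motorcycle.jpg'), cv2.COLOR_BGR2RGB)
+cond = api.get_cond_sketch(opt, im, 'image', cond_model)
+adapter_features, append_to_context = get_adapter_feature([cond], [adapter])
+
+with torch.inference_mode():
+    result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
+```
+
+The returned tensor can be converted to an image with `basicsr.utils.tensor2img`, as `app.py` does for its gallery outputs.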
diff --git a/docs/FAQ.md b/docs/FAQ.md
new file mode 100644
index 0000000000000000000000000000000000000000..6b34bb16e54c63afaee471d54405afc0164b601f
--- /dev/null
+++ b/docs/FAQ.md
@@ -0,0 +1,5 @@
+# FAQ
+
+- **Q: The openpose adapter (t2iadapter_openpose_sd14v1) outputs gray-scale images.**
+
+  **A:** You can add `colorful` to the prompt to avoid this problem.
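+
+  For example, an illustrative prompt would be `a man riding a bicycle, colorful, best quality`.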
diff --git a/docs/examples.md b/docs/examples.md
new file mode 100644
index 0000000000000000000000000000000000000000..4e422ee622b7a6e2042776df3944b255368cdb49
--- /dev/null
+++ b/docs/examples.md
@@ -0,0 +1,41 @@
+# Demos
+
+## Style Adapter
+
+## Color Adapter (Spatial Palette)
+
+## Openpose Adapter
+
+## Canny Adapter (Edge)
+
+## Multi-adapters
+
+*T2I adapters naturally support using multiple adapters together.*
+
+The testing command for this example is similar to the one below, except that the pretrained SD model is replaced with Anything 4.5 and Kenshi:
+
+>python test_composable_adapters.py --prompt "1girl, computer desk, best quality, extremely detailed" --neg_prompt "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality" --depth_cond_path examples/depth/desk_depth.png --depth_cond_weight 1.0 --depth_ckpt models/t2iadapter_depth_sd14v1.pth --depth_type_in depth --pose_cond_path examples/keypose/person_keypose.png --pose_cond_weight 1.5 --ckpt models/anything-v4.0-pruned.ckpt --n_sample 4 --max_resolution 524288
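+
+Here `--max_resolution 524288` corresponds to 512 × 512 pixels, the same default `app.py` sets via `global_opt.max_resolution = 512 * 512`.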
+
+[Image source](https://twitter.com/toyxyz3/status/1628375164781211648)
diff --git a/experiments/README.md b/experiments/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/data/dataset_coco.py b/ldm/data/dataset_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b4aa4facb12be8534522c9240ca6e63ce4a68b5
--- /dev/null
+++ b/ldm/data/dataset_coco.py
@@ -0,0 +1,36 @@
+import json
+import cv2
+import os
+from basicsr.utils import img2tensor
+
+
+class dataset_coco_mask_color():
+ def __init__(self, path_json, root_path_im, root_path_mask, image_size):
+ super(dataset_coco_mask_color, self).__init__()
+ with open(path_json, 'r', encoding='utf-8') as fp:
+ data = json.load(fp)
+ data = data['annotations']
+ self.files = []
+ self.root_path_im = root_path_im
+ self.root_path_mask = root_path_mask
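+        # COCO images and masks are keyed by the zero-padded 12-digit image id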
+ for file in data:
+ name = "%012d.png" % file['image_id']
+ self.files.append({'name': name, 'sentence': file['caption']})
+
+ def __getitem__(self, idx):
+ file = self.files[idx]
+ name = file['name']
+ # print(os.path.join(self.root_path_im, name))
+ im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png', '.jpg')))
+ im = cv2.resize(im, (512, 512))
+ im = img2tensor(im, bgr2rgb=True, float32=True) / 255.
+
+ mask = cv2.imread(os.path.join(self.root_path_mask, name)) # [:,:,0]
+ mask = cv2.resize(mask, (512, 512))
+ mask = img2tensor(mask, bgr2rgb=True, float32=True) / 255. # [0].unsqueeze(0)#/255.
+
+ sentence = file['sentence']
+ return {'im': im, 'mask': mask, 'sentence': sentence}
+
+ def __len__(self):
+ return len(self.files)
diff --git a/ldm/data/dataset_depth.py b/ldm/data/dataset_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3afe28da237c62795625574b89b60072da79cd2
--- /dev/null
+++ b/ldm/data/dataset_depth.py
@@ -0,0 +1,35 @@
+import json
+import cv2
+import os
+from basicsr.utils import img2tensor
+
+
+class DepthDataset():
+ def __init__(self, meta_file):
+ super(DepthDataset, self).__init__()
+
+ self.files = []
+ with open(meta_file, 'r') as f:
+ lines = f.readlines()
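+            # each meta line is an image path whose depth map (.depth.png) and
+            # caption file (.txt) share the same stem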
+ for line in lines:
+ img_path = line.strip()
+ depth_img_path = img_path.rsplit('.', 1)[0] + '.depth.png'
+ txt_path = img_path.rsplit('.', 1)[0] + '.txt'
+ self.files.append({'img_path': img_path, 'depth_img_path': depth_img_path, 'txt_path': txt_path})
+
+ def __getitem__(self, idx):
+ file = self.files[idx]
+
+ im = cv2.imread(file['img_path'])
+ im = img2tensor(im, bgr2rgb=True, float32=True) / 255.
+
+        depth = cv2.imread(file['depth_img_path'])
+        depth = img2tensor(depth, bgr2rgb=True, float32=True) / 255.
+
+ with open(file['txt_path'], 'r') as fs:
+ sentence = fs.readline().strip()
+
+ return {'im': im, 'depth': depth, 'sentence': sentence}
+
+ def __len__(self):
+ return len(self.files)
diff --git a/ldm/data/dataset_laion.py b/ldm/data/dataset_laion.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b1807b1d87e27e09656daf6e7144bd5fba6adce
--- /dev/null
+++ b/ldm/data/dataset_laion.py
@@ -0,0 +1,130 @@
+# -*- coding: utf-8 -*-
+
+import numpy as np
+import os
+import pytorch_lightning as pl
+import torch
+import webdataset as wds
+from torchvision.transforms import transforms
+
+from ldm.util import instantiate_from_config
+
+
+def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True):
+ """Take a list of samples (as dictionary) and create a batch, preserving the keys.
+ If `tensors` is True, `ndarray` objects are combined into
+ tensor batches.
+ :param dict samples: list of samples
+ :param bool tensors: whether to turn lists of ndarrays into a single ndarray
+ :returns: single sample consisting of a batch
+ :rtype: dict
+ """
+ keys = set.intersection(*[set(sample.keys()) for sample in samples])
+ batched = {key: [] for key in keys}
+
+    for s in samples:
+        for key in batched:
+            batched[key].append(s[key])
+
+ result = {}
+ for key in batched:
+ if isinstance(batched[key][0], (int, float)):
+ if combine_scalars:
+ result[key] = np.array(list(batched[key]))
+ elif isinstance(batched[key][0], torch.Tensor):
+ if combine_tensors:
+ result[key] = torch.stack(list(batched[key]))
+ elif isinstance(batched[key][0], np.ndarray):
+ if combine_tensors:
+ result[key] = np.array(list(batched[key]))
+ else:
+ result[key] = list(batched[key])
+ return result
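+
+# e.g. dict_collation_fn([{'x': 1}, {'x': 2}]) yields {'x': array([1, 2])}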
+
+
+class WebDataModuleFromConfig(pl.LightningDataModule):
+
+ def __init__(self,
+ tar_base,
+ batch_size,
+ train=None,
+ validation=None,
+ test=None,
+ num_workers=4,
+ multinode=True,
+ min_size=None,
+ max_pwatermark=1.0,
+ **kwargs):
+ super().__init__()
+ print(f'Setting tar base to {tar_base}')
+ self.tar_base = tar_base
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.train = train
+ self.validation = validation
+ self.test = test
+ self.multinode = multinode
+ self.min_size = min_size # filter out very small images
+ self.max_pwatermark = max_pwatermark # filter out watermarked images
+
+ def make_loader(self, dataset_config):
+ image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms]
+ image_transforms = transforms.Compose(image_transforms)
+
+ process = instantiate_from_config(dataset_config['process'])
+
+ shuffle = dataset_config.get('shuffle', 0)
+ shardshuffle = shuffle > 0
+
+ nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only
+
+ tars = os.path.join(self.tar_base, dataset_config.shards)
+
+ dset = wds.WebDataset(
+ tars, nodesplitter=nodesplitter, shardshuffle=shardshuffle,
+ handler=wds.warn_and_continue).repeat().shuffle(shuffle)
+ print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.')
+
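+        # pipeline: keep samples that contain both jpg and txt -> decode to PIL -> filter out small
+        # or watermarked images -> apply the image transforms to the jpg field -> run the configured
+        # process step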
+ dset = (
+ dset.select(self.filter_keys).decode('pil',
+ handler=wds.warn_and_continue).select(self.filter_size).map_dict(
+ jpg=image_transforms, handler=wds.warn_and_continue).map(process))
+ dset = (dset.batched(self.batch_size, partial=False, collation_fn=dict_collation_fn))
+
+ loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=self.num_workers)
+
+ return loader
+
+ def filter_size(self, x):
+ if self.min_size is None:
+ return True
+        try:
+            meta = x['json']
+            return (meta['original_width'] >= self.min_size and meta['original_height'] >= self.min_size
+                    and meta['pwatermark'] <= self.max_pwatermark)
+        except Exception:
+            return False
+
+ def filter_keys(self, x):
+ try:
+ return ("jpg" in x) and ("txt" in x)
+ except Exception:
+ return False
+
+ def train_dataloader(self):
+ return self.make_loader(self.train)
+
+ def val_dataloader(self):
+ return None
+
+ def test_dataloader(self):
+ return None
+
+
+if __name__ == '__main__':
+ from omegaconf import OmegaConf
+ config = OmegaConf.load("configs/stable-diffusion/train_canny_sd_v1.yaml")
+ datamod = WebDataModuleFromConfig(**config["data"]["params"])
+ dataloader = datamod.train_dataloader()
+
+ for batch in dataloader:
+ print(batch.keys())
+ print(batch['jpg'].shape)
diff --git a/ldm/data/dataset_wikiart.py b/ldm/data/dataset_wikiart.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7a2de87ccbba147580fed82e3c5e5a5ab38761e
--- /dev/null
+++ b/ldm/data/dataset_wikiart.py
@@ -0,0 +1,67 @@
+import json
+import os.path
+
+from PIL import Image
+from torch.utils.data import DataLoader
+
+from transformers import CLIPProcessor
+from torchvision.transforms import transforms
+
+import pytorch_lightning as pl
+
+
+class WikiArtDataset():
+ def __init__(self, meta_file):
+ super(WikiArtDataset, self).__init__()
+
+ self.files = []
+ with open(meta_file, 'r') as f:
+ js = json.load(f)
+ for img_path in js:
+ img_name = os.path.splitext(os.path.basename(img_path))[0]
+ caption = img_name.split('_')[-1]
+ caption = caption.split('-')
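+            # scan from the end and drop trailing all-digit chunks, keeping the remaining words as the caption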
+ j = len(caption) - 1
+ while j >= 0:
+ if not caption[j].isdigit():
+ break
+ j -= 1
+ if j < 0:
+ continue
+ sentence = ' '.join(caption[:j + 1])
+ self.files.append({'img_path': os.path.join('datasets/wikiart', img_path), 'sentence': sentence})
+
+ version = 'openai/clip-vit-large-patch14'
+ self.processor = CLIPProcessor.from_pretrained(version)
+
+ self.jpg_transform = transforms.Compose([
+ transforms.Resize(512),
+ transforms.RandomCrop(512),
+ transforms.ToTensor(),
+ ])
+
+ def __getitem__(self, idx):
+ file = self.files[idx]
+
+        im = Image.open(file['img_path']).convert('RGB')
+
+ im_tensor = self.jpg_transform(im)
+
+ clip_im = self.processor(images=im, return_tensors="pt")['pixel_values'][0]
+
+ return {'jpg': im_tensor, 'style': clip_im, 'txt': file['sentence']}
+
+ def __len__(self):
+ return len(self.files)
+
+
+class WikiArtDataModule(pl.LightningDataModule):
+ def __init__(self, meta_file, batch_size, num_workers):
+ super(WikiArtDataModule, self).__init__()
+ self.train_dataset = WikiArtDataset(meta_file)
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+
+ def train_dataloader(self):
+ return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers,
+ pin_memory=True)
diff --git a/ldm/data/utils.py b/ldm/data/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ece8c92b4aca12d6c65908900460cc4beaf522e
--- /dev/null
+++ b/ldm/data/utils.py
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+
+import cv2
+import numpy as np
+from torchvision.transforms import transforms
+from torchvision.transforms.functional import to_tensor
+from transformers import CLIPProcessor
+
+from basicsr.utils import img2tensor
+
+
+class AddCannyFreezeThreshold(object):
+
+ def __init__(self, low_threshold=100, high_threshold=200):
+ self.low_threshold = low_threshold
+ self.high_threshold = high_threshold
+
+ def __call__(self, sample):
+ # sample['jpg'] is PIL image
+ x = sample['jpg']
+ img = cv2.cvtColor(np.array(x), cv2.COLOR_RGB2BGR)
+ canny = cv2.Canny(img, self.low_threshold, self.high_threshold)[..., None]
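+        # Canny yields a single-channel HxWx1 uint8 map; img2tensor turns it into a 1xHxW float tensor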
+ sample['canny'] = img2tensor(canny, bgr2rgb=True, float32=True) / 255.
+ sample['jpg'] = to_tensor(x)
+ return sample
+
+
+class AddStyle(object):
+
+ def __init__(self, version):
+ self.processor = CLIPProcessor.from_pretrained(version)
+ self.pil_to_tensor = transforms.ToTensor()
+
+ def __call__(self, sample):
+ # sample['jpg'] is PIL image
+ x = sample['jpg']
+ style = self.processor(images=x, return_tensors="pt")['pixel_values'][0]
+ sample['style'] = style
+ sample['jpg'] = to_tensor(x)
+ return sample
diff --git a/ldm/inference_base.py b/ldm/inference_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7b62e852b4b52881e06ff66d478185b3a928396
--- /dev/null
+++ b/ldm/inference_base.py
@@ -0,0 +1,282 @@
+import argparse
+import torch
+from omegaconf import OmegaConf
+
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.modules.encoders.adapter import Adapter, StyleAdapter, Adapter_light
+from ldm.modules.extra_condition.api import ExtraCondition
+from ldm.util import fix_cond_shapes, load_model_from_config, read_state_dict
+
+DEFAULT_NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, ' \
+ 'fewer digits, cropped, worst quality, low quality'
+
+
+def get_base_argument_parser() -> argparse.ArgumentParser:
+ """get the base argument parser for inference scripts"""
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--outdir',
+ type=str,
+ help='dir to write results to',
+ default=None,
+ )
+
+ parser.add_argument(
+ '--prompt',
+ type=str,
+ nargs='?',
+ default=None,
+ help='positive prompt',
+ )
+
+ parser.add_argument(
+ '--neg_prompt',
+ type=str,
+ default=DEFAULT_NEGATIVE_PROMPT,
+ help='negative prompt',
+ )
+
+ parser.add_argument(
+ '--cond_path',
+ type=str,
+ default=None,
+ help='condition image path',
+ )
+
+ parser.add_argument(
+ '--cond_inp_type',
+ type=str,
+ default='image',
+        help='the type of the input condition image; taking depth T2I as an example, the input can be a raw '
+        'image, from which the depth map will be estimated, or directly a depth map image',
+ )
+
+ parser.add_argument(
+ '--sampler',
+ type=str,
+ default='ddim',
+ choices=['ddim', 'plms'],
+ help='sampling algorithm, currently, only ddim and plms are supported, more are on the way',
+ )
+
+ parser.add_argument(
+ '--steps',
+ type=int,
+ default=50,
+ help='number of sampling steps',
+ )
+
+ parser.add_argument(
+ '--sd_ckpt',
+ type=str,
+ default='models/sd-v1-4.ckpt',
+        help='path to checkpoint of stable diffusion model; both .ckpt and .safetensors are supported',
+ )
+
+ parser.add_argument(
+ '--vae_ckpt',
+ type=str,
+ default=None,
+        help='vae checkpoint; anime SD models usually have a separate vae ckpt that needs to be loaded',
+ )
+
+ parser.add_argument(
+ '--adapter_ckpt',
+ type=str,
+ default=None,
+ help='path to checkpoint of adapter',
+ )
+
+ parser.add_argument(
+ '--config',
+ type=str,
+ default='configs/stable-diffusion/sd-v1-inference.yaml',
+ help='path to config which constructs SD model',
+ )
+
+ parser.add_argument(
+ '--max_resolution',
+ type=float,
+ default=512 * 512,
+        help='max image height * width, only for computers with limited VRAM',
+ )
+
+ parser.add_argument(
+ '--resize_short_edge',
+ type=int,
+ default=None,
+ help='resize short edge of the input image, if this arg is set, max_resolution will not be used',
+ )
+
+ parser.add_argument(
+ '--C',
+ type=int,
+ default=4,
+ help='latent channels',
+ )
+
+ parser.add_argument(
+ '--f',
+ type=int,
+ default=8,
+ help='downsampling factor',
+ )
+
+ parser.add_argument(
+ '--scale',
+ type=float,
+ default=7.5,
+ help='unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
+ )
+
+ parser.add_argument(
+ '--cond_tau',
+ type=float,
+ default=1.0,
+        help='timestep parameter that determines until which sampling step the adapter is applied, '
+        'similar to the tau in Prompt-to-Prompt')
+
+ parser.add_argument(
+ '--cond_weight',
+ type=float,
+ default=1.0,
+        help='the adapter features are multiplied by cond_weight; a larger cond_weight makes the generated '
+        'image align more closely with the condition, but may reduce generation quality',
+ )
+
+ parser.add_argument(
+ '--seed',
+ type=int,
+ default=42,
+ )
+
+ parser.add_argument(
+ '--n_samples',
+ type=int,
+ default=4,
+ help='# of samples to generate',
+ )
+
+ return parser
+
+
+def get_sd_models(opt):
+ """
+ build stable diffusion model, sampler
+ """
+ # SD
+ config = OmegaConf.load(f"{opt.config}")
+ model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)
+ sd_model = model.to(opt.device)
+
+ # sampler
+ if opt.sampler == 'plms':
+ sampler = PLMSSampler(model)
+ elif opt.sampler == 'ddim':
+ sampler = DDIMSampler(model)
+ else:
+ raise NotImplementedError
+
+ return sd_model, sampler
+
+
+def get_t2i_adapter_models(opt):
+ config = OmegaConf.load(f"{opt.config}")
+ model = load_model_from_config(config, opt.sd_ckpt, opt.vae_ckpt)
+ adapter_ckpt_path = getattr(opt, f'{opt.which_cond}_adapter_ckpt', None)
+ if adapter_ckpt_path is None:
+ adapter_ckpt_path = getattr(opt, 'adapter_ckpt')
+ adapter_ckpt = read_state_dict(adapter_ckpt_path)
+ new_state_dict = {}
+ for k, v in adapter_ckpt.items():
+ if not k.startswith('adapter.'):
+ new_state_dict[f'adapter.{k}'] = v
+ else:
+ new_state_dict[k] = v
+ m, u = model.load_state_dict(new_state_dict, strict=False)
+ if len(u) > 0:
+ print(f"unexpected keys in loading adapter ckpt {adapter_ckpt_path}:")
+ print(u)
+
+ model = model.to(opt.device)
+
+ # sampler
+ if opt.sampler == 'plms':
+ sampler = PLMSSampler(model)
+ elif opt.sampler == 'ddim':
+ sampler = DDIMSampler(model)
+ else:
+ raise NotImplementedError
+
+ return model, sampler
+
+
+def get_cond_ch(cond_type: ExtraCondition):
+ if cond_type == ExtraCondition.sketch or cond_type == ExtraCondition.canny:
+ return 1
+ return 3
+
+
+def get_adapters(opt, cond_type: ExtraCondition):
+ adapter = {}
+ cond_weight = getattr(opt, f'{cond_type.name}_weight', None)
+ if cond_weight is None:
+ cond_weight = getattr(opt, 'cond_weight')
+ adapter['cond_weight'] = cond_weight
+
+ if cond_type == ExtraCondition.style:
+ adapter['model'] = StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8).to(opt.device)
+ elif cond_type == ExtraCondition.color:
+ adapter['model'] = Adapter_light(
+ cin=64 * get_cond_ch(cond_type),
+ channels=[320, 640, 1280, 1280],
+ nums_rb=4).to(opt.device)
+ else:
+ adapter['model'] = Adapter(
+ cin=64 * get_cond_ch(cond_type),
+ channels=[320, 640, 1280, 1280][:4],
+ nums_rb=2,
+ ksize=1,
+ sk=True,
+ use_conv=False).to(opt.device)
+ ckpt_path = getattr(opt, f'{cond_type.name}_adapter_ckpt', None)
+ if ckpt_path is None:
+ ckpt_path = getattr(opt, 'adapter_ckpt')
+ adapter['model'].load_state_dict(torch.load(ckpt_path))
+
+ return adapter
+
+
+def diffusion_inference(opt, model, sampler, adapter_features, append_to_context=None):
+ # get text embedding
+ c = model.get_learned_conditioning([opt.prompt])
+ if opt.scale != 1.0:
+ uc = model.get_learned_conditioning([opt.neg_prompt])
+ else:
+ uc = None
+ c, uc = fix_cond_shapes(model, c, uc)
+
+ if not hasattr(opt, 'H'):
+ opt.H = 512
+ opt.W = 512
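+    # latent spatial size is the image size divided by the downsampling factor f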
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+
+ samples_latents, _ = sampler.sample(
+ S=opt.steps,
+ conditioning=c,
+ batch_size=1,
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=opt.scale,
+ unconditional_conditioning=uc,
+ x_T=None,
+ features_adapter=adapter_features,
+ append_to_context=append_to_context,
+ cond_tau=opt.cond_tau,
+ )
+
+ x_samples = model.decode_first_stage(samples_latents)
+ x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+ return x_samples
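+
+
+# minimal usage sketch (assumes a parsed `opt` augmented with a `device` attribute, plus an
+# adapter-feature extraction step that is defined elsewhere):
+#   opt = get_base_argument_parser().parse_args()
+#   opt.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+#   sd_model, sampler = get_sd_models(opt)
+#   adapter = get_adapters(opt, ExtraCondition.canny)
+#   ... compute `adapter_features` from the condition image with adapter['model'] ...
+#   x_samples = diffusion_inference(opt, sd_model, sampler, adapter_features)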
diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py
new file mode 100755
index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade
--- /dev/null
+++ b/ldm/lr_scheduler.py
@@ -0,0 +1,98 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n, **kwargs):
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n, **kwargs):
+        return self.schedule(n, **kwargs)
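+
+
+# minimal usage sketch (hypothetical values): these schedulers return an lr *multiplier*,
+# so pair them with torch.optim.lr_scheduler.LambdaLR and a base lr of 1.0, e.g.
+#   sched = LambdaWarmUpCosineScheduler(warm_up_steps=1000, lr_min=0.01, lr_max=1.0,
+#                                       lr_start=0.0, max_decay_steps=10000)
+#   lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched)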
+
+
+class LambdaWarmUpCosineScheduler2:
+ """
+ supports repeated iterations, configurable via lists
+ note: use with a base_lr of 1.0.
+ """
+ def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
+ assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
+ self.lr_warm_up_steps = warm_up_steps
+ self.f_start = f_start
+ self.f_min = f_min
+ self.f_max = f_max
+ self.cycle_lengths = cycle_lengths
+ self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
+ self.last_f = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def find_in_interval(self, n):
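+        # index of the cycle that step n falls into (assumes n lies within the configured cycles)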
+ interval = 0
+ for cl in self.cum_cycles[1:]:
+ if n <= cl:
+ return interval
+ interval += 1
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, current cycle {cycle}")
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
+ t = min(t, 1.0)
+ f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
+ 1 + np.cos(t * np.pi))
+ self.last_f = f
+ return f
+
+ def __call__(self, n, **kwargs):
+ return self.schedule(n, **kwargs)
+
+
+class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
+
+ def schedule(self, n, **kwargs):
+ cycle = self.find_in_interval(n)
+ n = n - self.cum_cycles[cycle]
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, current cycle {cycle}")
+
+ if n < self.lr_warm_up_steps[cycle]:
+ f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
+ self.last_f = f
+ return f
+ else:
+ f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
+ self.last_f = f
+ return f
+
diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py
new file mode 100755
index 0000000000000000000000000000000000000000..e3ff5fe3ed0f70de8b31f1af27e107b93fbb94ca
--- /dev/null
+++ b/ldm/models/autoencoder.py
@@ -0,0 +1,211 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+import torch.nn as nn
+from contextlib import contextmanager
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+from ldm.modules.ema import LitEma
+
+
+class AutoencoderKL(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ ema_decay=None,
+ learn_logvar=False
+ ):
+ super().__init__()
+ self.learn_logvar = learn_logvar
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.use_ema = ema_decay is not None
+ if self.use_ema:
+ self.ema_decay = ema_decay
+ assert 0. < ema_decay < 1.
+ self.model_ema = LitEma(self, decay=ema_decay)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
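+        # quant_conv outputs 2*embed_dim channels: the concatenated mean and log-variance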
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, postfix=""):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+ if self.learn_logvar:
+ print(f"{self.__class__.__name__}: Learning logvar")
+ ae_params_list.append(self.loss.logvar)
+ opt_ae = torch.optim.Adam(ae_params_list,
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class IdentityFirstStage(nn.Module):
+    def __init__(self, *args, vq_interface=False, **kwargs):
+        super().__init__()
+        self.vq_interface = vq_interface
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
new file mode 100755
index 0000000000000000000000000000000000000000..9f19c803246a0125d9c67c31df49da351c7552f0
--- /dev/null
+++ b/ldm/models/diffusion/ddim.py
@@ -0,0 +1,292 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \
+ extract_into_tensor
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
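+        # not a real nn.Module buffer: DDIMSampler is not a module, so this simply stores the
+        # tensor as an attribute, moving it to CUDA first if necessary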
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta, verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,  # this has to come in the same format as the conditioning, e.g. as encoded tokens
+               features_adapter=None,
+               append_to_context=None,
+               cond_tau=0.4,
+               **kwargs
+               ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ features_adapter=features_adapter,
+ append_to_context=append_to_context,
+ cond_tau=cond_tau,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
+ append_to_context=None, cond_tau=0.4):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
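+                                      # apply the adapter only during the first cond_tau fraction of the sampling steps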
+ features_adapter=None if index < int(
+ (1 - cond_tau) * total_steps) else features_adapter,
+ # TODO support style_cond_tau
+ append_to_context=None if index < int(
+ 0.5 * total_steps) else append_to_context,
+ )
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
+ append_to_context=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ if append_to_context is not None:
+ model_output = self.model.apply_model(x, t, torch.cat([c, append_to_context], dim=1),
+ features_adapter=features_adapter)
+ else:
+ model_output = self.model.apply_model(x, t, c, features_adapter=features_adapter)
+ else:
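+            # classifier-free guidance: batch the unconditional and conditional passes together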
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [torch.cat([
+ unconditional_conditioning[k][i],
+ c[k][i]]) for i in range(len(c[k]))]
+ else:
+ c_in[k] = torch.cat([
+ unconditional_conditioning[k],
+ c[k]])
+ elif isinstance(c, list):
+ c_in = list()
+ assert isinstance(unconditional_conditioning, list)
+ for i in range(len(c)):
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+ else:
+ if append_to_context is not None:
+ pad_len = append_to_context.size(1)
+ new_unconditional_conditioning = torch.cat(
+ [unconditional_conditioning, unconditional_conditioning[:, -pad_len:, :]], dim=1)
+ new_c = torch.cat([c, append_to_context], dim=1)
+ c_in = torch.cat([new_unconditional_conditioning, new_c])
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in, features_adapter=features_adapter).chunk(2)
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", 'not implemented'
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ return x_dec
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
new file mode 100755
index 0000000000000000000000000000000000000000..263840b499ec9df0be40a02a665e0245b32a2f29
--- /dev/null
+++ b/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1313 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ make_it_fit=False,
+ ucg_training=None,
+ reset_ema=False,
+ reset_num_ema_updates=False,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ self.make_it_fit = make_it_fit
+ if reset_ema: assert exists(ckpt_path)
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ if reset_ema:
+ assert self.use_ema
+ print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+ self.ucg_training = ucg_training or dict()
+ if self.ucg_training:
+ self.ucg_prng = np.random.RandomState()
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+ else:
+ raise NotImplementedError("mu not supported")
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ @torch.no_grad()
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ if self.make_it_fit:
+ n_params = len([name for name, _ in
+ itertools.chain(self.named_parameters(),
+ self.named_buffers())])
+ for name, param in tqdm(
+ itertools.chain(self.named_parameters(),
+ self.named_buffers()),
+ desc="Fitting old weights to new weights",
+ total=n_params
+ ):
+                if name not in sd:
+ continue
+ old_shape = sd[name].shape
+ new_shape = param.shape
+ assert len(old_shape) == len(new_shape)
+ if len(new_shape) > 2:
+ # we only modify first two axes
+ assert new_shape[2:] == old_shape[2:]
+ # assumes first axis corresponds to output dim
+ if not new_shape == old_shape:
+ new_param = param.clone()
+ old_param = sd[name]
+ if len(new_shape) == 1:
+ for i in range(new_param.shape[0]):
+ new_param[i] = old_param[i % old_shape[0]]
+ elif len(new_shape) >= 2:
+ for i in range(new_param.shape[0]):
+ for j in range(new_param.shape[1]):
+ new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+ n_used_old = torch.ones(old_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_old[j % old_shape[1]] += 1
+ n_used_new = torch.zeros(new_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_new[j] = n_used_old[j % old_shape[1]]
+
+ n_used_new = n_used_new[None, :]
+ while len(n_used_new.shape) < len(new_shape):
+ n_used_new = n_used_new.unsqueeze(-1)
+ new_param /= n_used_new
+
+ sd[name] = new_param
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys:\n {missing}")
+ if len(unexpected) > 0:
+ print(f"\nUnexpected Keys:\n {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def predict_start_from_z_and_v(self, x_t, t, v):
+        # v-parameterization: x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
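+        # forward diffusion: sample x_t ~ q(x_t | x_0) in closed form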
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def get_v(self, x, noise, t):
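+        # v-objective target: v = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0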
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+        else:
+            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ # if len(x.shape) == 3:
+ # x = x[..., None]
+ # x = rearrange(x, 'b h w c -> b c h w')
+ # x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=True,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ *args, **kwargs):
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__':
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ reset_ema = kwargs.pop("reset_ema", False)
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+        try:
+            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+            # the first-stage config may not define ddconfig.ch_mult (e.g. an identity first stage)
+            self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+ if reset_ema:
+ assert self.use_ema
+                print(
+                    "Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+    def make_cond_schedule(self):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
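+        # The latent is rescaled by scale_factor, typically chosen (or computed via
+        # scale_by_std) so that the latents have roughly unit standard deviation.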
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+        :return: normalized distance to image border,
+         with min distance = 0 at the border and max dist = 0.5 at the image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+    def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1):  # TODO: build fold/unfold once instead of on every call
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[1] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+            fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[1] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None):
+ x = super().get_input(batch, k)
+ if bs is not None:
+ x = x[:bs]
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ if self.model.conditioning_key is not None:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ['caption', 'coordinates_bbox', "txt"]:
+ xc = batch[cond_key]
+ elif cond_key in ['class_label', 'cls']:
+ xc = batch
+ else:
+ xc = super().get_input(batch, cond_key).to(self.device)
+ else:
+ xc = x
+ if not self.cond_stage_trainable or force_c_encode:
+ if isinstance(xc, dict) or isinstance(xc, list):
+ # import pudb; pudb.set_trace()
+ c = self.get_learned_conditioning(xc)
+ else:
+ c = self.get_learned_conditioning(xc.to(self.device))
+ else:
+ c = xc
+ if bs is not None:
+ c = c[:bs]
+
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ ckey = __conditioning_keys__[self.model.conditioning_key]
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
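+        # return order: latent z, conditioning c, then optionally [x, xrec] and xc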
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_original_cond:
+ out.append(xc)
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c, **kwargs)
+ return loss
+
+ def forward(self, x, c, *args, **kwargs):
+ if 't' not in kwargs:
+ t = torch.randint(0, self.num_timesteps, (x.shape[0], ), device=self.device).long()
+ else:
+ t = kwargs.pop('t')
+
+ return self.p_losses(x, c, t, *args, **kwargs)
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False, **kwargs):
+ if isinstance(cond, dict):
+ # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ x_recon = self.model(x_noisy, t, **cond, **kwargs)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_losses(self, x_start, cond, t, noise=None, **kwargs):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond, **kwargs)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ logvar_t = self.logvar[t].to(self.device)
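+        # Weight by a learned per-timestep log-variance: loss / exp(logvar) + logvar
+        # is the Gaussian negative log-likelihood up to constants, so logvar_t can
+        # adapt to the error scale at each timestep.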
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+        if return_codebook_ids:
+            raise DeprecationWarning("Support dropped.")
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+        if batch_size is not None:
+            b = batch_size
+            shape = [batch_size] + list(shape)
+        else:
+            b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+        if isinstance(temperature, float):
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
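+            # with a shortened conditioning schedule, diffuse the conditioning to
+            # the timestep matched to the current sampling step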
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+            assert x0.shape[2:] == mask.shape[2:]  # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None, **kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+ shape, cond, verbose=False, **kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True, **kwargs)
+
+ return samples, intermediates
+
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ if self.cond_stage_key in ["class_label", "cls"]:
+ xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
+ return self.get_learned_conditioning(xc)
+ else:
+ raise NotImplementedError("todo")
+ if isinstance(c, list): # in case the encoder gives us a list
+ for i in range(len(c)):
+ c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
+ else:
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+ return c
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', "cls"]:
+ try:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ except KeyError:
+ # probably no "human_label" in batch
+ pass
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if unconditional_guidance_scale > 1.0:
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ if self.model.conditioning_key == "crossattn-adm":
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with ema_scope("Plotting Inpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ mask = 1. - mask
+ with ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
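+        # Dispatch summary: 'concat' stacks c_concat along channels, 'crossattn'
+        # passes c_crossattn as attention context, 'hybrid' does both; the
+        # '*-adm' variants additionally pass c_adm as the class/vector input `y`,
+        # and 'adm' passes the first c_crossattn entry as `y`.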
+
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None, **kwargs):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t, **kwargs)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t, **kwargs)
+ elif self.conditioning_key == 'crossattn':
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc, **kwargs)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, **kwargs)
+ elif self.conditioning_key == 'hybrid-adm':
+ assert c_adm is not None
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm, **kwargs)
+ elif self.conditioning_key == 'crossattn-adm':
+ assert c_adm is not None
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc, y=c_adm, **kwargs)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc, **kwargs)
+ else:
+ raise NotImplementedError()
+
+ return out
diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100755
index 0000000000000000000000000000000000000000..23ebfebf167a6c16f3b57e09d491998c4adf68db
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1217 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+
+ ***
+    Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
+    We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
+ ***
+
+    The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+
+ t = self.inverse_lambda(lambda_t)
+
+ ===============================================================
+
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+
+ 1. For discrete-time DPMs:
+
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+
+        Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+
+        **Important**: Please pay special attention to the argument `alphas_cumprod`:
+            The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+
+
+ 2. For continuous-time DPMs:
+
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+
+ ===============================================================
+
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+
+ Example:
+
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError(
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+ schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+                # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+ self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0 ** 2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+ torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ condition=None,
+ unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+
+    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+    first wrap the model function into a noise prediction model that accepts continuous time as input.
+
+ We support four types of the diffusion model by setting `model_type`:
+
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+        Note that the score function and the noise prediction model follow a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+
+ The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
+ or continuous-time labels (i.e. epsilon to T).
+
+    We wrap the model function to accept only `x` and `t_continuous` as inputs, and to output the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+ where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
+
+ ===============================================================
+
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
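+            # v-parameterization: eps = alpha_t * v + sigma_t * x_t,
+            # since alpha_t^2 + sigma_t^2 = 1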
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous):
+ """
+        The noise prediction model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t_continuous] * 2)
+ c_in = torch.cat([unconditional_condition, condition])
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+ assert model_type in ["noise", "x_start", "v"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
+class DPM_Solver:
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+ """Construct a DPM-Solver.
+
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+        In that case, we further support the "dynamic thresholding" from [1] when `thresholding` is True.
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+
+ Args:
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+ ``
+ def model_fn(x, t_continuous):
+ return noise
+ ``
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
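+            # Dynamic thresholding: clip each sample's x0 to its p-quantile s
+            # (floored at max_val), then rescale by s so values stay in [-1, 1].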
+ p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+
+ Args:
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+            N: An `int`. The total number of time-step intervals.
+ device: A torch device.
+ Returns:
+ A pytorch tensor of the time steps, with the shape (N + 1,).
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError(
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+
+        We combine DPM-Solver-1, 2 and 3 to use all the function evaluations, which we call "DPM-Solver-fast".
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+ - If order == 1:
+ We take `steps` of DPM-Solver-1 (i.e. DDIM).
+ - If order == 2:
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If order == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+
+ ============================================
+ Args:
+            order: An `int`. The max order for the solver (1, 2 or 3).
+            steps: An `int`. The total number of function evaluations (NFE).
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ device: A torch device.
+ Returns:
+ orders: A list of the solver order of each step.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3, ] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3, ] * (K - 1) + [1]
+ else:
+ orders = [3, ] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2, ] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2, ] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [1, ] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+ else:
+            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+                torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
+ return timesteps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity with a first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+ """
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+ else:
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+ solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the second-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 0.5
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+ s1), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_1 = torch.expm1(-h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+ model_s1 - model_s)
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_1 = torch.expm1(h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+ return_intermediate=False, solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model at `x` and `s`; otherwise we use it directly.
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we use it directly.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly affects the performance. We recommend the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 1. / 3.
+ if r2 is None:
+ r2 = 2. / 3.
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ lambda_s2 = lambda_s + r2 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ s2 = ns.inverse_lambda(lambda_s2)
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+ s2), ns.marginal_std(t)
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_12 = torch.expm1(-r2 * h)
+ phi_1 = torch.expm1(-h)
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+ phi_2 = phi_1 / h + 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(sigma_s2 / sigma_s, dims) * x
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + expand_dims(alpha_t * phi_2, dims) * D1
+ - expand_dims(alpha_t * phi_3, dims) * D2
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_12 = torch.expm1(r2 * h)
+ phi_1 = torch.expm1(h)
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+ phi_2 = phi_1 / h - 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - expand_dims(sigma_t * phi_2, dims) * D1
+ - expand_dims(sigma_t * phi_3, dims) * D2
+ )
+
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+ else:
+ return x_t
+
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+ """
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
+ t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly affects the performance. We recommend the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_1, model_prev_0 = model_prev_list
+ t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+ t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0 = h_0 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ if self.predict_x0:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+ )
+ else:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+ """
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
+ t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly affects the performance. We recommend the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_1 = lambda_prev_1 - lambda_prev_2
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0, r1 = h_0 / h, h_1 / h
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+ if self.predict_x0:
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
+ )
+ else:
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
+ )
+ return x_t
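+
+ # Editorial note: D1_0 and D1_1 above are first-order divided differences of the
+ # model outputs with respect to lambda, and D1, D2 combine them into estimates of
+ # the first and second lambda-derivatives at t_prev_0, which supply the
+ # third-order Taylor correction terms.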
+
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+ r2=None):
+ """
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: An `int`. The order of DPM-Solver. We only support order 1, 2, or 3.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly affects the performance. We recommend the 'dpm_solver' type.
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1)
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1, r2=r2)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+ """
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
+ t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: An `int`. The order of DPM-Solver. We only support order 1, 2, or 3.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly affects the performance. We recommend the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+ elif order == 2:
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ elif order == 3:
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+ solver_type='dpm_solver'):
+ """
+ The adaptive step size solver based on singlestep DPM-Solver.
+
+ Args:
+ x: A pytorch tensor. The initial value at time `t_T`.
+ order: An `int`. The (higher) order of the solver. We only support order 2 or 3.
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ h_init: A `float`. The initial step size (for logSNR).
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly affects the performance. We recommend the 'dpm_solver' type.
+ Returns:
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
+
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+ """
+ ns = self.noise_schedule
+ s = t_T * torch.ones((x.shape[0],)).to(x)
+ lambda_s = ns.marginal_lambda(s)
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+ h = h_init * torch.ones_like(s).to(x)
+ x_prev = x
+ nfe = 0
+ if order == 2:
+ r1 = 0.5
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ solver_type=solver_type,
+ **kwargs)
+ elif order == 3:
+ r1, r2 = 1. / 3., 2. / 3.
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ return_intermediate=True,
+ solver_type=solver_type)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
+ solver_type=solver_type,
+ **kwargs)
+ else:
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+ while torch.abs((s - t_0)).mean() > t_err:
+ t = ns.inverse_lambda(lambda_s + h)
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+ E = norm_fn((x_higher - x_lower) / delta).max()
+ if torch.all(E <= 1.):
+ x = x_higher
+ s = t
+ x_prev = x_lower
+ lambda_s = ns.marginal_lambda(s)
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+ nfe += order
+ print('adaptive solver nfe', nfe)
+ return x
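+
+ # Editorial note: the loop above is the standard accept/reject step-size
+ # controller of [1]: a step is accepted when the weighted error E <= 1, and the
+ # next logSNR step is h <- min(theta * h * E**(-1 / order), lambda_0 - lambda_s).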
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+ atol=0.0078, rtol=0.05,
+ ):
+ """
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+
+ =====================================================
+
+ We support the following algorithms for both noise prediction model and data prediction model:
+ - 'singlestep':
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+ The total number of function evaluations (NFE) == `steps`.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ - If `order` == 1:
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If `order` == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+ - 'multistep':
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+ We initialize the first `order` values by lower order multistep solvers.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ Denote K = steps.
+ - If `order` == 1:
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
+ - If `order` == 3:
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
+ - 'singlestep_fixed':
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+ - 'adaptive':
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
+ (NFE) and the sample quality.
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+
+ =====================================================
+
+ Some advice on choosing the algorithm:
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+ skip_type='time_uniform', method='singlestep')
+ - For **guided sampling with large guidance scale** by DPMs:
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+ skip_type='time_uniform', method='multistep')
+
+ We support three types of `skip_type`:
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images.**
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images.**
+ - 'time_quadratic': quadratic time for the time steps.
+
+ =====================================================
+ Args:
+ x: A pytorch tensor. The initial value at time `t_start`
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+ steps: An `int`. The total number of function evaluations (NFE).
+ t_start: A `float`. The starting time of the sampling.
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+ t_end: A `float`. The ending time of the sampling.
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
+ For discrete-time DPMs:
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+ For continuous-time DPMs:
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+ order: An `int`. The order of DPM-Solver.
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+
+ This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID of
+ diffusion models sampled by diffusion SDEs on low-resolution images
+ (such as CIFAR-10). However, we observed that it does not matter for
+ high-resolution images; as it needs an additional NFE, we do not recommend
+ it for them.
+ lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
+ this trick is key to stabilizing sampling with DPM-Solver at very few steps
+ (especially for steps <= 10), so we recommend setting it to `True`.
+ solver_type: A `str`. The Taylor expansion type for the solver: `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ Returns:
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
+
+ """
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'adaptive':
+ with torch.no_grad():
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+ solver_type=solver_type)
+ elif method == 'multistep':
+ assert steps >= order
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+ solver_type=solver_type)
+ model_prev_list.append(self.model_fn(x, vec_t))
+ t_prev_list.append(vec_t)
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final and steps < 15:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+ solver_type=solver_type)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ model_prev_list[-1] = self.model_fn(x, vec_t)
+ elif method in ['singlestep', 'singlestep_fixed']:
+ if method == 'singlestep':
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
+ skip_type=skip_type,
+ t_T=t_T, t_0=t_0,
+ device=device)
+ elif method == 'singlestep_fixed':
+ K = steps // order
+ orders = [order, ] * K
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+ for i, order in enumerate(orders):
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+ N=order, device=device)
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+ h = lambda_inner[-1] - lambda_inner[0]
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we extrapolate using the outermost keypoints of xp.)
+
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
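+
+# Editorial sketch (not part of the original file): a quick sanity check of
+# interpolate_fn with one channel and keypoints on the line y = 2x. All names
+# and values below are illustrative.
+#
+# xp = torch.tensor([[0., 1., 2.]]) # [C=1, K=3]
+# yp = torch.tensor([[0., 2., 4.]]) # [C=1, K=3]
+# x = torch.tensor([[0.5]], requires_grad=True) # [N=1, C=1]
+# y = interpolate_fn(x, xp, yp) # tensor([[1.0]])
+# y.sum().backward() # x.grad == tensor([[2.0]]), so autograd works through it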
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to `dims` dimensions.
+
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dims`: an `int`. The target total number of dimensions.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
+ """
+ return v[(...,) + (None,) * (dims - 1)]
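+
+# Editorial sketch (not part of the original file): expand_dims(torch.ones(4), 4)
+# has shape [4, 1, 1, 1], which lets a per-batch scalar broadcast against an
+# image batch of shape [4, C, H, W].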
diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100755
index 0000000000000000000000000000000000000000..fc2c96baf2bf5f8de3684c198bcd1b0df5b51149
--- /dev/null
+++ b/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,87 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+
+MODEL_TYPES = {
+ "eps": "noise",
+ "v": "v"
+}
+
+
+class DPMSolverSampler(object):
+ def __init__(self, model, **kwargs):
+ super().__init__()
+ self.model = model
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if isinstance(attr, torch.Tensor):
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+
+ print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type=MODEL_TYPES[self.model.parameterization],
+ guidance_type="classifier-free",
+ condition=conditioning,
+ unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+ x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
+
+ return x.to(device), None
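+
+# Editorial sketch (not part of the original file): hypothetical usage, assuming
+# `model` is a LatentDiffusion instance and `c`, `uc` are conditioning tensors
+# from its cond stage:
+#
+# sampler = DPMSolverSampler(model)
+# samples, _ = sampler.sample(S=20, batch_size=4, shape=(4, 64, 64),
+# conditioning=c,
+# unconditional_guidance_scale=7.5,
+# unconditional_conditioning=uc)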
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
new file mode 100755
index 0000000000000000000000000000000000000000..273ffbebaf952ffc25f6b92506b7c91b4af4c3bf
--- /dev/null
+++ b/ldm/models/diffusion/plms.py
@@ -0,0 +1,243 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+
+ def register_buffer(self, name, attr):
+ if isinstance(attr, torch.Tensor):
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ if ddim_eta != 0:
+ raise ValueError('ddim_eta must be 0 for PLMS')
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta, verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ features_adapter=None,
+ cond_tau=0.4,
+ # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ features_adapter=features_adapter,
+ cond_tau=cond_tau
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def plms_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
+ cond_tau=0.4):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+ old_eps = []
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+ if mask is not None: # and index>=10:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps, t_next=ts_next,
+ features_adapter=None if index < int(
+ (1 - cond_tau) * total_steps) else features_adapter)
+
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+ features_adapter=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c, features_adapter=features_adapter)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in, features_adapter=features_adapter).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
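+
+ # Editorial note: the coefficient sets above, (3, -1)/2, (23, -16, 5)/12 and
+ # (55, -59, 37, -9)/24, are the standard 2nd-, 3rd- and 4th-order
+ # Adams-Bashforth linear multistep weights applied to the noise predictions.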
diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
new file mode 100755
index 0000000000000000000000000000000000000000..88a4d4727a4a337206ecd1dcf559ce90efa3401e
--- /dev/null
+++ b/ldm/modules/attention.py
@@ -0,0 +1,344 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except ImportError:
+ XFORMERS_IS_AVAILBLE = False
+
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+if os.environ.get("DISABLE_XFORMERS", "false").lower() == 'true':
+ XFORMERS_IS_AVAILBLE = False
+
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
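+
+# Editorial note: GEGLU projects to 2 * dim_out and gates one half with the GELU
+# of the other, as in "GLU Variants Improve Transformer" (Shazeer, 2020); e.g.
+# GEGLU(512, 2048)(torch.randn(1, 77, 512)) has shape [1, 77, 2048].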
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ # force cast to fp32 to avoid overflowing
+ if _ATTN_PRECISION == "fp32":
+ with torch.autocast(enabled=False, device_type='cuda'):
+ q, k = q.float(), k.float()
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ else:
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
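+
+ # Editorial note on shapes: for x of shape [B, N, query_dim] and context of
+ # shape [B, M, context_dim], q/k/v are rearranged to [(B*heads), N or M, dim_head],
+ # `sim` is [(B*heads), N, M], and the output returns to [B, N, query_dim] after
+ # the head merge and the `to_out` projection.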
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads.")
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x, context=None, mask=None):
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention
+ }
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False, use_linear=False,
+ use_checkpoint=True):
+ super().__init__()
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+ for d in range(depth)]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+ # project back from inner_dim to in_channels; the original argument order
+ # was swapped here, which only works when inner_dim == in_channels
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
\ No newline at end of file
diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
new file mode 100755
index 0000000000000000000000000000000000000000..b089eebbe1676d8249005bb9def002ff5180715b
--- /dev/null
+++ b/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,852 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from typing import Optional, Any
+
+from ldm.modules.attention import MemoryEfficientCrossAttention
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except ImportError:
+ XFORMERS_IS_AVAILBLE = False
+ print("No module 'xformers'. Proceeding without it.")
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ Build sinusoidal timestep embeddings, as in Denoising Diffusion Probabilistic
+ Models (taken from Fairseq). This matches the implementation in tensor2tensor,
+ but differs slightly from the description in Section 3.5 of "Attention Is All
+ You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
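+
+# Editorial sketch (not part of the original file): for a batch of 8 timesteps,
+# get_timestep_embedding(torch.arange(8), 320) returns a [8, 320] tensor whose
+# first 160 columns are the sin terms and last 160 columns the cos terms.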
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
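+ # reshape to (batch * heads, tokens, head_dim); this block is single-head, so heads = 1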
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(B, 1, out.shape[1], C)
+ .permute(0, 2, 1, 3)
+ .reshape(B, out.shape[1], C)
+ )
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+ out = self.proj_out(out)
+ return x+out
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+ def forward(self, x, context=None, mask=None):
+ b, c, h, w = x.shape
+ x_in = x
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ out = super().forward(x, context=context, mask=mask)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+ # add the residual in the original (b, c, h, w) layout
+ return x_in + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+ attn_type = "vanilla-xformers"
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return MemoryEfficientAttnBlock(in_channels)
+ elif attn_type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ raise NotImplementedError()
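+# Note: "linear" passes the assert above but is not implemented here; e.g.
+# make_attn(512) returns an AttnBlock (or a MemoryEfficientAttnBlock when
+# xformers is available), and make_attn(512, attn_type="none") returns nn.Identity.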
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[i_level](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100755
index 0000000000000000000000000000000000000000..09972d58e1a65b88909dfe35c12c9126851da5cf
--- /dev/null
+++ b/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,798 @@
+from abc import abstractmethod
+import math
+import torch
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
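+ # the first token is the prepended mean query; return it as the pooled feature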
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
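+ # FiLM-style conditioning: h = norm(h) * (1 + scale) + shift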
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
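+ # scaling q and k each by ch**-0.25 gives the standard 1/sqrt(ch) attention scale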
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None, features_adapter=None, append_to_context=None, **kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+
+ if append_to_context is not None:
+ context = torch.cat([context, append_to_context], dim=1)
+
+ adapter_idx = 0
+ for block_idx, module in enumerate(self.input_blocks):
+ h = module(h, emb, context)
+ # inject adapter features after every third input block, if provided
+ if ((block_idx + 1) % 3 == 0) and features_adapter is not None:
+ h = h + features_adapter[adapter_idx]
+ adapter_idx += 1
+ hs.append(h)
+ if features_adapter is not None:
+ assert len(features_adapter)==adapter_idx, 'Wrong features_adapter'
+
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
new file mode 100755
index 0000000000000000000000000000000000000000..637363dfe34799e70cfdbcd11445212df9d9ca1f
--- /dev/null
+++ b/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,270 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
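+# Example (illustrative): make_beta_schedule("linear", 1000) returns a
+# length-1000 numpy schedule rising from linear_start=1e-4 to linear_end=2e-2.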
+
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def extract_into_tensor(a, t, x_shape):
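+ # gather a[t] per batch element, then reshape to (b, 1, ..., 1) so the result
+ # broadcasts against tensors of shape x_shape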
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
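+# Example (illustrative), mirroring the usage in ResBlock.forward:
+#   out = checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)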
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled()}
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), \
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
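+# Example (illustrative):
+#   t = torch.tensor([0, 10, 999])
+#   emb = timestep_embedding(t, 128)  # -> [3, 128]; cosine terms, then sine terms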
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
\ No newline at end of file
diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py
new file mode 100755
index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9
--- /dev/null
+++ b/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
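+ # reparameterization trick: sample = mean + std * eps, with eps ~ N(0, I)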
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py
new file mode 100755
index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4
--- /dev/null
+++ b/ldm/modules/ema.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+ else torch.tensor(-1, dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+                # '.' is not allowed in buffer names, so strip it from the parameter name
+ s_name = name.replace('.', '')
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+                    assert key not in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+                assert key not in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
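The intended calling pattern is: update the shadow weights once per optimizer step, and wrap validation in `store`/`copy_to`/`restore` so the live weights stay untouched. A minimal sketch with a stand-in model:

```python
import torch
from torch import nn
from ldm.modules.ema import LitEma

model = nn.Linear(4, 4)                      # stand-in for the actual network
ema = LitEma(model, decay=0.9999)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
ema(model)                                   # forward() applies one EMA update

ema.store(model.parameters())                # stash the live weights
ema.copy_to(model)                           # validate with EMA weights...
ema.restore(model.parameters())              # ...then restore the live weights
```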
diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/encoders/adapter.py b/ldm/modules/encoders/adapter.py
new file mode 100755
index 0000000000000000000000000000000000000000..0eef97edcaca1186835f32dc1b0c7bcb9c4bd3ec
--- /dev/null
+++ b/ldm/modules/encoders/adapter.py
@@ -0,0 +1,258 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
+ super().__init__()
+ ps = ksize // 2
+        if in_c != out_c or not sk:
+            self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
+        else:
+            self.in_conv = None
+ self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
+ self.act = nn.ReLU()
+ self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
+        if not sk:
+            # NOTE: for in_c != out_c this skip conv expects in_c channels, but forward()
+            # applies it after in_conv has already mapped x to out_c; channel-changing
+            # blocks therefore need sk=True.
+            self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
+        else:
+            self.skep = None
+
+        self.down = down
+        if self.down:
+            self.down_opt = Downsample(in_c, use_conv=use_conv)
+
+    def forward(self, x):
+        if self.down:
+            x = self.down_opt(x)
+        if self.in_conv is not None:
+            x = self.in_conv(x)
+
+ h = self.block1(x)
+ h = self.act(h)
+ h = self.block2(h)
+ if self.skep is not None:
+ return h + self.skep(x)
+ else:
+ return h + x
+
+
+class Adapter(nn.Module):
+ def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
+ super(Adapter, self).__init__()
+ self.unshuffle = nn.PixelUnshuffle(8)
+ self.channels = channels
+ self.nums_rb = nums_rb
+ self.body = []
+ for i in range(len(channels)):
+ for j in range(nums_rb):
+ if (i != 0) and (j == 0):
+ self.body.append(
+ ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
+ else:
+ self.body.append(
+ ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
+ self.body = nn.ModuleList(self.body)
+ self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
+
+ def forward(self, x):
+ # unshuffle
+ x = self.unshuffle(x)
+ # extract features
+ features = []
+ x = self.conv_in(x)
+ for i in range(len(self.channels)):
+ for j in range(self.nums_rb):
+ idx = i * self.nums_rb + j
+ x = self.body[idx](x)
+ features.append(x)
+
+ return features
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(
+ OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class StyleAdapter(nn.Module):
+
+ def __init__(self, width=1024, context_dim=768, num_head=8, n_layes=3, num_token=4):
+ super().__init__()
+
+ scale = width ** -0.5
+ self.transformer_layes = nn.Sequential(*[ResidualAttentionBlock(width, num_head) for _ in range(n_layes)])
+ self.num_token = num_token
+ self.style_embedding = nn.Parameter(torch.randn(1, num_token, width) * scale)
+ self.ln_post = LayerNorm(width)
+ self.ln_pre = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, context_dim))
+
+ def forward(self, x):
+ # x shape [N, HW+1, C]
+ style_embedding = self.style_embedding + torch.zeros(
+ (x.shape[0], self.num_token, self.style_embedding.shape[-1]), device=x.device)
+ x = torch.cat([x, style_embedding], dim=1)
+ x = self.ln_pre(x)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer_layes(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ x = self.ln_post(x[:, -self.num_token:, :])
+ x = x @ self.proj
+
+ return x
+
+
+class ResnetBlock_light(nn.Module):
+ def __init__(self, in_c):
+ super().__init__()
+ self.block1 = nn.Conv2d(in_c, in_c, 3, 1, 1)
+ self.act = nn.ReLU()
+ self.block2 = nn.Conv2d(in_c, in_c, 3, 1, 1)
+
+ def forward(self, x):
+ h = self.block1(x)
+ h = self.act(h)
+ h = self.block2(h)
+
+ return h + x
+
+
+class extractor(nn.Module):
+ def __init__(self, in_c, inter_c, out_c, nums_rb, down=False):
+ super().__init__()
+ self.in_conv = nn.Conv2d(in_c, inter_c, 1, 1, 0)
+ self.body = []
+ for _ in range(nums_rb):
+ self.body.append(ResnetBlock_light(inter_c))
+ self.body = nn.Sequential(*self.body)
+ self.out_conv = nn.Conv2d(inter_c, out_c, 1, 1, 0)
+        self.down = down
+        if self.down:
+            self.down_opt = Downsample(in_c, use_conv=False)
+
+    def forward(self, x):
+        if self.down:
+            x = self.down_opt(x)
+ x = self.in_conv(x)
+ x = self.body(x)
+ x = self.out_conv(x)
+
+ return x
+
+
+class Adapter_light(nn.Module):
+ def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
+ super(Adapter_light, self).__init__()
+ self.unshuffle = nn.PixelUnshuffle(8)
+ self.channels = channels
+ self.nums_rb = nums_rb
+ self.body = []
+ for i in range(len(channels)):
+ if i == 0:
+ self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
+ else:
+ self.body.append(extractor(in_c=channels[i-1], inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=True))
+ self.body = nn.ModuleList(self.body)
+
+ def forward(self, x):
+ # unshuffle
+ x = self.unshuffle(x)
+ # extract features
+ features = []
+ for i in range(len(self.channels)):
+ x = self.body[i](x)
+ features.append(x)
+
+ return features
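To see the feature pyramid the `Adapter` produces, run a dummy condition map through it. With the default `channels=[320, 640, 1280, 1280]`, a 512x512 input yields four feature maps whose resolution halves at each stage; `sk=True, use_conv=False` is the configuration under which the channel-changing blocks are well-defined (see the note in `ResnetBlock`). A sketch:

```python
import torch
from ldm.modules.encoders.adapter import Adapter

# A 1-channel 512x512 map becomes 64 channels after PixelUnshuffle(8),
# matching the default cin=64.
adapter = Adapter(cin=64, nums_rb=2, ksize=1, sk=True, use_conv=False)
features = adapter(torch.randn(1, 1, 512, 512))
for f in features:
    print(tuple(f.shape))
# (1, 320, 64, 64)
# (1, 640, 32, 32)
# (1, 1280, 16, 16)
# (1, 1280, 8, 8)
```

These four resolutions line up with the four encoder levels of the SD UNet, which is what lets the features be injected level by level.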
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..d59229ac1c97980e811e3b808f3431311c4f3b7d
--- /dev/null
+++ b/ldm/modules/encoders/modules.py
@@ -0,0 +1,441 @@
+import torch
+import torch.nn as nn
+import math
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel, CLIPModel
+
+import open_clip
+import re
+from ldm.util import default, count_params
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+
+ def encode(self, x):
+ return x
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class'):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+
+ def forward(self, batch, key=None):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ c = self.embedding(c)
+ return c
+
+
+class FrozenT5Embedder(AbstractEncoder):
+ """Uses the T5 transformer encoder for text"""
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length # TODO: typical value?
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ #self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
+ freeze=True, layer="last"): # clip-vit-base-patch32
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPModel.from_pretrained(version).text_model
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer != 'last')
+
+ if self.layer == 'penultimate':
+ z = outputs.hidden_states[-2]
+ z = self.transformer.final_layer_norm(z)
+ else:
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = [
+ #"pooled",
+ "last",
+ "penultimate"
+ ]
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+ freeze=True, layer="last"):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
+ clip_max_length=77, t5_max_length=77):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
+# code from sd-webui
+re_attention = re.compile(r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""", re.X)
+
+
+def parse_prompt_attention(text):
+    r"""
+ Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+ Accepted tokens are:
+ (abc) - increases attention to abc by a multiplier of 1.1
+ (abc:3.12) - increases attention to abc by a multiplier of 3.12
+ [abc] - decreases attention to abc by a multiplier of 1.1
+ \( - literal character '('
+ \[ - literal character '['
+ \) - literal character ')'
+ \] - literal character ']'
+ \\ - literal character '\'
+ anything else - just text
+
+ >>> parse_prompt_attention('normal text')
+ [['normal text', 1.0]]
+ >>> parse_prompt_attention('an (important) word')
+ [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+ >>> parse_prompt_attention('(unbalanced')
+ [['unbalanced', 1.1]]
+ >>> parse_prompt_attention('\(literal\]')
+ [['(literal]', 1.0]]
+ >>> parse_prompt_attention('(unnecessary)(parens)')
+ [['unnecessaryparens', 1.1]]
+ >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+ [['a ', 1.0],
+ ['house', 1.5730000000000004],
+ [' ', 1.1],
+ ['on', 1.0],
+ [' a ', 1.1],
+ ['hill', 0.55],
+ [', sun, ', 1.1],
+ ['sky', 1.4641000000000006],
+ ['.', 1.1]]
+ """
+
+ res = []
+ round_brackets = []
+ square_brackets = []
+
+ round_bracket_multiplier = 1.1
+ square_bracket_multiplier = 1 / 1.1
+
+ def multiply_range(start_position, multiplier):
+ for p in range(start_position, len(res)):
+ res[p][1] *= multiplier
+
+ for m in re_attention.finditer(text):
+ text = m.group(0)
+ weight = m.group(1)
+
+ if text.startswith('\\'):
+ res.append([text[1:], 1.0])
+ elif text == '(':
+ round_brackets.append(len(res))
+ elif text == '[':
+ square_brackets.append(len(res))
+ elif weight is not None and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), float(weight))
+ elif text == ')' and len(round_brackets) > 0:
+ multiply_range(round_brackets.pop(), round_bracket_multiplier)
+ elif text == ']' and len(square_brackets) > 0:
+ multiply_range(square_brackets.pop(), square_bracket_multiplier)
+ else:
+ res.append([text, 1.0])
+
+ for pos in round_brackets:
+ multiply_range(pos, round_bracket_multiplier)
+
+ for pos in square_brackets:
+ multiply_range(pos, square_bracket_multiplier)
+
+ if len(res) == 0:
+ res = [["", 1.0]]
+
+ # merge runs of identical weights
+ i = 0
+ while i + 1 < len(res):
+ if res[i][1] == res[i + 1][1]:
+ res[i][0] += res[i + 1][0]
+ res.pop(i + 1)
+ else:
+ i += 1
+
+ return res
+
+class WebUIFrozenCLIPEmebedder(AbstractEncoder):
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", freeze=True, layer="penultimate"):
+ super(WebUIFrozenCLIPEmebedder, self).__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPModel.from_pretrained(version).text_model
+ self.device = device
+ self.layer = layer
+ if freeze:
+ self.freeze()
+
+        self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
+ self.comma_padding_backtrack = 20
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def tokenize(self, texts):
+ tokenized = self.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"]
+ return tokenized
+
+ def encode_with_transformers(self, tokens):
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer!='last')
+
+ if self.layer == 'penultimate':
+ z = outputs.hidden_states[-2]
+ z = self.transformer.final_layer_norm(z)
+ else:
+ z = outputs.last_hidden_state
+
+ return z
+
+ def tokenize_line(self, line):
+ parsed = parse_prompt_attention(line)
+ # print(parsed)
+
+ tokenized = self.tokenize([text for text, _ in parsed])
+
+ remade_tokens = []
+ multipliers = []
+ last_comma = -1
+
+ for tokens, (text, weight) in zip(tokenized, parsed):
+ i = 0
+ while i < len(tokens):
+ token = tokens[i]
+
+ if token == self.comma_token:
+ last_comma = len(remade_tokens)
+                elif (self.comma_padding_backtrack != 0
+                      and max(len(remade_tokens), 1) % 75 == 0
+                      and last_comma != -1
+                      and len(remade_tokens) - last_comma <= self.comma_padding_backtrack):
+ last_comma += 1
+ reloc_tokens = remade_tokens[last_comma:]
+ reloc_mults = multipliers[last_comma:]
+
+ remade_tokens = remade_tokens[:last_comma]
+ length = len(remade_tokens)
+
+ rem = int(math.ceil(length / 75)) * 75 - length
+ remade_tokens += [self.tokenizer.eos_token_id] * rem + reloc_tokens
+ multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
+
+ remade_tokens.append(token)
+ multipliers.append(weight)
+ i += 1
+
+ token_count = len(remade_tokens)
+ prompt_target_length = math.ceil(max(token_count, 1) / 75) * 75
+ tokens_to_add = prompt_target_length - len(remade_tokens)
+
+ remade_tokens = remade_tokens + [self.tokenizer.eos_token_id] * tokens_to_add
+ multipliers = multipliers + [1.0] * tokens_to_add
+
+ return remade_tokens, multipliers, token_count
+
+ def process_text(self, texts):
+ remade_batch_tokens = []
+ token_count = 0
+
+ cache = {}
+ batch_multipliers = []
+ for line in texts:
+ if line in cache:
+ remade_tokens, multipliers = cache[line]
+ else:
+ remade_tokens, multipliers, current_token_count = self.tokenize_line(line)
+ token_count = max(current_token_count, token_count)
+
+ cache[line] = (remade_tokens, multipliers)
+
+ remade_batch_tokens.append(remade_tokens)
+ batch_multipliers.append(multipliers)
+
+ return batch_multipliers, remade_batch_tokens, token_count
+
+ def process_tokens(self, remade_batch_tokens, batch_multipliers):
+ remade_batch_tokens = [[self.tokenizer.bos_token_id] + x[:75] + [self.tokenizer.eos_token_id] for x in remade_batch_tokens]
+ batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
+
+ tokens = torch.asarray(remade_batch_tokens).to(self.device)
+
+ z = self.encode_with_transformers(tokens)
+
+ # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
+ batch_multipliers_of_same_length = [x + [1.0] * (75 - len(x)) for x in batch_multipliers]
+ batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(self.device)
+ original_mean = z.mean()
+ z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
+ new_mean = z.mean()
+ z *= original_mean / new_mean
+
+ return z
+
+ def forward(self, text):
+ batch_multipliers, remade_batch_tokens, token_count = self.process_text(text)
+
+        z = None
+ while max(map(len, remade_batch_tokens)) != 0:
+ rem_tokens = [x[75:] for x in remade_batch_tokens]
+ rem_multipliers = [x[75:] for x in batch_multipliers]
+
+ tokens = []
+ multipliers = []
+ for j in range(len(remade_batch_tokens)):
+ if len(remade_batch_tokens[j]) > 0:
+ tokens.append(remade_batch_tokens[j][:75])
+ multipliers.append(batch_multipliers[j][:75])
+ else:
+ tokens.append([self.tokenizer.eos_token_id] * 75)
+ multipliers.append([1.0] * 75)
+
+            z1 = self.process_tokens(tokens, multipliers)
+            z = z1 if z is None else torch.cat((z, z1), dim=-2)
+
+            remade_batch_tokens = rem_tokens
+            batch_multipliers = rem_multipliers
+
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+
+if __name__ == "__main__":
+ model = FrozenCLIPEmbedder()
+ count_params(model, verbose=True)
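`parse_prompt_attention` is pure Python, so its weighting rules can be checked without loading any model; weights nest multiplicatively (each `(` scales by 1.1, each `[` by 1/1.1):

```python
from ldm.modules.encoders.modules import parse_prompt_attention

print(parse_prompt_attention("a ((cat)) in [dim] light"))
# [['a ', 1.0], ['cat', 1.2100000000000002], [' in ', 1.0],
#  ['dim', 0.9090909090909091], [' light', 1.0]]
```

`WebUIFrozenCLIPEmebedder.process_tokens` then scales each token embedding by its weight and rescales the result so the mean activation matches the unweighted encoding, which is what keeps weighted prompts from drifting off the distribution the UNet was trained on.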
diff --git a/ldm/modules/extra_condition/__init__.py b/ldm/modules/extra_condition/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26
--- /dev/null
+++ b/ldm/modules/extra_condition/__init__.py
@@ -0,0 +1 @@
+# -*- coding: utf-8 -*-
diff --git a/ldm/modules/extra_condition/api.py b/ldm/modules/extra_condition/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6968ef9dd4a087c862f8e66b05108eb12f671f4
--- /dev/null
+++ b/ldm/modules/extra_condition/api.py
@@ -0,0 +1,269 @@
+from enum import Enum, unique
+
+import cv2
+import torch
+from basicsr.utils import img2tensor
+from ldm.util import resize_numpy_image
+from PIL import Image
+from torch import autocast
+
+
+@unique
+class ExtraCondition(Enum):
+ sketch = 0
+ keypose = 1
+ seg = 2
+ depth = 3
+ canny = 4
+ style = 5
+ color = 6
+ openpose = 7
+
+
+def get_cond_model(opt, cond_type: ExtraCondition):
+ if cond_type == ExtraCondition.sketch:
+ from ldm.modules.extra_condition.model_edge import pidinet
+ model = pidinet()
+ ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
+ model.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()}, strict=True)
+ model.to(opt.device)
+ return model
+ elif cond_type == ExtraCondition.seg:
+ raise NotImplementedError
+ elif cond_type == ExtraCondition.keypose:
+ import mmcv
+ from mmdet.apis import init_detector
+ from mmpose.apis import init_pose_model
+ det_config = 'configs/mm/faster_rcnn_r50_fpn_coco.py'
+ det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
+ pose_config = 'configs/mm/hrnet_w48_coco_256x192.py'
+ pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
+ det_config_mmcv = mmcv.Config.fromfile(det_config)
+ det_model = init_detector(det_config_mmcv, det_checkpoint, device=opt.device)
+ pose_config_mmcv = mmcv.Config.fromfile(pose_config)
+ pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=opt.device)
+ return {'pose_model': pose_model, 'det_model': det_model}
+ elif cond_type == ExtraCondition.depth:
+ from ldm.modules.extra_condition.midas.api import MiDaSInference
+ model = MiDaSInference(model_type='dpt_hybrid').to(opt.device)
+ return model
+ elif cond_type == ExtraCondition.canny:
+ return None
+ elif cond_type == ExtraCondition.style:
+ from transformers import CLIPProcessor, CLIPVisionModel
+ version = 'openai/clip-vit-large-patch14'
+ processor = CLIPProcessor.from_pretrained(version)
+ clip_vision_model = CLIPVisionModel.from_pretrained(version).to(opt.device)
+ return {'processor': processor, 'clip_vision_model': clip_vision_model}
+ elif cond_type == ExtraCondition.color:
+ return None
+ elif cond_type == ExtraCondition.openpose:
+ from ldm.modules.extra_condition.openpose.api import OpenposeInference
+ model = OpenposeInference().to(opt.device)
+ return model
+ else:
+ raise NotImplementedError
+
+
+def get_cond_sketch(opt, cond_image, cond_inp_type, cond_model=None):
+ if isinstance(cond_image, str):
+ edge = cv2.imread(cond_image)
+ else:
+        # gradio passes an RGB numpy array, so convert to BGR for OpenCV
+ edge = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
+ edge = resize_numpy_image(edge, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
+ opt.H, opt.W = edge.shape[:2]
+ if cond_inp_type == 'sketch':
+ edge = img2tensor(edge)[0].unsqueeze(0).unsqueeze(0) / 255.
+ edge = edge.to(opt.device)
+ elif cond_inp_type == 'image':
+ edge = img2tensor(edge).unsqueeze(0) / 255.
+ edge = cond_model(edge.to(opt.device))[-1]
+ else:
+ raise NotImplementedError
+
+ # edge = 1-edge # for white background
+ edge = edge > 0.5
+ edge = edge.float()
+
+ return edge
+
+
+def get_cond_seg(opt, cond_image, cond_inp_type='image', cond_model=None):
+ if isinstance(cond_image, str):
+ seg = cv2.imread(cond_image)
+ else:
+ seg = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
+ seg = resize_numpy_image(seg, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
+ opt.H, opt.W = seg.shape[:2]
+ if cond_inp_type == 'seg':
+ seg = img2tensor(seg).unsqueeze(0) / 255.
+ seg = seg.to(opt.device)
+ else:
+ raise NotImplementedError
+
+ return seg
+
+
+def get_cond_keypose(opt, cond_image, cond_inp_type='image', cond_model=None):
+ if isinstance(cond_image, str):
+ pose = cv2.imread(cond_image)
+ else:
+ pose = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
+ pose = resize_numpy_image(pose, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
+ opt.H, opt.W = pose.shape[:2]
+ if cond_inp_type == 'keypose':
+ pose = img2tensor(pose).unsqueeze(0) / 255.
+ pose = pose.to(opt.device)
+ elif cond_inp_type == 'image':
+ from ldm.modules.extra_condition.utils import imshow_keypoints
+ from mmdet.apis import inference_detector
+ from mmpose.apis import (inference_top_down_pose_model, process_mmdet_results)
+
+ # mmpose seems not compatible with autocast fp16
+ with autocast("cuda", dtype=torch.float32):
+ mmdet_results = inference_detector(cond_model['det_model'], pose)
+ # keep the person class bounding boxes.
+ person_results = process_mmdet_results(mmdet_results, 1)
+
+ # optional
+ return_heatmap = False
+ dataset = cond_model['pose_model'].cfg.data['test']['type']
+
+ # e.g. use ('backbone', ) to return backbone feature
+ output_layer_names = None
+ pose_results, returned_outputs = inference_top_down_pose_model(
+ cond_model['pose_model'],
+ pose,
+ person_results,
+ bbox_thr=0.2,
+ format='xyxy',
+ dataset=dataset,
+ dataset_info=None,
+ return_heatmap=return_heatmap,
+ outputs=output_layer_names)
+
+ # show the results
+ pose = imshow_keypoints(pose, pose_results, radius=2, thickness=2)
+ pose = img2tensor(pose).unsqueeze(0) / 255.
+ pose = pose.to(opt.device)
+ else:
+ raise NotImplementedError
+
+ return pose
+
+
+def get_cond_depth(opt, cond_image, cond_inp_type='image', cond_model=None):
+ if isinstance(cond_image, str):
+ depth = cv2.imread(cond_image)
+ else:
+ depth = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
+ depth = resize_numpy_image(depth, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
+ opt.H, opt.W = depth.shape[:2]
+ if cond_inp_type == 'depth':
+ depth = img2tensor(depth).unsqueeze(0) / 255.
+ depth = depth.to(opt.device)
+ elif cond_inp_type == 'image':
+ depth = img2tensor(depth).unsqueeze(0) / 127.5 - 1.0
+ depth = cond_model(depth.to(opt.device)).repeat(1, 3, 1, 1)
+ depth -= torch.min(depth)
+ depth /= torch.max(depth)
+ else:
+ raise NotImplementedError
+
+ return depth
+
+
+def get_cond_canny(opt, cond_image, cond_inp_type='image', cond_model=None):
+ if isinstance(cond_image, str):
+ canny = cv2.imread(cond_image)
+ else:
+ canny = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
+ canny = resize_numpy_image(canny, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
+ opt.H, opt.W = canny.shape[:2]
+ if cond_inp_type == 'canny':
+ canny = img2tensor(canny)[0:1].unsqueeze(0) / 255.
+ canny = canny.to(opt.device)
+ elif cond_inp_type == 'image':
+ canny = cv2.Canny(canny, 100, 200)[..., None]
+ canny = img2tensor(canny).unsqueeze(0) / 255.
+ canny = canny.to(opt.device)
+ else:
+ raise NotImplementedError
+
+ return canny
+
+
+def get_cond_style(opt, cond_image, cond_inp_type='image', cond_model=None):
+ assert cond_inp_type == 'image'
+ if isinstance(cond_image, str):
+ style = Image.open(cond_image)
+ else:
+ # numpy image to PIL image
+ style = Image.fromarray(cond_image)
+
+ style_for_clip = cond_model['processor'](images=style, return_tensors="pt")['pixel_values']
+ style_feat = cond_model['clip_vision_model'](style_for_clip.to(opt.device))['last_hidden_state']
+
+ return style_feat
+
+
+def get_cond_color(opt, cond_image, cond_inp_type='image', cond_model=None):
+ if isinstance(cond_image, str):
+ color = cv2.imread(cond_image)
+ else:
+ color = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
+ color = resize_numpy_image(color, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
+ opt.H, opt.W = color.shape[:2]
+ if cond_inp_type == 'image':
+        # downsample to a coarse 64x-reduced palette, then blow it back up with
+        # nearest-neighbour so the condition is a grid of flat colour blocks
+        color = cv2.resize(color, (opt.W//64, opt.H//64), interpolation=cv2.INTER_CUBIC)
+        color = cv2.resize(color, (opt.W, opt.H), interpolation=cv2.INTER_NEAREST)
+ color = img2tensor(color).unsqueeze(0) / 255.
+ color = color.to(opt.device)
+ return color
+
+
+def get_cond_openpose(opt, cond_image, cond_inp_type='image', cond_model=None):
+ if isinstance(cond_image, str):
+ openpose_keypose = cv2.imread(cond_image)
+ else:
+ openpose_keypose = cv2.cvtColor(cond_image, cv2.COLOR_RGB2BGR)
+ openpose_keypose = resize_numpy_image(
+ openpose_keypose, max_resolution=opt.max_resolution, resize_short_edge=opt.resize_short_edge)
+ opt.H, opt.W = openpose_keypose.shape[:2]
+ if cond_inp_type == 'openpose':
+ openpose_keypose = img2tensor(openpose_keypose).unsqueeze(0) / 255.
+ openpose_keypose = openpose_keypose.to(opt.device)
+ elif cond_inp_type == 'image':
+ with autocast('cuda', dtype=torch.float32):
+ openpose_keypose = cond_model(openpose_keypose)
+ openpose_keypose = img2tensor(openpose_keypose).unsqueeze(0) / 255.
+ openpose_keypose = openpose_keypose.to(opt.device)
+
+ else:
+ raise NotImplementedError
+
+ return openpose_keypose
+
+
+def get_adapter_feature(inputs, adapters):
+ ret_feat_map = None
+ ret_feat_seq = None
+ if not isinstance(inputs, list):
+ inputs = [inputs]
+ adapters = [adapters]
+
+    for cond, adapter in zip(inputs, adapters):
+        cur_feature = adapter['model'](cond)
+ if isinstance(cur_feature, list):
+ if ret_feat_map is None:
+ ret_feat_map = list(map(lambda x: x * adapter['cond_weight'], cur_feature))
+ else:
+ ret_feat_map = list(map(lambda x, y: x + y * adapter['cond_weight'], ret_feat_map, cur_feature))
+ else:
+ if ret_feat_seq is None:
+ ret_feat_seq = cur_feature
+ else:
+ ret_feat_seq = torch.cat([ret_feat_seq, cur_feature], dim=1)
+
+ return ret_feat_map, ret_feat_seq
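For conditions whose extractor is classical (canny, color), the `get_cond_*` helpers can be exercised without any checkpoint; all they need is an options object carrying the fields used above. A hedged sketch, where the `SimpleNamespace` stands in for the argparse options the real scripts build:

```python
import types
import numpy as np
from ldm.modules.extra_condition.api import get_cond_canny

opt = types.SimpleNamespace(device='cpu', max_resolution=512 * 512,
                            resize_short_edge=None)
rgb = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)  # stand-in image

canny = get_cond_canny(opt, rgb, cond_inp_type='image')
print(canny.shape)  # (1, 1, opt.H, opt.W); a binary edge map in {0.0, 1.0}
```

`get_adapter_feature` then consumes such tensors together with `{'model': ..., 'cond_weight': ...}` dicts, summing weighted feature maps across adapters and concatenating sequence-style features (e.g. style tokens) along the token dimension.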
diff --git a/ldm/modules/extra_condition/midas/__init__.py b/ldm/modules/extra_condition/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/extra_condition/midas/api.py b/ldm/modules/extra_condition/midas/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a6e194545c40ec263e65a140678b53a5a2abd54
--- /dev/null
+++ b/ldm/modules/extra_condition/midas/api.py
@@ -0,0 +1,175 @@
+# based on https://github.com/isl-org/MiDaS
+import os
+
+import cv2
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from ldm.modules.extra_condition.midas.midas.dpt_depth import DPTDepthModel
+from ldm.modules.extra_condition.midas.midas.midas_net import MidasNet
+from ldm.modules.extra_condition.midas.midas.midas_net_custom import MidasNet_small
+from ldm.modules.extra_condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+
+ISL_PATHS = {
+ "dpt_large": "models/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": "models/dpt_hybrid-midas-501f0c75.pt",
+ "midas_v21": "",
+ "midas_v21_small": "",
+}
+
+remote_model_path = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt"
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def load_midas_transform(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load transform only
+ if model_type == "dpt_large": # DPT-Large
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ elif model_type == "midas_v21_small":
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    else:
+        raise ValueError(f"model_type '{model_type}' not implemented")
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return transform
+
+
+def load_model(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load network
+ model_path = ISL_PATHS[model_type]
+ if model_type == "dpt_large": # DPT-Large
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ if not os.path.exists(model_path):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir='models')
+
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ model = MidasNet(model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "midas_v21_small":
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+ non_negative=True, blocks={'expand': True})
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+    else:
+        raise ValueError(f"model_type '{model_type}' not implemented")
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+ MODEL_TYPES_TORCH_HUB = [
+ "DPT_Large",
+ "DPT_Hybrid",
+ "MiDaS_small"
+ ]
+ MODEL_TYPES_ISL = [
+ "dpt_large",
+ "dpt_hybrid",
+ "midas_v21",
+ "midas_v21_small",
+ ]
+
+ def __init__(self, model_type):
+ super().__init__()
+ assert (model_type in self.MODEL_TYPES_ISL)
+ model, _ = load_model(model_type)
+ self.model = model
+ self.model.train = disabled_train
+
+ def forward(self, x):
+        # x in [0, 1]; the matching normalization (see load_midas_transform) is
+        # expected to have been applied during dataloading
+ with torch.no_grad():
+ prediction = self.model(x)
+ prediction = torch.nn.functional.interpolate(
+ prediction.unsqueeze(1),
+ size=x.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
+ return prediction
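`MiDaSInference` wraps the loaded network with the interpolation back to the input resolution, so the output is always `(B, 1, H, W)`. A smoke-test sketch (the `dpt_hybrid` checkpoint is fetched on first use, as implemented in `load_model`):

```python
import torch
from ldm.modules.extra_condition.midas.api import MiDaSInference

midas = MiDaSInference(model_type='dpt_hybrid')
x = torch.randn(1, 3, 384, 384)   # already normalized, see load_midas_transform
with torch.no_grad():
    depth = midas(x)
print(depth.shape)                # torch.Size([1, 1, 384, 384])
```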
diff --git a/ldm/modules/extra_condition/midas/midas/__init__.py b/ldm/modules/extra_condition/midas/midas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/extra_condition/midas/midas/base_model.py b/ldm/modules/extra_condition/midas/midas/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd
--- /dev/null
+++ b/ldm/modules/extra_condition/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
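The `"optimizer"` check exists because `load` accepts both bare state dicts and training checkpoints that wrap the weights under a `"model"` key. A toy illustration of the round trip (the file path is arbitrary):

```python
import torch
import torch.nn as nn
from ldm.modules.extra_condition.midas.midas.base_model import BaseModel

class TinyModel(BaseModel):          # any module subclassing BaseModel
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)

m = TinyModel()
# a training checkpoint stores weights under "model" next to "optimizer";
# BaseModel.load unwraps it transparently
torch.save({"model": m.state_dict(), "optimizer": {}}, "/tmp/ckpt.pt")
m.load("/tmp/ckpt.pt")
```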
diff --git a/ldm/modules/extra_condition/midas/midas/blocks.py b/ldm/modules/extra_condition/midas/midas/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78
--- /dev/null
+++ b/ldm/modules/extra_condition/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+        scratch = _make_scratch(
+            [256, 512, 768, 768], features, groups=groups, expand=expand
+        )  # ViT-Hybrid (ResNet-50 + ViT-B/16) backbone
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+    elif backbone == "resnext101_wsl":
+        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+    else:
+        raise ValueError(f"Backbone '{backbone}' not implemented")
+
+ return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+    if expand:
+        out_shape2 = out_shape * 2
+        out_shape3 = out_shape * 4
+        out_shape4 = out_shape * 8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+ pretrained = nn.Module()
+
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+ return pretrained
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+ """Interpolation module.
+ """
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+        if self.bn:
+            self.bn1 = nn.BatchNorm2d(features)
+            self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+        out = self.activation(x)
+        out = self.conv1(out)
+        if self.bn:
+            out = self.bn1(out)
+
+        out = self.activation(out)
+        out = self.conv2(out)
+        if self.bn:
+            out = self.bn2(out)
+
+        if self.groups > 1:
+            # unreachable with the fixed self.groups = 1 above; conv_merge is never defined
+            out = self.conv_merge(out)
+
+        return self.skip_add.add(out, x)
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+        self.groups = 1
+
+        self.expand = expand
+        out_features = features
+        if self.expand:
+            out_features = features // 2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
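Each `FeatureFusionBlock_custom` adds the lateral (skip) feature to the running decoder state, refines it, and upsamples by 2x, which is how the refinenet chain in the depth models walks back up the pyramid. A shape-check sketch:

```python
import torch
import torch.nn as nn
from ldm.modules.extra_condition.midas.midas.blocks import FeatureFusionBlock_custom

block = FeatureFusionBlock_custom(64, nn.ReLU(False), bn=False)
deeper = torch.randn(1, 64, 8, 8)    # output of the previous (coarser) fusion stage
lateral = torch.randn(1, 64, 8, 8)   # reassembled encoder feature at this level
out = block(deeper, lateral)
print(out.shape)                     # torch.Size([1, 64, 16, 16])
```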
diff --git a/ldm/modules/extra_condition/midas/midas/dpt_depth.py b/ldm/modules/extra_condition/midas/midas/dpt_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1
--- /dev/null
+++ b/ldm/modules/extra_condition/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+
+
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ ):
+
+ super(DPT, self).__init__()
+
+ self.channels_last = channels_last
+
+ hooks = {
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }
+
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+            False,  # use_pretrained: set True to initialize the backbone with ImageNet weights
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks[backbone],
+ use_readout=readout,
+ )
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv = head
+
+
+ def forward(self, x):
+        if self.channels_last:
+            x = x.contiguous(memory_format=torch.channels_last)
+
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return out
+
+
+class DPTDepthModel(DPT):
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ features = kwargs["features"] if "features" in kwargs else 256
+
+ head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ super().__init__(head, **kwargs)
+
+ if path is not None:
+ self.load(path)
+
+ def forward(self, x):
+ return super().forward(x).squeeze(dim=1)
+
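Putting `DPT` together: with `path=None` the model is randomly initialized (no checkpoint is loaded), and the head returns a depth map at the input resolution with the channel dimension squeezed away. A sketch, assuming `timm` is installed for the ViT backbone and input sides that are multiples of 32:

```python
import torch
from ldm.modules.extra_condition.midas.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, backbone="vitb_rn50_384").eval()
x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    depth = model(x)
print(depth.shape)   # torch.Size([1, 384, 384]); dim 1 squeezed by DPTDepthModel
```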
diff --git a/ldm/modules/extra_condition/midas/midas/midas_net.py b/ldm/modules/extra_condition/midas/midas/midas_net.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f
--- /dev/null
+++ b/ldm/modules/extra_condition/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+        """
+ print("Loading weights: ", path)
+
+ super(MidasNet, self).__init__()
+
+        use_pretrained = path is not None
+
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
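`MidasNet` follows the same encoder/refinenet/head pattern with a ResNeXt-101 backbone; note that `_make_pretrained_resnext101_wsl` fetches that backbone via `torch.hub` the first time it runs. A sketch of the inference contract (MiDaS predicts relative inverse depth, not metric depth):

```python
import torch
from ldm.modules.extra_condition.midas.midas.midas_net import MidasNet

model = MidasNet(path=None, features=256).eval()  # backbone comes from torch.hub
with torch.no_grad():
    inv_depth = model(torch.randn(1, 3, 384, 384))
print(inv_depth.shape)   # torch.Size([1, 384, 384])
```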
diff --git a/ldm/modules/extra_condition/midas/midas/midas_net_custom.py b/ldm/modules/extra_condition/midas/midas/midas_net_custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c
--- /dev/null
+++ b/ldm/modules/extra_condition/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+ blocks={'expand': True}):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 64.
+            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet_small, self).__init__()
+
+        use_pretrained = path is None
+
+ self.channels_last = channels_last
+ self.blocks = blocks
+ self.backbone = backbone
+
+ self.groups = 1
+
+        features1 = features
+        features2 = features
+        features3 = features
+        features4 = features
+        self.expand = False
+        if "expand" in self.blocks and self.blocks['expand']:
+            self.expand = True
+            features1 = features
+            features2 = features * 2
+            features3 = features * 4
+            features4 = features * 8
+
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+ self.scratch.activation = nn.ReLU(False)
+
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+ self.scratch.activation,
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+        if self.channels_last:
+            x = x.contiguous(memory_format=torch.channels_last)
+
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
+
+
+
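+# Fuses Conv2d+BatchNorm2d(+ReLU) runs in-place via torch.quantization so the
+# network is cheaper to quantize or export. A minimal usage sketch (assuming a
+# MidasNet_small instance named `net`; fusion requires eval mode):
+#
+#     net.eval()
+#     fuse_model(net)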
+def fuse_model(m):
+    prev_previous_type = nn.Identity
+    prev_previous_name = ''
+    previous_type = nn.Identity
+    previous_name = ''
+ for name, module in m.named_modules():
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
diff --git a/ldm/modules/extra_condition/midas/midas/transforms.py b/ldm/modules/extra_condition/midas/midas/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9
--- /dev/null
+++ b/ldm/modules/extra_condition/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+        tuple: new size (the sample is returned unchanged if it is already large enough)
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+                # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return sample
+
+
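+# The transforms in this file operate on dict samples and are typically chained
+# with torchvision's Compose. A sketch, assuming `img` is an HxWx3 float RGB
+# array in [0, 1]; the ImageNet mean/std shown here are an assumption — use the
+# statistics the checkpoint was trained with:
+#
+#     from torchvision.transforms import Compose
+#     transform = Compose([
+#         Resize(384, 384, resize_target=False, keep_aspect_ratio=True,
+#                ensure_multiple_of=32, resize_method="lower_bound"),
+#         NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+#         PrepareForNet(),
+#     ])
+#     net_input = transform({"image": img})["image"]   # CxHxW float32
+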
+class NormalizeImage(object):
+ """Normlize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ return sample
diff --git a/ldm/modules/extra_condition/midas/midas/vit.py b/ldm/modules/extra_condition/midas/midas/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74
--- /dev/null
+++ b/ldm/modules/extra_condition/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+ features = torch.cat((x[:, self.start_index :], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
+def forward_vit(pretrained, x):
+ b, c, h, w = x.shape
+
+ glob = pretrained.model.forward_flex(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
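+# A stock timm ViT keeps a position embedding sized for its training grid (e.g.
+# 24x24 patches for 384px inputs with 16px patches). This helper splits off the
+# readout token(s) and bilinearly resizes the grid portion to (gs_h, gs_w) so
+# the model accepts other input resolutions.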
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index :],
+ )
+
+ gs_old = int(math.sqrt(len(posemb_grid)))
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
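+# Drop-in replacement for the timm forward pass: it mirrors the stock token
+# pipeline (patch embedding, cls/dist tokens, transformer blocks, final norm)
+# but resizes the position embedding on the fly via _resize_pos_embed so it
+# matches the current input's patch grid.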
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
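+# Module-level store written to by the forward hooks below. Note it is shared
+# by every backbone constructed in this file, and entries are overwritten on
+# each forward pass.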
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+    # We inject these functions into the VisionTransformer instance so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+    hooks = [5, 11, 17, 23] if hooks is None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model(
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+ )
+
+    hooks = [2, 5, 8, 11] if hooks is None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ start_index=2,
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+    if use_vit_only:
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ else:
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+ get_activation("1")
+ )
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+ get_activation("2")
+ )
+
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+    if use_vit_only:
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ else:
+ pretrained.act_postprocess1 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+    # We inject these functions into the VisionTransformer instance so that
+    # we can use it with interpolated position embeddings without modifying the library source.
+    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+    pretrained.model._resize_pos_embed = types.MethodType(
+        _resize_pos_embed, pretrained.model
+    )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+    hooks = [0, 1, 8, 11] if hooks is None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
diff --git a/ldm/modules/extra_condition/midas/utils.py b/ldm/modules/extra_condition/midas/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59
--- /dev/null
+++ b/ldm/modules/extra_condition/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+ """Read pfm file.
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+
+ return data, scale
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+
+ Args:
+        path (str): path to file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+            len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1)
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+ file.write("PF\n" if color else "Pf\n".encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+
+def resize_image(img):
+ """Resize image and make it fit for network.
+
+ Args:
+ img (array): image
+
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+
+ return img_resized
+
+
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+
+ return depth_resized
+
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
+ max_val = (2**(8*bits))-1
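+    # bits=1 -> 8-bit PNG (max_val 255); bits=2 -> 16-bit PNG (max_val 65535)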
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+        out = np.zeros(depth.shape, dtype=depth.dtype)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
diff --git a/ldm/modules/extra_condition/model_edge.py b/ldm/modules/extra_condition/model_edge.py
new file mode 100644
index 0000000000000000000000000000000000000000..5511f1d89e30160477f37792ecc345901fe893a9
--- /dev/null
+++ b/ldm/modules/extra_condition/model_edge.py
@@ -0,0 +1,653 @@
+"""
+Author: Zhuo Su, Wenzhe Liu
+Date: Feb 18, 2021
+"""
+
+import math
+
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from basicsr.utils import img2tensor
+
+nets = {
+ 'baseline': {
+ 'layer0': 'cv',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'c-v15': {
+ 'layer0': 'cd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'a-v15': {
+ 'layer0': 'ad',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'r-v15': {
+ 'layer0': 'rd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cv',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cv',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cv',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'cvvv4': {
+ 'layer0': 'cd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'avvv4': {
+ 'layer0': 'ad',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'ad',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'ad',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'ad',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'rvvv4': {
+ 'layer0': 'rd',
+ 'layer1': 'cv',
+ 'layer2': 'cv',
+ 'layer3': 'cv',
+ 'layer4': 'rd',
+ 'layer5': 'cv',
+ 'layer6': 'cv',
+ 'layer7': 'cv',
+ 'layer8': 'rd',
+ 'layer9': 'cv',
+ 'layer10': 'cv',
+ 'layer11': 'cv',
+ 'layer12': 'rd',
+ 'layer13': 'cv',
+ 'layer14': 'cv',
+ 'layer15': 'cv',
+ },
+ 'cccv4': {
+ 'layer0': 'cd',
+ 'layer1': 'cd',
+ 'layer2': 'cd',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'cd',
+ 'layer6': 'cd',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'cd',
+ 'layer10': 'cd',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'cd',
+ 'layer14': 'cd',
+ 'layer15': 'cv',
+ },
+ 'aaav4': {
+ 'layer0': 'ad',
+ 'layer1': 'ad',
+ 'layer2': 'ad',
+ 'layer3': 'cv',
+ 'layer4': 'ad',
+ 'layer5': 'ad',
+ 'layer6': 'ad',
+ 'layer7': 'cv',
+ 'layer8': 'ad',
+ 'layer9': 'ad',
+ 'layer10': 'ad',
+ 'layer11': 'cv',
+ 'layer12': 'ad',
+ 'layer13': 'ad',
+ 'layer14': 'ad',
+ 'layer15': 'cv',
+ },
+ 'rrrv4': {
+ 'layer0': 'rd',
+ 'layer1': 'rd',
+ 'layer2': 'rd',
+ 'layer3': 'cv',
+ 'layer4': 'rd',
+ 'layer5': 'rd',
+ 'layer6': 'rd',
+ 'layer7': 'cv',
+ 'layer8': 'rd',
+ 'layer9': 'rd',
+ 'layer10': 'rd',
+ 'layer11': 'cv',
+ 'layer12': 'rd',
+ 'layer13': 'rd',
+ 'layer14': 'rd',
+ 'layer15': 'cv',
+ },
+ 'c16': {
+ 'layer0': 'cd',
+ 'layer1': 'cd',
+ 'layer2': 'cd',
+ 'layer3': 'cd',
+ 'layer4': 'cd',
+ 'layer5': 'cd',
+ 'layer6': 'cd',
+ 'layer7': 'cd',
+ 'layer8': 'cd',
+ 'layer9': 'cd',
+ 'layer10': 'cd',
+ 'layer11': 'cd',
+ 'layer12': 'cd',
+ 'layer13': 'cd',
+ 'layer14': 'cd',
+ 'layer15': 'cd',
+ },
+ 'a16': {
+ 'layer0': 'ad',
+ 'layer1': 'ad',
+ 'layer2': 'ad',
+ 'layer3': 'ad',
+ 'layer4': 'ad',
+ 'layer5': 'ad',
+ 'layer6': 'ad',
+ 'layer7': 'ad',
+ 'layer8': 'ad',
+ 'layer9': 'ad',
+ 'layer10': 'ad',
+ 'layer11': 'ad',
+ 'layer12': 'ad',
+ 'layer13': 'ad',
+ 'layer14': 'ad',
+ 'layer15': 'ad',
+ },
+ 'r16': {
+ 'layer0': 'rd',
+ 'layer1': 'rd',
+ 'layer2': 'rd',
+ 'layer3': 'rd',
+ 'layer4': 'rd',
+ 'layer5': 'rd',
+ 'layer6': 'rd',
+ 'layer7': 'rd',
+ 'layer8': 'rd',
+ 'layer9': 'rd',
+ 'layer10': 'rd',
+ 'layer11': 'rd',
+ 'layer12': 'rd',
+ 'layer13': 'rd',
+ 'layer14': 'rd',
+ 'layer15': 'rd',
+ },
+ 'carv4': {
+ 'layer0': 'cd',
+ 'layer1': 'ad',
+ 'layer2': 'rd',
+ 'layer3': 'cv',
+ 'layer4': 'cd',
+ 'layer5': 'ad',
+ 'layer6': 'rd',
+ 'layer7': 'cv',
+ 'layer8': 'cd',
+ 'layer9': 'ad',
+ 'layer10': 'rd',
+ 'layer11': 'cv',
+ 'layer12': 'cd',
+ 'layer13': 'ad',
+ 'layer14': 'rd',
+ 'layer15': 'cv',
+ },
+ }
+
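+# Pixel-difference convolutions (PDC). 'cv' is a vanilla convolution; 'cd'
+# (central) computes sum_i w_i * (x_i - x_center) by subtracting a 1x1 conv
+# whose kernel is the spatial sum of the 3x3 weights; 'ad' (angular) differences
+# each weight against its clockwise neighbour; 'rd' (radial) spreads the 3x3
+# weights into a 5x5 kernel of outer-ring-minus-inner-ring differences.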
+def createConvFunc(op_type):
+ assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type)
+ if op_type == 'cv':
+ return F.conv2d
+
+ if op_type == 'cd':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3'
+ assert padding == dilation, 'padding for cd_conv set wrong'
+
+ weights_c = weights.sum(dim=[2, 3], keepdim=True)
+ yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups)
+ y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y - yc
+ return func
+ elif op_type == 'ad':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3'
+ assert padding == dilation, 'padding for ad_conv set wrong'
+
+ shape = weights.shape
+ weights = weights.view(shape[0], shape[1], -1)
+ weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise
+ y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y
+ return func
+ elif op_type == 'rd':
+ def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2'
+ assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3'
+ padding = 2 * dilation
+
+ shape = weights.shape
+ if weights.is_cuda:
+ buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0)
+ else:
+ buffer = torch.zeros(shape[0], shape[1], 5 * 5)
+ weights = weights.view(shape[0], shape[1], -1)
+ buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:]
+ buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:]
+ buffer[:, :, 12] = 0
+ buffer = buffer.view(shape[0], shape[1], 5, 5)
+ y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+ return y
+ return func
+ else:
+        print('unreachable: op_type is validated at the top of createConvFunc')
+ return None
+
+class Conv2d(nn.Module):
+ def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False):
+ super(Conv2d, self).__init__()
+ if in_channels % groups != 0:
+ raise ValueError('in_channels must be divisible by groups')
+ if out_channels % groups != 0:
+ raise ValueError('out_channels must be divisible by groups')
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+ self.pdc = pdc
+
+ def reset_parameters(self):
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, input):
+
+ return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+class CSAM(nn.Module):
+ """
+ Compact Spatial Attention Module
+ """
+ def __init__(self, channels):
+ super(CSAM, self).__init__()
+
+ mid_channels = 4
+ self.relu1 = nn.ReLU()
+ self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0)
+ self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False)
+ self.sigmoid = nn.Sigmoid()
+ nn.init.constant_(self.conv1.bias, 0)
+
+ def forward(self, x):
+ y = self.relu1(x)
+ y = self.conv1(y)
+ y = self.conv2(y)
+ y = self.sigmoid(y)
+
+ return x * y
+
+class CDCM(nn.Module):
+ """
+ Compact Dilation Convolution based Module
+ """
+ def __init__(self, in_channels, out_channels):
+ super(CDCM, self).__init__()
+
+ self.relu1 = nn.ReLU()
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
+ self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False)
+ self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False)
+ self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False)
+ self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False)
+ nn.init.constant_(self.conv1.bias, 0)
+
+ def forward(self, x):
+ x = self.relu1(x)
+ x = self.conv1(x)
+ x1 = self.conv2_1(x)
+ x2 = self.conv2_2(x)
+ x3 = self.conv2_3(x)
+ x4 = self.conv2_4(x)
+ return x1 + x2 + x3 + x4
+
+
+class MapReduce(nn.Module):
+ """
+ Reduce feature maps into a single edge map
+ """
+ def __init__(self, channels):
+ super(MapReduce, self).__init__()
+ self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0)
+ nn.init.constant_(self.conv.bias, 0)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class PDCBlock(nn.Module):
+ def __init__(self, pdc, inplane, ouplane, stride=1):
+ super(PDCBlock, self).__init__()
+        self.stride = stride
+ if self.stride > 1:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
+ self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
+ self.relu2 = nn.ReLU()
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x):
+ if self.stride > 1:
+ x = self.pool(x)
+ y = self.conv1(x)
+ y = self.relu2(y)
+ y = self.conv2(y)
+ if self.stride > 1:
+ x = self.shortcut(x)
+ y = y + x
+ return y
+
+class PDCBlock_converted(nn.Module):
+ """
+ CPDC, APDC can be converted to vanilla 3x3 convolution
+ RPDC can be converted to vanilla 5x5 convolution
+ """
+ def __init__(self, pdc, inplane, ouplane, stride=1):
+ super(PDCBlock_converted, self).__init__()
+ self.stride=stride
+
+ if self.stride > 1:
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0)
+ if pdc == 'rd':
+ self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False)
+ else:
+ self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False)
+ self.relu2 = nn.ReLU()
+ self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False)
+
+ def forward(self, x):
+ if self.stride > 1:
+ x = self.pool(x)
+ y = self.conv1(x)
+ y = self.relu2(y)
+ y = self.conv2(y)
+ if self.stride > 1:
+ x = self.shortcut(x)
+ y = y + x
+ return y
+
+class PiDiNet(nn.Module):
+ def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False):
+ super(PiDiNet, self).__init__()
+ self.sa = sa
+ if dil is not None:
+ assert isinstance(dil, int), 'dil should be an int'
+ self.dil = dil
+
+ self.fuseplanes = []
+
+ self.inplane = inplane
+ if convert:
+ if pdcs[0] == 'rd':
+ init_kernel_size = 5
+ init_padding = 2
+ else:
+ init_kernel_size = 3
+ init_padding = 1
+ self.init_block = nn.Conv2d(3, self.inplane,
+ kernel_size=init_kernel_size, padding=init_padding, bias=False)
+ block_class = PDCBlock_converted
+ else:
+ self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1)
+ block_class = PDCBlock
+
+ self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane)
+ self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane)
+ self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # C
+
+ inplane = self.inplane
+ self.inplane = self.inplane * 2
+ self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2)
+ self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane)
+ self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane)
+ self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 2C
+
+ inplane = self.inplane
+ self.inplane = self.inplane * 2
+ self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2)
+ self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane)
+ self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane)
+ self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 4C
+
+ self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2)
+ self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane)
+ self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane)
+ self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane)
+ self.fuseplanes.append(self.inplane) # 4C
+
+ self.conv_reduces = nn.ModuleList()
+ if self.sa and self.dil is not None:
+ self.attentions = nn.ModuleList()
+ self.dilations = nn.ModuleList()
+ for i in range(4):
+ self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
+ self.attentions.append(CSAM(self.dil))
+ self.conv_reduces.append(MapReduce(self.dil))
+ elif self.sa:
+ self.attentions = nn.ModuleList()
+ for i in range(4):
+ self.attentions.append(CSAM(self.fuseplanes[i]))
+ self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
+ elif self.dil is not None:
+ self.dilations = nn.ModuleList()
+ for i in range(4):
+ self.dilations.append(CDCM(self.fuseplanes[i], self.dil))
+ self.conv_reduces.append(MapReduce(self.dil))
+ else:
+ for i in range(4):
+ self.conv_reduces.append(MapReduce(self.fuseplanes[i]))
+
+ self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias
+ nn.init.constant_(self.classifier.weight, 0.25)
+ nn.init.constant_(self.classifier.bias, 0)
+
+ # print('initialization done')
+
+ def get_weights(self):
+ conv_weights = []
+ bn_weights = []
+ relu_weights = []
+ for pname, p in self.named_parameters():
+ if 'bn' in pname:
+ bn_weights.append(p)
+ elif 'relu' in pname:
+ relu_weights.append(p)
+ else:
+ conv_weights.append(p)
+
+ return conv_weights, bn_weights, relu_weights
+
+ def forward(self, x):
+ H, W = x.size()[2:]
+
+ x = self.init_block(x)
+
+ x1 = self.block1_1(x)
+ x1 = self.block1_2(x1)
+ x1 = self.block1_3(x1)
+
+ x2 = self.block2_1(x1)
+ x2 = self.block2_2(x2)
+ x2 = self.block2_3(x2)
+ x2 = self.block2_4(x2)
+
+ x3 = self.block3_1(x2)
+ x3 = self.block3_2(x3)
+ x3 = self.block3_3(x3)
+ x3 = self.block3_4(x3)
+
+ x4 = self.block4_1(x3)
+ x4 = self.block4_2(x4)
+ x4 = self.block4_3(x4)
+ x4 = self.block4_4(x4)
+
+ x_fuses = []
+ if self.sa and self.dil is not None:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.attentions[i](self.dilations[i](xi)))
+ elif self.sa:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.attentions[i](xi))
+ elif self.dil is not None:
+ for i, xi in enumerate([x1, x2, x3, x4]):
+ x_fuses.append(self.dilations[i](xi))
+ else:
+ x_fuses = [x1, x2, x3, x4]
+
+ e1 = self.conv_reduces[0](x_fuses[0])
+ e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False)
+
+ e2 = self.conv_reduces[1](x_fuses[1])
+ e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False)
+
+ e3 = self.conv_reduces[2](x_fuses[2])
+ e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False)
+
+ e4 = self.conv_reduces[3](x_fuses[3])
+ e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False)
+
+ outputs = [e1, e2, e3, e4]
+
+ output = self.classifier(torch.cat(outputs, dim=1))
+ #if not self.training:
+ # return torch.sigmoid(output)
+
+ outputs.append(output)
+ outputs = [torch.sigmoid(r) for r in outputs]
+ return outputs
+
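+# Maps a configuration name from `nets` to the 16 per-layer convolution
+# functions consumed by PiDiNet; e.g. 'carv4' cycles cd/ad/rd/cv every four
+# layers.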
+def config_model(model):
+ model_options = list(nets.keys())
+ assert model in model_options, \
+ 'unrecognized model, please choose from %s' % str(model_options)
+
+ # print(str(nets[model]))
+
+ pdcs = []
+ for i in range(16):
+ layer_name = 'layer%d' % i
+ op = nets[model][layer_name]
+ pdcs.append(createConvFunc(op))
+
+ return pdcs
+
+def pidinet():
+ pdcs = config_model('carv4')
+    dil = 24  # number of CDCM channels; set to None to disable the dilation branches
+ return PiDiNet(60, pdcs, dil=dil, sa=True)
+
+
+if __name__ == '__main__':
+ model = pidinet()
+ ckp = torch.load('table5_pidinet.pth')['state_dict']
+ model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
+ im = cv2.imread('examples/test_my/cat_v4.png')
+ im = img2tensor(im).unsqueeze(0)/255.
+ res = model(im)[-1]
+ res = res>0.5
+ res = res.float()
+ res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8)
+ print(res.shape)
+ cv2.imwrite('edge.png', res)
diff --git a/ldm/modules/extra_condition/openpose/__init__.py b/ldm/modules/extra_condition/openpose/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/ldm/modules/extra_condition/openpose/api.py b/ldm/modules/extra_condition/openpose/api.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbe7a8c1c0f9c035cdff8660d33348c58a0579c5
--- /dev/null
+++ b/ldm/modules/extra_condition/openpose/api.py
@@ -0,0 +1,35 @@
+import numpy as np
+import os
+import torch.nn as nn
+
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
+import cv2
+import torch
+
+from . import util
+from .body import Body
+
+remote_model_path = "https://huggingface.co/TencentARC/T2I-Adapter/blob/main/third-party-models/body_pose_model.pth"
+
+
+class OpenposeInference(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ body_modelpath = os.path.join('models', "body_pose_model.pth")
+
+ if not os.path.exists(body_modelpath):
+ from basicsr.utils.download_util import load_file_from_url
+ load_file_from_url(remote_model_path, model_dir='models')
+
+ self.body_estimation = Body(body_modelpath)
+
+ def forward(self, x):
+ x = x[:, :, ::-1].copy()
+ with torch.no_grad():
+ candidate, subset = self.body_estimation(x)
+ canvas = np.zeros_like(x)
+ canvas = util.draw_bodypose(canvas, candidate, subset)
+ canvas = cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR)
+ return canvas
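+
+# Usage sketch (assuming `img` is an HxWx3 BGR uint8 array, e.g. from cv2.imread):
+#
+#     pose = OpenposeInference()
+#     canvas = pose(img)              # HxWx3 BGR rendering of the detected bodies
+#     cv2.imwrite('pose.png', canvas)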
diff --git a/ldm/modules/extra_condition/openpose/body.py b/ldm/modules/extra_condition/openpose/body.py
new file mode 100644
index 0000000000000000000000000000000000000000..ecfa8a0946ee9f653f7c00e928ae54b0109a9bdf
--- /dev/null
+++ b/ldm/modules/extra_condition/openpose/body.py
@@ -0,0 +1,211 @@
+import cv2
+import math
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import time
+import torch
+from scipy.ndimage import gaussian_filter
+from torchvision import transforms
+
+from . import util
+from .model import bodypose_model
+
+
+class Body(object):
+
+ def __init__(self, model_path):
+ self.model = bodypose_model()
+ if torch.cuda.is_available():
+ self.model = self.model.cuda()
+ print('cuda')
+ model_dict = util.transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+
+ def __call__(self, oriImg):
+ # scale_search = [0.5, 1.0, 1.5, 2.0]
+ scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre1 = 0.1
+ thre2 = 0.05
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19))
+ paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
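+        # 19 heatmap channels = 18 body joints + background; 38 PAF channels =
+        # x/y vector fields for the 19 limbs listed in limbSeq below.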
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ if torch.cuda.is_available():
+ data = data.cuda()
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+ with torch.no_grad():
+ Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
+ Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy()
+ Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy()
+
+ # extract outputs, resize, and remove padding
+ # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+ # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs
+ paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs
+ paf = cv2.resize(paf, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+ paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ paf = cv2.resize(paf, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+            heatmap_avg = heatmap_avg + heatmap / len(multiplier)
+            paf_avg = paf_avg + paf / len(multiplier)
+
+ all_peaks = []
+ peak_counter = 0
+
+ for part in range(18):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+
+ map_left = np.zeros(one_heatmap.shape)
+ map_left[1:, :] = one_heatmap[:-1, :]
+ map_right = np.zeros(one_heatmap.shape)
+ map_right[:-1, :] = one_heatmap[1:, :]
+ map_up = np.zeros(one_heatmap.shape)
+ map_up[:, 1:] = one_heatmap[:, :-1]
+ map_down = np.zeros(one_heatmap.shape)
+ map_down[:, :-1] = one_heatmap[:, 1:]
+
+ peaks_binary = np.logical_and.reduce((one_heatmap >= map_left, one_heatmap >= map_right,
+ one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1))
+ peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse
+ peaks_with_score = [x + (map_ori[x[1], x[0]], ) for x in peaks]
+ peak_id = range(peak_counter, peak_counter + len(peaks))
+ peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i], ) for i in range(len(peak_id))]
+
+ all_peaks.append(peaks_with_score_and_id)
+ peak_counter += len(peaks)
+
+ # find connection in the specified sequence, center 29 is in the position 15
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+ # the middle joints heatmap correpondence
+ mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \
+ [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \
+ [55, 56], [37, 38], [45, 46]]
+
+ connection_all = []
+ special_k = []
+ mid_num = 10
+
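+        # For each candidate limb, sample the PAF at mid_num points between the
+        # two joints and score the connection by how well the field aligns with
+        # the joint-to-joint direction, with a penalty for limbs longer than
+        # half the image height.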
+ for k in range(len(mapIdx)):
+ score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]]
+ candA = all_peaks[limbSeq[k][0] - 1]
+ candB = all_peaks[limbSeq[k][1] - 1]
+ nA = len(candA)
+ nB = len(candB)
+ indexA, indexB = limbSeq[k]
+ if (nA != 0 and nB != 0):
+ connection_candidate = []
+ for i in range(nA):
+ for j in range(nB):
+ vec = np.subtract(candB[j][:2], candA[i][:2])
+ norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1])
+ norm = max(0.001, norm)
+ vec = np.divide(vec, norm)
+
+ startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \
+ np.linspace(candA[i][1], candB[j][1], num=mid_num)))
+
+ vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \
+ for I in range(len(startend))])
+ vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \
+ for I in range(len(startend))])
+
+ score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1])
+ score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min(
+ 0.5 * oriImg.shape[0] / norm - 1, 0)
+ criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts)
+ criterion2 = score_with_dist_prior > 0
+ if criterion1 and criterion2:
+ connection_candidate.append(
+ [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]])
+
+ connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True)
+ connection = np.zeros((0, 5))
+ for c in range(len(connection_candidate)):
+ i, j, s = connection_candidate[c][0:3]
+ if (i not in connection[:, 3] and j not in connection[:, 4]):
+ connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]])
+ if (len(connection) >= min(nA, nB)):
+ break
+
+ connection_all.append(connection)
+ else:
+ special_k.append(k)
+ connection_all.append([])
+
+ # last number in each row is the total parts number of that person
+ # the second last number in each row is the score of the overall configuration
+ subset = -1 * np.ones((0, 20))
+ candidate = np.array([item for sublist in all_peaks for item in sublist])
+
+ for k in range(len(mapIdx)):
+ if k not in special_k:
+ partAs = connection_all[k][:, 0]
+ partBs = connection_all[k][:, 1]
+ indexA, indexB = np.array(limbSeq[k]) - 1
+
+ for i in range(len(connection_all[k])): # = 1:size(temp,1)
+ found = 0
+ subset_idx = [-1, -1]
+ for j in range(len(subset)): # 1:size(subset,1):
+ if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]:
+ subset_idx[found] = j
+ found += 1
+
+ if found == 1:
+ j = subset_idx[0]
+ if subset[j][indexB] != partBs[i]:
+ subset[j][indexB] = partBs[i]
+ subset[j][-1] += 1
+ subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+ elif found == 2: # if found 2 and disjoint, merge them
+ j1, j2 = subset_idx
+ membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2]
+ if len(np.nonzero(membership == 2)[0]) == 0: # merge
+ subset[j1][:-2] += (subset[j2][:-2] + 1)
+ subset[j1][-2:] += subset[j2][-2:]
+ subset[j1][-2] += connection_all[k][i][2]
+ subset = np.delete(subset, j2, 0)
+ else: # as like found == 1
+ subset[j1][indexB] = partBs[i]
+ subset[j1][-1] += 1
+ subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2]
+
+ # if find no partA in the subset, create a new subset
+ elif not found and k < 17:
+ row = -1 * np.ones(20)
+ row[indexA] = partAs[i]
+ row[indexB] = partBs[i]
+ row[-1] = 2
+ row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2]
+ subset = np.vstack([subset, row])
+ # delete some rows of subset which has few parts occur
+ deleteIdx = []
+ for i in range(len(subset)):
+ if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4:
+ deleteIdx.append(i)
+ subset = np.delete(subset, deleteIdx, axis=0)
+
+ # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts
+ # candidate: x, y, score, id
+ return candidate, subset
diff --git a/ldm/modules/extra_condition/openpose/hand.py b/ldm/modules/extra_condition/openpose/hand.py
new file mode 100644
index 0000000000000000000000000000000000000000..1100239e21d561cf0da050ff506bcd86c3b5fa04
--- /dev/null
+++ b/ldm/modules/extra_condition/openpose/hand.py
@@ -0,0 +1,77 @@
+import cv2
+import json
+import math
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import time
+import torch
+from scipy.ndimage import gaussian_filter
+from skimage.measure import label
+
+from . import util
+from .model import handpose_model
+
+
+class Hand(object):
+
+ def __init__(self, model_path):
+ self.model = handpose_model()
+ if torch.cuda.is_available():
+ self.model = self.model.cuda()
+ print('cuda')
+ model_dict = util.transfer(self.model, torch.load(model_path))
+ self.model.load_state_dict(model_dict)
+ self.model.eval()
+
+ def __call__(self, oriImg):
+ scale_search = [0.5, 1.0, 1.5, 2.0]
+ # scale_search = [0.5]
+ boxsize = 368
+ stride = 8
+ padValue = 128
+ thre = 0.05
+ multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search]
+ heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22))
+ # paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38))
+
+ for m in range(len(multiplier)):
+ scale = multiplier[m]
+ imageToTest = cv2.resize(oriImg, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
+ imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue)
+ im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5
+ im = np.ascontiguousarray(im)
+
+ data = torch.from_numpy(im).float()
+ if torch.cuda.is_available():
+ data = data.cuda()
+ # data = data.permute([2, 0, 1]).unsqueeze(0).float()
+ with torch.no_grad():
+ output = self.model(data).cpu().numpy()
+
+ # extract outputs, resize, and remove padding
+ heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps
+ heatmap = cv2.resize(heatmap, (0, 0), fx=stride, fy=stride, interpolation=cv2.INTER_CUBIC)
+ heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :]
+ heatmap = cv2.resize(heatmap, (oriImg.shape[1], oriImg.shape[0]), interpolation=cv2.INTER_CUBIC)
+
+ heatmap_avg += heatmap / len(multiplier)
+
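+        # For each of the 21 hand keypoints, keep only the connected component
+        # of the thresholded heatmap with the largest total score, then take its
+        # peak location.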
+ all_peaks = []
+ for part in range(21):
+ map_ori = heatmap_avg[:, :, part]
+ one_heatmap = gaussian_filter(map_ori, sigma=3)
+ binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8)
+            # all values are below the threshold
+ if np.sum(binary) == 0:
+ all_peaks.append([0, 0])
+ continue
+ label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim)
+ max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1
+ label_img[label_img != max_index] = 0
+ map_ori[label_img == 0] = 0
+
+ y, x = util.npmax(map_ori)
+ all_peaks.append([x, y])
+ return np.array(all_peaks)
diff --git a/ldm/modules/extra_condition/openpose/model.py b/ldm/modules/extra_condition/openpose/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f5d8eb6b7e4af7e2a4fc21fe500b29f02ff176d
--- /dev/null
+++ b/ldm/modules/extra_condition/openpose/model.py
@@ -0,0 +1,178 @@
+import torch
+import torch.nn as nn
+from collections import OrderedDict
+
+
+def make_layers(block, no_relu_layers):
+ layers = []
+ for layer_name, v in block.items():
+ if 'pool' in layer_name:
+ layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])
+ layers.append((layer_name, layer))
+ else:
+ conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4])
+ layers.append((layer_name, conv2d))
+ if layer_name not in no_relu_layers:
+ layers.append(('relu_' + layer_name, nn.ReLU(inplace=True)))
+
+ return nn.Sequential(OrderedDict(layers))
+
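+# Example: a conv entry ('conv1_1', [3, 64, 3, 1, 1]) expands to
+# nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) followed by a ReLU
+# (unless listed in no_relu_layers), and a 'pool*' entry [2, 2, 0] expands to
+# nn.MaxPool2d(kernel_size=2, stride=2, padding=0).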
+
+class bodypose_model(nn.Module):
+
+ def __init__(self):
+ super(bodypose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\
+ 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\
+ 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\
+                          'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L2']
+ blocks = {}
+        block0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), ('conv1_2', [64, 64, 3, 1, 1]),
+                              ('pool1_stage1', [2, 2, 0]),
+ ('conv2_1', [64, 128, 3, 1, 1]), ('conv2_2', [128, 128, 3, 1, 1]),
+ ('pool2_stage1', [2, 2, 0]), ('conv3_1', [128, 256, 3, 1, 1]),
+ ('conv3_2', [256, 256, 3, 1, 1]), ('conv3_3', [256, 256, 3, 1, 1]),
+ ('conv3_4', [256, 256, 3, 1, 1]), ('pool3_stage1', [2, 2, 0]),
+ ('conv4_1', [256, 512, 3, 1, 1]), ('conv4_2', [512, 512, 3, 1, 1]),
+ ('conv4_3_CPM', [512, 256, 3, 1, 1]), ('conv4_4_CPM', [256, 128, 3, 1, 1])])
+
+ # Stage 1
+ block1_1 = OrderedDict([('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L1', [512, 38, 1, 1, 0])])
+
+ block1_2 = OrderedDict([('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]),
+ ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]),
+ ('conv5_5_CPM_L2', [512, 19, 1, 1, 0])])
+ blocks['block1_1'] = block1_1
+ blocks['block1_2'] = block1_2
+
+ self.model0 = make_layers(block0, no_relu_layers)
+
+ # Stages 2 - 6
+ for i in range(2, 7):
+ blocks['block%d_1' % i] = OrderedDict([('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0])])
+
+ blocks['block%d_2' % i] = OrderedDict([('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]),
+ ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0])])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_1 = blocks['block1_1']
+ self.model2_1 = blocks['block2_1']
+ self.model3_1 = blocks['block3_1']
+ self.model4_1 = blocks['block4_1']
+ self.model5_1 = blocks['block5_1']
+ self.model6_1 = blocks['block6_1']
+
+ self.model1_2 = blocks['block1_2']
+ self.model2_2 = blocks['block2_2']
+ self.model3_2 = blocks['block3_2']
+ self.model4_2 = blocks['block4_2']
+ self.model5_2 = blocks['block5_2']
+ self.model6_2 = blocks['block6_2']
+
+ def forward(self, x):
+
+ out1 = self.model0(x)
+
+ out1_1 = self.model1_1(out1)
+ out1_2 = self.model1_2(out1)
+ out2 = torch.cat([out1_1, out1_2, out1], 1)
+
+ out2_1 = self.model2_1(out2)
+ out2_2 = self.model2_2(out2)
+ out3 = torch.cat([out2_1, out2_2, out1], 1)
+
+ out3_1 = self.model3_1(out3)
+ out3_2 = self.model3_2(out3)
+ out4 = torch.cat([out3_1, out3_2, out1], 1)
+
+ out4_1 = self.model4_1(out4)
+ out4_2 = self.model4_2(out4)
+ out5 = torch.cat([out4_1, out4_2, out1], 1)
+
+ out5_1 = self.model5_1(out5)
+ out5_2 = self.model5_2(out5)
+ out6 = torch.cat([out5_1, out5_2, out1], 1)
+
+ out6_1 = self.model6_1(out6)
+ out6_2 = self.model6_2(out6)
+
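+        # out6_1: part-affinity fields (38 channels); out6_2: keypoint heatmaps
+        # (19 channels = 18 body keypoints + background)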
+ return out6_1, out6_2
+
+
+class handpose_model(nn.Module):
+
+ def __init__(self):
+ super(handpose_model, self).__init__()
+
+ # these layers have no relu layer
+ no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\
+ 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6']
+ # stage 1
+ block1_0 = OrderedDict([('conv1_1', [3, 64, 3, 1, 1]), ('conv1_2', [64, 64, 3, 1, 1]),
+ ('pool1_stage1', [2, 2, 0]), ('conv2_1', [64, 128, 3, 1, 1]),
+ ('conv2_2', [128, 128, 3, 1, 1]), ('pool2_stage1', [2, 2, 0]),
+ ('conv3_1', [128, 256, 3, 1, 1]), ('conv3_2', [256, 256, 3, 1, 1]),
+ ('conv3_3', [256, 256, 3, 1, 1]), ('conv3_4', [256, 256, 3, 1, 1]),
+ ('pool3_stage1', [2, 2, 0]), ('conv4_1', [256, 512, 3, 1, 1]),
+ ('conv4_2', [512, 512, 3, 1, 1]), ('conv4_3', [512, 512, 3, 1, 1]),
+ ('conv4_4', [512, 512, 3, 1, 1]), ('conv5_1', [512, 512, 3, 1, 1]),
+ ('conv5_2', [512, 512, 3, 1, 1]), ('conv5_3_CPM', [512, 128, 3, 1, 1])])
+
+ block1_1 = OrderedDict([('conv6_1_CPM', [128, 512, 1, 1, 0]), ('conv6_2_CPM', [512, 22, 1, 1, 0])])
+
+ blocks = {}
+ blocks['block1_0'] = block1_0
+ blocks['block1_1'] = block1_1
+
+ # stage 2-6
+ for i in range(2, 7):
+ blocks['block%d' % i] = OrderedDict([('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]),
+ ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]),
+ ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]),
+ ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0])])
+
+ for k in blocks.keys():
+ blocks[k] = make_layers(blocks[k], no_relu_layers)
+
+ self.model1_0 = blocks['block1_0']
+ self.model1_1 = blocks['block1_1']
+ self.model2 = blocks['block2']
+ self.model3 = blocks['block3']
+ self.model4 = blocks['block4']
+ self.model5 = blocks['block5']
+ self.model6 = blocks['block6']
+
+ def forward(self, x):
+ out1_0 = self.model1_0(x)
+ out1_1 = self.model1_1(out1_0)
+ concat_stage2 = torch.cat([out1_1, out1_0], 1)
+ out_stage2 = self.model2(concat_stage2)
+ concat_stage3 = torch.cat([out_stage2, out1_0], 1)
+ out_stage3 = self.model3(concat_stage3)
+ concat_stage4 = torch.cat([out_stage3, out1_0], 1)
+ out_stage4 = self.model4(concat_stage4)
+ concat_stage5 = torch.cat([out_stage4, out1_0], 1)
+ out_stage5 = self.model5(concat_stage5)
+ concat_stage6 = torch.cat([out_stage5, out1_0], 1)
+ out_stage6 = self.model6(concat_stage6)
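+        # out_stage6: 22-channel heatmaps (21 hand keypoints + background)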
+ return out_stage6
diff --git a/ldm/modules/extra_condition/openpose/util.py b/ldm/modules/extra_condition/openpose/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..29724d52a3863cb307945b7170e16b32a59609ae
--- /dev/null
+++ b/ldm/modules/extra_condition/openpose/util.py
@@ -0,0 +1,203 @@
+import math
+
+import cv2
+import matplotlib
+import numpy as np
+
+
+def padRightDownCorner(img, stride, padValue):
+ h = img.shape[0]
+ w = img.shape[1]
+
+ pad = 4 * [None]
+ pad[0] = 0 # up
+ pad[1] = 0 # left
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
+
+ img_padded = img
+ pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
+ pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
+ pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
+ pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
+
+ return img_padded, pad
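+
+
+# e.g. a 100x130 image with stride 8 yields pad == [0, 0, 4, 6]: the image is padded
+# down/right with padValue until both sides are multiples of the stride.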
+
+
+# map weights from the Caffe-exported checkpoint onto this model's layer names
+def transfer(model, model_weights):
+ transfered_model_weights = {}
+ for weights_name in model.state_dict().keys():
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
+ return transfered_model_weights
+
+
+# draw the body keypoints and limbs
+def draw_bodypose(canvas, candidate, subset):
+ stickwidth = 4
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
+ [1, 16], [16, 18], [3, 17], [6, 18]]
+
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+ for i in range(18):
+ for n in range(len(subset)):
+ index = int(subset[n][i])
+ if index == -1:
+ continue
+ x, y = candidate[index][0:2]
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
+ for i in range(17):
+ for n in range(len(subset)):
+ index = subset[n][np.array(limbSeq[i]) - 1]
+ if -1 in index:
+ continue
+ cur_canvas = canvas.copy()
+ Y = candidate[index.astype(int), 0]
+ X = candidate[index.astype(int), 1]
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
+ cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+ # plt.imsave("preview.jpg", canvas[:, :, [2, 1, 0]])
+ # plt.imshow(canvas[:, :, [2, 1, 0]])
+ return canvas
+
+
+# draw hand keypoints and edges; plain OpenCV drawing looks rough, so each edge gets a distinct HSV color.
+def draw_handpose(canvas, all_hand_peaks, show_number=False):
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
+
+ for peaks in all_hand_peaks:
+ for ie, e in enumerate(edges):
+ if np.sum(np.all(peaks[e], axis=1) == 0) == 0:
+ x1, y1 = peaks[e[0]]
+ x2, y2 = peaks[e[1]]
+ cv2.line(
+ canvas, (x1, y1), (x2, y2),
+ matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255,
+ thickness=2)
+
+        for i, keypoint in enumerate(peaks):
+            x, y = keypoint
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
+ if show_number:
+ cv2.putText(canvas, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 0, 0), lineType=cv2.LINE_AA)
+ return canvas
+
+
+# detect hand according to body pose keypoints
+# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
+def handDetect(candidate, subset, oriImg):
+ # right hand: wrist 4, elbow 3, shoulder 2
+ # left hand: wrist 7, elbow 6, shoulder 5
+ ratioWristElbow = 0.33
+ detect_result = []
+ image_height, image_width = oriImg.shape[0:2]
+ for person in subset.astype(int):
+        # use a hand only if all three keypoints (shoulder, elbow, wrist) were detected
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
+ if not (has_left or has_right):
+ continue
+ hands = []
+ #left hand
+ if has_left:
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
+ x1, y1 = candidate[left_shoulder_index][:2]
+ x2, y2 = candidate[left_elbow_index][:2]
+ x3, y3 = candidate[left_wrist_index][:2]
+ hands.append([x1, y1, x2, y2, x3, y3, True])
+ # right hand
+ if has_right:
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
+ x1, y1 = candidate[right_shoulder_index][:2]
+ x2, y2 = candidate[right_elbow_index][:2]
+ x3, y3 = candidate[right_wrist_index][:2]
+ hands.append([x1, y1, x2, y2, x3, y3, False])
+
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
+            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbow) = (1 + ratio) * pos_wrist - ratio * pos_elbow
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
+ x = x3 + ratioWristElbow * (x3 - x2)
+ y = y3 + ratioWristElbow * (y3 - y2)
+ distanceWristElbow = math.sqrt((x3 - x2)**2 + (y3 - y2)**2)
+ distanceElbowShoulder = math.sqrt((x2 - x1)**2 + (y2 - y1)**2)
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
+ # x-y refers to the center --> offset to topLeft point
+ # handRectangle.x -= handRectangle.width / 2.f;
+ # handRectangle.y -= handRectangle.height / 2.f;
+ x -= width / 2
+ y -= width / 2 # width = height
+ # overflow the image
+ if x < 0: x = 0
+ if y < 0: y = 0
+ width1 = width
+ width2 = width
+ if x + width > image_width: width1 = image_width - x
+ if y + width > image_height: width2 = image_height - y
+ width = min(width1, width2)
+            # keep only hand boxes at least 20 pixels wide
+ if width >= 20:
+ detect_result.append([int(x), int(y), int(width), is_left])
+    '''
+    return value: [[x, y, w, True if left hand else False]].
+    width = height since the network requires a square input.
+    x, y are the coordinates of the top-left corner.
+    '''
+ return detect_result
+
+
+# return the (row, col) index of the maximum of a 2D array
+def npmax(array):
+ arrayindex = array.argmax(1)
+ arrayvalue = array.max(1)
+ i = arrayvalue.argmax()
+ j = arrayindex[i]
+ return i, j
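+
+
+# e.g. npmax(np.array([[1, 5], [3, 2]])) == (0, 1): row 0, column 1 holds the maximum.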
+
+
+def HWC3(x):
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+
+def resize_image(input_image, resolution):
+ H, W, C = input_image.shape
+ H = float(H)
+ W = float(W)
+ k = float(resolution) / min(H, W)
+ H *= k
+ W *= k
+ H = int(np.round(H / 64.0)) * 64
+ W = int(np.round(W / 64.0)) * 64
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+ return img
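+
+
+# Usage sketch: resize_image(img, 512) rescales so the short side is roughly 512 and
+# rounds both sides to multiples of 64 (a common constraint for diffusion UNets).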
diff --git a/ldm/modules/extra_condition/utils.py b/ldm/modules/extra_condition/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..af6bcb9e1116a431a39579f4bbdde3a9e868e0b4
--- /dev/null
+++ b/ldm/modules/extra_condition/utils.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+import cv2
+import numpy as np
+
+skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10],
+ [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]
+
+pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
+ [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0],
+ [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]
+
+pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
+ [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], [255, 128, 0],
+ [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
+ [51, 153, 255], [51, 153, 255], [51, 153, 255]]
+
+
+def imshow_keypoints(img,
+ pose_result,
+ kpt_score_thr=0.1,
+ radius=2,
+ thickness=2):
+ """Draw keypoints and links on an image.
+
+ Args:
+        img (ndarray): Image whose shape determines the canvas size; the poses
+            are drawn onto a black canvas of that shape.
+        pose_result (list[dict]): The poses to draw. Each element carries a
+            'keypoints' entry holding K keypoints as a Kx3 numpy.ndarray,
+            where each keypoint is represented as x, y, score.
+        kpt_score_thr (float, optional): Minimum score of keypoints
+            to be shown. Default: 0.1.
+        radius (int): Radius of the keypoint circles. Default: 2.
+        thickness (int): Thickness of lines. Default: 2.
+ """
+
+ img_h, img_w, _ = img.shape
+ img = np.zeros(img.shape)
+
+ for idx, kpts in enumerate(pose_result):
+ if idx > 1:
+ continue
+ kpts = kpts['keypoints']
+ # print(kpts)
+ kpts = np.array(kpts, copy=False)
+
+ # draw each point on image
+ assert len(pose_kpt_color) == len(kpts)
+
+ for kid, kpt in enumerate(kpts):
+ x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]
+
+ if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
+ # skip the point that should not be drawn
+ continue
+
+ color = tuple(int(c) for c in pose_kpt_color[kid])
+ cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)
+
+ # draw links
+
+ for sk_id, sk in enumerate(skeleton):
+ pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
+ pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))
+
+ if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
+ or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
+ or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
+ # skip the link that should not be drawn
+ continue
+ color = tuple(int(c) for c in pose_link_color[sk_id])
+ cv2.line(img, pos1, pos2, color, thickness=thickness)
+
+ return img
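+
+
+# Usage sketch (hypothetical input): render a COCO-style 17-keypoint pose map
+# onto a black canvas of the given image's shape.
+# kpts = np.concatenate([np.random.rand(17, 2) * 256, np.ones((17, 1))], axis=1)
+# pose_map = imshow_keypoints(np.zeros((256, 256, 3), np.uint8), [{'keypoints': kpts}])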
diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc
--- /dev/null
+++ b/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py
new file mode 100755
index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+    """
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate the Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
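+
+
+# e.g. gen_kernel() draws a random 15x15 anisotropic Gaussian blur kernel whose
+# covariance eigenvalues lie in [min_var, max_var], normalized to sum to 1.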
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+        bicubically downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')  # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+    x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening, borrowed from Real-ESRGAN.
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharpening weight. Default: 0.5.
+        radius (float): Kernel size of Gaussian blur. Default: 50.
+        threshold (int): Residual threshold on the 0-255 scale. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
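+
+
+# In short: unsharp masking with a soft mask, so only regions whose blur residual
+# exceeds `threshold` (on the 0-255 scale) are replaced by sharpened pixels.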
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+    img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+    vals = 10 ** (2 * random.random() + 2.0)  # exponent uniform in [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+                img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
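+
+
+# Usage sketch (hypothetical path): produce an aligned LR/HR training pair.
+# hq_in = util.uint2single(util.imread_uint('utils/test.png', 3))  # HxWx3, [0, 1]
+# lq, hq = degradation_bsrgan(hq_in, sf=4, lq_patchsize=72)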
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+    example: dict {"image": degraded low-quality image, uint8}
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+                image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+    example = {"image": image}
+ return example
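+
+
+# Usage sketch: unlike degradation_bsrgan, this variant takes a uint8 image and
+# returns a dict, e.g. degradation_bsrgan_variant(uint8_img, sf=4)["image"].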
+
+
+# TODO: in case of a pickle error, replace a += x with a = a + x in add_speckle_noise etc.
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+    shuffle_prob: probability of shuffling the order of the degradations
+    use_sharp: whether to apply USM sharpening to the image first
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+
+ h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
+
+if __name__ == '__main__':
+    print("hey")
+    img = util.imread_uint('utils/test.png', 3)
+    img = img[:448, :448]
+    h = img.shape[0] // 4
+    print("resizing to", h)
+    sf = 4
+    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+    for i in range(20):
+        print(i)
+        img_hq = img
+        img_lq = deg_fn(img)["image"]
+        img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+        print(img_lq)
+        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+        print(img_lq.shape)
+        print("bicubic", img_lq_bicubic.shape)
+        print(img_hq.shape)
+        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                interpolation=0)
+        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+                                        (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                        interpolation=0)
+        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+        util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100755
index 0000000000000000000000000000000000000000..808c7f882cb75e2ba2340d5b55881d11927351f0
--- /dev/null
+++ b/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+ # Calculate the big kernels size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+    """
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate the Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+    h[h < np.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+        bicubically downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+    """USM sharpening, borrowed from Real-ESRGAN.
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharpening weight. Default: 0.5.
+        radius (float): Kernel size of Gaussian blur. Default: 50.
+        threshold (int): Residual threshold on the 0-255 scale. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+
+    wd2 = wd2 / 4
+    wd = wd / 4
+
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+    vals = 10 ** (2 * random.random() + 2.0)  # exponent uniform in [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+    example: dict {"image": degraded low-quality image, uint8; resized back to the
+        input size when up=True}
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop (axis 0 is height, axis 1 is width)
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # keep downsample2 before downsample3 (downsample3 reuses the dims recorded there)
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+        elif i == 1:
+            pass  # second blur stage is disabled in this lighter variant (still applied in degradation_bsrgan)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ if up:
+ image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
+ example = {"image": image}
+ return example
+
+
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png
new file mode 100755
index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6
Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ
diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py
new file mode 100755
index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98
--- /dev/null
+++ b/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+    import matplotlib.pyplot as plt  # imported lazily, see the module-level TODO
+    plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+    import matplotlib.pyplot as plt  # imported lazily, see the module-level TODO
+    plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image paths
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if dataroot is not None:
+ paths = sorted(_get_paths_from_images(dataroot))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+    h, w = img.shape[:2]  # shape[:2] is (height, width)
+    patches = []
+    if h > p_max and w > p_max:
+        # np.int was removed in NumPy 1.24; the builtin int is equivalent here
+        h1 = list(np.arange(0, h - p_size, p_size - p_overlap, dtype=int))
+        w1 = list(np.arange(0, w - p_size, p_size - p_overlap, dtype=int))
+        h1.append(h - p_size)
+        w1.append(w - p_size)
+        for i in h1:
+            for j in w1:
+                patches.append(img[i:i + p_size, j:j + p_size, :])
+ else:
+ patches.append(img)
+
+ return patches
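+
+
+def _demo_patches_from_image():
+    # Illustrative sketch: a synthetic 1024x1024 image splits into nine
+    # overlapping 512x512 patches; images no larger than p_max pass through
+    # unchanged as a single-element list.
+    big = np.zeros((1024, 1024, 3), dtype=np.uint8)
+    patches = patches_from_image(big, p_size=512, p_overlap=64, p_max=800)
+    print(len(patches), patches[0].shape)  # 9 (512, 512, 3)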
+
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, target_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+    """
+    Split the large images from original_dataroot into small overlapping images of size (p_size)x(p_size),
+    and save them into target_dataroot; only images larger than (p_max)x(p_max) are split.
+    Args:
+        original_dataroot: folder containing the source images
+        target_dataroot: folder the patches are written to
+        p_size: size of the small images
+        p_overlap: overlap between neighboring patches (the patch size used in training is a good choice)
+        p_max: images no larger than (p_max)x(p_max) are kept unchanged
+    """
+    paths = get_image_paths(original_dataroot)
+    for img_path in paths:
+        img = imread_uint(img_path, n_channels=n_channels)
+        patches = patches_from_image(img, p_size, p_overlap, p_max)
+        imssave(patches, os.path.join(target_dataroot, os.path.basename(img_path)))
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+ print('Path already exists. Rename it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(uint)
+# numpy(single) <---> tensor
+# numpy(uint) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(uint)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
+ return img_np.astype(out_type)
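+
+
+def _demo_tensor2img():
+    # Illustrative sketch: a CHW tensor in [0, 1] becomes an HxWxC uint8 BGR
+    # array; a 4D batch would first be tiled into a grid via make_grid.
+    arr = tensor2img(torch.rand(3, 8, 8))
+    print(arr.shape, arr.dtype)  # (8, 8, 3) uint8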
+
+
+'''
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of HxWxC or HxW
+# (2) augment_img_tensor4: tensor image 1xCxHxW
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)  # astype returns a copy; the assignment was missing
+    if in_img_type != np.uint8:
+        img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)  # astype returns a copy; the assignment was missing
+    if in_img_type != np.uint8:
+        img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+    in_img_type = img.dtype
+    img = img.astype(np.float32)  # astype returns a copy; the assignment was missing
+    if in_img_type != np.uint8:
+        img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
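+
+
+def _demo_color_conversion():
+    # Illustrative sketch: the ycbcr conversions preserve the input dtype, so
+    # uint8 in [0, 255] comes back uint8 and float in [0, 1] comes back float.
+    rgb = np.full((4, 4, 3), 128, dtype=np.uint8)
+    y = rgb2ycbcr(rgb, only_y=True)
+    print(y.dtype, y.mean())  # uint8, luma about 126 for mid-gray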
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
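+
+
+def _demo_metrics():
+    # Illustrative sketch: both metrics expect inputs in [0, 255]; identical
+    # images give PSNR = inf and SSIM = 1.0.
+    a = (np.random.rand(32, 32, 3) * 255).round()
+    print(calculate_psnr(a, a), calculate_ssim(a, a))  # inf 1.0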
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, currently supporting only 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+        # Use a modified kernel to simultaneously interpolate and antialias (larger kernel width).
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+    need_squeeze = img.dim() == 2
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+    need_squeeze = img.dim() == 2
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ print('---')
+# img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/ldm/util.py b/ldm/util.py
new file mode 100755
index 0000000000000000000000000000000000000000..dc9e3c48b1924fbc1ac3ecdf7a2192e1a46d9228
--- /dev/null
+++ b/ldm/util.py
@@ -0,0 +1,200 @@
+import importlib
+import math
+
+import cv2
+import torch
+import numpy as np
+
+import os
+from safetensors.torch import load_file
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('assets/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+ print("Cant encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
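+
+
+def _demo_log_txt_as_img():
+    # Illustrative sketch; assumes the DejaVuSans.ttf font is available under
+    # assets/ as expected by log_txt_as_img.
+    txts = log_txt_as_img((256, 256), ["a photo of a cat", "a photo of a dog"])
+    print(txts.shape)  # torch.Size([2, 3, 256, 256]), values scaled to [-1, 1]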
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
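+
+
+def _demo_instantiate_from_config():
+    # Illustrative sketch: any object reachable by dotted path can be built from
+    # a {"target": ..., "params": ...} mapping; collections.OrderedDict is just
+    # a stdlib stand-in for a real model config.
+    cfg = {"target": "collections.OrderedDict", "params": {}}
+    obj = instantiate_from_config(cfg)
+    print(type(obj))  # <class 'collections.OrderedDict'>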
+
+
+checkpoint_dict_replacements = {
+ 'cond_stage_model.transformer.text_model.embeddings.': 'cond_stage_model.transformer.embeddings.',
+ 'cond_stage_model.transformer.text_model.encoder.': 'cond_stage_model.transformer.encoder.',
+ 'cond_stage_model.transformer.text_model.final_layer_norm.': 'cond_stage_model.transformer.final_layer_norm.',
+}
+
+
+def transform_checkpoint_dict_key(k):
+ for text, replacement in checkpoint_dict_replacements.items():
+ if k.startswith(text):
+ k = replacement + k[len(text):]
+
+ return k
+
+
+def get_state_dict_from_checkpoint(pl_sd):
+ pl_sd = pl_sd.pop("state_dict", pl_sd)
+ pl_sd.pop("state_dict", None)
+
+ sd = {}
+ for k, v in pl_sd.items():
+ new_key = transform_checkpoint_dict_key(k)
+
+ if new_key is not None:
+ sd[new_key] = v
+
+ pl_sd.clear()
+ pl_sd.update(sd)
+
+ return pl_sd
+
+
+def read_state_dict(checkpoint_file, print_global_state=False):
+ _, extension = os.path.splitext(checkpoint_file)
+ if extension.lower() == ".safetensors":
+ pl_sd = load_file(checkpoint_file, device='cpu')
+ else:
+ pl_sd = torch.load(checkpoint_file, map_location='cpu')
+
+ if print_global_state and "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+
+ sd = get_state_dict_from_checkpoint(pl_sd)
+ return sd
+
+
+def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
+ print(f"Loading model from {ckpt}")
+ sd = read_state_dict(ckpt)
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ if 'anything' in ckpt.lower() and vae_ckpt is None:
+ vae_ckpt = 'models/anything-v4.0.vae.pt'
+
+ if vae_ckpt is not None and vae_ckpt != 'None':
+ print(f"Loading vae model from {vae_ckpt}")
+ vae_sd = torch.load(vae_ckpt, map_location="cpu")
+ if "global_step" in vae_sd:
+ print(f"Global Step: {vae_sd['global_step']}")
+ sd = vae_sd["state_dict"]
+ m, u = model.first_stage_model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+
+def resize_numpy_image(image, max_resolution=512 * 512, resize_short_edge=None):
+ h, w = image.shape[:2]
+ if resize_short_edge is not None:
+ k = resize_short_edge / min(h, w)
+ else:
+ k = max_resolution / (h * w)
+ k = k**0.5
+ h = int(np.round(h * k / 64)) * 64
+ w = int(np.round(w * k / 64)) * 64
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
+ return image
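+
+
+def _demo_resize_numpy_image():
+    # Illustrative sketch: both output sides are rounded to multiples of 64 so
+    # they stay divisible by the f=8 latent downsampling factor.
+    img = np.zeros((500, 333, 3), dtype=np.uint8)
+    out = resize_numpy_image(img, max_resolution=512 * 512)
+    print(out.shape)  # (640, 448, 3): both sides are multiples of 64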
+
+
+# make uc and prompt shapes match via padding for long prompts
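+# (a length mismatch arises when long prompts are encoded in chunks of the text
+# encoder's context length; the shorter of the two sequences is padded with the
+# empty-prompt embedding until both match)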
+null_cond = None
+
+def fix_cond_shapes(model, prompt_condition, uc):
+ if uc is None:
+ return prompt_condition, uc
+ global null_cond
+ if null_cond is None:
+ null_cond = model.get_learned_conditioning([""])
+ while prompt_condition.shape[1] > uc.shape[1]:
+ uc = torch.cat((uc, null_cond.repeat((uc.shape[0], 1, 1))), axis=1)
+ while prompt_condition.shape[1] < uc.shape[1]:
+ prompt_condition = torch.cat((prompt_condition, null_cond.repeat((prompt_condition.shape[0], 1, 1))), axis=1)
+ return prompt_condition, uc
diff --git a/requirements.txt b/requirements.txt
new file mode 100755
index 0000000000000000000000000000000000000000..22a11291fd310d712100382ddbb040399b64d94e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+transformers==4.19.2
+diffusers==0.11.1
+invisible_watermark==0.1.5
+basicsr==1.4.2
+einops==0.6.0
+omegaconf==2.3.0
+pytorch_lightning==1.5.9
+kornia==0.6.8
+gradio
+opencv-python
+pudb
+imageio
+imageio-ffmpeg
+k-diffusion
+webdataset
+open-clip-torch
+safetensors
+timm
diff --git a/style.css b/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..c4739b4ea5fc35e774a049e3dacc443f7f0eac19
--- /dev/null
+++ b/style.css
@@ -0,0 +1,3 @@
+h1 {
+ text-align: center;
+}
diff --git a/test_adapter.py b/test_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa8f7ae0cd5817eac836b3ab66d51480aa7bede4
--- /dev/null
+++ b/test_adapter.py
@@ -0,0 +1,80 @@
+import os
+
+import cv2
+import torch
+from basicsr.utils import tensor2img
+from pytorch_lightning import seed_everything
+from torch import autocast
+
+from ldm.inference_base import (diffusion_inference, get_adapters, get_base_argument_parser, get_sd_models)
+from ldm.modules.extra_condition import api
+from ldm.modules.extra_condition.api import (ExtraCondition, get_adapter_feature, get_cond_model)
+
+torch.set_grad_enabled(False)
+
+
+def main():
+ supported_cond = [e.name for e in ExtraCondition]
+ parser = get_base_argument_parser()
+ parser.add_argument(
+ '--which_cond',
+ type=str,
+ required=True,
+ choices=supported_cond,
+ help='which condition modality you want to test',
+ )
+ opt = parser.parse_args()
+ which_cond = opt.which_cond
+ if opt.outdir is None:
+ opt.outdir = f'outputs/test-{which_cond}'
+ os.makedirs(opt.outdir, exist_ok=True)
+ if opt.resize_short_edge is None:
+ print(f"you don't specify the resize_shot_edge, so the maximum resolution is set to {opt.max_resolution}")
+ opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+    # supports two test modes: single-image test, and batch test (through a txt file)
+    if opt.prompt.endswith('.txt'):
+ image_paths = []
+ prompts = []
+ with open(opt.prompt, 'r') as f:
+ lines = f.readlines()
+ for line in lines:
+ line = line.strip()
+ image_paths.append(line.split('; ')[0])
+ prompts.append(line.split('; ')[1])
+ else:
+ image_paths = [opt.cond_path]
+ prompts = [opt.prompt]
+ print(image_paths)
+
+ # prepare models
+ sd_model, sampler = get_sd_models(opt)
+ adapter = get_adapters(opt, getattr(ExtraCondition, which_cond))
+ cond_model = None
+ if opt.cond_inp_type == 'image':
+ cond_model = get_cond_model(opt, getattr(ExtraCondition, which_cond))
+
+ process_cond_module = getattr(api, f'get_cond_{which_cond}')
+
+ # inference
+ with torch.inference_mode(), \
+ sd_model.ema_scope(), \
+ autocast('cuda'):
+ for test_idx, (cond_path, prompt) in enumerate(zip(image_paths, prompts)):
+ seed_everything(opt.seed)
+ for v_idx in range(opt.n_samples):
+ # seed_everything(opt.seed+v_idx+test_idx)
+ cond = process_cond_module(opt, cond_path, opt.cond_inp_type, cond_model)
+
+ base_count = len(os.listdir(opt.outdir)) // 2
+ cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_{which_cond}.png'), tensor2img(cond))
+
+ adapter_features, append_to_context = get_adapter_feature(cond, adapter)
+ opt.prompt = prompt
+ result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
+ cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_result.png'), tensor2img(result))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/test_composable_adapters.py b/test_composable_adapters.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e814e949c381d096581d6b46029649be982a22e
--- /dev/null
+++ b/test_composable_adapters.py
@@ -0,0 +1,101 @@
+import cv2
+import os
+import torch
+from pytorch_lightning import seed_everything
+from torch import autocast
+
+from basicsr.utils import tensor2img
+from ldm.inference_base import diffusion_inference, get_adapters, get_base_argument_parser, get_sd_models
+from ldm.modules.extra_condition import api
+from ldm.modules.extra_condition.api import ExtraCondition, get_adapter_feature, get_cond_model
+
+torch.set_grad_enabled(False)
+
+
+def main():
+ supported_cond = [e.name for e in ExtraCondition]
+ parser = get_base_argument_parser()
+ for cond_name in supported_cond:
+ parser.add_argument(
+ f'--{cond_name}_path',
+ type=str,
+ default=None,
+ help=f'condition image path for {cond_name}',
+ )
+ parser.add_argument(
+ f'--{cond_name}_inp_type',
+ type=str,
+ default='image',
+ help=f'the type of the input condition image, can be image or {cond_name}',
+ choices=['image', cond_name],
+ )
+ parser.add_argument(
+ f'--{cond_name}_adapter_ckpt',
+ type=str,
+ default=None,
+ help=f'path to checkpoint of the {cond_name} adapter, '
+            f'if {cond_name}_path is not None, this must not be None either',
+ )
+ parser.add_argument(
+ f'--{cond_name}_weight',
+ type=float,
+ default=1.0,
+            help=f'the {cond_name} adapter features are scaled by {cond_name}_weight before being summed together',
+ )
+ opt = parser.parse_args()
+
+ # process argument
+ activated_conds = []
+ cond_paths = []
+ adapter_ckpts = []
+ for cond_name in supported_cond:
+ if getattr(opt, f'{cond_name}_path') is None:
+ continue
+ assert getattr(opt, f'{cond_name}_adapter_ckpt') is not None, f'you should specify the {cond_name}_adapter_ckpt'
+ activated_conds.append(cond_name)
+ cond_paths.append(getattr(opt, f'{cond_name}_path'))
+ adapter_ckpts.append(getattr(opt, f'{cond_name}_adapter_ckpt'))
+ assert len(activated_conds) != 0, 'you did not input any condition'
+
+ if opt.outdir is None:
+        opt.outdir = 'outputs/test-composable-adapters'
+ os.makedirs(opt.outdir, exist_ok=True)
+ if opt.resize_short_edge is None:
+ print(f"you don't specify the resize_shot_edge, so the maximum resolution is set to {opt.max_resolution}")
+ opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+ # prepare models
+ adapters = []
+ cond_models = []
+ cond_inp_types = []
+ process_cond_modules = []
+ for cond_name in activated_conds:
+ adapters.append(get_adapters(opt, getattr(ExtraCondition, cond_name)))
+ cond_inp_type = getattr(opt, f'{cond_name}_inp_type', 'image')
+ if cond_inp_type == 'image':
+ cond_models.append(get_cond_model(opt, getattr(ExtraCondition, cond_name)))
+ else:
+ cond_models.append(None)
+ cond_inp_types.append(cond_inp_type)
+ process_cond_modules.append(getattr(api, f'get_cond_{cond_name}'))
+ sd_model, sampler = get_sd_models(opt)
+
+ # inference
+ with torch.inference_mode(), \
+ sd_model.ema_scope(), \
+ autocast('cuda'):
+ seed_everything(opt.seed)
+ conds = []
+ for cond_idx, cond_name in enumerate(activated_conds):
+ conds.append(process_cond_modules[cond_idx](
+ opt, cond_paths[cond_idx], cond_inp_types[cond_idx], cond_models[cond_idx],
+ ))
+ adapter_features, append_to_context = get_adapter_feature(conds, adapters)
+ for v_idx in range(opt.n_samples):
+ result = diffusion_inference(opt, sd_model, sampler, adapter_features, append_to_context)
+ base_count = len(os.listdir(opt.outdir))
+ cv2.imwrite(os.path.join(opt.outdir, f'{base_count:05}_result.png'), tensor2img(result))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/train_depth.py b/train_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..af9a203bdd8b0904440bdcc55f2127c26aab7ebd
--- /dev/null
+++ b/train_depth.py
@@ -0,0 +1,281 @@
+import argparse
+import logging
+import os
+import os.path as osp
+import torch
+from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
+ scandir)
+from basicsr.utils.options import copy_opt_file, dict2str
+from omegaconf import OmegaConf
+
+from ldm.data.dataset_depth import DepthDataset
+from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
+from ldm.modules.encoders.adapter import Adapter
+from ldm.util import load_model_from_config
+
+
+@master_only
+def mkdir_and_rename(path):
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+ Args:
+ path (str): Folder path.
+ """
+ if osp.exists(path):
+ new_name = path + '_archived_' + get_time_str()
+ print(f'Path already exists. Rename it to {new_name}', flush=True)
+ os.rename(path, new_name)
+ os.makedirs(path, exist_ok=True)
+ os.makedirs(osp.join(path, 'models'))
+ os.makedirs(osp.join(path, 'training_states'))
+ os.makedirs(osp.join(path, 'visualization'))
+
+
+def load_resume_state(opt):
+ resume_state_path = None
+ if opt.auto_resume:
+ state_path = osp.join('experiments', opt.name, 'training_states')
+ if osp.isdir(state_path):
+ states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
+ if len(states) != 0:
+ states = [float(v.split('.state')[0]) for v in states]
+ resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
+ opt.resume_state_path = resume_state_path
+
+ if resume_state_path is None:
+ resume_state = None
+ else:
+ device_id = torch.cuda.current_device()
+ resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
+ return resume_state
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--bsize",
+ type=int,
+ default=8,
+ )
+ parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10000,
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=8,
+ )
+ parser.add_argument(
+ "--plms",
+ action='store_true',
+ help="use plms sampling",
+ )
+ parser.add_argument(
+ "--auto_resume",
+ action='store_true',
+ help="use plms sampling",
+ )
+ parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="models/sd-v1-4.ckpt",
+ help="path to checkpoint of model",
+ )
+ parser.add_argument(
+ "--config",
+ type=str,
+ default="configs/stable-diffusion/sd-v1-train.yaml",
+ help="path to config which constructs model",
+ )
+ parser.add_argument(
+ "--name",
+ type=str,
+ default="train_depth",
+ help="experiment name",
+ )
+ parser.add_argument(
+ "--print_fq",
+ type=int,
+ default=100,
+ help="path to config which constructs model",
+ )
+ parser.add_argument(
+ "--H",
+ type=int,
+ default=512,
+ help="image height, in pixel space",
+ )
+ parser.add_argument(
+ "--W",
+ type=int,
+ default=512,
+ help="image width, in pixel space",
+ )
+ parser.add_argument(
+ "--C",
+ type=int,
+ default=4,
+ help="latent channels",
+ )
+ parser.add_argument(
+ "--f",
+ type=int,
+ default=8,
+ help="downsampling factor",
+ )
+ parser.add_argument(
+ "--sample_steps",
+ type=int,
+ default=50,
+ help="number of ddim sampling steps",
+ )
+ parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=1,
+ help="how many samples to produce for each given prompt. A.k.a. batch size",
+ )
+ parser.add_argument(
+ "--scale",
+ type=float,
+ default=7.5,
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+ )
+ parser.add_argument(
+ "--gpus",
+ default=[0, 1, 2, 3],
+ help="gpu idx",
+ )
+ parser.add_argument(
+ '--local_rank',
+ default=0,
+ type=int,
+        help='local rank for distributed training'
+ )
+ parser.add_argument(
+ '--launcher',
+ default='pytorch',
+ type=str,
+        help='job launcher for distributed training'
+ )
+ opt = parser.parse_args()
+ return opt
+
+
+def main():
+    opt = parse_args()
+ config = OmegaConf.load(f"{opt.config}")
+
+ # distributed setting
+ init_dist(opt.launcher)
+ torch.backends.cudnn.benchmark = True
+ device = 'cuda'
+ torch.cuda.set_device(opt.local_rank)
+
+ # dataset
+ train_dataset = DepthDataset('datasets/laion_depth_meta_v1.txt')
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=opt.bsize,
+ shuffle=(train_sampler is None),
+ num_workers=opt.num_workers,
+ pin_memory=True,
+ sampler=train_sampler)
+
+ # stable diffusion
+ model = load_model_from_config(config, f"{opt.ckpt}").to(device)
+
+ # depth encoder
+ model_ad = Adapter(cin=3 * 64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(
+ device)
+
+ # to gpus
+ model_ad = torch.nn.parallel.DistributedDataParallel(
+ model_ad,
+ device_ids=[opt.local_rank],
+ output_device=opt.local_rank)
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[opt.local_rank],
+ output_device=opt.local_rank)
+
+ # optimizer
+ params = list(model_ad.parameters())
+ optimizer = torch.optim.AdamW(params, lr=config['training']['lr'])
+
+ experiments_root = osp.join('experiments', opt.name)
+
+ # resume state
+ resume_state = load_resume_state(opt)
+ if resume_state is None:
+ mkdir_and_rename(experiments_root)
+ start_epoch = 0
+ current_iter = 0
+        # WARNING: do not use get_root_logger before this point (including in
+        # called functions); otherwise the logger will not be properly initialized
+ log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log")
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(config))
+ else:
+        # WARNING: do not use get_root_logger before this point (including in
+        # called functions); otherwise the logger will not be properly initialized
+ log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log")
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(config))
+ resume_optimizers = resume_state['optimizers']
+ optimizer.load_state_dict(resume_optimizers)
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
+ start_epoch = resume_state['epoch']
+ current_iter = resume_state['iter']
+
+ # copy the yml file to the experiment root
+ copy_opt_file(opt.config, experiments_root)
+
+ # training
+ logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
+ for epoch in range(start_epoch, opt.epochs):
+ train_dataloader.sampler.set_epoch(epoch)
+ # train
+ for _, data in enumerate(train_dataloader):
+ current_iter += 1
+ with torch.no_grad():
+ c = model.module.get_learned_conditioning(data['sentence'])
+ z = model.module.encode_first_stage((data['im'] * 2 - 1.).to(device))
+ z = model.module.get_first_stage_encoding(z)
+
+ optimizer.zero_grad()
+ model.zero_grad()
+ features_adapter = model_ad(data['depth'].to(device))
+ l_pixel, loss_dict = model(z, c=c, features_adapter=features_adapter)
+ l_pixel.backward()
+ optimizer.step()
+
+ if (current_iter + 1) % opt.print_fq == 0:
+ logger.info(loss_dict)
+
+ # save checkpoint
+ rank, _ = get_dist_info()
+ if (rank == 0) and ((current_iter + 1) % config['training']['save_freq'] == 0):
+ save_filename = f'model_ad_{current_iter + 1}.pth'
+ save_path = os.path.join(experiments_root, 'models', save_filename)
+ save_dict = {}
+ state_dict = model_ad.state_dict()
+ for key, param in state_dict.items():
+ if key.startswith('module.'): # remove unnecessary 'module.'
+ key = key[7:]
+ save_dict[key] = param.cpu()
+ torch.save(save_dict, save_path)
+ # save state
+ state = {'epoch': epoch, 'iter': current_iter + 1, 'optimizers': optimizer.state_dict()}
+ save_filename = f'{current_iter + 1}.state'
+ save_path = os.path.join(experiments_root, 'training_states', save_filename)
+ torch.save(state, save_path)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/train_seg.py b/train_seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..82ed0724ef757a93e9f9fdd4ef3ada4a0203f906
--- /dev/null
+++ b/train_seg.py
@@ -0,0 +1,372 @@
+import cv2
+import torch
+import os
+from basicsr.utils import img2tensor, tensor2img, scandir, get_time_str, get_root_logger, get_env_info
+from ldm.data.dataset_coco import dataset_coco_mask_color
+import argparse
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+from omegaconf import OmegaConf
+from ldm.util import instantiate_from_config
+from ldm.modules.encoders.adapter import Adapter
+from PIL import Image
+import numpy as np
+import torch.nn as nn
+import matplotlib.pyplot as plt
+import time
+import os.path as osp
+from basicsr.utils.options import copy_opt_file, dict2str
+import logging
+from dist_util import init_dist, master_only, get_bare_model, get_dist_info
+
+def load_model_from_config(config, ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+@master_only
+def mkdir_and_rename(path):
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+ Args:
+ path (str): Folder path.
+ """
+ if osp.exists(path):
+ new_name = path + '_archived_' + get_time_str()
+ print(f'Path already exists. Rename it to {new_name}', flush=True)
+ os.rename(path, new_name)
+ os.makedirs(path, exist_ok=True)
+    os.makedirs(osp.join(path, 'models'))
+    os.makedirs(osp.join(path, 'training_states'))
+    os.makedirs(osp.join(path, 'visualization'))
+
+def load_resume_state(opt):
+ resume_state_path = None
+ if opt.auto_resume:
+ state_path = osp.join('experiments', opt.name, 'training_states')
+ if osp.isdir(state_path):
+ states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
+ if len(states) != 0:
+ states = [float(v.split('.state')[0]) for v in states]
+ resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
+ opt.resume_state_path = resume_state_path
+ # else:
+ # if opt['path'].get('resume_state'):
+ # resume_state_path = opt['path']['resume_state']
+
+ if resume_state_path is None:
+ resume_state = None
+ else:
+ device_id = torch.cuda.current_device()
+ resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
+ # check_resume(opt, resume_state['iter'])
+ return resume_state
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--bsize",
+ type=int,
+ default=8,
+ help="the prompt to render"
+)
+parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10000,
+ help="the prompt to render"
+)
+parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=8,
+ help="the prompt to render"
+)
+parser.add_argument(
+ "--use_shuffle",
+ type=bool,
+ default=True,
+ help="the prompt to render"
+)
+parser.add_argument(
+ "--dpm_solver",
+ action='store_true',
+ help="use dpm_solver sampling",
+)
+parser.add_argument(
+ "--plms",
+ action='store_true',
+ help="use plms sampling",
+)
+parser.add_argument(
+ "--auto_resume",
+ action='store_true',
+ help="use plms sampling",
+)
+parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="ckp/sd-v1-4.ckpt",
+ help="path to checkpoint of model",
+)
+parser.add_argument(
+ "--config",
+ type=str,
+ default="configs/stable-diffusion/train_mask.yaml",
+ help="path to config which constructs model",
+)
+parser.add_argument(
+ "--print_fq",
+ type=int,
+ default=100,
+ help="path to config which constructs model",
+)
+parser.add_argument(
+ "--H",
+ type=int,
+ default=512,
+ help="image height, in pixel space",
+)
+parser.add_argument(
+ "--W",
+ type=int,
+ default=512,
+ help="image width, in pixel space",
+)
+parser.add_argument(
+ "--C",
+ type=int,
+ default=4,
+ help="latent channels",
+)
+parser.add_argument(
+ "--f",
+ type=int,
+ default=8,
+ help="downsampling factor",
+)
+parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ default=50,
+ help="number of ddim sampling steps",
+)
+parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=1,
+ help="how many samples to produce for each given prompt. A.k.a. batch size",
+)
+parser.add_argument(
+ "--ddim_eta",
+ type=float,
+ default=0.0,
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
+)
+parser.add_argument(
+ "--scale",
+ type=float,
+ default=7.5,
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+)
+parser.add_argument(
+ "--gpus",
+ default=[0,1,2,3],
+ help="gpu idx",
+)
+parser.add_argument(
+ '--local_rank',
+ default=0,
+ type=int,
+ help='local process rank for distributed training'
+)
+parser.add_argument(
+ '--launcher',
+ default='pytorch',
+ type=str,
+ help='job launcher for distributed training'
+)
+opt = parser.parse_args()
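+
+# Typical multi-GPU launch (a sketch, not from the source; assumes
+# torch.distributed.launch supplies --local_rank and that this file is train_mask.py):
+#   python -m torch.distributed.launch --nproc_per_node=4 train_mask.py --auto_resume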
+
+if __name__ == '__main__':
+ config = OmegaConf.load(f"{opt.config}")
+ opt.name = config['name']
+
+ # distributed setting
+ init_dist(opt.launcher)
+ torch.backends.cudnn.benchmark = True
+ device='cuda'
+ torch.cuda.set_device(opt.local_rank)
+
+ # dataset
+ path_json_train = 'coco_stuff/mask/annotations/captions_train2017.json'
+ path_json_val = 'coco_stuff/mask/annotations/captions_val2017.json'
+ train_dataset = dataset_coco_mask_color(path_json_train,
+ root_path_im='coco/train2017',
+ root_path_mask='coco_stuff/mask/train2017_color',
+ image_size=512
+ )
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
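+ # The DistributedSampler gives each rank a disjoint shard of the dataset; the epoch
+ # loop below calls set_epoch() so the shuffling order changes every epoch.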
+ val_dataset = dataset_coco_mask_color(path_json_val,
+ root_path_im='coco/val2017',
+ root_path_mask='coco_stuff/mask/val2017_color',
+ image_size=512
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=opt.bsize,
+ shuffle=(train_sampler is None),
+ num_workers=opt.num_workers,
+ pin_memory=True,
+ sampler=train_sampler)
+ val_dataloader = torch.utils.data.DataLoader(
+ val_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=1,
+ pin_memory=False)
+
+ # stable diffusion
+ model = load_model_from_config(config, f"{opt.ckpt}").to(device)
+
+ # mask encoder
+ model_ad = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
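+ # cin=3*64: presumably the RGB mask is pixel-unshuffled by the adapter's stride-8
+ # downsampling, giving 3 * 8 * 8 = 192 input channels (an assumption from the signature).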
+
+ # to gpus
+ model_ad = torch.nn.parallel.DistributedDataParallel(
+ model_ad,
+ device_ids=[opt.local_rank],
+ output_device=opt.local_rank)
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[opt.local_rank],
+ output_device=opt.local_rank)
+ # device_ids=[torch.cuda.current_device()])
+
+ # optimizer
+ params = list(model_ad.parameters())
+ optimizer = torch.optim.AdamW(params, lr=config['training']['lr'])
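+ # Only the adapter is optimized; the Stable Diffusion weights stay frozen
+ # (the model was loaded in eval mode and its parameters are not in the optimizer).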
+
+ experiments_root = osp.join('experiments', opt.name)
+
+ # resume state
+ resume_state = load_resume_state(opt)
+ if resume_state is None:
+ mkdir_and_rename(experiments_root)
+ start_epoch = 0
+ current_iter = 0
+ # WARNING: should not use get_root_logger in the above codes, including the called functions
+ # Otherwise the logger will not be properly initialized
+ log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log")
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(config))
+ else:
+ # WARNING: should not use get_root_logger in the above codes, including the called functions
+ # Otherwise the logger will not be properly initialized
+ log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log")
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(config))
+ resume_optimizers = resume_state['optimizers']
+ optimizer.load_state_dict(resume_optimizers)
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
+ start_epoch = resume_state['epoch']
+ current_iter = resume_state['iter']
+
+ # copy the yml file to the experiment root
+ copy_opt_file(opt.config, experiments_root)
+
+ # training
+ logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
+ for epoch in range(start_epoch, opt.epochs):
+ train_dataloader.sampler.set_epoch(epoch)
+ # train
+ for _, data in enumerate(train_dataloader):
+ current_iter += 1
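+ # Conditioning and latents come from frozen SD components, so no gradients are needed:
+ # the text encoder yields c and the VAE encodes the [-1, 1] image into latent z.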
+ with torch.no_grad():
+ c = model.module.get_learned_conditioning(data['sentence'])
+ z = model.module.encode_first_stage((data['im']*2-1.).cuda(non_blocking=True))
+ z = model.module.get_first_stage_encoding(z)
+
+ mask = data['mask'].cuda(non_blocking=True)  # adapter input on GPU (assumes the dataset yields CPU tensors)
+ optimizer.zero_grad()
+ model.zero_grad()
+ features_adapter = model_ad(mask)
+ l_pixel, loss_dict = model(z, c=c, features_adapter=features_adapter)
+ l_pixel.backward()
+ optimizer.step()
+
+ if (current_iter+1)%opt.print_fq == 0:
+ logger.info(loss_dict)
+
+ # save checkpoint
+ rank, _ = get_dist_info()
+ if (rank==0) and ((current_iter+1)%config['training']['save_freq'] == 0):
+ save_filename = f'model_ad_{current_iter+1}.pth'
+ save_path = os.path.join(experiments_root, 'models', save_filename)
+ save_dict = {}
+ model_ad_bare = get_bare_model(model_ad)
+ state_dict = model_ad_bare.state_dict()
+ for key, param in state_dict.items():
+ if key.startswith('module.'): # remove unnecessary 'module.'
+ key = key[7:]
+ save_dict[key] = param.cpu()
+ torch.save(save_dict, save_path)
+ # save state
+ state = {'epoch': epoch, 'iter': current_iter+1, 'optimizers': optimizer.state_dict()}
+ save_filename = f'{current_iter+1}.state'
+ save_path = os.path.join(experiments_root, 'training_states', save_filename)
+ torch.save(state, save_path)
+
+ # val
+ rank, _ = get_dist_info()
+ if rank==0:
+ for data in val_dataloader:
+ with torch.no_grad():
+ if opt.dpm_solver:
+ sampler = DPMSolverSampler(model.module)
+ elif opt.plms:
+ sampler = PLMSSampler(model.module)
+ else:
+ sampler = DDIMSampler(model.module)
+ c = model.module.get_learned_conditioning(data['sentence'])
+ mask = data['mask'].cuda(non_blocking=True)  # same device as the adapter (assumption, as in the training loop)
+ im_mask = tensor2img(mask)
+ cv2.imwrite(os.path.join(experiments_root, 'visualization', 'mask_%04d.png'%epoch), im_mask)
+ features_adapter = model_ad(mask)
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
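+ # Latent shape [C, H/f, W/f] = [4, 64, 64] for 512x512 images with downsampling factor 8.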
+ samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+ conditioning=c,
+ batch_size=opt.n_samples,
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=opt.scale,
+ unconditional_conditioning=model.module.get_learned_conditioning(opt.n_samples * [""]),
+ eta=opt.ddim_eta,
+ x_T=None,
+ features_adapter=features_adapter)
+ x_samples_ddim = model.module.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+ for id_sample, x_sample in enumerate(x_samples_ddim):
+ x_sample = 255.*x_sample
+ img = x_sample.astype(np.uint8)
+ img = cv2.putText(img.copy(), data['sentence'][0], (10,30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
+ cv2.imwrite(os.path.join(experiments_root, 'visualization', 'sample_e%04d_s%04d.png'%(epoch, id_sample)), img[:,:,::-1])
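+ # Only the first validation batch is visualized.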
+ break
diff --git a/train_sketch.py b/train_sketch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9ab8d2f0e742f8a0395578b697bbad00415ccfa
--- /dev/null
+++ b/train_sketch.py
@@ -0,0 +1,399 @@
+import argparse
+import logging
+import os
+import os.path as osp
+import time
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
+ img2tensor, scandir, tensor2img)
+from basicsr.utils.options import copy_opt_file, dict2str
+from omegaconf import OmegaConf
+from PIL import Image
+
+from ldm.data.dataset_coco import dataset_coco_mask_color
+from dist_util import get_bare_model, get_dist_info, init_dist, master_only
+from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.models.diffusion.dpm_solver import DPMSolverSampler
+from ldm.models.diffusion.plms import PLMSSampler
+from ldm.modules.encoders.adapter import Adapter
+from ldm.util import instantiate_from_config
+from ldm.modules.structure_condition.model_edge import pidinet
+
+
+def load_model_from_config(config, ckpt, verbose=False):
+ print(f"Loading model from {ckpt}")
+ pl_sd = torch.load(ckpt, map_location="cpu")
+ if "global_step" in pl_sd:
+ print(f"Global Step: {pl_sd['global_step']}")
+ sd = pl_sd["state_dict"]
+ model = instantiate_from_config(config.model)
+ m, u = model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ model.cuda()
+ model.eval()
+ return model
+
+@master_only
+def mkdir_and_rename(path):
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+ Args:
+ path (str): Folder path.
+ """
+ if osp.exists(path):
+ new_name = path + '_archived_' + get_time_str()
+ print(f'Path already exists. Rename it to {new_name}', flush=True)
+ os.rename(path, new_name)
+ os.makedirs(path, exist_ok=True)
+ os.makedirs(osp.join(experiments_root, 'models'))
+ os.makedirs(osp.join(experiments_root, 'training_states'))
+ os.makedirs(osp.join(experiments_root, 'visualization'))
+
+def load_resume_state(opt):
+ resume_state_path = None
+ if opt.auto_resume:
+ state_path = osp.join('experiments', opt.name, 'training_states')
+ if osp.isdir(state_path):
+ states = list(scandir(state_path, suffix='state', recursive=False, full_path=False))
+ if len(states) != 0:
+ states = [float(v.split('.state')[0]) for v in states]
+ resume_state_path = osp.join(state_path, f'{max(states):.0f}.state')
+ opt.resume_state_path = resume_state_path
+ # else:
+ # if opt['path'].get('resume_state'):
+ # resume_state_path = opt['path']['resume_state']
+
+ if resume_state_path is None:
+ resume_state = None
+ else:
+ device_id = torch.cuda.current_device()
+ resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id))
+ # check_resume(opt, resume_state['iter'])
+ return resume_state
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--bsize",
+ type=int,
+ default=8,
+ help="the prompt to render"
+)
+parser.add_argument(
+ "--epochs",
+ type=int,
+ default=10000,
+ help="the prompt to render"
+)
+parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=8,
+ help="the prompt to render"
+)
+parser.add_argument(
+ "--use_shuffle",
+ type=bool,  # caution: argparse parses any non-empty string as True
+ default=True,
+ help="shuffle the training data (unused here; the DistributedSampler handles shuffling)"
+)
+parser.add_argument(
+ "--dpm_solver",
+ action='store_true',
+ help="use dpm_solver sampling",
+)
+parser.add_argument(
+ "--plms",
+ action='store_true',
+ help="use plms sampling",
+)
+parser.add_argument(
+ "--auto_resume",
+ action='store_true',
+ help="use plms sampling",
+)
+parser.add_argument(
+ "--ckpt",
+ type=str,
+ default="models/sd-v1-4.ckpt",
+ help="path to checkpoint of model",
+)
+parser.add_argument(
+ "--config",
+ type=str,
+ default="configs/stable-diffusion/train_sketch.yaml",
+ help="path to config which constructs model",
+)
+parser.add_argument(
+ "--print_fq",
+ type=int,
+ default=100,
+ help="path to config which constructs model",
+)
+parser.add_argument(
+ "--H",
+ type=int,
+ default=512,
+ help="image height, in pixel space",
+)
+parser.add_argument(
+ "--W",
+ type=int,
+ default=512,
+ help="image width, in pixel space",
+)
+parser.add_argument(
+ "--C",
+ type=int,
+ default=4,
+ help="latent channels",
+)
+parser.add_argument(
+ "--f",
+ type=int,
+ default=8,
+ help="downsampling factor",
+)
+parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ default=50,
+ help="number of ddim sampling steps",
+)
+parser.add_argument(
+ "--n_samples",
+ type=int,
+ default=1,
+ help="how many samples to produce for each given prompt. A.k.a. batch size",
+)
+parser.add_argument(
+ "--ddim_eta",
+ type=float,
+ default=0.0,
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
+)
+parser.add_argument(
+ "--scale",
+ type=float,
+ default=7.5,
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+)
+parser.add_argument(
+ "--gpus",
+ default=[0,1,2,3],
+ help="gpu idx",
+)
+parser.add_argument(
+ '--local_rank',
+ default=0,
+ type=int,
+ help='local process rank for distributed training'
+)
+parser.add_argument(
+ '--launcher',
+ default='pytorch',
+ type=str,
+ help='job launcher for distributed training'
+)
+parser.add_argument(
+ '--l_cond',
+ default=4,
+ type=int,
+ help='number of scales'
+)
+opt = parser.parse_args()
+
+if __name__ == '__main__':
+ config = OmegaConf.load(f"{opt.config}")
+ opt.name = config['name']
+
+ # distributed setting
+ init_dist(opt.launcher)
+ torch.backends.cudnn.benchmark = True
+ device='cuda'
+ torch.cuda.set_device(opt.local_rank)
+
+ # dataset
+ path_json_train = 'coco_stuff/mask/annotations/captions_train2017.json'
+ path_json_val = 'coco_stuff/mask/annotations/captions_val2017.json'
+ train_dataset = dataset_coco_mask_color(path_json_train,
+ root_path_im='coco/train2017',
+ root_path_mask='coco_stuff/mask/train2017_color',
+ image_size=512
+ )
+ train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+ val_dataset = dataset_coco_mask_color(path_json_val,
+ root_path_im='coco/val2017',
+ root_path_mask='coco_stuff/mask/val2017_color',
+ image_size=512
+ )
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=opt.bsize,
+ shuffle=(train_sampler is None),
+ num_workers=opt.num_workers,
+ pin_memory=True,
+ sampler=train_sampler)
+ val_dataloader = torch.utils.data.DataLoader(
+ val_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=1,
+ pin_memory=False)
+
+ # edge_generator
+ net_G = pidinet()
+ ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
+ net_G.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()})
+ net_G.cuda()
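+ # The checkpoint was saved from a DataParallel wrapper, hence stripping the 'module.' prefix.
+ # NOTE: net_G is never set to eval(); if pidinet contains BatchNorm, its running stats
+ # would still update under no_grad (an assumption about the architecture).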
+
+ # stable diffusion
+ model = load_model_from_config(config, f"{opt.ckpt}").to(device)
+
+ # sketch encoder
+ model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+
+ # to gpus
+ model_ad = torch.nn.parallel.DistributedDataParallel(
+ model_ad,
+ device_ids=[opt.local_rank],
+ output_device=opt.local_rank)
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[opt.local_rank],
+ output_device=opt.local_rank)
+ # device_ids=[torch.cuda.current_device()])
+ net_G = torch.nn.parallel.DistributedDataParallel(
+ net_G,
+ device_ids=[opt.local_rank],
+ output_device=opt.local_rank)
+ # device_ids=[torch.cuda.current_device()])
+
+ # optimizer
+ params = list(model_ad.parameters())
+ optimizer = torch.optim.AdamW(params, lr=config['training']['lr'])
+
+ experiments_root = osp.join('experiments', opt.name)
+
+ # resume state
+ resume_state = load_resume_state(opt)
+ if resume_state is None:
+ mkdir_and_rename(experiments_root)
+ start_epoch = 0
+ current_iter = 0
+ # WARNING: should not use get_root_logger in the above codes, including the called functions
+ # Otherwise the logger will not be properly initialized
+ log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log")
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(config))
+ else:
+ # WARNING: should not use get_root_logger in the above codes, including the called functions
+ # Otherwise the logger will not be properly initialized
+ log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log")
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(config))
+ resume_optimizers = resume_state['optimizers']
+ optimizer.load_state_dict(resume_optimizers)
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
+ start_epoch = resume_state['epoch']
+ current_iter = resume_state['iter']
+
+ # copy the yml file to the experiment root
+ copy_opt_file(opt.config, experiments_root)
+
+
+ # training
+ logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}')
+ for epoch in range(start_epoch, opt.epochs):
+ train_dataloader.sampler.set_epoch(epoch)
+ # train
+ for _, data in enumerate(train_dataloader):
+ current_iter += 1
+ with torch.no_grad():
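+ # PiDiNet predicts soft edge maps; thresholding at 0.5 yields the binary sketch
+ # the adapter is trained on.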
+ edge = net_G(data['im'].cuda(non_blocking=True))[-1]
+ edge = (edge > 0.5).float()
+ c = model.module.get_learned_conditioning(data['sentence'])
+ z = model.module.encode_first_stage((data['im']*2-1.).cuda(non_blocking=True))
+ z = model.module.get_first_stage_encoding(z)
+
+ optimizer.zero_grad()
+ model.zero_grad()
+ features_adapter = model_ad(edge)
+ l_pixel, loss_dict = model(z, c=c, features_adapter=features_adapter)
+ l_pixel.backward()
+ optimizer.step()
+
+ if (current_iter+1)%opt.print_fq == 0:
+ logger.info(loss_dict)
+
+ # save checkpoint
+ rank, _ = get_dist_info()
+ if (rank==0) and ((current_iter+1)%config['training']['save_freq'] == 0):
+ save_filename = f'model_ad_{current_iter+1}.pth'
+ save_path = os.path.join(experiments_root, 'models', save_filename)
+ save_dict = {}
+ model_ad_bare = get_bare_model(model_ad)
+ state_dict = model_ad_bare.state_dict()
+ for key, param in state_dict.items():
+ if key.startswith('module.'): # remove unnecessary 'module.'
+ key = key[7:]
+ save_dict[key] = param.cpu()
+ torch.save(save_dict, save_path)
+ # save state
+ state = {'epoch': epoch, 'iter': current_iter+1, 'optimizers': optimizer.state_dict()}
+ save_filename = f'{current_iter+1}.state'
+ save_path = os.path.join(experiments_root, 'training_states', save_filename)
+ torch.save(state, save_path)
+
+ # val
+ rank, _ = get_dist_info()
+ if rank==0:
+ for data in val_dataloader:
+ with torch.no_grad():
+ if opt.dpm_solver:
+ sampler = DPMSolverSampler(model.module)
+ elif opt.plms:
+ sampler = PLMSSampler(model.module)
+ else:
+ sampler = DDIMSampler(model.module)
+ c = model.module.get_learned_conditioning(data['sentence'])
+ edge = net_G(data['im'].cuda(non_blocking=True))[-1]
+ edge = (edge > 0.5).float()
+ im_edge = tensor2img(edge)
+ cv2.imwrite(os.path.join(experiments_root, 'visualization', 'edge_%04d.png'%epoch), im_edge)
+ features_adapter = model_ad(edge)
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
+ samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+ conditioning=c,
+ batch_size=opt.n_samples,
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=opt.scale,
+ unconditional_conditioning=model.module.get_learned_conditioning(opt.n_samples * [""]),
+ eta=opt.ddim_eta,
+ x_T=None,
+ features_adapter=features_adapter)
+ x_samples_ddim = model.module.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+ x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+ for id_sample, x_sample in enumerate(x_samples_ddim):
+ x_sample = 255.*x_sample
+ img = x_sample.astype(np.uint8)
+ img = cv2.putText(img.copy(), data['sentence'][0], (10,30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2)
+ cv2.imwrite(os.path.join(experiments_root, 'visualization', 'sample_e%04d_s%04d.png'%(epoch, id_sample)), img[:,:,::-1])
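+ # Stop after one batch; validation here is a visual spot check, not a metric.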
+ break