diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..bc13a7ccd68e17685804fbbd899540be662cbda0 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/face/0229.png filter=lfs diff=lfs merge=lfs -text +examples/general/bx2vqrcj.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9b203b72ef8125b40d1f6d9b4563c87fb72e5096 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +__pycache__ +*.ckpt +*.pth +/data +/exps +*.sh +!install_env.sh +/weights +/temp +/gradio_cached_examples +/flagged diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..725dea2389b919e8b7166543977d432a4d03a22b --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 XPixelGroup + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..b956d3e8f6852ddc1eb70776f91674c84d5632b8 --- /dev/null +++ b/app.py @@ -0,0 +1,281 @@ +from typing import List +import math +import os + +import numpy as np +import torch +import einops +import pytorch_lightning as pl +import gradio as gr +from PIL import Image +from omegaconf import OmegaConf +from openxlab.model import download +from tqdm import tqdm + +from model.spaced_sampler import SpacedSampler +from model.cldm import ControlLDM +from utils.image import auto_resize, pad +from utils.common import instantiate_from_config, load_state_dict +from utils.face_restoration_helper import FaceRestoreHelper + + +# download models to local directory +download(model_repo="linxinqi/DiffBIR", model_name="diffbir_general_full_v1") +download(model_repo="linxinqi/DiffBIR", model_name="diffbir_general_swinir_v1") +download(model_repo="linxinqi/DiffBIR", model_name="diffbir_face_full_v1") + +config = "cldm.yaml" +general_full_ckpt = "general_full_v1.ckpt" +general_swinir_ckpt = "general_swinir_v1.ckpt" +face_full_ckpt = "face_full_v1.ckpt" + +# create general model +general_model: ControlLDM = instantiate_from_config(OmegaConf.load(config)).cuda() +load_state_dict(general_model, torch.load(general_full_ckpt, map_location="cuda"), strict=True) +load_state_dict(general_model.preprocess_model, torch.load(general_swinir_ckpt, map_location="cuda"), strict=True) +general_model.freeze() + +# keep a reference of general model's preprocess model and parallel model +general_preprocess_model = general_model.preprocess_model +general_control_model = general_model.control_model + +# create face model +face_model: ControlLDM = instantiate_from_config(OmegaConf.load(config)) +load_state_dict(face_model, torch.load(face_full_ckpt, map_location="cpu"), strict=True) +face_model.freeze() + +# share the pretrained weights with general model +_tmp = face_model.first_stage_model +face_model.first_stage_model = general_model.first_stage_model +del _tmp + +_tmp = face_model.cond_stage_model +face_model.cond_stage_model = general_model.cond_stage_model +del _tmp + +_tmp = face_model.model +face_model.model = general_model.model +del _tmp + +face_model.cuda() + +def to_tensor(image, device, bgr2rgb=False): + if bgr2rgb: + image = image[:, :, ::-1] + image_tensor = torch.tensor(image[None] / 255.0, dtype=torch.float32, device=device).clamp_(0, 1) + image_tensor = einops.rearrange(image_tensor, "n h w c -> n c h w").contiguous() + return image_tensor + +def to_array(image): + image = image.clamp(0, 1) + image_array = (einops.rearrange(image, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8) + return image_array + +@torch.no_grad() +def process( + control_img: Image.Image, + use_face_model: bool, + num_samples: int, + sr_scale: int, + disable_preprocess_model: bool, + strength: float, + positive_prompt: str, + negative_prompt: str, + cfg_scale: float, + steps: int, + use_color_fix: bool, + seed: int, + tiled: bool, + tile_size: int, + tile_stride: int + # progress = gr.Progress(track_tqdm=True) +) -> List[np.ndarray]: + pl.seed_everything(seed) + + global general_model + global face_model + + model = general_model + sampler = SpacedSampler(model, var_type="fixed_small") + model.control_scales = [strength] * 13 + if use_face_model: + print("use face model") + sampler_face = SpacedSampler(face_model, var_type="fixed_small") + face_model.control_scales = [strength] * 13 + + # prepare condition + if sr_scale != 1: + control_img = control_img.resize( + tuple(math.ceil(x * sr_scale) for x in control_img.size), + Image.BICUBIC + ) + input_size = control_img.size + if not tiled: + control_img = auto_resize(control_img, 512) + else: + control_img = auto_resize(control_img, tile_size) + h, w = control_img.height, control_img.width + control_img = pad(np.array(control_img), scale=64) # HWC, RGB, [0, 255] + + if use_face_model: + # set up FaceRestoreHelper + face_size = 512 + face_helper = FaceRestoreHelper(device=model.device, upscale_factor=1, face_size=face_size, use_parse=True) + # read BGR numpy [0, 255] + face_helper.read_image(np.array(control_img)[:, :, ::-1]) + # detect faces in input lq control image + face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) + face_helper.align_warp_face() + + control = to_tensor(control_img, device=model.device) + if not disable_preprocess_model: + control = model.preprocess_model(control) + height, width = control.size(-2), control.size(-1) + + preds = [] + for _ in tqdm(range(num_samples)): + shape = (1, 4, height // 8, width // 8) + x_T = torch.randn(shape, device=model.device, dtype=torch.float32) + if not tiled: + samples = sampler.sample( + steps=steps, shape=shape, cond_img=control, + positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T, + cfg_scale=cfg_scale, cond_fn=None, + color_fix_type="wavelet" if use_color_fix else "none" + ) + else: + samples = sampler.sample_with_mixdiff( + tile_size=int(tile_size), tile_stride=int(tile_stride), + steps=steps, shape=shape, cond_img=control, + positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T, + cfg_scale=cfg_scale, cond_fn=None, + color_fix_type="wavelet" if use_color_fix else "none" + ) + restored_bg = to_array(samples) + + if use_face_model and len(face_helper.cropped_faces) > 0: + shape_face = (1, 4, face_size // 8, face_size // 8) + x_T_face = torch.randn(shape_face, device=model.device, dtype=torch.float32) + # face detected + for cropped_face in face_helper.cropped_faces: + cropped_face = to_tensor(cropped_face, device=model.device, bgr2rgb=True) + if not disable_preprocess_model: + cropped_face = face_model.preprocess_model(cropped_face) + samples_face = sampler_face.sample( + steps=steps, shape=shape, cond_img=cropped_face, + positive_prompt=positive_prompt, negative_prompt=negative_prompt, x_T=x_T_face, + cfg_scale=1.0, cond_fn=None, + color_fix_type="wavelet" if use_color_fix else "none" + ) + restored_face = to_array(samples_face) + face_helper.add_restored_face(restored_face[0]) + face_helper.get_inverse_affine(None) + # paste each restored face to the input image + restored_img = face_helper.paste_faces_to_input_image( + upsample_img=restored_bg[0] + ) + + # remove padding and resize to input size + restored_img = Image.fromarray(restored_img[:h, :w, :]).resize(input_size, Image.LANCZOS) + preds.append(np.array(restored_img)) + return preds + +MAX_SIZE = int(os.getenv("MAX_SIZE")) +CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT")) + +print(f"max size = {MAX_SIZE}, concurrency_count = {CONCURRENCY_COUNT}") + +MARKDOWN = \ +""" +## DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior + +[GitHub](https://github.com/XPixelGroup/DiffBIR) | [Paper](https://arxiv.org/abs/2308.15070) | [Project Page](https://0x3f3f3f3fun.github.io/projects/diffbir/) + +If DiffBIR is helpful for you, please help star the GitHub Repo. Thanks! + +## NOTE + +1. This app processes user-uploaded images in sequence, so it may take some time before your image begins to be processed. +2. This is a publicly-used app, so please don't upload large images (>= 1024) to avoid taking up too much time. +""" + +block = gr.Blocks().queue(concurrency_count=CONCURRENCY_COUNT, max_size=MAX_SIZE) +with block: + with gr.Row(): + gr.Markdown(MARKDOWN) + with gr.Row(): + with gr.Column(): + input_image = gr.Image(source="upload", type="pil") + run_button = gr.Button(label="Run") + with gr.Accordion("Options", open=True): + use_face_model = gr.Checkbox(label="Use Face Model", value=False) + tiled = gr.Checkbox(label="Tiled", value=False) + tile_size = gr.Slider(label="Tile Size", minimum=512, maximum=1024, value=512, step=256) + tile_stride = gr.Slider(label="Tile Stride", minimum=256, maximum=512, value=256, step=128) + num_samples = gr.Slider(label="Number Of Samples", minimum=1, maximum=12, value=1, step=1) + sr_scale = gr.Number(label="SR Scale", value=1) + positive_prompt = gr.Textbox(label="Positive Prompt", value="") + negative_prompt = gr.Textbox( + label="Negative Prompt", + value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality" + ) + cfg_scale = gr.Slider(label="Classifier Free Guidance Scale (Set to a value larger than 1 to enable it!)", minimum=0.1, maximum=30.0, value=1.0, step=0.1) + strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1.0, step=0.01) + steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1) + disable_preprocess_model = gr.Checkbox(label="Disable Preprocess Model", value=False) + use_color_fix = gr.Checkbox(label="Use Color Correction", value=True) + seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=231) + with gr.Column(): + result_gallery = gr.Gallery(label="Output", show_label=False, elem_id="gallery").style(height="auto", grid=2) + # gr.Markdown("## Image Examples") + gr.Examples( + examples=[ + ["examples/face/0229.png", True, 1, 1, False, 1.0, "", "", 1.0, 50, True, 231, False, 512, 256], + ["examples/face/hermione.jpg", True, 1, 2, False, 1.0, "", "", 1.0, 50, True, 231, False, 512, 256], + ["examples/general/14.jpg", False, 1, 4, False, 1.0, "", "", 1.0, 50, True, 231, False, 512, 256], + ["examples/general/49.jpg", False, 1, 4, False, 1.0, "", "", 1.0, 50, True, 231, False, 512, 256], + ["examples/general/53.jpeg", False, 1, 4, False, 1.0, "", "", 1.0, 50, True, 231, False, 512, 256], + # ["examples/general/bx2vqrcj.png", False, 1, 4, False, 1.0, "", "", 1.0, 50, True, 231, True, 512, 256], + ], + inputs=[ + input_image, + use_face_model, + num_samples, + sr_scale, + disable_preprocess_model, + strength, + positive_prompt, + negative_prompt, + cfg_scale, + steps, + use_color_fix, + seed, + tiled, + tile_size, + tile_stride + ], + outputs=[result_gallery], + fn=process, + cache_examples=True, + ) + + inputs = [ + input_image, + use_face_model, + num_samples, + sr_scale, + disable_preprocess_model, + strength, + positive_prompt, + negative_prompt, + cfg_scale, + steps, + use_color_fix, + seed, + tiled, + tile_size, + tile_stride + ] + run_button.click(fn=process, inputs=inputs, outputs=[result_gallery]) + +block.launch() diff --git a/cldm.yaml b/cldm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52e532fb43f067ec05044864e2ee3285c1779038 --- /dev/null +++ b/cldm.yaml @@ -0,0 +1,106 @@ +target: model.cldm.ControlLDM +params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + control_key: "hint" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn + monitor: val/loss_simple_ema + scale_factor: 0.18215 + use_ema: False + + sd_locked: True + only_mid_control: False + # Learning rate. + learning_rate: 1e-4 + + control_stage_config: + target: model.cldm.ControlNet + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + hint_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + unet_config: + target: model.cldm.ControlledUnetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + preprocess_config: + target: model.swinir.SwinIR + params: + img_size: 64 + patch_size: 1 + in_chans: 3 + embed_dim: 180 + depths: [6, 6, 6, 6, 6, 6, 6, 6] + num_heads: [6, 6, 6, 6, 6, 6, 6, 6] + window_size: 8 + mlp_ratio: 2 + sf: 8 + img_range: 1.0 + upsampler: "nearest+conv" + resi_connection: "1conv" + unshuffle: True + unshuffle_scale: 8 diff --git a/examples/face/0229.png b/examples/face/0229.png new file mode 100644 index 0000000000000000000000000000000000000000..b28dda08845e9e70c0e8dbf95a85a91bfa96abb0 --- /dev/null +++ b/examples/face/0229.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9555d24b342b2d6ac453102aaecf2a32be9d2865caba1560aa7596835055ac21 +size 153945 diff --git a/examples/face/hermione.jpg b/examples/face/hermione.jpg new file mode 100644 index 0000000000000000000000000000000000000000..daabcd32f79437653a922f3747b581e6bbc13d58 Binary files /dev/null and b/examples/face/hermione.jpg differ diff --git a/examples/general/14.jpg b/examples/general/14.jpg new file mode 100644 index 0000000000000000000000000000000000000000..be45784c2e7178d1c1a165231cc52d87d343b82f Binary files /dev/null and b/examples/general/14.jpg differ diff --git a/examples/general/49.jpg b/examples/general/49.jpg new file mode 100644 index 0000000000000000000000000000000000000000..de46b72825be0f49126adc1ba2ae35d5a0a2a6d6 Binary files /dev/null and b/examples/general/49.jpg differ diff --git a/examples/general/53.jpeg b/examples/general/53.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..da0ca2a86815e931ea76df2e21a10ed9237d03ba Binary files /dev/null and b/examples/general/53.jpeg differ diff --git a/examples/general/bx2vqrcj.png b/examples/general/bx2vqrcj.png new file mode 100644 index 0000000000000000000000000000000000000000..dcd49c3c82b6e818035bf5dbccd650b0cdb64253 --- /dev/null +++ b/examples/general/bx2vqrcj.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2537ce5bbb6f303d6074bbc93bd966d6b58515328b2060da8fa2ed2d77c1c1e9 +size 402003 diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/data/util.py b/ldm/data/util.py new file mode 100644 index 0000000000000000000000000000000000000000..5b60ceb2349e3bd7900ff325740e2022d2903b1c --- /dev/null +++ b/ldm/data/util.py @@ -0,0 +1,24 @@ +import torch + +from ldm.modules.midas.api import load_midas_transform + + +class AddMiDaS(object): + def __init__(self, model_type): + super().__init__() + self.transform = load_midas_transform(model_type) + + def pt2np(self, x): + x = ((x + 1.0) * .5).detach().cpu().numpy() + return x + + def np2pt(self, x): + x = torch.from_numpy(x) * 2 - 1. + return x + + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = self.pt2np(sample['jpg']) + x = self.transform({"image": x})["image"] + sample['midas_in'] = x + return sample \ No newline at end of file diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d122549995ce2cd64092c81a58419ed4a15a02fd --- /dev/null +++ b/ldm/models/autoencoder.py @@ -0,0 +1,219 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config +from ldm.modules.ema import LitEma + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False + ): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0. < ema_decay < 1. + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( + self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..27ead0ea914c64c747b64e690662899fb3801144 --- /dev/null +++ b/ldm/models/diffusion/ddim.py @@ -0,0 +1,336 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + ucg_schedule=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + ucg_schedule=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + elif isinstance(c, list): + c_in = list() + assert isinstance(unconditional_conditioning, list) + for i in range(len(c)): + c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) + else: + c_in = torch.cat([unconditional_conditioning, c]) + model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: callback(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec \ No newline at end of file diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..8d20a818355d115ff17cd9fe5b717a9976a7dbd6 --- /dev/null +++ b/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1811 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager, nullcontext +from functools import partial +import itertools +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only +from omegaconf import ListConfig + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from model.mixins import ImageLoggerMixin + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule, ImageLoggerMixin): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, + ): + super().__init__() + assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + if reset_ema: assert exists(ckpt_path) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + if reset_ema: + assert self.use_ema + print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + else: + self.register_buffer('logvar', logvar) + + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like(self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) + else: + raise NotImplementedError("mu not supported") + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + if self.make_it_fit: + n_params = len([name for name, _ in + itertools.chain(self.named_parameters(), + self.named_buffers())]) + for name, param in tqdm( + itertools.chain(self.named_parameters(), + self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape) == len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys:\n {missing}") + if len(unexpected) > 0: + print(f"\nUnexpected Keys:\n {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1 - p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + force_null_conditioning=False, + *args, **kwargs): + self.force_null_conditioning = force_null_conditioning + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + reset_ema = kwargs.pop("reset_ema", False) + reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + if reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + # print(self.scale_factor) + # exit() + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, return_x=False): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None and not self.force_null_conditioning: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox', "txt"]: + xc = batch[cond_key] + elif cond_key in ['class_label', 'cls']: + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_x: + out.extend([x]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + return self.first_stage_model.decode(z) + + # 2023-04-08 + def decode_first_stage_with_grad(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is expected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None, **kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, + shape, cond, verbose=False, **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True, **kwargs) + + return samples, intermediates + + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + if self.cond_stage_key in ["class_label", "cls"]: + xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device) + return self.get_learned_conditioning(xc) + else: + raise NotImplementedError("todo") + if isinstance(c, list): # in case the encoder gives us a list + for i in range(len(c)): + c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + else: + c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + return c + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', "cls"]: + try: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + except KeyError: + # probably no "human_label" in batch + pass + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) + if self.model.conditioning_key == "crossattn-adm": + uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1. - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + if not self.sequential_cross_attn: + cc = torch.cat(c_crossattn, 1) + else: + cc = c_crossattn + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == 'crossattn-adm': + assert c_adm is not None + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, y=c_adm) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + self.noise_level_key = noise_level_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + if self.noise_level_key is not None: + # get noise level from batch instead, e.g. when extracting a custom noise level for bsr + raise NotImplementedError('TODO') + + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + if log_mode: + # TODO: maybe disable if too expensive + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, + log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? + assert isinstance(c[k], torch.Tensor) + #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + uc[k] = c[k] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + return log + + +class LatentFinetuneDiffusion(LatentDiffusion): + """ + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys: tuple, + finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight" + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, **kwargs + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' + if exists(ckpt_path): + self.init_from_ckpt(ckpt_path, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + print( + f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), 'did not find matching parameter to modify' + new_entry[:, :self.keep_dims, ...] = sd[k] + sd[k] = new_entry + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + return log + + +class LatentInpaintDiffusion(LatentFinetuneDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, **kwargs + ): + super().__init__(concat_keys, *args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) + log["masked_image"] = rearrange(args[0]["masked_image"], + 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + return log + + +class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): + """ + condition on monocular depth estimation + """ + + def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.depth_model = instantiate_from_config(depth_stage_config) + self.depth_stage_key = concat_keys[0] + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + c_cat = list() + for ck in self.concat_keys: + cc = batch[ck] + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + cc = self.depth_model(cc) + cc = torch.nn.functional.interpolate( + cc, + size=z.shape[2:], + mode="bicubic", + align_corners=False, + ) + + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], + keepdim=True) + cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + depth = self.depth_model(args[0][self.depth_stage_key]) + depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ + torch.amax(depth, dim=[1, 2, 3], keepdim=True) + log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. + return log + + +class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): + """ + condition on low-res image (and optionally on some spatial noise augmentation) + """ + def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None, + low_scale_config=None, low_scale_key=None, *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.reshuffle_patch_size = reshuffle_patch_size + self.low_scale_model = None + if low_scale_config is not None: + print("Initializing a low-scale model") + assert exists(low_scale_key) + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + # optionally make spatial noise_level here + c_cat = list() + noise_level = None + for ck in self.concat_keys: + cc = batch[ck] + cc = rearrange(cc, 'b h w c -> b c h w') + if exists(self.reshuffle_patch_size): + assert isinstance(self.reshuffle_patch_size, int) + cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', + p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + if exists(self.low_scale_model) and ck == self.low_scale_key: + cc, noise_level = self.low_scale_model(cc) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + if exists(noise_level): + all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} + else: + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + return log diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7427f38c07530afbab79154ea8aaf88c4bf70a08 --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..095e5ba3ce0b1aa7f4b3f1e2e5d8fff7cfe6dc8c --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1154 @@ +import torch +import torch.nn.functional as F +import math +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. + We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + We support four types of the diffusion model by setting `model_type`: + 1. "noise": noise prediction model. (Trained by predicting noise). + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + =============================================================== + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. + """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3, ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3, ] * (K - 1) + [1] + else: + orders = [3, ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2, ] * K + else: + K = steps // 2 + 1 + orders = [2, ] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1, ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( + model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, + return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + solver_type=solver_type, + **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. + lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + ===================================================== + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). + We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + ===================================================== + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in tqdm(range(1, order), desc="DPM init order"): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, + solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1), desc="DPM multistep"): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, + solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. + if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, + skip_type=skip_type, + t_T=t_T, t_0=t_0, + device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order, ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), + N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..7d137b8cf36718c1c58faa09f9dd919e5fb2977b --- /dev/null +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,87 @@ +"""SAMPLING ONLY.""" +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + + +MODEL_TYPES = { + "eps": "noise", + "v": "v" +} + + +class DPMSolverSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type=MODEL_TYPES[self.model.parameterization], + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) + + return x.to(device), None \ No newline at end of file diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..7002a365d27168ced0a04e9a4d83e088f8284eae --- /dev/null +++ b/ldm/models/diffusion/plms.py @@ -0,0 +1,244 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.models.diffusion.sampling_util import norm_thresholding + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next, + dynamic_threshold=dynamic_threshold) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py new file mode 100644 index 0000000000000000000000000000000000000000..7eff02be6d7c54d43ee6680636ac0698dd3b3f33 --- /dev/null +++ b/ldm/models/diffusion/sampling_util.py @@ -0,0 +1,22 @@ +import torch +import numpy as np + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) \ No newline at end of file diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..509cd873768f0dd75a75ab3fcdd652822b12b59f --- /dev/null +++ b/ldm/modules/attention.py @@ -0,0 +1,341 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + +from ldm.modules.diffusionmodules.util import checkpoint + + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +# CrossAttn precision handling +import os +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention + } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b089eebbe1676d8249005bb9def002ff5180715b --- /dev/null +++ b/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,852 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from typing import Optional, Any + +from ldm.modules.attention import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..7df6b5abfe8eff07f0c8e8703ba8aee90d45984b --- /dev/null +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,786 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer +from ldm.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/ldm/modules/diffusionmodules/upscaling.py b/ldm/modules/diffusionmodules/upscaling.py new file mode 100644 index 0000000000000000000000000000000000000000..03816662098ce1ffac79bd939b892e867ab91988 --- /dev/null +++ b/ldm/modules/diffusionmodules/upscaling.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial + +from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule +from ldm.util import default + + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class SimpleImageConcat(AbstractLowScaleModel): + # no noise level conditioning + def __init__(self): + super(SimpleImageConcat, self).__init__(noise_schedule_config=None) + self.max_noise_level = 0 + + def forward(self, x): + # fix to constant noise level + return x, torch.zeros(x.shape[0], device=x.device).long() + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + + diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..637363dfe34799e70cfdbcd11445212df9d9ca1f --- /dev/null +++ b/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,270 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..bded25019b9bcbcd0260f0b8185f8c7859ca58c4 --- /dev/null +++ b/ldm/modules/ema.py @@ -0,0 +1,80 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates + else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4edd5496b9e668ea72a5be39db9cca94b6a42f9b --- /dev/null +++ b/ldm/modules/encoders/modules.py @@ -0,0 +1,213 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +import open_clip +from ldm.util import default, count_params + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + clip_max_length=77, t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + diff --git a/ldm/modules/midas/__init__.py b/ldm/modules/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/midas/api.py b/ldm/modules/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..b58ebbffd942a2fc22264f0ab47e400c26b9f41c --- /dev/null +++ b/ldm/modules/midas/api.py @@ -0,0 +1,170 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from ldm.modules.midas.midas.dpt_depth import DPTDepthModel +from ldm.modules.midas.midas.midas_net import MidasNet +from ldm.modules.midas.midas.midas_net_custom import MidasNet_small +from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet + + +ISL_PATHS = { + "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array + # NOTE: we expect that the correct transform has been called during dataloading. + with torch.no_grad(): + prediction = self.model(x) + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=x.shape[2:], + mode="bicubic", + align_corners=False, + ) + assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) + return prediction + diff --git a/ldm/modules/midas/midas/__init__.py b/ldm/modules/midas/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/midas/midas/base_model.py b/ldm/modules/midas/midas/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/ldm/modules/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/ldm/modules/midas/midas/blocks.py b/ldm/modules/midas/midas/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/ldm/modules/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/ldm/modules/midas/midas/dpt_depth.py b/ldm/modules/midas/midas/dpt_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/ldm/modules/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/ldm/modules/midas/midas/midas_net.py b/ldm/modules/midas/midas/midas_net.py new file mode 100644 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/ldm/modules/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/ldm/modules/midas/midas/midas_net_custom.py b/ldm/modules/midas/midas/midas_net_custom.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/ldm/modules/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/ldm/modules/midas/midas/transforms.py b/ldm/modules/midas/midas/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/ldm/modules/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/ldm/modules/midas/midas/vit.py b/ldm/modules/midas/midas/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/ldm/modules/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/ldm/modules/midas/utils.py b/ldm/modules/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/ldm/modules/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/ldm/util.py b/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..72d4f843fe87d2ae1c6f4c94bc80ad0520fcd213 --- /dev/null +++ b/ldm/util.py @@ -0,0 +1,198 @@ +import importlib + +import torch +from torch import optim +import numpy as np + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + # font = ImageFont.truetype('font/DejaVuSans.ttf', size=size) + font = ImageFont.load_default() + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file diff --git a/model/callbacks.py b/model/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..0574fd0b5138c9d0b9c23c9f5dcc56febbe6e102 --- /dev/null +++ b/model/callbacks.py @@ -0,0 +1,74 @@ +from typing import Dict, Any +import os + +import numpy as np +import pytorch_lightning as pl +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities.types import STEP_OUTPUT +import torch +import torchvision +from PIL import Image +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities.distributed import rank_zero_only + +from .mixins import ImageLoggerMixin + + +__all__ = [ + "ModelCheckpoint", + "ImageLogger" +] + +class ImageLogger(Callback): + """ + Log images during training or validating. + + TODO: Support validating. + """ + + def __init__( + self, + log_every_n_steps: int=2000, + max_images_each_step: int=4, + log_images_kwargs: Dict[str, Any]=None + ) -> "ImageLogger": + super().__init__() + self.log_every_n_steps = log_every_n_steps + self.max_images_each_step = max_images_each_step + self.log_images_kwargs = log_images_kwargs or dict() + + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + assert isinstance(pl_module, ImageLoggerMixin) + + @rank_zero_only + def on_train_batch_end( + self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: STEP_OUTPUT, + batch: Any, batch_idx: int, dataloader_idx: int + ) -> None: + if pl_module.global_step % self.log_every_n_steps == 0: + is_train = pl_module.training + if is_train: + pl_module.freeze() + + with torch.no_grad(): + # returned images should be: nchw, rgb, [0, 1] + images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs) + + # save images + save_dir = os.path.join(pl_module.logger.save_dir, "image_log", "train") + os.makedirs(save_dir, exist_ok=True) + for image_key in images: + image = images[image_key].detach().cpu() + N = min(self.max_images_each_step, len(image)) + grid = torchvision.utils.make_grid(image[:N], nrow=4) + # chw -> hwc (hw if gray) + grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1).numpy() + grid = (grid * 255).clip(0, 255).astype(np.uint8) + filename = "{}_step-{:06}_e-{:06}_b-{:06}.png".format( + image_key, pl_module.global_step, pl_module.current_epoch, batch_idx + ) + path = os.path.join(save_dir, filename) + Image.fromarray(grid).save(path) + + if is_train: + pl_module.unfreeze() diff --git a/model/cldm.py b/model/cldm.py new file mode 100644 index 0000000000000000000000000000000000000000..5ad85aa5589b5e4eb73b7e6b99d1b0bd05e071d5 --- /dev/null +++ b/model/cldm.py @@ -0,0 +1,411 @@ +from typing import Mapping, Any +import copy +from collections import OrderedDict + +import einops +import torch +import torch as th +import torch.nn as nn + +from ldm.modules.diffusionmodules.util import ( + conv_nd, + linear, + zero_module, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer +from ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock, UNetModel +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.util import log_txt_as_img, exists, instantiate_from_config +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution +from utils.common import frozen_module +from .spaced_sampler import SpacedSampler + + +class ControlledUnetModel(UNetModel): + def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs): + hs = [] + with torch.no_grad(): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + + if control is not None: + h += control.pop() + + for i, module in enumerate(self.output_blocks): + if only_mid_control or control is None: + h = torch.cat([h, hs.pop()], dim=1) + else: + h = torch.cat([h, hs.pop() + control.pop()], dim=1) + h = module(h, emb, context) + + h = h.type(x.dtype) + return self.out(h) + + +class ControlNet(nn.Module): + def __init__( + self, + image_size, + in_channels, + model_channels, + hint_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.dims = dims + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels + hint_channels, model_channels, 3, padding=1) + ) + ] + ) + self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) + + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self.zero_convs.append(self.make_zero_conv(ch)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + self.zero_convs.append(self.make_zero_conv(ch)) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self.middle_block_out = self.make_zero_conv(ch) + self._feature_size += ch + + def make_zero_conv(self, channels): + return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) + + def forward(self, x, hint, timesteps, context, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + x = torch.cat((x, hint), dim=1) + outs = [] + + h = x.type(self.dtype) + for module, zero_conv in zip(self.input_blocks, self.zero_convs): + h = module(h, emb, context) + outs.append(zero_conv(h, emb, context)) + + h = self.middle_block(h, emb, context) + outs.append(self.middle_block_out(h, emb, context)) + + return outs + + +class ControlLDM(LatentDiffusion): + + def __init__( + self, + control_stage_config: Mapping[str, Any], + control_key: str, + sd_locked: bool, + only_mid_control: bool, + learning_rate: float, + preprocess_config, + *args, + **kwargs + ) -> "ControlLDM": + super().__init__(*args, **kwargs) + # instantiate control module + self.control_model: ControlNet = instantiate_from_config(control_stage_config) + self.control_key = control_key + self.sd_locked = sd_locked + self.only_mid_control = only_mid_control + self.learning_rate = learning_rate + self.control_scales = [1.0] * 13 + + # instantiate preprocess module (SwinIR) + self.preprocess_model = instantiate_from_config(preprocess_config) + frozen_module(self.preprocess_model) + + # instantiate condition encoder, since our condition encoder has the same + # structure with AE encoder, we just make a copy of AE encoder. please + # note that AE encoder's parameters has not been initialized here. + self.cond_encoder = nn.Sequential(OrderedDict([ + ("encoder", copy.deepcopy(self.first_stage_model.encoder)), # cond_encoder.encoder + ("quant_conv", copy.deepcopy(self.first_stage_model.quant_conv)) # cond_encoder.quant_conv + ])) + frozen_module(self.cond_encoder) + + def apply_condition_encoder(self, control): + c_latent_meanvar = self.cond_encoder(control * 2 - 1) + c_latent = DiagonalGaussianDistribution(c_latent_meanvar).mode() # only use mode + c_latent = c_latent * self.scale_factor + return c_latent + + @torch.no_grad() + def get_input(self, batch, k, bs=None, *args, **kwargs): + x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs) + control = batch[self.control_key] + if bs is not None: + control = control[:bs] + control = control.to(self.device) + control = einops.rearrange(control, 'b h w c -> b c h w') + control = control.to(memory_format=torch.contiguous_format).float() + lq = control + # apply preprocess model + control = self.preprocess_model(control) + # apply condition encoder + c_latent = self.apply_condition_encoder(control) + return x, dict(c_crossattn=[c], c_latent=[c_latent], lq=[lq], c_concat=[control]) + + def apply_model(self, x_noisy, t, cond, *args, **kwargs): + assert isinstance(cond, dict) + diffusion_model = self.model.diffusion_model + + cond_txt = torch.cat(cond['c_crossattn'], 1) + + if cond['c_latent'] is None: + eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control) + else: + control = self.control_model( + x=x_noisy, hint=torch.cat(cond['c_latent'], 1), + timesteps=t, context=cond_txt + ) + control = [c * scale for c, scale in zip(control, self.control_scales)] + eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control) + + return eps + + @torch.no_grad() + def get_unconditional_conditioning(self, N): + return self.get_learned_conditioning([""] * N) + + @torch.no_grad() + def log_images(self, batch, sample_steps=50): + log = dict() + z, c = self.get_input(batch, self.first_stage_key) + c_lq = c["lq"][0] + c_latent = c["c_latent"][0] + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + + log["hq"] = (self.decode_first_stage(z) + 1) / 2 + log["control"] = c_cat + log["decoded_control"] = (self.decode_first_stage(c_latent) + 1) / 2 + log["lq"] = c_lq + log["text"] = (log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16) + 1) / 2 + + samples = self.sample_log( + # TODO: remove c_concat from cond + cond={"c_concat": [c_cat], "c_crossattn": [c], "c_latent": [c_latent]}, + steps=sample_steps + ) + x_samples = self.decode_first_stage(samples) + log["samples"] = (x_samples + 1) / 2 + + return log + + @torch.no_grad() + def sample_log(self, cond, steps): + sampler = SpacedSampler(self) + b, c, h, w = cond["c_concat"][0].shape + shape = (b, self.channels, h // 8, w // 8) + samples = sampler.sample( + steps, shape, cond, unconditional_guidance_scale=1.0, + unconditional_conditioning=None + ) + return samples + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.control_model.parameters()) + if not self.sd_locked: + params += list(self.model.diffusion_model.output_blocks.parameters()) + params += list(self.model.diffusion_model.out.parameters()) + opt = torch.optim.AdamW(params, lr=lr) + return opt + + def validation_step(self, batch, batch_idx): + # TODO: + pass diff --git a/model/cond_fn.py b/model/cond_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..10c7e4fdee33d8dc918b614452a8136352eb974d --- /dev/null +++ b/model/cond_fn.py @@ -0,0 +1,58 @@ +from typing import overload +import torch +from torch.nn import functional as F + + +class Guidance: + + def __init__(self, scale, type, t_start, t_stop, space, repeat, loss_type): + self.scale = scale + self.type = type + self.t_start = t_start + self.t_stop = t_stop + self.target = None + self.space = space + self.repeat = repeat + self.loss_type = loss_type + + def load_target(self, target): + self.target = target + + def __call__(self, target_x0, pred_x0, t): + if self.t_stop < t and t < self.t_start: + # print("sampling with classifier guidance") + # avoid propagating gradient out of this scope + pred_x0 = pred_x0.detach().clone() + target_x0 = target_x0.detach().clone() + return self.scale * self._forward(target_x0, pred_x0) + else: + return None + + @overload + def _forward(self, target_x0, pred_x0): ... + + +class MSEGuidance(Guidance): + + def __init__(self, scale, type, t_start, t_stop, space, repeat, loss_type) -> None: + super().__init__( + scale, type, t_start, t_stop, space, repeat, loss_type + ) + + @torch.enable_grad() + def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor): + # inputs: [-1, 1], nchw, rgb + pred_x0.requires_grad_(True) + + if self.loss_type == "mse": + loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum() + elif self.loss_type == "downsample_mse": + # FIXME: scale_factor should be 1/4, not 4 + lr_pred_x0 = F.interpolate(pred_x0, scale_factor=4, mode="bicubic") + lr_target_x0 = F.interpolate(target_x0, scale_factor=4, mode="bicubic") + loss = (lr_pred_x0 - lr_target_x0).pow(2).mean((1, 2, 3)).sum() + else: + raise ValueError(self.loss_type) + + print(f"loss = {loss.item()}") + return -torch.autograd.grad(loss, pred_x0)[0] diff --git a/model/mixins.py b/model/mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..a86d3c2d334b98e5de410b8af3e6f06acd4cccc2 --- /dev/null +++ b/model/mixins.py @@ -0,0 +1,9 @@ +from typing import overload, Any, Dict +import torch + + +class ImageLoggerMixin: + + @overload + def log_images(self, batch: Any, **kwargs: Dict[str, Any]) -> Dict[str, torch.Tensor]: + ... diff --git a/model/spaced_sampler.py b/model/spaced_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..ef5490c7cb519991506373bb00639437fd3e4af5 --- /dev/null +++ b/model/spaced_sampler.py @@ -0,0 +1,545 @@ +from typing import Optional, Tuple, Dict + +import torch +import numpy as np +from tqdm import tqdm + +from ldm.modules.diffusionmodules.util import make_beta_schedule +from model.cond_fn import Guidance +from utils.image import ( + wavelet_reconstruction, adaptive_instance_normalization +) + +# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/gaussian_diffusion.py +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res.expand(broadcast_shape) + + +class SpacedSampler: + """ + Implementation for spaced sampling schedule proposed in IDDPM. This class is designed + for sampling ControlLDM. + + https://arxiv.org/pdf/2102.09672.pdf + """ + + def __init__( + self, + model: "ControlLDM", + schedule: str="linear", + var_type: str="fixed_small" + ) -> "SpacedSampler": + self.model = model + self.original_num_steps = model.num_timesteps + self.schedule = schedule + self.var_type = var_type + + def make_schedule(self, num_steps: int) -> None: + """ + Initialize sampling parameters according to `num_steps`. + + Args: + num_steps (int): Sampling steps. + + Returns: + None + """ + # NOTE: this schedule, which generates betas linearly in log space, is a little different + # from guided diffusion. + original_betas = make_beta_schedule( + self.schedule, self.original_num_steps, linear_start=self.model.linear_start, + linear_end=self.model.linear_end + ) + original_alphas = 1.0 - original_betas + original_alphas_cumprod = np.cumprod(original_alphas, axis=0) + + # calcualte betas for spaced sampling + # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py + used_timesteps = space_timesteps(self.original_num_steps, str(num_steps)) + print(f"timesteps used in spaced sampler: \n\t{sorted(list(used_timesteps))}") + + betas = [] + last_alpha_cumprod = 1.0 + for i, alpha_cumprod in enumerate(original_alphas_cumprod): + if i in used_timesteps: + # marginal distribution is the same as q(x_{S_t}|x_0) + betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + assert len(betas) == num_steps + betas = np.array(betas, dtype=np.float64) + self.betas = betas + + self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32) # e.g. [0, 10, 20, ...] + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (num_steps, ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # log calculation clipped because the posterior variance is 0 at the + # beginning of the diffusion chain. + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) + ) + + def q_sample( + self, + x_start: torch.Tensor, + t: torch.Tensor, + noise: Optional[torch.Tensor]=None + ) -> torch.Tensor: + """ + Implement the marginal distribution q(x_t|x_0). + + Args: + x_start (torch.Tensor): Images (NCHW) sampled from data distribution. + t (torch.Tensor): Timestep (N) for diffusion process. `t` serves as an index + to get parameters for each timestep. + noise (torch.Tensor, optional): Specify the noise (NCHW) added to `x_start`. + + Returns: + x_t (torch.Tensor): The noisy images. + """ + if noise is None: + noise = torch.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance( + self, + x_start: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor + ) -> Tuple[torch.Tensor]: + """ + Implement the posterior distribution q(x_{t-1}|x_t, x_0). + + Args: + x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`. + x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`. + t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get + parameters for each timestep. + + Returns: + posterior_mean (torch.Tensor): Mean of the posterior distribution. + posterior_variance (torch.Tensor): Variance of the posterior distribution. + posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution. + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def _predict_xstart_from_eps( + self, + x_t: torch.Tensor, + t: torch.Tensor, + eps: torch.Tensor + ) -> torch.Tensor: + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def predict_noise( + self, + x: torch.Tensor, + t: torch.Tensor, + cond: Dict[str, torch.Tensor], + cfg_scale: float, + uncond: Optional[Dict[str, torch.Tensor]] + ) -> torch.Tensor: + if uncond is None or cfg_scale == 1.: + model_output = self.model.apply_model(x, t, cond) + else: + # apply classifier-free guidance + model_cond = self.model.apply_model(x, t, cond) + model_uncond = self.model.apply_model(x, t, uncond) + model_output = model_uncond + cfg_scale * (model_cond - model_uncond) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + return e_t + + def apply_cond_fn( + self, + x: torch.Tensor, + cond: Dict[str, torch.Tensor], + t: torch.Tensor, + index: torch.Tensor, + cond_fn: Guidance, + cfg_scale: float, + uncond: Optional[Dict[str, torch.Tensor]] + ) -> torch.Tensor: + device = x.device + t_now = int(t[0].item()) + 1 + # ----------------- predict noise and x0 ----------------- # + e_t = self.predict_noise( + x, t, cond, cfg_scale, uncond + ) + pred_x0: torch.Tensor = self._predict_xstart_from_eps(x_t=x, t=index, eps=e_t) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_x0, x_t=x, t=index + ) + + # apply classifier guidance for multiple times + for _ in range(cond_fn.repeat): + # ----------------- compute gradient for x0 in latent space ----------------- # + target, pred = None, None + if cond_fn.space == "latent": + target = self.model.get_first_stage_encoding( + self.model.encode_first_stage(cond_fn.target.to(device)) + ) + pred = pred_x0 + elif cond_fn.space == "rgb": + # We need to backward gradient to x0 in latent space, so it's required + # to trace the computation graph while decoding the latent. + with torch.enable_grad(): + pred_x0.requires_grad_(True) + target = cond_fn.target.to(device) + pred = self.model.decode_first_stage_with_grad(pred_x0) + else: + raise NotImplementedError(cond_fn.space) + delta_pred = cond_fn(target, pred, t_now) + + # ----------------- apply classifier guidance ----------------- # + if delta_pred is not None: + if cond_fn.space == "rgb": + # compute gradient for pred_x0 + pred.backward(delta_pred) + delta_pred_x0 = pred_x0.grad + # update prex_x0 + pred_x0 += delta_pred_x0 + # our classifier guidance is equivalent to multiply delta_pred_x0 + # by a constant and then add it to model_mean, We set the constant + # to 0.5 + model_mean += 0.5 * delta_pred_x0 + pred_x0.grad.zero_() + else: + delta_pred_x0 = delta_pred + pred_x0 += delta_pred_x0 + model_mean += 0.5 * delta_pred_x0 + else: + # means stop guidance + break + + return model_mean.detach().clone(), pred_x0.detach().clone() + + @torch.no_grad() + def p_sample( + self, + x: torch.Tensor, + cond: Dict[str, torch.Tensor], + t: torch.Tensor, + index: torch.Tensor, + cfg_scale: float, + uncond: Optional[Dict[str, torch.Tensor]], + cond_fn: Optional[Guidance] + ) -> torch.Tensor: + # variance of posterior distribution q(x_{t-1}|x_t, x_0) + model_variance = { + "fixed_large": np.append(self.posterior_variance[1], self.betas[1:]), + "fixed_small": self.posterior_variance + }[self.var_type] + model_variance = _extract_into_tensor(model_variance, index, x.shape) + + # mean of posterior distribution q(x_{t-1}|x_t, x_0) + if cond_fn is not None: + # apply classifier guidance + model_mean, pred_x0 = self.apply_cond_fn( + x, cond, t, index, cond_fn, + cfg_scale, uncond + ) + else: + e_t = self.predict_noise( + x, t, cond, cfg_scale, uncond + ) + pred_x0 = self._predict_xstart_from_eps(x_t=x, t=index, eps=e_t) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_x0, x_t=x, t=index + ) + + # sample x_t from q(x_{t-1}|x_t, x_0) + noise = torch.randn_like(x) + nonzero_mask = ( + (index != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) + x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise + return x_prev + + @torch.no_grad() + def sample_with_mixdiff( + self, + tile_size: int, + tile_stride: int, + steps: int, + shape: Tuple[int], + cond_img: torch.Tensor, + positive_prompt: str, + negative_prompt: str, + x_T: Optional[torch.Tensor]=None, + cfg_scale: float=1., + cond_fn: Optional[Guidance]=None, + color_fix_type: str="none" + ) -> torch.Tensor: + def _sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> Tuple[int, int, int, int]: + hi_list = list(range(0, h - tile_size + 1, tile_stride)) + if (h - tile_size) % tile_stride != 0: + hi_list.append(h - tile_size) + + wi_list = list(range(0, w - tile_size + 1, tile_stride)) + if (w - tile_size) % tile_stride != 0: + wi_list.append(w - tile_size) + + coords = [] + for hi in hi_list: + for wi in wi_list: + coords.append((hi, hi + tile_size, wi, wi + tile_size)) + return coords + + # make sampling parameters (e.g. sigmas) + self.make_schedule(num_steps=steps) + + device = next(self.model.parameters()).device + b, _, h, w = shape + if x_T is None: + img = torch.randn(shape, dtype=torch.float32, device=device) + else: + img = x_T + # create buffers for accumulating predicted noise of different diffusion process + noise_buffer = torch.zeros_like(img) + count = torch.zeros(shape, dtype=torch.long, device=device) + # timesteps iterator + time_range = np.flip(self.timesteps) # [1000, 950, 900, ...] + total_steps = len(self.timesteps) + iterator = tqdm(time_range, desc="Spaced Sampler", total=total_steps) + + # sampling loop + for i, step in enumerate(iterator): + ts = torch.full((b,), step, device=device, dtype=torch.long) + index = torch.full_like(ts, fill_value=total_steps - i - 1) + + # predict noise for each tile + tiles_iterator = tqdm(_sliding_windows(h, w, tile_size // 8, tile_stride // 8)) + for hi, hi_end, wi, wi_end in tiles_iterator: + tiles_iterator.set_description(f"Process tile with location ({hi} {hi_end}) ({wi} {wi_end})") + # noisy latent of this diffusion process (tile) at this step + tile_img = img[:, :, hi:hi_end, wi:wi_end] + # prepare condition for this tile + tile_cond_img = cond_img[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8] + tile_cond = { + "c_latent": [self.model.apply_condition_encoder(tile_cond_img)], + "c_crossattn": [self.model.get_learned_conditioning([positive_prompt] * b)] + } + tile_uncond = { + "c_latent": [self.model.apply_condition_encoder(tile_cond_img)], + "c_crossattn": [self.model.get_learned_conditioning([negative_prompt] * b)] + } + # TODO: tile_cond_fn + + # predict noise for this tile + tile_noise = self.predict_noise(tile_img, ts, tile_cond, cfg_scale, tile_uncond) + + # accumulate mean and variance + noise_buffer[:, :, hi:hi_end, wi:wi_end] += tile_noise + count[:, :, hi:hi_end, wi:wi_end] += 1 + + if (count == 0).any().item(): + print(f"find count == 0!") + # average on noise + noise_buffer.div_(count) + # sample previous latent + pred_x0 = self._predict_xstart_from_eps(x_t=img, t=index, eps=noise_buffer) + mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_x0, x_t=img, t=index + ) + variance = { + "fixed_large": np.append(self.posterior_variance[1], self.betas[1:]), + "fixed_small": self.posterior_variance + }[self.var_type] + variance = _extract_into_tensor(variance, index, noise_buffer.shape) + + nonzero_mask = ( + (index != 0).float().view(-1, *([1] * (len(noise_buffer.shape) - 1))) + ) + img = mean + nonzero_mask * torch.sqrt(variance) * torch.randn_like(mean) + + noise_buffer.zero_() + count.zero_() + + # decode samples of each diffusion process + img_buffer = torch.zeros_like(cond_img) + count = torch.zeros_like(cond_img, dtype=torch.long) + for hi, hi_end, wi, wi_end in _sliding_windows(h, w, tile_size // 8, tile_stride // 8): + tile_img = img[:, :, hi:hi_end, wi:wi_end] + tile_img_pixel = (self.model.decode_first_stage(tile_img) + 1) / 2 + tile_cond_img = cond_img[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8] + # apply color correction (borrowed from StableSR) + if color_fix_type == "adain": + tile_img_pixel = adaptive_instance_normalization(tile_img_pixel, tile_cond_img) + elif color_fix_type == "wavelet": + tile_img_pixel = wavelet_reconstruction(tile_img_pixel, tile_cond_img) + else: + assert color_fix_type == "none", f"unexpected color fix type: {color_fix_type}" + img_buffer[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8] += tile_img_pixel + count[:, :, hi * 8:hi_end * 8, wi * 8: wi_end * 8] += 1 + img_buffer.div_(count) + + return img_buffer + + @torch.no_grad() + def sample( + self, + steps: int, + shape: Tuple[int], + cond_img: torch.Tensor, + positive_prompt: str, + negative_prompt: str, + x_T: Optional[torch.Tensor]=None, + cfg_scale: float=1., + cond_fn: Optional[Guidance]=None, + color_fix_type: str="none" + ) -> torch.Tensor: + self.make_schedule(num_steps=steps) + + device = next(self.model.parameters()).device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + time_range = np.flip(self.timesteps) # [1000, 950, 900, ...] + total_steps = len(self.timesteps) + iterator = tqdm(time_range, desc="Spaced Sampler", total=total_steps) + + cond = { + "c_latent": [self.model.apply_condition_encoder(cond_img)], + "c_crossattn": [self.model.get_learned_conditioning([positive_prompt] * b)] + } + uncond = { + "c_latent": [self.model.apply_condition_encoder(cond_img)], + "c_crossattn": [self.model.get_learned_conditioning([negative_prompt] * b)] + } + for i, step in enumerate(iterator): + ts = torch.full((b,), step, device=device, dtype=torch.long) + index = torch.full_like(ts, fill_value=total_steps - i - 1) + img = self.p_sample( + img, cond, ts, index=index, + cfg_scale=cfg_scale, uncond=uncond, + cond_fn=cond_fn + ) + + img_pixel = (self.model.decode_first_stage(img) + 1) / 2 + # apply color correction (borrowed from StableSR) + if color_fix_type == "adain": + img_pixel = adaptive_instance_normalization(img_pixel, cond_img) + elif color_fix_type == "wavelet": + img_pixel = wavelet_reconstruction(img_pixel, cond_img) + else: + assert color_fix_type == "none", f"unexpected color fix type: {color_fix_type}" + return img_pixel diff --git a/model/swinir.py b/model/swinir.py new file mode 100644 index 0000000000000000000000000000000000000000..08b13ee322c3da7477a0f2ce6447df45a4ef89af --- /dev/null +++ b/model/swinir.py @@ -0,0 +1,986 @@ +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +from typing import Any, Dict, Set + +import torch +import torch.nn as nn +from torch import optim +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +import pytorch_lightning as pl +from pytorch_lightning.utilities.types import STEP_OUTPUT +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from einops import rearrange + +from utils.metrics import calculate_psnr_pt, LPIPS +from .mixins import ImageLoggerMixin + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(pl.LightningModule, ImageLoggerMixin): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + sf: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__( + self, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=[6, 6, 6, 6], + num_heads=[6, 6, 6, 6], + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + sf=4, + img_range=1., + upsampler='', + resi_connection='1conv', + unshuffle=False, + unshuffle_scale=None, + hq_key: str="jpg", + lq_key: str="hint", + learning_rate: float=None, + weight_decay: float=None + ) -> "SwinIR": + super(SwinIR, self).__init__() + num_in_ch = in_chans * (unshuffle_scale**2) if unshuffle else in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = sf + self.upsampler = upsampler + self.window_size = window_size + self.unshuffle_scale = unshuffle_scale + self.unshuffle = unshuffle + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + if unshuffle: + assert unshuffle_scale is not None + self.conv_first = nn.Sequential( + nn.PixelUnshuffle(sf), + nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1), + ) + else: + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None + ) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None + ) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB( + dim=embed_dim, + input_resolution=(patches_resolution[0], patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1) + ) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True) + ) + self.upsample = Upsample(sf, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep( + sf, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1]) + ) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True) + ) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + elif self.upscale == 8: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + self.hq_key = hq_key + self.lq_key = lq_key + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.lpips_metric = LPIPS(net="alex") + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + # TODO: What's this ? + @torch.jit.ignore + def no_weight_decay(self) -> Set[str]: + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self) -> Set[str]: + return {'relative_position_bias_table'} + + def check_image_size(self, x: torch.Tensor) -> torch.Tensor: + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + elif self.upscale == 8: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale] + + def flops(self) -> int: + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + def get_loss(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor: + """ + Compute loss between model predictions and labels. + + Args: + pred (torch.Tensor): Batch model predictions. + label (torch.Tensor): Batch labels. + + Returns: + loss (torch.Tensor): The loss tensor. + """ + return F.mse_loss(input=pred, target=label, reduction="sum") + + def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT: + """ + Args: + batch (Dict[str, torch.Tensor]): A dict contains LQ and HQ (NHWC, RGB, + LQ range in [0, 1] and HQ range in [-1, 1]). + batch_idx (int): Index of this batch. + + Returns: + outputs (torch.Tensor): The loss tensor. + """ + hq, lq = batch[self.hq_key], batch[self.lq_key] + hq = rearrange(((hq + 1) / 2).clamp_(0, 1), "n h w c -> n c h w") + lq = rearrange(lq, "n h w c -> n c h w") + pred = self(lq) + loss = self.get_loss(pred, hq) + self.log("train_loss", loss, on_step=True) + return loss + + def on_validation_start(self) -> None: + self.lpips_metric.to(self.device) + + def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None: + hq, lq = batch[self.hq_key], batch[self.lq_key] + lq = rearrange(lq, "n h w c -> n c h w") + pred = self(lq) + hq = rearrange(((hq + 1) / 2).clamp_(0, 1), "n h w c -> n c h w") + + # requiremtns for lpips model inputs: + # https://github.com/richzhang/PerceptualSimilarity + lpips = self.lpips_metric(pred, hq, normalize=True).mean() + self.log("val_lpips", lpips) + + pnsr = calculate_psnr_pt(pred, hq, crop_border=0).mean() + self.log("val_pnsr", pnsr) + + loss = self.get_loss(pred, hq) + self.log("val_loss", loss) + + def configure_optimizers(self) -> optim.AdamW: + """ + Configure optimizer for this model. + + Returns: + optimizer (optim.AdamW): The optimizer for this model. + """ + optimizer = optim.AdamW( + [p for p in self.parameters() if p.requires_grad], lr=self.learning_rate, + weight_decay=self.weight_decay + ) + return optimizer + + @torch.no_grad() + def log_images(self, batch: Any) -> Dict[str, torch.Tensor]: + hq, lq = batch[self.hq_key], batch[self.lq_key] + hq = rearrange(((hq + 1) / 2).clamp_(0, 1), "n h w c -> n c h w") + lq = rearrange(lq, "n h w c -> n c h w") + pred = self(lq) + return dict(lq=lq, pred=pred, hq=hq) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..d1b37283835c0fc9c79f185bc2d9b384071c55b7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +--extra-index-url https://download.pytorch.org/whl/cu116 +torch==1.13.1+cu116 +torchvision==0.14.1+cu116 +torchaudio==0.13.1 +xformers==0.0.16 +pytorch_lightning==1.4.2 +einops +open-clip-torch +omegaconf +torchmetrics==0.6.0 +triton==2.0.0 +opencv-python-headless +scipy +matplotlib +lpips +gradio +chardet +transformers +facexlib +openxlab diff --git a/utils/common.py b/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..b9982ed81c859719dd0743faabdc3a8361f2bb74 --- /dev/null +++ b/utils/common.py @@ -0,0 +1,51 @@ +from typing import Mapping, Any +import importlib + +from torch import nn + + +def get_obj_from_str(string: str, reload: bool=False) -> object: + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config: Mapping[str, Any]) -> object: + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def disabled_train(self: nn.Module) -> nn.Module: + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def frozen_module(module: nn.Module) -> None: + module.eval() + module.train = disabled_train + for p in module.parameters(): + p.requires_grad = False + + +def load_state_dict(model: nn.Module, state_dict: Mapping[str, Any], strict: bool=False) -> None: + state_dict = state_dict.get("state_dict", state_dict) + + is_model_key_starts_with_module = list(model.state_dict().keys())[0].startswith("module.") + is_state_dict_key_starts_with_module = list(state_dict.keys())[0].startswith("module.") + + if ( + is_model_key_starts_with_module and + (not is_state_dict_key_starts_with_module) + ): + state_dict = {f"module.{key}": value for key, value in state_dict.items()} + if ( + (not is_model_key_starts_with_module) and + is_state_dict_key_starts_with_module + ): + state_dict = {key[len("module."):]: value for key, value in state_dict.items()} + + model.load_state_dict(state_dict, strict=strict) diff --git a/utils/degradation.py b/utils/degradation.py new file mode 100644 index 0000000000000000000000000000000000000000..aa75976523a36d0020ddbc0db9f0f0c3d753fde9 --- /dev/null +++ b/utils/degradation.py @@ -0,0 +1,765 @@ +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/data/degradations.py +import cv2 +import math +import numpy as np +import random +import torch +from scipy import special +from scipy.stats import multivariate_normal +from torchvision.transforms.functional_tensor import rgb_to_grayscale + +# -------------------------------------------------------------------- # +# --------------------------- blur kernels --------------------------- # +# -------------------------------------------------------------------- # + + +# --------------------------- util functions --------------------------- # +def sigma_matrix2(sig_x, sig_y, theta): + """Calculate the rotated sigma matrix (two dimensional matrix). + + Args: + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + + Returns: + ndarray: Rotated sigma matrix. + """ + d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) + u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) + return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) + + +def mesh_grid(kernel_size): + """Generate the mesh grid, centering at zero. + + Args: + kernel_size (int): + + Returns: + xy (ndarray): with the shape (kernel_size, kernel_size, 2) + xx (ndarray): with the shape (kernel_size, kernel_size) + yy (ndarray): with the shape (kernel_size, kernel_size) + """ + ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) + xx, yy = np.meshgrid(ax, ax) + xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size, + 1))).reshape(kernel_size, kernel_size, 2) + return xy, xx, yy + + +def pdf2(sigma_matrix, grid): + """Calculate PDF of the bivariate Gaussian distribution. + + Args: + sigma_matrix (ndarray): with the shape (2, 2) + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + kernel (ndarrray): un-normalized kernel. + """ + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) + return kernel + + +def cdf2(d_matrix, grid): + """Calculate the CDF of the standard bivariate Gaussian distribution. + Used in skewed Gaussian distribution. + + Args: + d_matrix (ndarrasy): skew matrix. + grid (ndarray): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. + + Returns: + cdf (ndarray): skewed cdf. + """ + rv = multivariate_normal([0, 0], [[1, 0], [0, 1]]) + grid = np.dot(grid, d_matrix) + cdf = rv.cdf(grid) + return cdf + + +def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True): + """Generate a bivariate isotropic or anisotropic Gaussian kernel. + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + isotropic (bool): + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + kernel = pdf2(sigma_matrix, grid) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): + """Generate a bivariate generalized Gaussian kernel. + + ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions`` + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) + kernel = kernel / np.sum(kernel) + return kernel + + +def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): + """Generate a plateau-like anisotropic kernel. + + 1 / (1+x^(beta)) + + Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution + + In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. + + Args: + kernel_size (int): + sig_x (float): + sig_y (float): + theta (float): Radian measurement. + beta (float): shape parameter, beta = 1 is the normal distribution. + grid (ndarray, optional): generated by :func:`mesh_grid`, + with the shape (K, K, 2), K is the kernel size. Default: None + + Returns: + kernel (ndarray): normalized kernel. + """ + if grid is None: + grid, _, _ = mesh_grid(kernel_size) + if isotropic: + sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) + else: + sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) + inverse_sigma = np.linalg.inv(sigma_matrix) + kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate isotropic or anisotropic Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_generalized_Gaussian(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate generalized Gaussian kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + # assume beta_range[0] < 1 < beta_range[1] + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) + + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + return kernel + + +def random_bivariate_plateau(kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + beta_range, + noise_range=None, + isotropic=True): + """Randomly generate bivariate plateau kernels. + + In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. + + Args: + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi/2, math.pi/2] + beta_range (tuple): [1, 4] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' + sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) + if isotropic is False: + assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' + assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' + sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) + rotation = np.random.uniform(rotation_range[0], rotation_range[1]) + else: + sigma_y = sigma_x + rotation = 0 + + # TODO: this may be not proper + if np.random.uniform() < 0.5: + beta = np.random.uniform(beta_range[0], 1) + else: + beta = np.random.uniform(1, beta_range[1]) + + kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) + # add multiplicative noise + if noise_range is not None: + assert noise_range[0] < noise_range[1], 'Wrong noise range.' + noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) + kernel = kernel * noise + kernel = kernel / np.sum(kernel) + + return kernel + + +def random_mixed_kernels(kernel_list, + kernel_prob, + kernel_size=21, + sigma_x_range=(0.6, 5), + sigma_y_range=(0.6, 5), + rotation_range=(-math.pi, math.pi), + betag_range=(0.5, 8), + betap_range=(0.5, 8), + noise_range=None): + """Randomly generate mixed kernels. + + Args: + kernel_list (tuple): a list name of kernel types, + support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', + 'plateau_aniso'] + kernel_prob (tuple): corresponding kernel probability for each + kernel type + kernel_size (int): + sigma_x_range (tuple): [0.6, 5] + sigma_y_range (tuple): [0.6, 5] + rotation range (tuple): [-math.pi, math.pi] + beta_range (tuple): [0.5, 8] + noise_range(tuple, optional): multiplicative kernel noise, + [0.75, 1.25]. Default: None + + Returns: + kernel (ndarray): + """ + kernel_type = random.choices(kernel_list, kernel_prob)[0] + if kernel_type == 'iso': + kernel = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True) + elif kernel_type == 'aniso': + kernel = random_bivariate_Gaussian( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False) + elif kernel_type == 'generalized_iso': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=True) + elif kernel_type == 'generalized_aniso': + kernel = random_bivariate_generalized_Gaussian( + kernel_size, + sigma_x_range, + sigma_y_range, + rotation_range, + betag_range, + noise_range=noise_range, + isotropic=False) + elif kernel_type == 'plateau_iso': + kernel = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True) + elif kernel_type == 'plateau_aniso': + kernel = random_bivariate_plateau( + kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False) + return kernel + + +np.seterr(divide='ignore', invalid='ignore') + + +def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0): + """2D sinc filter + + Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter + + Args: + cutoff (float): cutoff frequency in radians (pi is max) + kernel_size (int): horizontal and vertical size, must be odd. + pad_to (int): pad kernel size to desired size, must be odd or zero. + """ + assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' + kernel = np.fromfunction( + lambda x, y: cutoff * special.j1(cutoff * np.sqrt( + (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt( + (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size]) + kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi) + kernel = kernel / np.sum(kernel) + if pad_to > kernel_size: + pad_size = (pad_to - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + return kernel + + +# ------------------------------------------------------------- # +# --------------------------- noise --------------------------- # +# ------------------------------------------------------------- # + +# ----------------------- Gaussian Noise ----------------------- # + + +def generate_gaussian_noise(img, sigma=10, gray_noise=False): + """Generate Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255. + noise = np.expand_dims(noise, axis=2).repeat(3, axis=2) + else: + noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255. + return noise + + +def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False): + """Add Gaussian noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + sigma (float): Noise scale (measured in range 255). Default: 10. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + noise = generate_gaussian_noise(img, sigma, gray_noise) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0): + """Add Gaussian noise (PyTorch version). + + Args: + img (Tensor): Shape (b, c, h, w), range[0, 1], float32. + scale (float | Tensor): Noise scale. Default: 1.0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + b, _, h, w = img.size() + if not isinstance(sigma, (float, int)): + sigma = sigma.view(img.size(0), 1, 1, 1) + if isinstance(gray_noise, (float, int)): + cal_gray_noise = gray_noise > 0 + else: + gray_noise = gray_noise.view(b, 1, 1, 1) + cal_gray_noise = torch.sum(gray_noise) > 0 + + if cal_gray_noise: + noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255. + noise_gray = noise_gray.view(b, 1, h, w) + + # always calculate color noise + noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255. + + if cal_gray_noise: + noise = noise * (1 - gray_noise) + noise_gray * gray_noise + return noise + + +def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False): + """Add Gaussian noise (PyTorch version). + + Args: + img (Tensor): Shape (b, c, h, w), range[0, 1], float32. + scale (float | Tensor): Noise scale. Default: 1.0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + noise = generate_gaussian_noise_pt(img, sigma, gray_noise) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Random Gaussian Noise ----------------------- # +def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0): + sigma = np.random.uniform(sigma_range[0], sigma_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_gaussian_noise(img, sigma, gray_noise) + + +def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_gaussian_noise(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0): + sigma = torch.rand( + img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0] + gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) + gray_noise = (gray_noise < gray_prob).float() + return generate_gaussian_noise_pt(img, sigma, gray_noise) + + +def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Poisson (Shot) Noise ----------------------- # + + +def generate_poisson_noise(img, scale=1.0, gray_noise=False): + """Generate poisson noise. + + Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219 + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + scale (float): Noise scale. Default: 1.0. + gray_noise (bool): Whether generate gray noise. Default: False. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + if gray_noise: + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # round and clip image for counting vals correctly + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = len(np.unique(img)) + vals = 2**np.ceil(np.log2(vals)) + out = np.float32(np.random.poisson(img * vals) / float(vals)) + noise = out - img + if gray_noise: + noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2) + return noise * scale + + +def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False): + """Add poisson noise. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + scale (float): Noise scale. Default: 1.0. + gray_noise (bool): Whether generate gray noise. Default: False. + + Returns: + (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], + float32. + """ + noise = generate_poisson_noise(img, scale, gray_noise) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0): + """Generate a batch of poisson noise (PyTorch version) + + Args: + img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. + scale (float | Tensor): Noise scale. Number or Tensor with shape (b). + Default: 1.0. + gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). + 0 for False, 1 for True. Default: 0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + b, _, h, w = img.size() + if isinstance(gray_noise, (float, int)): + cal_gray_noise = gray_noise > 0 + else: + gray_noise = gray_noise.view(b, 1, 1, 1) + cal_gray_noise = torch.sum(gray_noise) > 0 + if cal_gray_noise: + img_gray = rgb_to_grayscale(img, num_output_channels=1) + # round and clip image for counting vals correctly + img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255. + # use for-loop to get the unique values for each sample + vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)] + vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] + vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1) + out = torch.poisson(img_gray * vals) / vals + noise_gray = out - img_gray + noise_gray = noise_gray.expand(b, 3, h, w) + + # always calculate color noise + # round and clip image for counting vals correctly + img = torch.clamp((img * 255.0).round(), 0, 255) / 255. + # use for-loop to get the unique values for each sample + vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)] + vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] + vals = img.new_tensor(vals_list).view(b, 1, 1, 1) + out = torch.poisson(img * vals) / vals + noise = out - img + if cal_gray_noise: + noise = noise * (1 - gray_noise) + noise_gray * gray_noise + if not isinstance(scale, (float, int)): + scale = scale.view(b, 1, 1, 1) + return noise * scale + + +def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0): + """Add poisson noise to a batch of images (PyTorch version). + + Args: + img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. + scale (float | Tensor): Noise scale. Number or Tensor with shape (b). + Default: 1.0. + gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). + 0 for False, 1 for True. Default: 0. + + Returns: + (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], + float32. + """ + noise = generate_poisson_noise_pt(img, scale, gray_noise) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ----------------------- Random Poisson (Shot) Noise ----------------------- # + + +def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0): + scale = np.random.uniform(scale_range[0], scale_range[1]) + if np.random.uniform() < gray_prob: + gray_noise = True + else: + gray_noise = False + return generate_poisson_noise(img, scale, gray_noise) + + +def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_poisson_noise(img, scale_range, gray_prob) + out = img + noise + if clip and rounds: + out = np.clip((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = np.clip(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0): + scale = torch.rand( + img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0] + gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) + gray_noise = (gray_noise < gray_prob).float() + return generate_poisson_noise_pt(img, scale, gray_noise) + + +def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): + noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob) + out = img + noise + if clip and rounds: + out = torch.clamp((out * 255.0).round(), 0, 255) / 255. + elif clip: + out = torch.clamp(out, 0, 1) + elif rounds: + out = (out * 255.0).round() / 255. + return out + + +# ------------------------------------------------------------------------ # +# --------------------------- JPEG compression --------------------------- # +# ------------------------------------------------------------------------ # + + +def add_jpg_compression(img, quality=90): + """Add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality (float): JPG compression quality. 0 for lowest quality, 100 for + best quality. Default: 90. + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + img = np.clip(img, 0, 1) + encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] + _, encimg = cv2.imencode('.jpg', img * 255., encode_param) + img = np.float32(cv2.imdecode(encimg, 1)) / 255. + return img + + +def random_add_jpg_compression(img, quality_range=(90, 100)): + """Randomly add JPG compression artifacts. + + Args: + img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. + quality_range (tuple[float] | list[float]): JPG compression quality + range. 0 for lowest quality, 100 for best quality. + Default: (90, 100). + + Returns: + (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], + float32. + """ + quality = np.random.uniform(quality_range[0], quality_range[1]) + return add_jpg_compression(img, int(quality)) diff --git a/utils/face_restoration_helper.py b/utils/face_restoration_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..0a22f25db66c9d971174df0268ded176a45f7834 --- /dev/null +++ b/utils/face_restoration_helper.py @@ -0,0 +1,379 @@ +import cv2 +import numpy as np +import os +import torch +from torchvision.transforms.functional import normalize + +from facexlib.detection import init_detection_model +from facexlib.parsing import init_parsing_model +from facexlib.utils.misc import img2tensor, imwrite + +''' +Stolen from Facexlib.utils.face_restoration_helper. +Modified to read PIL Images. +''' + + +def get_largest_face(det_faces, h, w): + + def get_location(val, length): + if val < 0: + return 0 + elif val > length: + return length + else: + return val + + face_areas = [] + for det_face in det_faces: + left = get_location(det_face[0], w) + right = get_location(det_face[2], w) + top = get_location(det_face[1], h) + bottom = get_location(det_face[3], h) + face_area = (right - left) * (bottom - top) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + return det_faces[largest_idx], largest_idx + + +def get_center_face(det_faces, h=0, w=0, center=None): + if center is not None: + center = np.array(center) + else: + center = np.array([w / 2, h / 2]) + center_dist = [] + for det_face in det_faces: + face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]) + dist = np.linalg.norm(face_center - center) + center_dist.append(dist) + center_idx = center_dist.index(min(center_dist)) + return det_faces[center_idx], center_idx + + +class FaceRestoreHelper(object): + """Helper for the face restoration pipeline (base class).""" + + def __init__(self, + upscale_factor, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + template_3points=False, + pad_blur=False, + use_parse=False, + device=None, + model_rootpath=None): + self.template_3points = template_3points # improve robustness + self.upscale_factor = upscale_factor + # the cropped face ratio based on the square face + self.crop_ratio = crop_ratio # (h, w) + assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1' + self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0])) + + if self.template_3points: + self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) + else: + # standard 5 landmarks for FFHQ faces with 512 x 512 + self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], + [201.26117, 371.41043], [313.08905, 371.15118]]) + self.face_template = self.face_template * (face_size / 512.0) + if self.crop_ratio[0] > 1: + self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 + if self.crop_ratio[1] > 1: + self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 + self.save_ext = save_ext + self.pad_blur = pad_blur + if self.pad_blur is True: + self.template_3points = False + + self.all_landmarks_5 = [] + self.det_faces = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = [] + self.restored_faces = [] + self.pad_input_imgs = [] + + if device is None: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = device + + # init face detection model + self.face_det = init_detection_model(det_model, half=False, device=self.device, model_rootpath=model_rootpath) + + # init face parsing model + self.use_parse = use_parse + self.face_parse = init_parsing_model(model_name='parsenet', device=self.device, model_rootpath=model_rootpath) + + def set_upscale_factor(self, upscale_factor): + self.upscale_factor = upscale_factor + + def read_image(self, img): + """img can be image path or cv2 loaded image.""" + # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255] + if isinstance(img, str): + img = cv2.imread(img) + + if np.max(img) > 256: # 16-bit image + img = img / 65535 * 255 + if len(img.shape) == 2: # gray image + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.shape[2] == 4: # RGBA image with alpha channel + img = img[:, :, 0:3] + + self.input_img = img + + def get_face_landmarks_5(self, + only_keep_largest=False, + only_center_face=False, + resize=None, + blur_ratio=0.01, + eye_dist_threshold=None): + if resize is None: + scale = 1 + input_img = self.input_img + else: + h, w = self.input_img.shape[0:2] + scale = min(h, w) / resize + h, w = int(h / scale), int(w / scale) + input_img = cv2.resize(self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4) + + with torch.no_grad(): + bboxes = self.face_det.detect_faces(input_img, 0.97) * scale + for bbox in bboxes: + # remove faces with too small eye distance: side faces or too small faces + eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]]) + if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold): + continue + + if self.template_3points: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)]) + else: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)]) + self.all_landmarks_5.append(landmark) + self.det_faces.append(bbox[0:5]) + if len(self.det_faces) == 0: + return 0 + if only_keep_largest: + h, w, _ = self.input_img.shape + self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]] + elif only_center_face: + h, w, _ = self.input_img.shape + self.det_faces, center_idx = get_center_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[center_idx]] + + # pad blurry images + if self.pad_blur: + self.pad_input_imgs = [] + for landmarks in self.all_landmarks_5: + # get landmarks + eye_left = landmarks[0, :] + eye_right = landmarks[1, :] + eye_avg = (eye_left + eye_right) * 0.5 + mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1.5 + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + border = max(int(np.rint(qsize * 0.1)), 3) + + # get pad + # pad: (width_left, height_top, width_right, height_bottom) + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = [ + max(-pad[0] + border, 1), + max(-pad[1] + border, 1), + max(pad[2] - self.input_img.shape[0] + border, 1), + max(pad[3] - self.input_img.shape[1] + border, 1) + ] + + if max(pad) > 1: + # pad image + pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + # modify landmark coords + landmarks[:, 0] += pad[0] + landmarks[:, 1] += pad[1] + # blur pad images + h, w, _ = pad_img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * blur_ratio) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur)) + # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0) + + pad_img = pad_img.astype('float32') + pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0) + pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255] + self.pad_input_imgs.append(pad_img) + else: + self.pad_input_imgs.append(np.copy(self.input_img)) + + return len(self.all_landmarks_5) + + def align_warp_face(self, save_cropped_path=None, border_mode='constant'): + """Align and warp faces with face template. + """ + if self.pad_blur: + assert len(self.pad_input_imgs) == len( + self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}' + for idx, landmark in enumerate(self.all_landmarks_5): + # use 5 landmarks to get affine matrix + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + if border_mode == 'constant': + border_mode = cv2.BORDER_CONSTANT + elif border_mode == 'reflect101': + border_mode = cv2.BORDER_REFLECT101 + elif border_mode == 'reflect': + border_mode = cv2.BORDER_REFLECT + if self.pad_blur: + input_img = self.pad_input_imgs[idx] + else: + input_img = self.input_img + cropped_face = cv2.warpAffine( + input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path = os.path.splitext(save_cropped_path)[0] + save_path = f'{path}_{idx:02d}.{self.save_ext}' + imwrite(cropped_face, save_path) + + def get_inverse_affine(self, save_inverse_affine_path=None): + """Get inverse affine matrix.""" + for idx, affine_matrix in enumerate(self.affine_matrices): + inverse_affine = cv2.invertAffineTransform(affine_matrix) + inverse_affine *= self.upscale_factor + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) + + def add_restored_face(self, face): + self.restored_faces.append(face) + + def paste_faces_to_input_image(self, save_path=None, upsample_img=None): + h, w, _ = self.input_img.shape + h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor) + + if upsample_img is None: + # simply resize the background + upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + else: + upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + + assert len(self.restored_faces) == len( + self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.') + for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices): + # Add an offset to inverse affine matrix, for more precise back alignment + if self.upscale_factor > 1: + extra_offset = 0.5 * self.upscale_factor + else: + extra_offset = 0 + inverse_affine[:, 2] += extra_offset + inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up)) + + if self.use_parse: + # inference + face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR) + face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True) + normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + face_input = torch.unsqueeze(face_input, 0).to(self.device) + with torch.no_grad(): + out = self.face_parse(face_input)[0] + out = out.argmax(dim=1).squeeze().cpu().numpy() + + mask = np.zeros(out.shape) + MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0] + for idx, color in enumerate(MASK_COLORMAP): + mask[out == idx] = color + # blur the mask + mask = cv2.GaussianBlur(mask, (101, 101), 11) + mask = cv2.GaussianBlur(mask, (101, 101), 11) + # remove the black borders + thres = 10 + mask[:thres, :] = 0 + mask[-thres:, :] = 0 + mask[:, :thres] = 0 + mask[:, -thres:] = 0 + mask = mask / 255. + + mask = cv2.resize(mask, restored_face.shape[:2]) + mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3) + inv_soft_mask = mask[:, :, None] + pasted_face = inv_restored + + else: # use square parse maps + mask = np.ones(self.face_size, dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + pasted_face = inv_mask_erosion[:, :, None] * inv_restored + total_face_area = np.sum(inv_mask_erosion) # // 3 + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + if len(upsample_img.shape) == 2: # upsample_img is gray image + upsample_img = upsample_img[:, :, None] + inv_soft_mask = inv_soft_mask[:, :, None] + + if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel + alpha = upsample_img[:, :, 3:] + upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3] + upsample_img = np.concatenate((upsample_img, alpha), axis=2) + else: + upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img + + if np.max(upsample_img) > 256: # 16-bit image + upsample_img = upsample_img.astype(np.uint16) + else: + upsample_img = upsample_img.astype(np.uint8) + if save_path is not None: + path = os.path.splitext(save_path)[0] + save_path = f'{path}.{self.save_ext}' + imwrite(upsample_img, save_path) + return upsample_img + + def clean_all(self): + self.all_landmarks_5 = [] + self.restored_faces = [] + self.affine_matrices = [] + self.cropped_faces = [] + self.inverse_affine_matrices = [] + self.det_faces = [] + self.pad_input_imgs = [] diff --git a/utils/file.py b/utils/file.py new file mode 100644 index 0000000000000000000000000000000000000000..9d4a3958eef6766dd2061ac31c67819eed6cce53 --- /dev/null +++ b/utils/file.py @@ -0,0 +1,43 @@ +import os +from typing import List, Tuple + + +def load_file_list(file_list_path: str) -> List[str]: + files = [] + # each line in file list contains a path of an image + with open(file_list_path, "r") as fin: + for line in fin: + path = line.strip() + if path: + files.append(path) + return files + + +def list_image_files( + img_dir: str, + exts: Tuple[str]=(".jpg", ".png", ".jpeg"), + follow_links: bool=False, + log_progress: bool=False, + log_every_n_files: int=10000, + max_size: int=-1 +) -> List[str]: + files = [] + for dir_path, _, file_names in os.walk(img_dir, followlinks=follow_links): + early_stop = False + for file_name in file_names: + if os.path.splitext(file_name)[1].lower() in exts: + if max_size >= 0 and len(files) >= max_size: + early_stop = True + break + files.append(os.path.join(dir_path, file_name)) + if log_progress and len(files) % log_every_n_files == 0: + print(f"find {len(files)} images in {img_dir}") + if early_stop: + break + return files + + +def get_file_name_parts(file_path: str) -> Tuple[str, str, str]: + parent_path, file_name = os.path.split(file_path) + stem, ext = os.path.splitext(file_name) + return parent_path, stem, ext diff --git a/utils/image/__init__.py b/utils/image/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3bbe064d019e3f4ddc9f041d80ed4e23cbf9d643 --- /dev/null +++ b/utils/image/__init__.py @@ -0,0 +1,26 @@ +from .diffjpeg import DiffJPEG +from .usm_sharp import USMSharp +from .common import ( + random_crop_arr, center_crop_arr, augment, + filter2D, rgb2ycbcr_pt, auto_resize, pad +) +from .align_color import ( + wavelet_reconstruction, adaptive_instance_normalization +) + +__all__ = [ + "DiffJPEG", + + "USMSharp", + + "random_crop_arr", + "center_crop_arr", + "augment", + "filter2D", + "rgb2ycbcr_pt", + "auto_resize", + "pad", + + "wavelet_reconstruction", + "adaptive_instance_normalization" +] diff --git a/utils/image/align_color.py b/utils/image/align_color.py new file mode 100644 index 0000000000000000000000000000000000000000..36da591b240b1daefbcbc3ca81c202e31c74b650 --- /dev/null +++ b/utils/image/align_color.py @@ -0,0 +1,119 @@ +''' +# -------------------------------------------------------------------------------- +# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py) +# -------------------------------------------------------------------------------- +''' + +import torch +from PIL import Image +from torch import Tensor +from torch.nn import functional as F +from torchvision.transforms import ToTensor, ToPILImage + + +def adain_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply adaptive instance normalization + result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + +def wavelet_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply wavelet reconstruction + result_tensor = wavelet_reconstruction(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + +def calc_mean_std(feat: Tensor, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.reshape(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().reshape(b, c, 1, 1) + feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) + return feat_mean, feat_std + +def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): + """Adaptive instance normalization. + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + +def wavelet_blur(image: Tensor, radius: int): + """ + Apply wavelet blur to the input tensor. + """ + # input shape: (1, 3, H, W) + # convolution kernel + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + # add channel dimensions to the kernel to make it a 4D tensor + kernel = kernel[None, None] + # repeat the kernel across all input channels + kernel = kernel.repeat(3, 1, 1, 1) + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') + # apply convolution + output = F.conv2d(image, kernel, groups=3, dilation=radius) + return output + +def wavelet_decomposition(image: Tensor, levels=5): + """ + Apply wavelet decomposition to the input tensor. + This function only returns the low frequency & the high frequency. + """ + high_freq = torch.zeros_like(image) + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq += (image - low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): + """ + Apply wavelet decomposition, so that the content will have the same color as the style. + """ + # calculate the wavelet decomposition of the content feature + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq + # calculate the wavelet decomposition of the style feature + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq + # reconstruct the content feature with the style's high frequency + return content_high_freq + style_low_freq \ No newline at end of file diff --git a/utils/image/common.py b/utils/image/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2e8ca3ce954778f0554968595446dcbe2a3455bb --- /dev/null +++ b/utils/image/common.py @@ -0,0 +1,235 @@ +import random +import math + +from PIL import Image +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/image_datasets.py +def center_crop_arr(pil_image, image_size): + # We are not on a new enough PIL to support the `reducing_gap` + # argument, which uses BOX downsampling at powers of two first. + # Thus, we do it by hand to improve downsample quality. + while min(*pil_image.size) >= 2 * image_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = image_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = (arr.shape[0] - image_size) // 2 + crop_x = (arr.shape[1] - image_size) // 2 + return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] + + +# https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/image_datasets.py +def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): + min_smaller_dim_size = math.ceil(image_size / max_crop_frac) + max_smaller_dim_size = math.ceil(image_size / min_crop_frac) + smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) + + # We are not on a new enough PIL to support the `reducing_gap` + # argument, which uses BOX downsampling at powers of two first. + # Thus, we do it by hand to improve downsample quality. + while min(*pil_image.size) >= 2 * smaller_dim_size: + pil_image = pil_image.resize( + tuple(x // 2 for x in pil_image.size), resample=Image.BOX + ) + + scale = smaller_dim_size / min(*pil_image.size) + pil_image = pil_image.resize( + tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC + ) + + arr = np.array(pil_image) + crop_y = random.randrange(arr.shape[0] - image_size + 1) + crop_x = random.randrange(arr.shape[1] - image_size + 1) + return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] + + +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/data/transforms.py +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. + + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/img_process_util.py +def filter2D(img, kernel): + """PyTorch version of cv2.filter2D + + Args: + img (Tensor): (b, c, h, w) + kernel (Tensor): (b, k, k) + """ + k = kernel.size(-1) + b, c, h, w = img.size() + if k % 2 == 1: + img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') + else: + raise ValueError('Wrong kernel size') + + ph, pw = img.size()[-2:] + + if kernel.size(0) == 1: + # apply the same kernel to all batch images + img = img.view(b * c, 1, ph, pw) + kernel = kernel.view(1, 1, k, k) + return F.conv2d(img, kernel, padding=0).view(b, c, h, w) + else: + img = img.view(1, b * c, ph, pw) + kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) + return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) + + +# https://github.com/XPixelGroup/BasicSR/blob/033cd6896d898fdd3dcda32e3102a792efa1b8f4/basicsr/utils/color_util.py#L186 +def rgb2ycbcr_pt(img, y_only=False): + """Convert RGB images to YCbCr images (PyTorch version). + + It implements the ITU-R BT.601 conversion for standard-definition television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + Args: + img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float. + """ + if y_only: + weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 + else: + weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img) + bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img) + out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias + + out_img = out_img / 255. + return out_img + + +def to_pil_image(inputs, mem_order, val_range, channel_order): + # convert inputs to numpy array + if isinstance(inputs, torch.Tensor): + inputs = inputs.cpu().numpy() + assert isinstance(inputs, np.ndarray) + + # make sure that inputs is a 4-dimension array + if mem_order in ["hwc", "chw"]: + inputs = inputs[None, ...] + mem_order = f"n{mem_order}" + # to NHWC + if mem_order == "nchw": + inputs = inputs.transpose(0, 2, 3, 1) + # to RGB + if channel_order == "bgr": + inputs = inputs[..., ::-1].copy() + else: + assert channel_order == "rgb" + + if val_range == "0,1": + inputs = inputs * 255 + elif val_range == "-1,1": + inputs = (inputs + 1) * 127.5 + else: + assert val_range == "0,255" + + inputs = inputs.clip(0, 255).astype(np.uint8) + return [inputs[i] for i in range(len(inputs))] + + +def put_text(pil_img_arr, text): + cv_img = pil_img_arr[..., ::-1].copy() + cv2.putText(cv_img, text, (10, 35), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) + return cv_img[..., ::-1].copy() + + +def auto_resize(img: Image.Image, size: int) -> Image.Image: + short_edge = min(img.size) + if short_edge < size: + r = size / short_edge + img = img.resize( + tuple(math.ceil(x * r) for x in img.size), Image.BICUBIC + ) + else: + # make a deep copy of this image for safety + img = img.copy() + return img + + +def pad(img: np.ndarray, scale: int) -> np.ndarray: + h, w = img.shape[:2] + ph = 0 if h % scale == 0 else math.ceil(h / scale) * scale - h + pw = 0 if w % scale == 0 else math.ceil(w / scale) * scale - w + return np.pad( + img, pad_width=((0, ph), (0, pw), (0, 0)), mode="constant", + constant_values=0 + ) diff --git a/utils/image/diffjpeg.py b/utils/image/diffjpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..339282170b3a58e3bc9764c3cec0c4cfc9cd9816 --- /dev/null +++ b/utils/image/diffjpeg.py @@ -0,0 +1,492 @@ +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/diffjpeg.py +""" +Modified from https://github.com/mlomnitz/DiffJPEG + +For images not divisible by 8 +https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343 +""" +import itertools +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F + +# ------------------------ utils ------------------------# +y_table = np.array( + [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56], + [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92], + [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], + dtype=np.float32).T +y_table = nn.Parameter(torch.from_numpy(y_table)) +c_table = np.empty((8, 8), dtype=np.float32) +c_table.fill(99) +c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T +c_table = nn.Parameter(torch.from_numpy(c_table)) + + +def diff_round(x): + """ Differentiable rounding function + """ + return torch.round(x) + (x - torch.round(x))**3 + + +def quality_to_factor(quality): + """ Calculate factor corresponding to quality + + Args: + quality(float): Quality for jpeg compression. + + Returns: + float: Compression factor. + """ + if quality < 50: + quality = 5000. / quality + else: + quality = 200. - quality * 2 + return quality / 100. + + +# ------------------------ compression ------------------------# +class RGB2YCbCrJpeg(nn.Module): + """ Converts RGB image to YCbCr + """ + + def __init__(self): + super(RGB2YCbCrJpeg, self).__init__() + matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], + dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(Tensor): batch x 3 x height x width + + Returns: + Tensor: batch x height x width x 3 + """ + image = image.permute(0, 2, 3, 1) + result = torch.tensordot(image, self.matrix, dims=1) + self.shift + return result.view(image.shape) + + +class ChromaSubsampling(nn.Module): + """ Chroma subsampling on CbCr channels + """ + + def __init__(self): + super(ChromaSubsampling, self).__init__() + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + y(tensor): batch x height x width + cb(tensor): batch x height/2 x width/2 + cr(tensor): batch x height/2 x width/2 + """ + image_2 = image.permute(0, 3, 1, 2).clone() + cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) + cb = cb.permute(0, 2, 3, 1) + cr = cr.permute(0, 2, 3, 1) + return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) + + +class BlockSplitting(nn.Module): + """ Splitting image into patches + """ + + def __init__(self): + super(BlockSplitting, self).__init__() + self.k = 8 + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x h*w/64 x h x w + """ + height, _ = image.shape[1:3] + batch_size = image.shape[0] + image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) + + +class DCT8x8(nn.Module): + """ Discrete Cosine Transformation + """ + + def __init__(self): + super(DCT8x8, self).__init__() + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16) + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image - 128 + result = self.scale * torch.tensordot(image, self.tensor, dims=2) + result.view(image.shape) + return result + + +class YQuantize(nn.Module): + """ JPEG Quantization for Y channel + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(YQuantize, self).__init__() + self.rounding = rounding + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CQuantize(nn.Module): + """ JPEG Quantization for CbCr channels + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding): + super(CQuantize, self).__init__() + self.rounding = rounding + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + image = image.float() / (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + image = image.float() / table + image = self.rounding(image) + return image + + +class CompressJpeg(nn.Module): + """Full JPEG compression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(CompressJpeg, self).__init__() + self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling()) + self.l2 = nn.Sequential(BlockSplitting(), DCT8x8()) + self.c_quantize = CQuantize(rounding=rounding) + self.y_quantize = YQuantize(rounding=rounding) + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x 3 x height x width + + Returns: + dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8. + """ + y, cb, cr = self.l1(image * 255) + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + comp = self.l2(components[k]) + if k in ('cb', 'cr'): + comp = self.c_quantize(comp, factor=factor) + else: + comp = self.y_quantize(comp, factor=factor) + + components[k] = comp + + return components['y'], components['cb'], components['cr'] + + +# ------------------------ decompression ------------------------# + + +class YDequantize(nn.Module): + """Dequantize Y channel + """ + + def __init__(self): + super(YDequantize, self).__init__() + self.y_table = y_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.y_table * factor) + else: + b = factor.size(0) + table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class CDequantize(nn.Module): + """Dequantize CbCr channel + """ + + def __init__(self): + super(CDequantize, self).__init__() + self.c_table = c_table + + def forward(self, image, factor=1): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + if isinstance(factor, (int, float)): + out = image * (self.c_table * factor) + else: + b = factor.size(0) + table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) + out = image * table + return out + + +class iDCT8x8(nn.Module): + """Inverse discrete Cosine Transformation + """ + + def __init__(self): + super(iDCT8x8, self).__init__() + alpha = np.array([1. / np.sqrt(2)] + [1] * 7) + self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float()) + tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) + for x, y, u, v in itertools.product(range(8), repeat=4): + tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16) + self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width + + Returns: + Tensor: batch x height x width + """ + image = image * self.alpha + result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128 + result.view(image.shape) + return result + + +class BlockMerging(nn.Module): + """Merge patches into image + """ + + def __init__(self): + super(BlockMerging, self).__init__() + + def forward(self, patches, height, width): + """ + Args: + patches(tensor) batch x height*width/64, height x width + height(int) + width(int) + + Returns: + Tensor: batch x height x width + """ + k = 8 + batch_size = patches.shape[0] + image_reshaped = patches.view(batch_size, height // k, width // k, k, k) + image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) + return image_transposed.contiguous().view(batch_size, height, width) + + +class ChromaUpsampling(nn.Module): + """Upsample chroma layers + """ + + def __init__(self): + super(ChromaUpsampling, self).__init__() + + def forward(self, y, cb, cr): + """ + Args: + y(tensor): y channel image + cb(tensor): cb channel + cr(tensor): cr channel + + Returns: + Tensor: batch x height x width x 3 + """ + + def repeat(x, k=2): + height, width = x.shape[1:3] + x = x.unsqueeze(-1) + x = x.repeat(1, 1, k, k) + x = x.view(-1, height * k, width * k) + return x + + cb = repeat(cb) + cr = repeat(cr) + return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3) + + +class YCbCr2RGBJpeg(nn.Module): + """Converts YCbCr image to RGB JPEG + """ + + def __init__(self): + super(YCbCr2RGBJpeg, self).__init__() + + matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T + self.shift = nn.Parameter(torch.tensor([0, -128., -128.])) + self.matrix = nn.Parameter(torch.from_numpy(matrix)) + + def forward(self, image): + """ + Args: + image(tensor): batch x height x width x 3 + + Returns: + Tensor: batch x 3 x height x width + """ + result = torch.tensordot(image + self.shift, self.matrix, dims=1) + return result.view(image.shape).permute(0, 3, 1, 2) + + +class DeCompressJpeg(nn.Module): + """Full JPEG decompression algorithm + + Args: + rounding(function): rounding function to use + """ + + def __init__(self, rounding=torch.round): + super(DeCompressJpeg, self).__init__() + self.c_dequantize = CDequantize() + self.y_dequantize = YDequantize() + self.idct = iDCT8x8() + self.merging = BlockMerging() + self.chroma = ChromaUpsampling() + self.colors = YCbCr2RGBJpeg() + + def forward(self, y, cb, cr, imgh, imgw, factor=1): + """ + Args: + compressed(dict(tensor)): batch x h*w/64 x 8 x 8 + imgh(int) + imgw(int) + factor(float) + + Returns: + Tensor: batch x 3 x height x width + """ + components = {'y': y, 'cb': cb, 'cr': cr} + for k in components.keys(): + if k in ('cb', 'cr'): + comp = self.c_dequantize(components[k], factor=factor) + height, width = int(imgh / 2), int(imgw / 2) + else: + comp = self.y_dequantize(components[k], factor=factor) + height, width = imgh, imgw + comp = self.idct(comp) + components[k] = self.merging(comp, height, width) + # + image = self.chroma(components['y'], components['cb'], components['cr']) + image = self.colors(image) + + image = torch.min(255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)) + return image / 255 + + +# ------------------------ main DiffJPEG ------------------------ # + + +class DiffJPEG(nn.Module): + """This JPEG algorithm result is slightly different from cv2. + DiffJPEG supports batch processing. + + Args: + differentiable(bool): If True, uses custom differentiable rounding function, if False, uses standard torch.round + """ + + def __init__(self, differentiable=True): + super(DiffJPEG, self).__init__() + if differentiable: + rounding = diff_round + else: + rounding = torch.round + + self.compress = CompressJpeg(rounding=rounding) + self.decompress = DeCompressJpeg(rounding=rounding) + + def forward(self, x, quality): + """ + Args: + x (Tensor): Input image, bchw, rgb, [0, 1] + quality(float): Quality factor for jpeg compression scheme. + """ + factor = quality + if isinstance(factor, (int, float)): + factor = quality_to_factor(factor) + else: + for i in range(factor.size(0)): + factor[i] = quality_to_factor(factor[i]) + h, w = x.size()[-2:] + h_pad, w_pad = 0, 0 + # why should use 16 + if h % 16 != 0: + h_pad = 16 - h % 16 + if w % 16 != 0: + w_pad = 16 - w % 16 + x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0) + + y, cb, cr = self.compress(x, factor=factor) + recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor) + recovered = recovered[:, :, 0:h, 0:w] + return recovered diff --git a/utils/image/usm_sharp.py b/utils/image/usm_sharp.py new file mode 100644 index 0000000000000000000000000000000000000000..7b83532e89204e8f88cc0e120e7d9021b489be90 --- /dev/null +++ b/utils/image/usm_sharp.py @@ -0,0 +1,29 @@ +# https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/img_process_util.py +import cv2 +import numpy as np +import torch + +from .common import filter2D + + +class USMSharp(torch.nn.Module): + + def __init__(self, radius=50, sigma=0): + super(USMSharp, self).__init__() + if radius % 2 == 0: + radius += 1 + self.radius = radius + kernel = cv2.getGaussianKernel(radius, sigma) + kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) + self.register_buffer('kernel', kernel) + + def forward(self, img, weight=0.5, threshold=10): + blur = filter2D(img, self.kernel) + residual = img - blur + + mask = torch.abs(residual) * 255 > threshold + mask = mask.float() + soft_mask = filter2D(mask, self.kernel) + sharp = img + weight * residual + sharp = torch.clip(sharp, 0, 1) + return soft_mask * sharp + (1 - soft_mask) * img diff --git a/utils/metrics.py b/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..e1c3e4961d92e0d4e42522fefcddf401270a8f26 --- /dev/null +++ b/utils/metrics.py @@ -0,0 +1,66 @@ +import torch +import lpips + +from .image import rgb2ycbcr_pt +from .common import frozen_module + + +# https://github.com/XPixelGroup/BasicSR/blob/033cd6896d898fdd3dcda32e3102a792efa1b8f4/basicsr/metrics/psnr_ssim.py#L52 +def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False): + """Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version). + + Reference: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w). + crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: PSNR result. + """ + + assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.') + + if crop_border != 0: + img = img[:, :, crop_border:-crop_border, crop_border:-crop_border] + img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border] + + if test_y_channel: + img = rgb2ycbcr_pt(img, y_only=True) + img2 = rgb2ycbcr_pt(img2, y_only=True) + + img = img.to(torch.float64) + img2 = img2.to(torch.float64) + + mse = torch.mean((img - img2)**2, dim=[1, 2, 3]) + return 10. * torch.log10(1. / (mse + 1e-8)) + + +class LPIPS: + + def __init__(self, net: str) -> None: + self.model = lpips.LPIPS(net=net) + frozen_module(self.model) + + @torch.no_grad() + def __call__(self, img1: torch.Tensor, img2: torch.Tensor, normalize: bool) -> torch.Tensor: + """ + Compute LPIPS. + + Args: + img1 (torch.Tensor): The first image (NCHW, RGB, [-1, 1]). Specify `normalize` if input + image is range in [0, 1]. + img2 (torch.Tensor): The second image (NCHW, RGB, [-1, 1]). Specify `normalize` if input + image is range in [0, 1]. + normalize (bool): If specified, the input images will be normalized from [0, 1] to [-1, 1]. + + Returns: + lpips_values (torch.Tensor): The lpips scores of this batch. + """ + return self.model(img1, img2, normalize=normalize) + + def to(self, device: str) -> "LPIPS": + self.model.to(device) + return self