diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ed8ebf583f771da9150c35db3955987b7d757904 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..87e71e24be20c0b4ed3b26fa45612883492d4862 --- /dev/null +++ b/app.py @@ -0,0 +1,395 @@ +import spaces +import os +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +from functools import partial +import gradio as gr +from huggingface_hub import hf_hub_download + +import torch +from torchvision import transforms +import rembg +import cv2 + +from src.visualizer import CameraVisualizer +from src.pose_estimation import load_model_from_config, estimate_poses, estimate_elevs +from src.pose_funcs import find_optimal_poses +from src.utils import spherical_to_cartesian, elu_to_c2w + +if torch.cuda.is_available(): + _device_ = 'cuda:0' +else: + _device_ = 'cpu' + +_config_path_ = 'src/configs/sd-objaverse-finetune-c_concat-256.yaml' + +_ckpt_path_ = hf_hub_download(repo_id='tokenid/ID-Pose', filename='ckpts/zero123-xl.ckpt', repo_type='model') +_matcher_ckpt_path_ = hf_hub_download(repo_id='tokenid/ID-Pose', filename='ckpts/indoor_ds_new.ckpt', repo_type='model') + +_config_ = OmegaConf.load(_config_path_) +_model_ = load_model_from_config(_config_, _ckpt_path_, device='cpu') +_model_ = _model_.to(_device_) +_model_.eval() + + +def rgba_to_rgb(img): + + assert img.mode == 'RGBA' + + img = np.asarray(img, dtype=np.float32) + img[:, :, :3] = img[:, :, :3] * (img[..., 3:]/255.) + (255-img[..., 3:]) + img = img.clip(0, 255).astype(np.uint8) + return Image.fromarray(img[:, :, :3]) + + +def remove_background(image, rembg_session = None, force = False, **rembg_kwargs): + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + + +def group_recenter(images, ratio=1.5, mask_thres=127, bkg_color=[255, 255, 255, 255]): + + ws = [] + hs = [] + + images = [ np.asarray(img) for img in images ] + + for img in images: + + alpha = img[:, :, 3] + + yy, xx = np.where(alpha > mask_thres) + y0, y1 = yy.min(), yy.max() + x0, x1 = xx.min(), xx.max() + + ws.append(x1 - x0) + hs.append(y1 - y0) + + sz_w = np.max(ws) + sz_h = np.max(hs) + + sz = int( max(ratio*sz_w, ratio*sz_h) ) + + out_rgbs = [] + + for rgba in images: + + rgb = rgba[:, :, :3] + alpha = rgba[:, :, 3] + + yy, xx = np.where(alpha > mask_thres) + y0, y1 = yy.min(), yy.max() + x0, x1 = xx.min(), xx.max() + + height, width, chn = rgb.shape + + cy = (y0 + y1) // 2 + cx = (x0 + x1) // 2 + + y0 = cy - int(np.floor(sz / 2)) + y1 = cy + int(np.ceil(sz / 2)) + x0 = cx - int(np.floor(sz / 2)) + x1 = cx + int(np.ceil(sz / 2)) + out = rgba[ max(y0, 0) : min(y1, height) , max(x0, 0) : min(x1, width), : ].copy() + pads = [(max(0-y0, 0), max(y1-height, 0)), (max(0-x0, 0), max(x1-width, 0)), (0, 0)] + out = np.pad(out, pads, mode='constant', constant_values=0) + + assert(out.shape[:2] == (sz, sz)) + + out[:, :, :3] = out[:, :, :3] * (out[..., 3:]/255.) + np.array(bkg_color)[None, None, :3] * (1-out[..., 3:]/255.) 
+ out[:, :, -1] = bkg_color[-1] + + out = cv2.resize(out.astype(np.uint8), (256, 256)) + out = out[:, :, :3] + + out_rgbs.append(out) + + return out_rgbs + + +def run_preprocess(image1, image2, preprocess_chk): + + if preprocess_chk: + rembg_session = rembg.new_session() + image1 = remove_background(image1, force=True, rembg_session = rembg_session) + image2 = remove_background(image2, force=True, rembg_session = rembg_session) + + rgbs = group_recenter([image1, image2]) + + image1 = Image.fromarray(rgbs[0]) + image2 = Image.fromarray(rgbs[1]) + + return image1, image2 + + +def image_to_tensor(img, width=256, height=256): + + img = transforms.ToTensor()(img).unsqueeze(0) + img = img * 2 - 1 + img = transforms.functional.resize(img, [height, width]) + + return img + + +@spaces.GPU +def run_pose_exploration_a(cam_vis, image1, image2): + + image1 = image_to_tensor(image1).to(_device_) + image2 = image_to_tensor(image2).to(_device_) + + images = [image1, image2] + + elevs, elev_ranges = estimate_elevs( + _model_, images, + est_type='all', + matcher_ckpt_path=_matcher_ckpt_path_ + ) + + fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True) + + return elevs, elev_ranges, fig + + +@spaces.GPU +def run_pose_exploration_b(cam_vis, image1, image2, elevs, elev_ranges, probe_bsz, adj_bsz, adj_iters): + + noise = np.random.randn(probe_bsz, 4, 32, 32) + + cam_vis.set_images([np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)]) + + image1 = image_to_tensor(image1).to(_device_) + image2 = image_to_tensor(image2).to(_device_) + + images = [image1, image2] + result_poses, aux_data = estimate_poses( + _model_, images, + seed_cand_num=8, + init_type='triangular', + optm_type='triangular', + probe_ts_range=[0.2, 0.21], + ts_range=[0.2, 0.21], + probe_bsz=probe_bsz, + adjust_factor=10.0, + adjust_iters=adj_iters, + adjust_bsz=adj_bsz, + refine_factor=1.0, + refine_iters=0, + refine_bsz=4, + noise=noise, + elevs=elevs, + elev_ranges=elev_ranges + ) + + theta, azimuth, radius = result_poses[0] + anchor_polar = aux_data['elev'][0] + if anchor_polar is None: + anchor_polar = np.pi/2 + + xyz0 = spherical_to_cartesian((anchor_polar, 0., 4.)) + c2w0 = elu_to_c2w(xyz0, np.zeros(3), np.array([0., 0., 1.])) + + xyz1 = spherical_to_cartesian((theta + anchor_polar, 0. + azimuth, 4. 
+ radius)) + c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.])) + + cam_vis._poses = [c2w0, c2w1] + fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True) + + explored_sph = (theta, azimuth, radius) + + return anchor_polar, explored_sph, fig, gr.update(interactive=True) + + +@spaces.GPU +def run_pose_refinement(cam_vis, image1, image2, anchor_polar, explored_sph, refine_iters): + + cam_vis.set_images([np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)]) + + image1 = image_to_tensor(image1).to(_device_) + image2 = image_to_tensor(image2).to(_device_) + + images = [image1, image2] + images = [ img.permute(0, 2, 3, 1) for img in images ] + + out_poses, _, loss = find_optimal_poses( + _model_, images, + 1.0, + bsz=1, + n_iter=refine_iters, + init_poses={1: explored_sph}, + ts_range=[0.2, 0.21], + combinations=[(0, 1), (1, 0)], + avg_last_n=20, + print_n=100 + ) + + final_sph = out_poses[0] + theta, azimuth, radius = final_sph + + xyz0 = spherical_to_cartesian((anchor_polar, 0., 4.)) + c2w0 = elu_to_c2w(xyz0, np.zeros(3), np.array([0., 0., 1.])) + + xyz1 = spherical_to_cartesian((theta + anchor_polar, 0. + azimuth, 4. + radius)) + c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.])) + + cam_vis._poses = [c2w0, c2w1] + fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True) + + return final_sph, fig + + +_HEADER_ = ''' +# Official 🤗 Gradio Demo for [ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models](https://github.com/xt4d/id-pose) +- ID-Pose accepts input images with NO overlapping appearance. +- The estimation takes about 1 minute. ZeroGPU may be halted during processing due to quota restrictions. 
+''' + +_FOOTER_ = ''' +- Project Page: [https://xt4d.github.io/id-pose-web/](https://xt4d.github.io/id-pose-web/) +- Github: [https://github.com/xt4d/id-pose](https://github.com/xt4d/id-pose) +''' + +_CITE_ = r""" +```bibtex +@article{cheng2023id, + title={ID-Pose: Sparse-view Camera Pose Estimation by Inverting Diffusion Models}, + author={Cheng, Weihao and Cao, Yan-Pei and Shan, Ying}, + journal={arXiv preprint arXiv:2306.17140}, + year={2023} +} +``` +""" + +def run_demo(): + + demo = gr.Blocks(title='ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models') + + with demo: + gr.Markdown(_HEADER_) + + with gr.Row(variant='panel'): + with gr.Column(scale=1): + + with gr.Row(): + with gr.Column(min_width=280): + input_image1 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 1', width=280) + + with gr.Column(min_width=280): + input_image2 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 2', width=280) + + with gr.Row(): + with gr.Column(min_width=280): + processed_image1 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 1', width=280, interactive=False) + with gr.Column(min_width=280): + processed_image2 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 2', width=280, interactive=False) + + + with gr.Row(): + preprocess_chk = gr.Checkbox(True, label='Remove background and recenter object') + + with gr.Accordion('Advanced options', open=False): + probe_bsz = gr.Slider(4, 32, value=16, step=4, label='Probe Batch Size') + adj_bsz = gr.Slider(1, 8, value=4, step=1, label='Adjust Batch Size') + adj_iters = gr.Slider(1, 20, value=5, step=1, label='Adjust Iterations') + + with gr.Row(): + run_btn = gr.Button('Estimate', variant='primary', interactive=True) + + with gr.Row(): + refine_iters = gr.Slider(0, 1000, value=0, step=50, label='Refinement Iterations') + + with gr.Row(): + refine_btn = gr.Button('Refine', variant='primary', interactive=False) + + with gr.Row(): + gr.Markdown(_FOOTER_) + + with gr.Row(): + gr.Markdown(_CITE_) + + + with gr.Column(scale=1.4): + + with gr.Row(): + vis_output = gr.Plot(label='Camera Pose Results: anchor (red) and target (blue)') + + with gr.Row(): + + with gr.Column(min_width=200): + gr.Examples( + examples = [ + ['data/gradio_demo/duck_0.png', 'data/gradio_demo/duck_1.png'], + ['data/gradio_demo/chair_0.png', 'data/gradio_demo/chair_1.png'], + ['data/gradio_demo/foosball_0.png', 'data/gradio_demo/foosball_1.png'], + ], + inputs=[input_image1, input_image2], + label='Examples (Self-captured)', + cache_examples=False, + examples_per_page=3 + ) + + with gr.Column(min_width=200): + gr.Examples( + examples = [ + ['data/gradio_demo/bunny_0.png', 'data/gradio_demo/bunny_1.png'], + ['data/gradio_demo/bus_0.png', 'data/gradio_demo/bus_1.png'], + ['data/gradio_demo/circo_0.png', 'data/gradio_demo/circo_1.png'], + ], + inputs=[input_image1, input_image2], + label='Examples (Images from NAVI)', + cache_examples=False, + examples_per_page=3 + ) + + with gr.Column(min_width=200): + gr.Examples( + examples = [ + ['data/gradio_demo/status_0.png', 'data/gradio_demo/status_1.png'], + ['data/gradio_demo/bag_0.png', 'data/gradio_demo/bag_1.png'], + ['data/gradio_demo/cat_0.png', 'data/gradio_demo/cat_1.png'], + ], + inputs=[input_image1, input_image2], + label='Examples (Generated)', + cache_examples=False, + examples_per_page=3 + ) + + cam_vis = CameraVisualizer([np.eye(4), np.eye(4)], ['Image 1', 'Image 2'], ['red', 'blue']) + + explored_sph = gr.State() + anchor_polar = gr.State() + refined_sph = 
gr.State() + elevs = gr.State() + elev_ranges = gr.State() + + run_btn.click( + fn=run_preprocess, + inputs=[input_image1, input_image2, preprocess_chk], + outputs=[processed_image1, processed_image2], + ).success( + fn=partial(run_pose_exploration_a, cam_vis), + inputs=[processed_image1, processed_image2], + outputs=[elevs, elev_ranges, vis_output] + ).success( + fn=partial(run_pose_exploration_b, cam_vis), + inputs=[processed_image1, processed_image2, elevs, elev_ranges, probe_bsz, adj_bsz, adj_iters], + outputs=[anchor_polar, explored_sph, vis_output, refine_btn] + ) + + refine_btn.click( + fn=partial(run_pose_refinement, cam_vis), + inputs=[processed_image1, processed_image2, anchor_polar, explored_sph, refine_iters], + outputs=[refined_sph, vis_output] + ) + + demo.launch() + + +if __name__ == '__main__': + + run_demo() \ No newline at end of file diff --git a/data/gradio_demo/bag_0.png b/data/gradio_demo/bag_0.png new file mode 100644 index 0000000000000000000000000000000000000000..86a5fd395b24efef2ec0d3a5074f131f96eccfe1 Binary files /dev/null and b/data/gradio_demo/bag_0.png differ diff --git a/data/gradio_demo/bag_1.png b/data/gradio_demo/bag_1.png new file mode 100644 index 0000000000000000000000000000000000000000..9137b12fb362b9c5b752ee82084e692a02e14e3a Binary files /dev/null and b/data/gradio_demo/bag_1.png differ diff --git a/data/gradio_demo/bunny_0.png b/data/gradio_demo/bunny_0.png new file mode 100644 index 0000000000000000000000000000000000000000..8894b0d1b814f028bbd9211e2270e6d3c0317918 Binary files /dev/null and b/data/gradio_demo/bunny_0.png differ diff --git a/data/gradio_demo/bunny_1.png b/data/gradio_demo/bunny_1.png new file mode 100644 index 0000000000000000000000000000000000000000..34f6d997269a85aef01a872dc7ae51439cd0b215 Binary files /dev/null and b/data/gradio_demo/bunny_1.png differ diff --git a/data/gradio_demo/bus_0.png b/data/gradio_demo/bus_0.png new file mode 100644 index 0000000000000000000000000000000000000000..c37f589df0dafb452a5c511b2a0cb9432e82473c Binary files /dev/null and b/data/gradio_demo/bus_0.png differ diff --git a/data/gradio_demo/bus_1.png b/data/gradio_demo/bus_1.png new file mode 100644 index 0000000000000000000000000000000000000000..8a87f200393cdadc034dd5b7fed408f585443021 Binary files /dev/null and b/data/gradio_demo/bus_1.png differ diff --git a/data/gradio_demo/cat_0.png b/data/gradio_demo/cat_0.png new file mode 100644 index 0000000000000000000000000000000000000000..82019287faeb1a752af5f7d23520a764207c4771 Binary files /dev/null and b/data/gradio_demo/cat_0.png differ diff --git a/data/gradio_demo/cat_1.png b/data/gradio_demo/cat_1.png new file mode 100644 index 0000000000000000000000000000000000000000..0df4ebc804d63b6a99a69e81b06917e3b7d1b30a Binary files /dev/null and b/data/gradio_demo/cat_1.png differ diff --git a/data/gradio_demo/chair_0.png b/data/gradio_demo/chair_0.png new file mode 100644 index 0000000000000000000000000000000000000000..e6542a80f56027fd4ac8698fc216784bc0959944 Binary files /dev/null and b/data/gradio_demo/chair_0.png differ diff --git a/data/gradio_demo/chair_1.png b/data/gradio_demo/chair_1.png new file mode 100644 index 0000000000000000000000000000000000000000..2b36747c1d4bcfab3662bc934dc3b2b59953da94 Binary files /dev/null and b/data/gradio_demo/chair_1.png differ diff --git a/data/gradio_demo/circo_0.png b/data/gradio_demo/circo_0.png new file mode 100644 index 0000000000000000000000000000000000000000..b192b873a561e8abffb966a21bc2f9d6e5da788a Binary files /dev/null and b/data/gradio_demo/circo_0.png 
differ diff --git a/data/gradio_demo/circo_1.png b/data/gradio_demo/circo_1.png new file mode 100644 index 0000000000000000000000000000000000000000..2d093dfa98bbff7a0e818597f191f59d94dc02c7 Binary files /dev/null and b/data/gradio_demo/circo_1.png differ diff --git a/data/gradio_demo/duck_0.png b/data/gradio_demo/duck_0.png new file mode 100644 index 0000000000000000000000000000000000000000..a4d6faa4ab35280a9e040cffdaf04ea55c83da3c Binary files /dev/null and b/data/gradio_demo/duck_0.png differ diff --git a/data/gradio_demo/duck_1.png b/data/gradio_demo/duck_1.png new file mode 100644 index 0000000000000000000000000000000000000000..6ff7d8a0eb3059cd662d2734006da7049aa479e1 Binary files /dev/null and b/data/gradio_demo/duck_1.png differ diff --git a/data/gradio_demo/foosball_0.png b/data/gradio_demo/foosball_0.png new file mode 100644 index 0000000000000000000000000000000000000000..50d792f67e53de74b40ec6943f2fc00a456cb95d Binary files /dev/null and b/data/gradio_demo/foosball_0.png differ diff --git a/data/gradio_demo/foosball_1.png b/data/gradio_demo/foosball_1.png new file mode 100644 index 0000000000000000000000000000000000000000..f6a36c01d9a992d706aa34ba4fa9f5d686a45ba7 Binary files /dev/null and b/data/gradio_demo/foosball_1.png differ diff --git a/data/gradio_demo/status_0.png b/data/gradio_demo/status_0.png new file mode 100644 index 0000000000000000000000000000000000000000..cb50e674683c4127bd4cc5b5c5e53d546bb497a9 Binary files /dev/null and b/data/gradio_demo/status_0.png differ diff --git a/data/gradio_demo/status_1.png b/data/gradio_demo/status_1.png new file mode 100644 index 0000000000000000000000000000000000000000..8ca38dd7a65dc5dc12a3de4a9e09b2ee2cb32959 Binary files /dev/null and b/data/gradio_demo/status_1.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..845b43384d47a44885a4061136f5def20461bfaa --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +torch==2.0.0 +torchvision==0.15.1 +opencv-python==4.7.0.72 +pudb==2019.2 +imageio==2.9.0 +pytorch-lightning==1.4.2 +omegaconf==2.1.1 +einops==0.3.0 +kornia==0.6 +torchmetrics==0.6.0 +gradio==3.41.2 +pillow==9.5.0 +rembg==2.0.56 +plotly==5.13.1 +trimesh==3.23.5 +yacs==0.1.8 +dl-ext==1.3.4 +git+https://github.com/openai/CLIP.git +-e git+https://github.com/CompVis/taming-transformers#egg=taming-transformers +loguru==0.7.2 \ No newline at end of file diff --git a/src/configs/sd-objaverse-finetune-c_concat-256.yaml b/src/configs/sd-objaverse-finetune-c_concat-256.yaml new file mode 100755 index 0000000000000000000000000000000000000000..a2caae8d4efe4c425eece488ee198ac5eaac433f --- /dev/null +++ b/src/configs/sd-objaverse-finetune-c_concat-256.yaml @@ -0,0 +1,119 @@ +model: + base_learning_rate: 1.0e-04 + target: src.ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "image_target" + cond_stage_key: "image_cond" + image_size: 32 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid + monitor: val/loss_simple_ema + scale_factor: 0.18215 + + scheduler_config: # 10000 warmup steps + target: src.ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 100 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. 
] + + unet_config: + target: src.ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: src.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: src.ldm.modules.encoders.modules.FrozenCLIPImageEmbedder + params: + clip_root: 'ckpts/' + + +data: + target: src.ldm.data.simple.ObjaverseDataModuleFromConfig + params: + root_dir: 'views_whole_sphere' + batch_size: 192 + num_workers: 16 + total_view: 4 + train: + validation: False + image_transforms: + size: 256 + + validation: + validation: True + image_transforms: + size: 256 + + +lightning: + find_unused_parameters: false + metrics_over_trainsteps_checkpoint: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 500 + max_images: 32 + increase_log_steps: False + log_first_step: True + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 32 + unconditional_guidance_scale: 3.0 + unconditional_guidance_label: [""] + + trainer: + benchmark: True + val_check_interval: 5000000 # really sorry + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 diff --git a/src/ldm/data/__init__.py b/src/ldm/data/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/ldm/data/base.py b/src/ldm/data/base.py new file mode 100755 index 0000000000000000000000000000000000000000..742794e631081bbfa7c44f3df6f83373ca5c15c1 --- /dev/null +++ b/src/ldm/data/base.py @@ -0,0 +1,40 @@ +import os +import numpy as np +from abc import abstractmethod +from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset + + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + def __init__(self, num_records=0, valid_ids=None, size=256): + super().__init__() + self.num_records = num_records + self.valid_ids = valid_ids + self.sample_ids = valid_ids + self.size = size + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + return self.num_records + + @abstractmethod + def __iter__(self): + pass + + +class PRNGMixin(object): + """ + Adds a prng property which is a numpy RandomState which gets + reinitialized whenever the pid changes to avoid synchronized sampling + behavior when used in conjunction with multiprocessing. 
+ """ + @property + def prng(self): + currentpid = os.getpid() + if getattr(self, "_initpid", None) != currentpid: + self._initpid = currentpid + self._prng = np.random.RandomState() + return self._prng diff --git a/src/ldm/data/dummy.py b/src/ldm/data/dummy.py new file mode 100755 index 0000000000000000000000000000000000000000..3b74a77fe8954686e480d28aaed19e52d3e3c9b7 --- /dev/null +++ b/src/ldm/data/dummy.py @@ -0,0 +1,34 @@ +import numpy as np +import random +import string +from torch.utils.data import Dataset, Subset + +class DummyData(Dataset): + def __init__(self, length, size): + self.length = length + self.size = size + + def __len__(self): + return self.length + + def __getitem__(self, i): + x = np.random.randn(*self.size) + letters = string.ascii_lowercase + y = ''.join(random.choice(string.ascii_lowercase) for i in range(10)) + return {"jpg": x, "txt": y} + + +class DummyDataWithEmbeddings(Dataset): + def __init__(self, length, size, emb_size): + self.length = length + self.size = size + self.emb_size = emb_size + + def __len__(self): + return self.length + + def __getitem__(self, i): + x = np.random.randn(*self.size) + y = np.random.randn(*self.emb_size).astype(np.float32) + return {"jpg": x, "txt": y} + diff --git a/src/ldm/data/simple.py b/src/ldm/data/simple.py new file mode 100755 index 0000000000000000000000000000000000000000..9374107e25db080e344bbcc984a28bd60c0269e9 --- /dev/null +++ b/src/ldm/data/simple.py @@ -0,0 +1,191 @@ +from typing import Dict +import webdataset as wds +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +import torchvision +from einops import rearrange +from ..util import instantiate_from_config +from datasets import load_dataset +import pytorch_lightning as pl +import copy +import csv +import cv2 +import random +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +import json +import os, sys +import webdataset as wds +import math +from torch.utils.data.distributed import DistributedSampler + + +class ObjaverseDataModuleFromConfig(pl.LightningDataModule): + def __init__(self, root_dir, batch_size, total_view, train=None, validation=None, + test=None, num_workers=4, **kwargs): + super().__init__(self) + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.total_view = total_view + + if train is not None: + dataset_config = train + if validation is not None: + dataset_config = validation + + if 'image_transforms' in dataset_config: + image_transforms = [torchvision.transforms.Resize(dataset_config.image_transforms.size)] + else: + image_transforms = [] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. 
- 1., 'c h w -> h w c'))]) + self.image_transforms = torchvision.transforms.Compose(image_transforms) + + + def train_dataloader(self): + dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False, \ + image_transforms=self.image_transforms) + sampler = DistributedSampler(dataset) + return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler) + + def val_dataloader(self): + dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True, \ + image_transforms=self.image_transforms) + sampler = DistributedSampler(dataset) + return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) + + def test_dataloader(self): + return wds.WebLoader(ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=self.validation),\ + batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) + + +class ObjaverseData(Dataset): + def __init__(self, + root_dir='.objaverse/hf-objaverse-v1/views', + image_transforms=[], + ext="png", + default_trans=torch.zeros(3), + postprocess=None, + return_paths=False, + total_view=4, + validation=False + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_trans = default_trans + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + self.total_view = total_view + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + with open(os.path.join(root_dir, 'valid_paths.json')) as f: + self.paths = json.load(f) + + total_objects = len(self.paths) + if validation: + self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation + else: + self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training + print('============= length of dataset %d =============' % len(self.paths)) + self.tform = image_transforms + + def __len__(self): + return len(self.paths) + + def cartesian_to_spherical(self, xyz): + ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) + xy = xyz[:,0]**2 + xyz[:,1]**2 + z = np.sqrt(xy + xyz[:,2]**2) + theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down + #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up + azimuth = np.arctan2(xyz[:,1], xyz[:,0]) + return np.array([theta, azimuth, z]) + + def get_T(self, target_RT, cond_RT): + R, T = target_RT[:3, :3], target_RT[:, -1] + T_target = -R.T @ T + + R, T = cond_RT[:3, :3], cond_RT[:, -1] + T_cond = -R.T @ T + + theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :]) + theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :]) + + d_theta = theta_target - theta_cond + d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) + d_z = z_target - z_cond + + d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()]) + return d_T + + def load_im(self, path, color): + ''' + replace background pixel with random color in rendering + ''' + try: + img = plt.imread(path) + except: + print(path) + sys.exit() + img[img[:, :, -1] == 0.] 
= color + img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)) + return img + + def __getitem__(self, index): + + data = {} + if self.paths[index][-2:] == '_1': # dirty fix for rendering dataset twice + total_view = 8 + else: + total_view = 4 + index_target, index_cond = random.sample(range(total_view), 2) # without replacement + filename = os.path.join(self.root_dir, self.paths[index]) + + # print(self.paths[index]) + + if self.return_paths: + data["path"] = str(filename) + + color = [1., 1., 1., 1.] + + try: + target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color)) + cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color)) + target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target)) + cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond)) + except: + # very hacky solution, sorry about this + filename = os.path.join(self.root_dir, '692db5f2d3a04bb286cb977a7dba903e_1') # this one we know is valid + target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color)) + cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color)) + target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target)) + cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond)) + target_im = torch.zeros_like(target_im) + cond_im = torch.zeros_like(cond_im) + + data["image_target"] = target_im + data["image_cond"] = cond_im + data["T"] = self.get_T(target_RT, cond_RT) + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) diff --git a/src/ldm/lr_scheduler.py b/src/ldm/lr_scheduler.py new file mode 100755 index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade --- /dev/null +++ b/src/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. 
+ """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/src/ldm/models/autoencoder.py b/src/ldm/models/autoencoder.py new file mode 100755 index 0000000000000000000000000000000000000000..866defe0b2dea2f8eca3cdb7b396dc46f40028d7 --- /dev/null +++ b/src/ldm/models/autoencoder.py @@ -0,0 +1,443 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ..modules.diffusionmodules.model import Encoder, Decoder +from ..modules.distributions.distributions import DiagonalGaussianDistribution + +from ..util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = 
torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, 
on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. 
+ return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, 
reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/src/ldm/models/diffusion/__init__.py b/src/ldm/models/diffusion/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/ldm/models/diffusion/classifier.py b/src/ldm/models/diffusion/classifier.py new file mode 100755 index 0000000000000000000000000000000000000000..e61fbd476608b0236a78bd3f63a3d353befe0073 --- /dev/null +++ b/src/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ...modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ...util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, 
+ log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, 
k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = 
self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/src/ldm/models/diffusion/ddim.py b/src/ldm/models/diffusion/ddim.py new file mode 100755 index 0000000000000000000000000000000000000000..0add14193d3aa40075f8016a70a5cb2adaf83e31 --- /dev/null +++ b/src/ldm/models/diffusion/ddim.py @@ -0,0 +1,322 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial +from einops import rearrange + +from ...modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor +from .sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def to(self, device): + """Same as to in torch module + Don't really underestand why this isn't a module in the first place""" + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + new_v = getattr(self, k).to(device) + setattr(self, k, new_v) + + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + t_start=-1): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + timesteps = timesteps[:t_start] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 
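# A minimal, self-contained sketch of the single reverse step that the loop below
# delegates to p_sample_ddim: classifier-free guidance on the noise estimate, then the
# deterministic DDIM update (the function name, scalar schedule values, and the 3.0
# guidance scale are illustrative assumptions, not part of this file):
#
#     import torch
#
#     def ddim_step(x_t, e_uncond, e_cond, a_t, a_prev, sigma_t, guidance_scale=3.0):
#         # guided noise estimate: e = e_uncond + s * (e_cond - e_uncond)
#         e_t = e_uncond + guidance_scale * (e_cond - e_uncond)
#         # predicted clean latent: x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
#         pred_x0 = (x_t - (1.0 - a_t) ** 0.5 * e_t) / a_t ** 0.5
#         # direction pointing to x_{t-1}, plus a stochastic term when eta > 0
#         dir_xt = (1.0 - a_prev - sigma_t ** 2) ** 0.5 * e_t
#         noise = sigma_t * torch.randn_like(x_t)
#         return a_prev ** 0.5 * pred_x0 + dir_xt + noise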
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + for i, step in enumerate(time_range): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: + img = callback(i, img, pred_x0) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 
+ unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec \ No newline at end of file diff --git a/src/ldm/models/diffusion/ddpm.py b/src/ldm/models/diffusion/ddpm.py new file mode 100755 index 0000000000000000000000000000000000000000..a0179c2bf8a1b63bb70eed209f2a1f995b1c60cb --- /dev/null +++ b/src/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1999 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager, nullcontext +from functools import partial +import itertools +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only +from omegaconf import ListConfig + +from ...util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ...modules.ema import LitEma +from ...modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ..autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ...modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from .ddim import DDIMSampler +from ...modules.attention import CrossAttention + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + make_it_fit=False, + ucg_training=None, + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? 
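# register_schedule (called a few lines below) precomputes every schedule tensor the
# sampler needs from the betas alone; a rough stand-alone sketch, assuming a plain
# linear beta ramp (the actual make_beta_schedule may differ) and the v_posterior
# interpolation sigma^2 = (1 - v) * beta_tilde + v * beta used in this class:
#
#     import numpy as np
#
#     def toy_schedule(timesteps=1000, linear_start=1e-4, linear_end=2e-2, v_posterior=0.0):
#         betas = np.linspace(linear_start, linear_end, timesteps)
#         alphas = 1.0 - betas
#         alphas_cumprod = np.cumprod(alphas, axis=0)
#         alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
#         posterior_variance = ((1 - v_posterior) * betas * (1.0 - alphas_cumprod_prev)
#                               / (1.0 - alphas_cumprod) + v_posterior * betas)
#         return betas, alphas_cumprod, posterior_variance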
+ self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. 
- alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + + if self.make_it_fit: + n_params = len([name for name, _ in + itertools.chain(self.named_parameters(), + self.named_buffers())]) + for name, param in tqdm( + itertools.chain(self.named_parameters(), + self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape)==len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
+ """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + 
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1-p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + 
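# The loop here noises the same x_start at each logged timestep; _get_rows_from_list
# (above) then tiles the collected [n_steps, b, c, h, w] stack into one image, roughly
# equivalent to this stand-alone helper (name and one-step rearrange are illustrative):
#
#     import torch
#     from einops import rearrange
#     from torchvision.utils import make_grid
#
#     def rows_to_grid(samples):                      # samples: list of [b, c, h, w] tensors
#         grid = torch.stack(samples)                 # -> n b c h w
#         grid = rearrange(grid, 'n b c h w -> (b n) c h w')
#         return make_grid(grid, nrow=len(samples))   # one grid row per batch element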
noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + unet_trainable=True, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.unet_trainable = unet_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + + # construct linear projection layer for concatenating image CLIP embedding and RT + self.cc_projection = nn.Linear(772, 768) + nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768]) + nn.init.zeros_(list(self.cc_projection.parameters())[1]) + self.cc_projection.requires_grad_(True) + + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, 
self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 
0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, uncond=0.05): + x = super().get_input(batch, k) + T = 
batch['T'].to(memory_format=torch.contiguous_format).float() + + if bs is not None: + x = x[:bs] + T = T[:bs].to(self.device) + + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + cond_key = cond_key or self.cond_stage_key + xc = super().get_input(batch, cond_key).to(self.device) + if bs is not None: + xc = xc[:bs] + cond = {} + + # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%. + # random = torch.rand(x.size(0), device=x.device) + # prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1") + # input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1") + # null_prompt = self.get_learned_conditioning([""]) + + # z.shape: [8, 4, 64, 64]; c.shape: [8, 1, 768] + # print('=========== xc shape ===========', xc.shape) + with torch.enable_grad(): + clip_emb = self.get_learned_conditioning(xc).detach() + null_prompt = self.get_learned_conditioning([""]).detach() + cond["c_crossattn"] = [self.cc_projection(torch.cat([clip_emb, T[:, None, :]], dim=-1))] + cond["c_concat"] = [self.encode_first_stage((xc.to(self.device))).mode().detach()] + out = [z, cond] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + # @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. 
reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c, **kwargs) + return loss + + def forward(self, x, c, *args, **kwargs): + if 'ts' in kwargs: + t = torch.tensor(kwargs['ts']*self.num_timesteps, device=self.device).long() + kwargs.pop('ts') + else: + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + #t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + # if self.cond_stage_trainable: + # c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if 
hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + adapted_cond = self.get_learned_conditioning(adapted_cond) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant 
deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t.to(self.logvar.device)].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
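# With eps-parameterization, predict_start_from_noise and q_posterior (defined in the
# base class) reduce to the closed forms below; a scalar sketch assuming v_posterior = 0
# (function name and arguments are illustrative):
#
#     def posterior_from_eps(x_t, eps, alpha_t, a_bar_t, a_bar_prev):
#         beta_t = 1.0 - alpha_t
#         # x0 = sqrt(1 / a_bar_t) * x_t - sqrt(1 / a_bar_t - 1) * eps
#         x0 = (1.0 / a_bar_t) ** 0.5 * x_t - (1.0 / a_bar_t - 1.0) ** 0.5 * eps
#         # posterior mean and variance of q(x_{t-1} | x_t, x0)
#         mean = (beta_t * a_bar_prev ** 0.5 / (1.0 - a_bar_t)) * x0 \
#                + ((1.0 - a_bar_prev) * alpha_t ** 0.5 / (1.0 - a_bar_t)) * x_t
#         var = beta_t * (1.0 - a_bar_prev) / (1.0 - a_bar_t)
#         return x0, mean, var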
+ if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + 
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, + shape, cond, verbose=False, **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True, **kwargs) + + return samples, intermediates + + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None, image_size=512): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + 
else: + # todo: get null label from cond_stage_model + raise NotImplementedError() + c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + cond = {} + cond["c_crossattn"] = [c] + cond["c_concat"] = [torch.zeros([batch_size, 4, image_size // 8, image_size // 8]).to(self.device)] + return cond + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = 
self.get_unconditional_conditioning(N, unconditional_guidance_label, image_size=x.shape[-1]) + # uc = torch.zeros_like(c) + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1. - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = [] + if self.unet_trainable == "attn": + print("Training only unet attention layers") + for n, m in self.model.named_modules(): + if isinstance(m, CrossAttention) and n.endswith('attn2'): + params.extend(m.parameters()) + if self.unet_trainable == "conv_in": + print("Training only unet input conv layers") + params = list(self.model.diffusion_model.input_blocks[0][0].parameters()) + elif self.unet_trainable is True or self.unet_trainable == "all": + print("Training the full unet") + params = list(self.model.parameters()) + else: + raise ValueError(f"Unrecognised setting for unet_trainable: {self.unet_trainable}") + + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + + if self.cc_projection is not None: + params = params + list(self.cc_projection.parameters()) + print('========== optimizing for cc projection weight ==========') + + opt = torch.optim.AdamW([{"params": self.model.parameters(), "lr": lr}, + {"params": self.cc_projection.parameters(), "lr": 10. 
* lr}], lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + # c_crossattn dimension: torch.Size([8, 1, 768]) 1 + # cc dimension: torch.Size([8, 1, 768] + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + #import pudb; pu.db + if log_mode: + # TODO: maybe disable if too expensive + interpretability = False + if interpretability: + zx = zx[:, :, ::2, ::2] + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def 
log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, + log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? 
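+ # for the noise-level key, the unconditional branch below is filled with the maximum noise level, i.e. guidance is computed against the most degraded low-resolution conditioning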
+ assert isinstance(c[k], torch.Tensor) + uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + return log + + +class LatentInpaintDiffusion(LatentDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ + def __init__(self, + finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight" + ), + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + keep_finetune_dims=4, # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, **kwargs + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' + if exists(ckpt_path): + self.init_from_ckpt(ckpt_path, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + print(f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), 'did not find matching parameter to modify' + new_entry[:, :self.keep_dims, ...] 
= sd[k] + sd[k] = new_entry + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:,self.c_concat_log_start:self.c_concat_log_end]) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + 
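+ # arrange the logged steps so that make_grid (nrow = number of logged steps) renders one grid row per sample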
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + log["masked_image"] = rearrange(batch["masked_image"], + 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + return log + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs + + +class SimpleUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_key="LR", **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.low_scale_key = low_scale_key + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = x_low.to(memory_format=torch.contiguous_format).float() + + encoder_posterior = self.encode_first_stage(x_low) + zx = self.get_first_stage_encoding(encoder_posterior).detach() + all_conds = {"c_concat": [zx], "c_crossattn": [c]} + + if log_mode: + # TODO: maybe disable if too expensive + interpretability = False + if interpretability: + zx = 
zx[:, :, ::2, ::2] + return z, all_conds, x, xrec, xc, x_low + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + return log + +class MultiCatFrameDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_key="LR", **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.low_scale_key = low_scale_key + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + n = 2 + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + cat_conds = batch[self.low_scale_key][:bs] + cats = [] + for i in range(n): + x_low = cat_conds[:,:,:,3*i:3*(i+1)] + x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = x_low.to(memory_format=torch.contiguous_format).float() + encoder_posterior = self.encode_first_stage(x_low) + zx = self.get_first_stage_encoding(encoder_posterior).detach() + cats.append(zx) + + all_conds = {"c_concat": [torch.cat(cats, dim=1)], 
"c_crossattn": [c]} + + if log_mode: + # TODO: maybe disable if too expensive + interpretability = False + if interpretability: + zx = zx[:, :, ::2, ::2] + return z, all_conds, x, xrec, xc, x_low + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + return log diff --git a/src/ldm/models/diffusion/plms.py b/src/ldm/models/diffusion/plms.py new file mode 100755 index 0000000000000000000000000000000000000000..b46f8cf196cdcc643586d7d78ce6b604050d186c --- /dev/null +++ b/src/ldm/models/diffusion/plms.py @@ -0,0 +1,259 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ...modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from .sampling_util import norm_thresholding + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, 
ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next, + dynamic_threshold=dynamic_threshold) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3rd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4th order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/src/ldm/models/diffusion/sampling_util.py b/src/ldm/models/diffusion/sampling_util.py new file mode 100755 index 0000000000000000000000000000000000000000..a0ae00fe86044456fc403af403be71ff15112424 --- /dev/null +++ b/src/ldm/models/diffusion/sampling_util.py @@ -0,0 +1,51 @@ +import torch +import numpy as np +from einops import rearrange + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def renorm_thresholding(x0, value): + # renorm + pred_max = x0.max() + pred_min = x0.min() + pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 + pred_x0 = 2 * pred_x0 - 1. # -1 ... 1 + + s = torch.quantile( + rearrange(pred_x0, 'b ... -> b (...)').abs(), + value, + dim=-1 + ) + s.clamp_(min=1.0) + s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) + + # clip by threshold + # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max + + # temporary hack: numpy on cpu + pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() + pred_x0 = torch.tensor(pred_x0).to(x0.device) # module-level function, so use the input's device rather than self.model.device + + # re.renorm + pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 
1 + pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range + return pred_x0 + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) \ No newline at end of file diff --git a/src/ldm/modules/attention.py b/src/ldm/modules/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..864f4d06998da49f8c233a473ee7b5ff8414154f --- /dev/null +++ b/src/ldm/modules/attention.py @@ -0,0 +1,266 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from .diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
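+ Used in this module for output projections (see SpatialTransformer.proj_out), so a freshly initialized block acts as an identity mapping through its residual connection.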
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, + disable_self_attn=disable_self_attn) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + x = self.proj_out(x) + return x + x_in diff --git a/src/ldm/modules/diffusionmodules/__init__.py b/src/ldm/modules/diffusionmodules/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/ldm/modules/diffusionmodules/model.py b/src/ldm/modules/diffusionmodules/model.py new file mode 100755 index 0000000000000000000000000000000000000000..2362bfcb1d7415e060a58b7f50091ce1a449486c --- /dev/null +++ b/src/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,835 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ...util import instantiate_from_config +from ..attention import 
LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + 
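+ # a single head with dim_head equal to the channel count, so this block can stand in wherever AttnBlock is expected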
super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + 
out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if 
curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # 
prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = 
nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise 
NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/src/ldm/modules/diffusionmodules/openaimodel.py b/src/ldm/modules/diffusionmodules/openaimodel.py new file mode 100755 index 0000000000000000000000000000000000000000..aa4408d3adc5fb7896b299baa2ac4e323042cd7d --- /dev/null +++ b/src/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,996 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ..attention import SpatialTransformer +from ...util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 
* embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. 
+ :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
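+    Note (added comment): attention here is computed densely over all H*W flattened
+    spatial positions, so its cost grows quadratically with resolution; the UNet below
+    only inserts it at the downsample rates listed in `attention_resolutions`.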
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
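+            # Note (added comment): the two asserts above tie cross-attention
+            # conditioning to the SpatialTransformer blocks -- `context_dim` is only
+            # consumed there, so both options must be enabled (or disabled) together.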
+ from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + #self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + 
num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), 
+ nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. 
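+    The `pool` argument ("adaptive", "attention", "spatial" or "spatial_v2")
+    selects how the final feature map is reduced to the output vector.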
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + 
) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/src/ldm/modules/diffusionmodules/util.py b/src/ldm/modules/diffusionmodules/util.py new file mode 100755 index 0000000000000000000000000000000000000000..8ff2c070b8d30e7687e24cce8425d0227c17b893 --- /dev/null +++ b/src/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,267 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
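+# Usage sketch (illustrative only; the names and values below are examples, not part
+# of the original file): the schedule helpers defined here are typically combined as
+#
+#   betas = make_beta_schedule("linear", n_timestep=1000)             # numpy array, shape (1000,)
+#   alphas_cumprod = np.cumprod(1. - betas)                           # cumulative product of alphas
+#   ddim_ts = make_ddim_timesteps("uniform", num_ddim_timesteps=50,
+#                                 num_ddpm_timesteps=1000)
+#   sigmas, a_t, a_prev = make_ddim_sampling_parameters(alphas_cumprod, ddim_ts, eta=0.0)
+#
+# where eta=0.0 gives deterministic DDIM sampling and eta=1.0 recovers the
+# DDPM-like variance schedule.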
+ + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ...util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. 
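+    The returned GroupNorm32 (defined below) uses 32 groups and computes the
+    normalization in float32 before casting back to the input dtype.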
+ """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/src/ldm/modules/distributions/__init__.py b/src/ldm/modules/distributions/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/ldm/modules/distributions/distributions.py b/src/ldm/modules/distributions/distributions.py new file mode 100755 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/src/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) 
/ other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/src/ldm/modules/ema.py b/src/ldm/modules/ema.py new file mode 100755 index 0000000000000000000000000000000000000000..c8c75af43565f6e140287644aaaefa97dd6e67c5 --- /dev/null +++ b/src/ldm/modules/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. 
+ Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/src/ldm/modules/encoders/__init__.py b/src/ldm/modules/encoders/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/ldm/modules/encoders/modules.py b/src/ldm/modules/encoders/modules.py new file mode 100755 index 0000000000000000000000000000000000000000..91a6a6091842a987c2242de3c111784998256c31 --- /dev/null +++ b/src/ldm/modules/encoders/modules.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial +import kornia +import clip + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + +class FrozenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the CLIP image encoder. + Not actually frozen... If you want that set cond_stage_trainable=False in cfg + """ + def __init__( + self, + model='ViT-L/14', + jit=False, + device='cpu', + antialias=False, + clip_root=None + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit, download_root=clip_root) + # We don't use the text part so delete it + del self.model.transformer + self.antialias = antialias + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # Expects inputs in the range -1, 1 + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + if isinstance(x, list): + # [""] denotes condition dropout for ucg + device = self.model.visual.conv1.weight.device + return torch.zeros(1, 768, device=device) + return self.model.encode_image(self.preprocess(x)).float() + + def encode(self, im): + return self(im).unsqueeze(1) diff --git a/src/ldm/modules/losses/__init__.py b/src/ldm/modules/losses/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..9a8b4f66d63552b5b80df394eeb6b2c16564fe5d --- /dev/null +++ b/src/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from .contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/src/ldm/modules/losses/contperceptual.py b/src/ldm/modules/losses/contperceptual.py new file mode 100755 index 0000000000000000000000000000000000000000..672c1e32a1389def02461c0781339681060c540e --- /dev/null +++ b/src/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
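+ # Note on the loss defined below: on the generator pass (optimizer_idx == 0) it mixes
+ # reconstruction, KL and adversarial terms, as implemented in forward():
+ #     rec_loss = |x - x_hat| + perceptual_weight * LPIPS(x, x_hat)
+ #     nll_loss = rec_loss / exp(logvar) + logvar          # logvar is a learned scalar
+ #     loss     = weighted_nll_loss + kl_weight * kl_loss + d_weight * disc_factor * g_loss
+ # d_weight balances reconstruction and GAN gradients at the layer passed as last_layer
+ # (calculate_adaptive_weight), and adopt_weight keeps disc_factor at 0 until global_step
+ # reaches disc_start, so the discriminator only influences training after a warm-up.
+ #
+ # Minimal usage sketch (hypothetical names, assuming a decoder exposing its final conv
+ # layer and a DiagonalGaussianDistribution posterior):
+ #     criterion = LPIPSWithDiscriminator(disc_start=50001, kl_weight=1e-6, disc_weight=0.5)
+ #     loss, log = criterion(inputs, reconstructions, posterior, optimizer_idx=0,
+ #                           global_step=step, last_layer=decoder.conv_out.weight)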
+ + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): 
torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/src/ldm/modules/losses/vqperceptual.py b/src/ldm/modules/losses/vqperceptual.py new file mode 100755 index 0000000000000000000000000000000000000000..f69981769e4bd5462600458c4fcf26620f7e4306 --- /dev/null +++ b/src/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,167 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + +def l1(x, y): + return torch.abs(x-y) + + +def l2(x, y): + return torch.pow((x-y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", + pixel_loss="l1"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + if not exists(codebook_loss): + codebook_loss = torch.tensor([0.]).to(inputs.device) + #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + 
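+ # conditional discriminator: the conditioning tensor is concatenated to the
+ # reconstruction along the channel dimension before being scored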
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/src/ldm/modules/x_transformer.py b/src/ldm/modules/x_transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576 --- /dev/null +++ b/src/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. 
/ (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + 
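+ # With glu=True the input projection is a GEGLU gate (one Linear producing 2*inner_dim,
+ # split in half, then x * GELU(gate)); otherwise it is a plain Linear followed by GELU.
+ # For example, FeedForward(512, mult=4, glu=True) uses an inner width of 2048.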
project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = 
einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - 
len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits 
= nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/src/ldm/util.py b/src/ldm/util.py new file mode 100755 index 0000000000000000000000000000000000000000..301fafff26227d33562f3e893a5dbe79570d47f1 --- /dev/null +++ b/src/ldm/util.py @@ -0,0 +1,256 @@ +import importlib + +import torchvision +import torch +from torch import optim +import numpy as np + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + +import os +import numpy as np +from PIL import Image +import torch +import cv2 +import PIL + +def pil_rectangle_crop(im): + width, height = im.size # Get dimensions + + if width <= height: + left = 0 + right = width + top = (height - width)/2 + bottom = (height + width)/2 + else: + + top = 0 + bottom = height + left = (width - height) / 2 + bottom = (width + height) / 2 + + # Crop the center of the image + im = im.crop((left, top, right, bottom)) + return im + +def add_margin(pil_img, color, size=256): + width, height = pil_img.size + result = Image.new(pil_img.mode, (size, size), color) + result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) + return result + + +def load_and_preprocess(interface, input_im): + ''' + :param input_im (PIL Image). + :return image (H, W, 3) array in [0, 1]. + ''' + # See https://github.com/Ir1d/image-background-remove-tool + image = input_im.convert('RGB') + + image_without_background = interface([image])[0] + image_without_background = np.array(image_without_background) + est_seg = image_without_background > 127 + image = np.array(image) + foreground = est_seg[:, : , -1].astype(np.bool_) + image[~foreground] = [255., 255., 255.] 
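+ # crop to the tight bounding box of the foreground mask, shrink with thumbnail so the
+ # long edge fits within 200 px, then paste centered onto a white 256x256 canvas (add_margin)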
+ x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8)) + image = image[y:y+h, x:x+w, :] + image = PIL.Image.fromarray(np.array(image)) + + # resize image such that long edge is 512 + image.thumbnail([200, 200], Image.Resampling.LANCZOS) + image = add_margin(image, (255, 255, 255), size=256) + image = np.array(image) + + return image + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + 
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file diff --git a/src/oee/models/loftr/__init__.py b/src/oee/models/loftr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d69b9c131cf41e95c5c6ee7d389b375267b22fa --- /dev/null +++ b/src/oee/models/loftr/__init__.py @@ -0,0 +1,2 @@ +from .loftr import LoFTR +from .utils.cvpr_ds_config import default_cfg diff --git a/src/oee/models/loftr/backbone/__init__.py b/src/oee/models/loftr/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e731b3f53ab367c89ef0ea8e1cbffb0d990775 --- /dev/null +++ b/src/oee/models/loftr/backbone/__init__.py @@ -0,0 +1,11 @@ +from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4 + + +def build_backbone(config): + if config['backbone_type'] == 'ResNetFPN': + if config['resolution'] == (8, 2): + return ResNetFPN_8_2(config['resnetfpn']) + elif config['resolution'] == (16, 4): + return ResNetFPN_16_4(config['resnetfpn']) + else: + raise 
ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") diff --git a/src/oee/models/loftr/backbone/resnet_fpn.py b/src/oee/models/loftr/backbone/resnet_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..985e5b3f273a51e51447a8025ca3aadbe46752eb --- /dev/null +++ b/src/oee/models/loftr/backbone/resnet_fpn.py @@ -0,0 +1,199 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_planes, planes, stride=1): + super().__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.conv2 = conv3x3(planes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + conv1x1(in_planes, planes, stride=stride), + nn.BatchNorm2d(planes) + ) + + def forward(self, x): + y = x + y = self.relu(self.bn1(self.conv1(y))) + y = self.bn2(self.conv2(y)) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResNetFPN_8_2(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + # 3. 
FPN upsample + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + # FPN + x3_out = self.layer3_outconv(x3) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + return [x3_out, x1_out] + + +class ResNetFPN_16_4(nn.Module): + """ + ResNet+FPN, output resolution are 1/16 and 1/4. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16 + + # 3. 
FPN upsample + self.layer4_outconv = conv1x1(block_dims[3], block_dims[3]) + self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) + self.layer3_outconv2 = nn.Sequential( + conv3x3(block_dims[3], block_dims[3]), + nn.BatchNorm2d(block_dims[3]), + nn.LeakyReLU(), + conv3x3(block_dims[3], block_dims[2]), + ) + + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + x4 = self.layer4(x3) # 1/16 + + # FPN + x4_out = self.layer4_outconv(x4) + + x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) + x3_out = self.layer3_outconv(x3) + x3_out = self.layer3_outconv2(x3_out+x4_out_2x) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + return [x4_out, x2_out] diff --git a/src/oee/models/loftr/loftr.py b/src/oee/models/loftr/loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..79c491ee47a4d67cb8b3fe493397349e0867accd --- /dev/null +++ b/src/oee/models/loftr/loftr.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange + +from .backbone import build_backbone +from .utils.position_encoding import PositionEncodingSine +from .loftr_module import LocalFeatureTransformer, FinePreprocess +from .utils.coarse_matching import CoarseMatching +from .utils.fine_matching import FineMatching + + +class LoFTR(nn.Module): + def __init__(self, config): + super().__init__() + # Misc + self.config = config + + # Modules + self.backbone = build_backbone(config) + self.pos_encoding = PositionEncodingSine( + config['coarse']['d_model'], + temp_bug_fix=config['coarse']['temp_bug_fix']) + self.loftr_coarse = LocalFeatureTransformer(config['coarse']) + self.coarse_matching = CoarseMatching(config['match_coarse']) + self.fine_preprocess = FinePreprocess(config) + self.loftr_fine = LocalFeatureTransformer(config["fine"]) + self.fine_matching = FineMatching() + + def forward(self, data): + """ + Update: + data (dict): { + 'image0': (torch.Tensor): (N, 1, H, W) + 'image1': (torch.Tensor): (N, 1, H, W) + 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position + 'mask1'(optional) : (torch.Tensor): (N, H, W) + } + """ + # 1. 
Local Feature CNN + data.update({ + 'bs': data['image0'].size(0), + 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] + }) + + if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence + feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0)) + (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs']) + else: # handle different input shapes + (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1']) + + data.update({ + 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], + 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] + }) + + # 2. coarse-level loftr module + # add featmap with positional encoding, then flatten it to sequence [N, HW, C] + feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c') + feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c') + + mask_c0 = mask_c1 = None # mask is useful in training + if 'mask0' in data: + mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) + feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) + + # 3. match coarse-level + self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1) + + # 4. fine-level refinement + feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data) + if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted + feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold) + + # 5. match fine-level + self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('matcher.'): + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) diff --git a/src/oee/models/loftr/loftr_module/__init__.py b/src/oee/models/loftr/loftr_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca51db4f50a0c4f3dcd795e74b83e633ab2e990a --- /dev/null +++ b/src/oee/models/loftr/loftr_module/__init__.py @@ -0,0 +1,2 @@ +from .transformer import LocalFeatureTransformer +from .fine_preprocess import FinePreprocess diff --git a/src/oee/models/loftr/loftr_module/fine_preprocess.py b/src/oee/models/loftr/loftr_module/fine_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb8eefd362240a9901a335f0e6e07770ff04567 --- /dev/null +++ b/src/oee/models/loftr/loftr_module/fine_preprocess.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange, repeat + + +class FinePreprocess(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.cat_c_feat = config['fine_concat_coarse_feat'] + self.W = self.config['fine_window_size'] + + d_model_c = self.config['coarse']['d_model'] + d_model_f = self.config['fine']['d_model'] + self.d_model_f = d_model_f + if self.cat_c_feat: + self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) + self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") + + def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): + W = self.W + stride = data['hw0_f'][0] // data['hw0_c'][0] + + data.update({'W': W}) + if 
data['b_ids'].shape[0] == 0: + feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + return feat0, feat1 + + # 1. unfold(crop) all local windows + feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + + # 2. select only the predicted matches + feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] + feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] + + # option: use coarse-level loftr feature as context: concat and linear + if self.cat_c_feat: + feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], + feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] + feat_cf_win = self.merge_feat(torch.cat([ + torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] + repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] + ], -1)) + feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) + + return feat_f0_unfold, feat_f1_unfold diff --git a/src/oee/models/loftr/loftr_module/linear_attention.py b/src/oee/models/loftr/loftr_module/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b73c5a6a6a722a44c0b68f70cb77c0988b8a5fb3 --- /dev/null +++ b/src/oee/models/loftr/loftr_module/linear_attention.py @@ -0,0 +1,81 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module, Dropout + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. 
+ Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) + + # Compute the attention and the weighted average + softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() diff --git a/src/oee/models/loftr/loftr_module/transformer.py b/src/oee/models/loftr/loftr_module/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d79390ca08953bbef44e98149e662a681a16e42e --- /dev/null +++ b/src/oee/models/loftr/loftr_module/transformer.py @@ -0,0 +1,101 @@ +import copy +import torch +import torch.nn as nn +from .linear_attention import LinearAttention, FullAttention + + +class LoFTREncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + attention='linear'): + super(LoFTREncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = LinearAttention() if attention == 'linear' else FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.size(0) + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = config['layer_names'] + encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, 
C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" + + for layer, name in zip(self.layers, self.layer_names): + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + else: + raise KeyError + + return feat0, feat1 diff --git a/src/oee/models/loftr/utils/coarse_matching.py b/src/oee/models/loftr/utils/coarse_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..a97263339462dec3af9705d33d6ee634e2f46914 --- /dev/null +++ b/src/oee/models/loftr/utils/coarse_matching.py @@ -0,0 +1,261 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange + +INF = 1e9 + +def mask_border(m, b: int, v): + """ Mask borders with value + Args: + m (torch.Tensor): [N, H0, W0, H1, W1] + b (int) + v (m.dtype) + """ + if b <= 0: + return + + m[:, :b] = v + m[:, :, :b] = v + m[:, :, :, :b] = v + m[:, :, :, :, :b] = v + m[:, -b:] = v + m[:, :, -b:] = v + m[:, :, :, -b:] = v + m[:, :, :, :, -b:] = v + + +def mask_border_with_padding(m, bd, v, p_m0, p_m1): + if bd <= 0: + return + + m[:, :bd] = v + m[:, :, :bd] = v + m[:, :, :, :bd] = v + m[:, :, :, :, :bd] = v + + h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() + h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() + for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): + m[b_idx, h0 - bd:] = v + m[b_idx, :, w0 - bd:] = v + m[b_idx, :, :, h1 - bd:] = v + m[b_idx, :, :, :, w1 - bd:] = v + + +def compute_max_candidates(p_m0, p_m1): + """Compute the max candidates of all pairs within a batch + + Args: + p_m0, p_m1 (torch.Tensor): padded masks + """ + h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] + h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] + max_cand = torch.sum( + torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + return max_cand + + +class CoarseMatching(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # general config + self.thr = config['thr'] + self.border_rm = config['border_rm'] + # -- # for trainig fine-level LoFTR + self.train_coarse_percent = config['train_coarse_percent'] + self.train_pad_num_gt_min = config['train_pad_num_gt_min'] + + # we provide 2 options for differentiable matching + self.match_type = config['match_type'] + if self.match_type == 'dual_softmax': + self.temperature = config['dsmax_temperature'] + elif self.match_type == 'sinkhorn': + try: + from .superglue import log_optimal_transport + except ImportError: + raise ImportError("download superglue.py first!") + self.log_optimal_transport = log_optimal_transport + self.bin_score = nn.Parameter( + torch.tensor(config['skh_init_bin_score'], requires_grad=True)) + self.skh_iters = config['skh_iters'] + self.skh_prefilter = config['skh_prefilter'] + else: + raise NotImplementedError() + + def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + data (dict) + mask_c0 (torch.Tensor): [N, L] (optional) + mask_c1 (torch.Tensor): [N, S] (optional) + Update: + data (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): 
[M'], + 'gt_mask' (torch.Tensor): [M'], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + NOTE: M' != M during training. + """ + N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) + + # normalize + feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, + [feat_c0, feat_c1]) + + if self.match_type == 'dual_softmax': + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, + feat_c1) / self.temperature + if mask_c0 is not None: + sim_matrix.masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) + + elif self.match_type == 'sinkhorn': + # sinkhorn, dustbin included + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) + if mask_c0 is not None: + sim_matrix[:, :L, :S].masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + + # build uniform prior & use sinkhorn + log_assign_matrix = self.log_optimal_transport( + sim_matrix, self.bin_score, self.skh_iters) + assign_matrix = log_assign_matrix.exp() + conf_matrix = assign_matrix[:, :-1, :-1] + + # filter prediction with dustbin score (only in evaluation mode) + if not self.training and self.skh_prefilter: + filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L] + filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S] + conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0 + conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0 + + if self.config['sparse_spvs']: + data.update({'conf_matrix_with_bin': assign_matrix.clone()}) + + data.update({'conf_matrix': conf_matrix}) + + # predict coarse matches from conf_matrix + data.update(**self.get_coarse_match(conf_matrix, data)) + + @torch.no_grad() + def get_coarse_match(self, conf_matrix, data): + """ + Args: + conf_matrix (torch.Tensor): [N, L, S] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] + Returns: + coarse_matches (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + axes_lengths = { + 'h0c': data['hw0_c'][0], + 'w0c': data['hw0_c'][1], + 'h1c': data['hw1_c'][0], + 'w1c': data['hw1_c'][1] + } + _device = conf_matrix.device + # 1. confidence thresholding + mask = conf_matrix > self.thr + mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', + **axes_lengths) + if 'mask0' not in data: + mask_border(mask, self.border_rm, False) + else: + mask_border_with_padding(mask, self.border_rm, False, + data['mask0'], data['mask1']) + mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', + **axes_lengths) + + # 2. mutual nearest + mask = mask \ + * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ + * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) + + # 3. find all valid coarse matches + # this only works when at most one `True` in each row + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + mconf = conf_matrix[b_ids, i_ids, j_ids] + + # 4. Random sampling of training samples for fine-level LoFTR + # (optional) pad samples with gt coarse-level matches + if self.training: + # NOTE: + # The sampling is performed across all pairs in a batch without manually balancing + # #samples for fine-level increases w.r.t. 
batch_size + if 'mask0' not in data: + num_candidates_max = mask.size(0) * max( + mask.size(1), mask.size(2)) + else: + num_candidates_max = compute_max_candidates( + data['mask0'], data['mask1']) + num_matches_train = int(num_candidates_max * + self.train_coarse_percent) + num_matches_pred = len(b_ids) + assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" + + # pred_indices is to select from prediction + if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: + pred_indices = torch.arange(num_matches_pred, device=_device) + else: + pred_indices = torch.randint( + num_matches_pred, + (num_matches_train - self.train_pad_num_gt_min, ), + device=_device) + + # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200) + gt_pad_indices = torch.randint( + len(data['spv_b_ids']), + (max(num_matches_train - num_matches_pred, + self.train_pad_num_gt_min), ), + device=_device) + mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + + b_ids, i_ids, j_ids, mconf = map( + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], + dim=0), + *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], + [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + + # These matches select patches that feed into fine-level network + coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + + # 4. Update with matches in original image resolution + scale = data['hw0_i'][0] / data['hw0_c'][0] + scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + mkpts0_c = torch.stack( + [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], + dim=1) * scale0 + mkpts1_c = torch.stack( + [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], + dim=1) * scale1 + + # These matches is the current prediction (for visualization) + coarse_matches.update({ + 'gt_mask': mconf == 0, + 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches + 'mkpts0_c': mkpts0_c[mconf != 0], + 'mkpts1_c': mkpts1_c[mconf != 0], + 'mconf': mconf[mconf != 0] + }) + + return coarse_matches diff --git a/src/oee/models/loftr/utils/cvpr_ds_config.py b/src/oee/models/loftr/utils/cvpr_ds_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9ce70154d3a1b961d3b4f08897415720f451f8 --- /dev/null +++ b/src/oee/models/loftr/utils/cvpr_ds_config.py @@ -0,0 +1,50 @@ +from yacs.config import CfgNode as CN + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +_CN = CN() +_CN.BACKBONE_TYPE = 'ResNetFPN' +_CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] +_CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd +_CN.FINE_CONCAT_COARSE_FEAT = True + +# 1. LoFTR-backbone (local feature CNN) config +_CN.RESNETFPN = CN() +_CN.RESNETFPN.INITIAL_DIM = 128 +_CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 + +# 2. LoFTR-coarse module config +_CN.COARSE = CN() +_CN.COARSE.D_MODEL = 256 +_CN.COARSE.D_FFN = 256 +_CN.COARSE.NHEAD = 8 +_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 +_CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] +_CN.COARSE.TEMP_BUG_FIX = False + +# 3. 
Coarse-Matching config +_CN.MATCH_COARSE = CN() +_CN.MATCH_COARSE.THR = 0.2 +_CN.MATCH_COARSE.BORDER_RM = 2 +_CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] +_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.MATCH_COARSE.SKH_ITERS = 3 +_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 +_CN.MATCH_COARSE.SKH_PREFILTER = True +_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory +_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock + +# 4. LoFTR-fine module config +_CN.FINE = CN() +_CN.FINE.D_MODEL = 128 +_CN.FINE.D_FFN = 128 +_CN.FINE.NHEAD = 8 +_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1 +_CN.FINE.ATTENTION = 'linear' + +default_cfg = lower_config(_CN) diff --git a/src/oee/models/loftr/utils/fine_matching.py b/src/oee/models/loftr/utils/fine_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..6e77aded52e1eb5c01e22c2738104f3b09d6922a --- /dev/null +++ b/src/oee/models/loftr/utils/fine_matching.py @@ -0,0 +1,74 @@ +import math +import torch +import torch.nn as nn + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + + +class FineMatching(nn.Module): + """FineMatching with s2d paradigm""" + + def __init__(self): + super().__init__() + + def forward(self, feat_f0, feat_f1, data): + """ + Args: + feat0 (torch.Tensor): [M, WW, C] + feat1 (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + M, WW, C = feat_f0.shape + W = int(math.sqrt(WW)) + scale = data['hw0_i'][0] / data['hw0_f'][0] + self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale + + # corner case: if no coarse matches found + if M == 0: + assert self.training == False, "M is always >0, when training, see coarse_matching.py" + # logger.warning('No matches found in coarse-level.') + data.update({ + 'expec_f': torch.empty(0, 3, device=feat_f0.device), + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] + sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) + softmax_temp = 1. 
/ C**.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) + + # compute coordinates from heatmap + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] + + # compute std over + var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability + + # for fine-level supervision + data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) + + # compute absolute kpt coords + self.get_fine_match(coords_normalized, data) + + @torch.no_grad() + def get_fine_match(self, coords_normed, data): + W, WW, C, scale = self.W, self.WW, self.C, self.scale + + # mkpts0_f and mkpts1_f + mkpts0_f = data['mkpts0_c'] + scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale + mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] + + data.update({ + "mkpts0_f": mkpts0_f, + "mkpts1_f": mkpts1_f + }) diff --git a/src/oee/models/loftr/utils/geometry.py b/src/oee/models/loftr/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..f95cdb65b48324c4f4ceb20231b1bed992b41116 --- /dev/null +++ b/src/oee/models/loftr/utils/geometry.py @@ -0,0 +1,54 @@ +import torch + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): + """ Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). + + Args: + kpts0 (torch.Tensor): [N, L, 2] - , + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + kpts0_long = kpts0.round().long() + + # Sample depth, get calculable_mask on depth != 0 + kpts0_depth = torch.stack( + [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + ) # (N, L) + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ + (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + w_kpts0_long = w_kpts0.long() + w_kpts0_long[~covisible_mask, :] = 0 + + w_kpts0_depth = torch.stack( + [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + ) # (N, L) + consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 diff --git a/src/oee/models/loftr/utils/position_encoding.py b/src/oee/models/loftr/utils/position_encoding.py new file mode 100644 index 
0000000000000000000000000000000000000000..732d28c814ef93bf48d338ba7554f6dcfc3b880e --- /dev/null +++ b/src/oee/models/loftr/utils/position_encoding.py @@ -0,0 +1,42 @@ +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), + the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact + on the final performance. For now, we keep both impls for backward compatability. + We will remove the buggy impl after re-training all variants of our released models. + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + if temp_bug_fix: + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) + else: # a buggy implementation (for backward compatability only) + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, :x.size(2), :x.size(3)] diff --git a/src/oee/models/loftr/utils/supervision.py b/src/oee/models/loftr/utils/supervision.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce6e79ec72b45fcb6b187e33bda93a47b168acd --- /dev/null +++ b/src/oee/models/loftr/utils/supervision.py @@ -0,0 +1,151 @@ +from math import log +from loguru import logger + +import torch +from einops import repeat +from kornia.utils import create_meshgrid + +from .geometry import warp_kpts + +############## ↓ Coarse-Level supervision ↓ ############## + + +@torch.no_grad() +def mask_pts_at_padded_regions(grid_pt, mask): + """For megadepth dataset, zero-padding exists in images""" + mask = repeat(mask, 'n h w -> n (h w) c', c=2) + grid_pt[~mask.bool()] = 0 + return grid_pt + + +@torch.no_grad() +def spvs_coarse(data, config): + """ + Update: + data (dict): { + "conf_matrix_gt": [N, hw0, hw1], + 'spv_b_ids': [M] + 'spv_i_ids': [M] + 'spv_j_ids': [M] + 'spv_w_pt0_i': [N, hw0, 2], in original image resolution + 'spv_pt1_i': [N, hw1, 2], in original image resolution + } + + NOTE: + - for scannet dataset, there're 3 kinds of resolution {i, c, f} + - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} + """ + # 1. misc + device = data['image0'].device + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + scale = config['LOFTR']['RESOLUTION'][0] + scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale + scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale + h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + + # 2. 
warp grids + # create kpts in meshgrid and resize them to image resolution + grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_i = scale0 * grid_pt0_c + grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_i = scale1 * grid_pt1_c + + # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt + if 'mask0' in data: + grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0']) + grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1']) + + # warp kpts bi-directionally and resize them to coarse-level resolution + # (no depth consistency check, since it leads to worse results experimentally) + # (unhandled edge case: points with 0-depth will be warped to the left-up corner) + _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) + _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) + w_pt0_c = w_pt0_i / scale1 + w_pt1_c = w_pt1_i / scale0 + + # 3. check if mutual nearest neighbor + w_pt0_c_round = w_pt0_c[:, :, :].round().long() + nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 + w_pt1_c_round = w_pt1_c[:, :, :].round().long() + nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0 + + # corner case: out of boundary + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0 + nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0 + + loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0) + correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1) + correct_0to1[:, 0] = False # ignore the top-left corner + + # 4. construct a gt conf_matrix + conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device) + b_ids, i_ids = torch.where(correct_0to1 != 0) + j_ids = nearest_index1[b_ids, i_ids] + + conf_matrix_gt[b_ids, i_ids, j_ids] = 1 + data.update({'conf_matrix_gt': conf_matrix_gt}) + + # 5. save coarse matches(gt) for training fine level + if len(b_ids) == 0: + logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}") + # this won't affect fine-level loss calculation + b_ids = torch.tensor([0], device=device) + i_ids = torch.tensor([0], device=device) + j_ids = torch.tensor([0], device=device) + + data.update({ + 'spv_b_ids': b_ids, + 'spv_i_ids': i_ids, + 'spv_j_ids': j_ids + }) + + # 6. save intermediate results (for fast fine-level computation) + data.update({ + 'spv_w_pt0_i': w_pt0_i, + 'spv_pt1_i': grid_pt1_i + }) + + +def compute_supervision_coarse(data, config): + assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!" + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_coarse(data, config) + else: + raise ValueError(f'Unknown data source: {data_source}') + + +############## ↓ Fine-Level supervision ↓ ############## + +@torch.no_grad() +def spvs_fine(data, config): + """ + Update: + data (dict):{ + "expec_f_gt": [M, 2]} + """ + # 1. misc + # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i') + w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i'] + scale = config['LOFTR']['RESOLUTION'][1] + radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2 + + # 2. 
get coarse prediction + b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] + + # 3. compute gt + scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale + # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later + expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2] + data.update({"expec_f_gt": expec_f_gt}) + + +def compute_supervision_fine(data, config): + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_fine(data, config) + else: + raise NotImplementedError diff --git a/src/oee/utils/elev_est_api.py b/src/oee/utils/elev_est_api.py new file mode 100644 index 0000000000000000000000000000000000000000..903e669b3eabd3d95881079ec8abed7fd27e919d --- /dev/null +++ b/src/oee/utils/elev_est_api.py @@ -0,0 +1,182 @@ +import matplotlib.pyplot as plt +import warnings + +import numpy as np +import cv2 +import os +import os.path as osp +import imageio +from copy import deepcopy + +import loguru +import torch +from ..models.loftr import LoFTR, default_cfg + +from .utils3d import rect_to_img, canonical_to_camera, calc_pose + + +class ElevEstHelper: + _feature_matcher = None + + @classmethod + def get_feature_matcher(cls, ckpt_path, device): + if cls._feature_matcher is None: + loguru.logger.info("Loading feature matcher...") + assert os.path.exists(ckpt_path) + _default_cfg = deepcopy(default_cfg) + _default_cfg['coarse']['temp_bug_fix'] = True # set to False when using the old ckpt + matcher = LoFTR(config=_default_cfg) + matcher.load_state_dict(torch.load(ckpt_path)['state_dict']) + matcher = matcher.eval().to(device) + cls._feature_matcher = matcher + return cls._feature_matcher + + +def mask_out_bkgd(img): + if img.shape[-1] == 4: + fg_mask = img[:, :, :3] + else: + loguru.logger.info("Image has no alpha channel, using thresholding to mask out background") + fg_mask = ~(img > 245).all(axis=-1) + return fg_mask + + +def get_feature_matching(matcher, images): + assert len(images) == 4 + feature_matching = {} + masks = [] + for i in range(4): + mask = mask_out_bkgd(images[i]) + masks.append(mask) + for i in range(0, 4): + for j in range(i + 1, 4): + mask0 = masks[i] + mask1 = masks[j] + img0_raw = cv2.cvtColor(images[i], cv2.COLOR_RGB2GRAY) + img1_raw = cv2.cvtColor(images[j], cv2.COLOR_RGB2GRAY) + original_shape = img0_raw.shape + img0_raw_resized = cv2.resize(img0_raw, (480, 480)) + img1_raw_resized = cv2.resize(img1_raw, (480, 480)) + + img0 = torch.from_numpy(img0_raw_resized)[None][None].cuda() / 255. + img1 = torch.from_numpy(img1_raw_resized)[None][None].cuda() / 255. 
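+            # LoFTR consumes grayscale tensors of shape [N, 1, H, W] scaled to [0, 1];
+            # matching runs at 480x480 and the keypoints are rescaled to the original resolution below.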
+ batch = {'image0': img0, 'image1': img1} + + # Inference with LoFTR and get prediction + with torch.no_grad(): + matcher(batch) + mkpts0 = batch['mkpts0_f'].cpu().numpy() + mkpts1 = batch['mkpts1_f'].cpu().numpy() + mconf = batch['mconf'].cpu().numpy() + mkpts0[:, 0] = mkpts0[:, 0] * original_shape[1] / 480 + mkpts0[:, 1] = mkpts0[:, 1] * original_shape[0] / 480 + mkpts1[:, 0] = mkpts1[:, 0] * original_shape[1] / 480 + mkpts1[:, 1] = mkpts1[:, 1] * original_shape[0] / 480 + keep0 = mask0[mkpts0[:, 1].astype(int), mkpts1[:, 0].astype(int)] + keep1 = mask1[mkpts1[:, 1].astype(int), mkpts1[:, 0].astype(int)] + keep = np.logical_and(keep0, keep1) + mkpts0 = mkpts0[keep] + mkpts1 = mkpts1[keep] + mconf = mconf[keep] + feature_matching[f"{i}_{j}"] = np.concatenate([mkpts0, mkpts1, mconf[:, None]], axis=1) + + return feature_matching + + +def gen_pose_hypothesis(center_elevation): + elevations = np.radians( + [center_elevation, center_elevation - 10, center_elevation + 10, center_elevation, center_elevation]) # 45~120 + azimuths = np.radians([30, 30, 30, 20, 40]) + input_poses = calc_pose(elevations, azimuths, len(azimuths)) + input_poses = input_poses[1:] + input_poses[..., 1] *= -1 + input_poses[..., 2] *= -1 + return input_poses + + +def ba_error_general(K, matches, poses): + projmat0 = K @ poses[0].inverse()[:3, :4] + projmat1 = K @ poses[1].inverse()[:3, :4] + match_01 = matches[0] + pts0 = match_01[:, :2] + pts1 = match_01[:, 2:4] + Xref = cv2.triangulatePoints(projmat0.cpu().numpy(), projmat1.cpu().numpy(), + pts0.cpu().numpy().T, pts1.cpu().numpy().T) + Xref = Xref[:3] / Xref[3:] + Xref = Xref.T + Xref = torch.from_numpy(Xref).float() + reproj_error = 0 + for match, cp in zip(matches[1:], poses[2:]): + dist = (torch.norm(match_01[:, :2][:, None, :] - match[:, :2][None, :, :], dim=-1)) + if dist.numel() > 0: + # print("dist.shape", dist.shape) + m0to2_index = dist.argmin(1) + keep = dist[torch.arange(match_01.shape[0]), m0to2_index] < 1 + if keep.sum() > 0: + xref_in2 = rect_to_img(K, canonical_to_camera(Xref, cp.inverse())) + reproj_error2 = torch.norm(match[m0to2_index][keep][:, 2:4] - xref_in2[keep], dim=-1) + conf02 = match[m0to2_index][keep][:, -1] + reproj_error += (reproj_error2 * conf02).sum() / (conf02.sum()) + + return reproj_error + + +def find_optim_elev(elevs, nimgs, matches, K): + errs = [] + for elev in elevs: + err = 0 + cam_poses = gen_pose_hypothesis(elev) + for start in range(nimgs - 1): + batch_matches, batch_poses = [], [] + for i in range(start, nimgs + start): + ci = i % nimgs + batch_poses.append(cam_poses[ci]) + for j in range(nimgs - 1): + key = f"{start}_{(start + j + 1) % nimgs}" + match = matches[key] + batch_matches.append(match) + err += ba_error_general(K, batch_matches, batch_poses) + errs.append(err) + errs = torch.tensor(errs) + optim_elev = elevs[torch.argmin(errs)].item() + return optim_elev + + +def get_elev_est(feature_matching, min_elev=30, max_elev=150, K=None): + flag = True + matches = {} + for i in range(4): + for j in range(i + 1, 4): + match_ij = feature_matching[f"{i}_{j}"] + if len(match_ij) == 0: + flag = False + match_ji = np.concatenate([match_ij[:, 2:4], match_ij[:, 0:2], match_ij[:, 4:5]], axis=1) + matches[f"{i}_{j}"] = torch.from_numpy(match_ij).float() + matches[f"{j}_{i}"] = torch.from_numpy(match_ji).float() + if not flag: + loguru.logger.info("0 matches, could not estimate elevation") + return None + interval = 10 + elevs = np.arange(min_elev, max_elev, interval) + optim_elev1 = find_optim_elev(elevs, 4, matches, K) + + elevs = 
np.arange(optim_elev1 - 10, optim_elev1 + 10, 1) + elevs = elevs[elevs % 180 != 0] + elevs = elevs[(elevs - 10) % 180 != 0] + elevs = elevs[(elevs + 10) % 180 != 0] + + optim_elev2 = find_optim_elev(elevs, 4, matches, K) + + return optim_elev2 + + +def elev_est_api(matcher, images, min_elev=30, max_elev=150, K=None): + feature_matching = get_feature_matching(matcher, images) + if K is None: + loguru.logger.warning("K is not provided, using default K") + K = np.array([[280.0, 0, 128.0], + [0, 280.0, 128.0], + [0, 0, 1]]) + K = torch.from_numpy(K).float() + elev = get_elev_est(feature_matching, min_elev, max_elev, K) + return elev diff --git a/src/oee/utils/utils3d.py b/src/oee/utils/utils3d.py new file mode 100644 index 0000000000000000000000000000000000000000..1e8e577ef2ff95a11a1aa8b7b25032417c3e7bd8 --- /dev/null +++ b/src/oee/utils/utils3d.py @@ -0,0 +1,61 @@ +import numpy as np +import torch + + +def cart_to_hom(pts): + """ + :param pts: (N, 3 or 2) + :return pts_hom: (N, 4 or 3) + """ + if isinstance(pts, np.ndarray): + pts_hom = np.concatenate((pts, np.ones([*pts.shape[:-1], 1], dtype=np.float32)), -1) + else: + ones = torch.ones([*pts.shape[:-1], 1], dtype=torch.float32, device=pts.device) + pts_hom = torch.cat((pts, ones), dim=-1) + return pts_hom + + +def hom_to_cart(pts): + return pts[..., :-1] / pts[..., -1:] + + +def canonical_to_camera(pts, pose): + pts = cart_to_hom(pts) + pts = pts @ pose.transpose(-1, -2) + pts = hom_to_cart(pts) + return pts + + +def rect_to_img(K, pts_rect): + from dl_ext.vision_ext.datasets.kitti.structures import Calibration + pts_2d_hom = pts_rect @ K.t() + pts_img = Calibration.hom_to_cart(pts_2d_hom) + return pts_img + + +def calc_pose(phis, thetas, size, radius=1.2): + import torch + def normalize(vectors): + return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) + + thetas = torch.FloatTensor(thetas) + phis = torch.FloatTensor(phis) + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + -radius * torch.cos(thetas) * torch.sin(phis), + radius * torch.cos(phis), + ], dim=-1) # [B, 3] + + # lookat + forward_vector = normalize(centers).squeeze(0) + up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1) + right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1)) + if right_vector.pow(2).sum() < 0.01: + right_vector = torch.FloatTensor([0, 1, 0]).unsqueeze(0).repeat(size, 1) + up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1)) + + poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(size, 1, 1) + poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) + poses[:, :3, 3] = centers + return poses diff --git a/src/pose_estimation.py b/src/pose_estimation.py new file mode 100644 index 0000000000000000000000000000000000000000..226e4375f588be141666cfb202bb3190979a85c1 --- /dev/null +++ b/src/pose_estimation.py @@ -0,0 +1,280 @@ +import numpy as np +import torch +from PIL import Image +from torchvision import transforms +from datetime import datetime + +from .ldm.util import load_and_preprocess, instantiate_from_config +from .pose_funcs import probe_pose, find_optimal_poses, get_inv_pose, add_pose, pairwise_loss + +from .oee.utils.elev_est_api import elev_est_api, ElevEstHelper +from .sampling import sample_images + + +def load_image(img_path, mask_path=None, preprocessor=None, threshold=0.9): + + img = Image.open(img_path) + + if preprocessor is not None: + img = load_and_preprocess(preprocessor, img) + else: + if img.mode == 'RGBA': + img = 
np.asarray(img, dtype=np.float32) / 255. + img[img[:, :, -1] <= threshold] = [1., 1., 1., 1.] # thresholding background + img = img[:, :, :3] + elif img.mode == 'RGB': + if mask_path is not None: + mask = Image.open(mask_path) + bkg = Image.new('RGB', (img.width, img.height), color=(255, 255, 255)) + img = Image.composite(img, bkg, mask) + img = np.asarray(img, dtype=np.float32) / 255. + else: + print('Wrong format:', img_path) + + return img + + +def load_model_from_config(config, ckpt, device, verbose=False): + print(f'Loading model from {ckpt}') + pl_sd = torch.load(ckpt, map_location=device) + if 'global_step' in pl_sd: + step = pl_sd['global_step'] + print(f'Global Step: {step}') + sd = pl_sd['state_dict'] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print('missing keys:') + print(m) + if len(u) > 0 and verbose: + print('unexpected keys:') + print(u) + + model.to(device) + model.eval() + return model + + +def estimate_elevs(model, images, est_type=None, matcher_ckpt_path=None): + + num = len(images) + + elevs = {i: None for i in range(num)} + elev_ranges = {i: None for i in range(num)} + + if est_type == 'all': + matcher = ElevEstHelper.get_feature_matcher(matcher_ckpt_path, model.device) + for i in range(num): + simgs = sample_surrounding_images(model, images[i]) + elev = elev_est_api(matcher, simgs, min_elev=20, max_elev=160) + elevs[i] = elev + + for i in range(num): + if elevs[i] is not None: + elevs[i] = np.deg2rad(elevs[i]) + + for i in range(1, num): + + if elevs[i] is not None and elevs[0] is not None: + elev_ranges[i] = np.array([ elevs[i] - elevs[0] ]) + elif elevs[i] is not None: + elev_ranges[i] = -make_elev_probe_range(elevs[i]) + elif elevs[0] is not None: + elev_ranges[i] = make_elev_probe_range(elevs[0]) + + elif est_type == 'simple': + matcher = ElevEstHelper.get_feature_matcher(matcher_ckpt_path, model.device) + simgs = sample_surrounding_images(model, images[0]) + elev = elev_est_api(matcher, simgs, min_elev=20, max_elev=160) + elevs[0] = np.deg2rad(elev) if elev is not None else None + ae = elevs[0] if elevs[0] is not None else np.pi/2 + for i in range(1, num): + elev_ranges[i] = np.array([np.pi/2 - ae]) + + return elevs, elev_ranges + + +def estimate_poses( + model, images, + seed_cand_num=8, + init_type='pairwise', + optm_type='pairwise', + probe_ts_range=[0.02, 0.98], ts_range=[0.02, 0.98], + probe_bsz=16, + adjust_factor=10., + adjust_iters=10, + adjust_bsz=1, + refine_factor=1., + refine_iters=600, + refine_bsz=1, + noise=None, + elevs=None, + elev_ranges=None + ): + + num = len(images) + + if elevs is None: + elevs = {i: None for i in range(num)} + if elev_ranges is None: + elev_ranges = {i: None for i in range(num)} + + if num <= 2: + init_type = 'pairwise' + + cands = {} + losses = {} + + init_poses = {i: None for i in range(num)} + pairwise_init_poses = {i: None for i in range(num)} + + print('Initialization: Probe', datetime.now()) + + + images = [ img.permute(0, 2, 3, 1) for img in images ] + + for i in range(1, num): + + print('PAIR', 0, i, datetime.now()) + + azimuth_range = np.arange(start=0.0, stop=np.pi*2, step=np.pi*2 / seed_cand_num) + + all_cands = probe_pose(model, images[0], images[i], probe_ts_range, probe_bsz, theta_range=elev_ranges[i], azimuth_range=azimuth_range, noise=noise) + all_cands = sorted(all_cands) + + print('Exploration', len(all_cands), datetime.now()) + + adjusted_cands = all_cands[:5] + if adjust_iters > 0: + adjusted_cands = [] + '''only adjust 
the first half''' + for cand in all_cands[:len(all_cands)//2]: + + out_poses, _, _ = find_optimal_poses( + model, [images[0], images[i]], + adjust_factor, + bsz=adjust_bsz, + n_iter=adjust_iters, + init_poses={1: cand[1]}, + ts_range=ts_range, + print_n=100, + avg_last_n=1 + ) + + loss = pairwise_loss(out_poses[0], model, images[0], images[i], probe_ts_range, probe_bsz, noise=noise) + adjusted_cands.append((loss, out_poses[0], cand[0], cand[1])) + + adjusted_cands = sorted(adjusted_cands)[:5] + + for cand in adjusted_cands: + print(cand) + + cands[i] = [ cand[:2] for cand in adjusted_cands ] + losses[i] = [loss if (init_type == 'pairwise') else 0.0 for loss, _ in cands[i]] + + pairwise_init_poses[i] = min(cands[i])[1] + + print('Selection', datetime.now()) + + if init_type == 'triangular': + + for i in range(1, num): + + for j in range(i+1, num): + + iloss = [ [None for v in range(0, len(cands[j]))] for u in range(0, len(cands[i])) ] + jloss = [ [None for u in range(0, len(cands[i]))] for v in range(0, len(cands[j])) ] + + for u in range(0, len(cands[i])): + + la, pa = cands[i][u] + + # pose i -> 0 + pa = get_inv_pose(pa) + + for v in range(0, len(cands[j])): + + # pose 0 -> j + lb, pb = cands[j][v] + + theta, azimuth, radius = add_pose(pa, pb) + lp = pairwise_loss([theta, azimuth, radius], model, images[i], images[j], probe_ts_range, probe_bsz, noise=noise) + + iloss[u][v] = la + lb + lp + jloss[v][u] = la + lb + lp + + for u in range(0, len(cands[i])): + losses[i][u] += min(min(iloss[u]), cands[i][u][0]*3) + + for v in range(0, len(cands[j])): + losses[j][v] += min(min(jloss[v]), cands[j][v][0]*3) + + for i in range(1, num): + + ranks = sorted([x for x in range(0, len(losses[i]))], key=lambda x: losses[i][x]) + + min_rank = ranks[0] + + for u in range(0, len(cands[i])): + print(cands[i][u], losses[i][u]) + print(i, 'SELECT', min_rank, losses[i][min_rank]) + + init_poses[i] = cands[i][min_rank][1] + + print('Refinement', datetime.now()) + + combinations = None + if optm_type == 'pairwise': + combinations = [ (0, i) for i in range(1, num) ] + [ (i, 0) for i in range(1, num) ] + + elif optm_type == 'triangular': + combinations = [] + for i in range(0, num): + for j in range(i+1, num): + combinations.append((i, j)) + combinations.append((j, i)) + + print('Combinations', len(combinations), combinations) + + '''Refinement''' + out_poses, _, loss = find_optimal_poses( + model, images, + refine_factor, + bsz=refine_bsz, + n_iter=(num-1)*refine_iters, + init_poses=init_poses, + ts_range=ts_range, + combinations=combinations, + avg_last_n=20, + print_n=100 + ) + + print('Done', datetime.now()) + + aux_data = { + 'tri_init_sph': init_poses, + 'pw_init_sph': pairwise_init_poses, + 'elev': elevs + } + + return out_poses, aux_data + + +def make_elev_probe_range(elev, interval=np.pi/4): + + up_range = np.arange(elev, 0, -interval) + down_range = np.arange(elev+interval, np.pi, interval) + probe_range = np.concatenate([up_range, down_range]) + probe_range -= elev + + return probe_range + + +def sample_surrounding_images(model, image): + + s0 = sample_images(model, image, float(np.deg2rad(-10)), 0, 0, n_samples=1) + s1 = sample_images(model, image, float(np.deg2rad(+10)), 0, 0, n_samples=1) + s2 = sample_images(model, image, 0, float(np.deg2rad(-10)), 0, n_samples=1) + s3 = sample_images(model, image, 0, float(np.deg2rad(+10)), 0, n_samples=1) + + return s0 + s1 + s2 + s3 \ No newline at end of file diff --git a/src/pose_funcs.py b/src/pose_funcs.py new file mode 100644 index 
0000000000000000000000000000000000000000..1a25f5c3983f3a9a1e2148b8242d0da34a86f2a0 --- /dev/null +++ b/src/pose_funcs.py @@ -0,0 +1,222 @@ +import numpy as np +import torch + +class PoseT(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, pose): + + p1 = pose[..., 0:1] + p2 = torch.sin(pose[..., 1:2]) + p3 = torch.cos(pose[..., 1:2]) + p4 = pose[..., 2:] + + return torch.cat([p1, p2, p3, p4], dim=-1) + + +@torch.no_grad() +def noise_loss(model, cond_image, target_image, pose, ts_range, bsz, noise=None): + + mx = ts_range[1] + mn = ts_range[0] + + pose_layer = PoseT() + + batch = {} + batch['image_target'] = target_image.repeat(bsz, 1, 1, 1) + batch['image_cond'] = cond_image.repeat(bsz, 1, 1, 1) + batch['T'] = pose_layer(pose.detach()).repeat(bsz, 1) + + if noise is not None: + noise = torch.tensor(noise, dtype=model.dtype, device=model.device) + + loss, _ = model.shared_step(batch, ts=np.arange(mn, mx, (mx-mn) / bsz), noise=noise[:bsz]) + + return loss.item() + + +@torch.no_grad() +def pairwise_loss(pose, model, cond_image, target_image, ts_range, probe_bsz, noise=None): + + theta, azimuth, radius = pose + + pose1 = torch.tensor([[theta, azimuth, radius]], device=model.device, dtype=torch.float32) + pose2 = torch.tensor([[-theta, np.pi*2-azimuth, -radius]], device=model.device, dtype=torch.float32) + loss1 = noise_loss(model, cond_image, target_image, pose1, ts_range, probe_bsz, noise=noise) + loss2 = noise_loss(model, target_image, cond_image, pose2, ts_range, probe_bsz, noise=noise) + + return loss1 + loss2 + + +@torch.no_grad() +def probe_pose(model, cond_image, target_image, ts_range, probe_bsz, theta_range=None, azimuth_range=None, radius_range=None, noise=None): + + eps = 1e-5 + + if theta_range is None: + theta_range = np.arange(start=-np.pi*2/3, stop=np.pi*2/3+eps, step=np.pi/3) + if azimuth_range is None: + azimuth_range = np.arange(start=0.0, stop=np.pi*2, step=np.pi/4) + if radius_range is None: + radius_range = np.arange(start=0.0, stop=0.0+eps, step=0.1) + + cands = [] + + for radius in radius_range: + for azimuth in azimuth_range: + for theta in theta_range: + + loss = pairwise_loss([theta, azimuth, radius], model, cond_image, target_image, ts_range, probe_bsz, noise=noise) + + '''convert numpy.float to float''' + cands.append((loss, [float(theta), float(azimuth), float(radius)])) + + return cands + + +def create_random_pose(): + + theta = np.random.rand() * np.pi - np.pi / 2 + azimuth = np.random.rand() * np.pi * 2 + radius = np.random.rand() - 0.5 + + return [theta, azimuth, radius] + + +def get_inv_pose(pose): + + return [-pose[0], np.pi*2 - pose[1], -pose[2]] + + +def add_pose(pose1, pose2): + + theta = pose1[0] + pose2[0] + azimuth = pose1[1] + pose2[1] + azimuth = azimuth % (np.pi*2) + + return [ theta, azimuth, (pose1[2] + pose2[2]) ] + + +def create_pose_params(pose, device): + + theta = torch.tensor([pose[0]], requires_grad=True, device=device) + azimuth = torch.tensor([pose[1]], requires_grad=True, device=device) + radius = torch.tensor([pose[2]], requires_grad=True, device=device) + + return [theta, azimuth, radius] + + +def find_optimal_poses(model, images, learning_rate, bsz=1, n_iter=1000, init_poses={}, ts_range=[0.02, 0.92], combinations=None, print_n=50, avg_last_n=1): + + layer = PoseT() + + num = len(images) + + batch = {} + + pose_params = { i:None for i in range(1, num)} + pose_trajs = { i:[] for i in range(1, num) } + + for i in range(1, num): + + if i in init_poses: + init_pose = init_poses[i] + else: + init_pose = 
create_random_pose() + + pose = create_pose_params(init_pose, model.device) + pose_params[i] = pose + + if combinations is None: + combinations = [] + for i in range(0, num): + for j in range(i+1, num): + combinations.append((i, j)) + combinations.append((j, i)) + + param_list = [] + for i in pose_params: + param_list += pose_params[i] + + optimizer = torch.optim.SGD(param_list, lr = learning_rate) + + loss_traj = [] + select_indces = set([]) + + for iter in range(0, n_iter): + + if print_n > 0 and iter % print_n == 0 and iter > 0: + print(iter, np.mean(loss_traj[-avg_last_n:]), flush=True) + for i in range(1, num): + print(0, i, np.mean(pose_trajs[i][-avg_last_n:], axis=0).tolist()) + + '''record poses''' + for i in select_indces: + pose = pose_params[i] + pose_trajs[i].append([pose[0].item(), pose[1].item(), pose[2].item()]) + + select_indces = set([]) + + conds = [] + targets = [] + rts = [] + + choices = [ iter % len(combinations) ] + + if bsz > 1: + choices = np.random.choice(len(combinations), size=bsz, replace=True) + + for cho in choices: + + i, j = combinations[cho] + + conds.append(images[i]) + targets.append(images[j]) + if i == 0: + pose = pose_params[j] + select_indces.add(j) + + elif j == 0: + pose = get_inv_pose(pose_params[i]) + select_indces.add(i) + + else: + pose0j = pose_params[j] + posei0 = get_inv_pose(pose_params[i]) + + if np.random.rand() < 0.5: + posei0 = [a.item() for a in posei0] + select_indces.add(j) + else: + pose0j = [b.item() for b in pose0j] + select_indces.add(i) + + #pose = [ torch.remainder(a+b+2*np.pi, 2*np.pi) - np.pi for a, b in zip(posei0, pose0j) ] + pose = [ a+b for a, b in zip(posei0, pose0j) ] + + rts.append(torch.cat(pose)[None, ...]) + + batch['image_cond'] = torch.cat(conds, dim=0) + batch['image_target'] = torch.cat(targets, dim=0) + batch['T'] = layer(torch.cat(rts, dim=0)) + ts = np.arange(ts_range[0], ts_range[1], (ts_range[1]-ts_range[0]) / len(conds)) + + optimizer.zero_grad() + loss, loss_dict = model.shared_step(batch, ts=ts) + loss.backward() + + optimizer.step() + + loss_traj.append(loss.item()) + + if n_iter > 0: + result_poses = [ np.mean(pose_trajs[i][-avg_last_n:], axis=0).tolist() for i in range(1, num) ] + result_loss = np.mean(loss_traj[-avg_last_n:]) + else: + result_poses = [ init_poses[i] for i in range(1, num) ] + result_loss = None + + return result_poses, [ init_poses[i] for i in range(1, num) ], result_loss + diff --git a/src/sampling.py b/src/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d32066942232870ca269704ba72a2d5b114d4e14 --- /dev/null +++ b/src/sampling.py @@ -0,0 +1,75 @@ +import math + +import numpy as np + +from contextlib import nullcontext +from PIL import Image +from einops import rearrange + +import torch +from torch import autocast + +from .ldm.models.diffusion.ddim import DDIMSampler + +@torch.no_grad() +def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_samples, scale, \ + ddim_eta, x, y, z): + precision_scope = autocast if precision=='autocast' else nullcontext + with precision_scope('cuda'): + with model.ema_scope(): + c = model.get_learned_conditioning(input_im).tile(n_samples,1,1) + T = torch.tensor([x, math.sin(y), math.cos(y), z]) + T = T[None, None, :].repeat(n_samples, 1, 1).to(c.device) + c = torch.cat([c, T], dim=-1) + c = model.cc_projection(c) + cond = {} + cond['c_crossattn'] = [c] + c_concat = model.encode_first_stage((input_im.to(c.device))).mode().detach() + cond['c_concat'] = 
[model.encode_first_stage((input_im.to(c.device))).mode().detach()\ + .repeat(n_samples, 1, 1, 1)] + if scale != 1.0: + uc = {} + uc['c_concat'] = [torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)] + uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)] + else: + uc = None + + shape = [4, h // 8, w // 8] + samples_ddim, _ = sampler.sample(S=ddim_steps, + conditioning=cond, + batch_size=n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc, + eta=ddim_eta, + x_T=None) + # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False) + x_samples_ddim = model.decode_first_stage(samples_ddim) + return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu() + + +def sample_images( + model, + input_im, + x=0., + y=0., + z=0., + scale=3.0, + n_samples=4, + ddim_steps=50, + ddim_eta=1.0, + precision='fp32', + h=256, + w=256, + ): + + sampler = DDIMSampler(model) + + x_samples_ddim = sample_model(input_im, model, sampler, precision, h, w,\ + ddim_steps, n_samples, scale, ddim_eta, x, y, z) + output_ims = [] + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + output_ims.append(x_sample.astype(np.uint8)) + return output_ims \ No newline at end of file diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0a58804129d2d9f40cb1ecb713964fadf9473836 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,166 @@ +import os, sys +import numpy as np +import math + + +def spherical_to_cartesian(sph): + + theta, azimuth, radius = sph + + return np.array([ + radius * np.sin(theta) * np.cos(azimuth), + radius * np.sin(theta) * np.sin(azimuth), + radius * np.cos(theta), + ]) + + +def cartesian_to_spherical(xyz): + + xy = xyz[0]**2 + xyz[1]**2 + radius = np.sqrt(xy + xyz[2]**2) + theta = np.arctan2(np.sqrt(xy), xyz[2]) + azimuth = np.arctan2(xyz[1], xyz[0]) + + return np.array([theta, azimuth, radius]) + + +def relative_spherical(xyz_target, xyz_cond): + + sp_target = cartesian_to_spherical(xyz_target) + sp_cond = cartesian_to_spherical(xyz_cond) + + theta_cond, azimuth_cond, z_cond = sp_cond + theta_target, azimuth_target, z_target = sp_target + + d_theta = theta_target - theta_cond + d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) + d_z = z_target - z_cond + + return np.array([d_theta, d_azimuth, d_z]) + + +def elu_to_c2w(eye, lookat, up): + + if isinstance(eye, list): + eye = np.array(eye) + if isinstance(lookat, list): + lookat = np.array(lookat) + if isinstance(up, list): + up = np.array(up) + + l = eye - lookat + l = l / np.linalg.norm(l) + s = np.cross(l, up) + s = s / np.linalg.norm(s) + uu = np.cross(s, l) + + rot = np.eye(3) + rot[0, :] = -s + rot[1, :] = uu + rot[2, :] = l + + c2w = np.eye(4) + c2w[:3, :3] = rot.T + c2w[:3, 3] = eye + + return c2w + + +def c2w_to_elu(c2w): + + w2c = np.linalg.inv(c2w) + eye = c2w[:3, 3] + lookat_dir = -w2c[2, :3] + lookat = eye + lookat_dir + up = w2c[1, :3] + + return eye, lookat, up + + + '''save pose output''' + jdata = { + 'anchor_vid': anchor_vid, + 'obs': {} + } + + jdata['obs'][anchor_vid] = { + 'img_path': f'{anchor_vid:03d}.png' + } + + for key in aux_data: + jdata['obs'][anchor_vid][key] = aux_data[key][0] + + anchor_sph = None + + anchor_rt = None + + if os.path.exists(os.path.join(obj_root, 'poses', f'{anchor_vid:03d}.npy')): + anchor_rt = np.load(os.path.join(obj_root, 'poses', f'{anchor_vid:03d}.npy')) + elif 
os.path.exists(os.path.join(obj_root, 'poses', f'{anchor_vid:03d}.txt')): + anchor_rt = np.loadtxt(os.path.join(obj_root, 'poses', f'{anchor_vid:03d}.txt')) + + if anchor_rt is not None: + + anchor_xyz = anchor_rt[:3, -1] + + anchor_sph = cartesian_to_spherical(anchor_xyz) + jdata['obs'][anchor_vid]['sph'] = anchor_sph.tolist() + + if export_xyz: + jdata['obs'][anchor_vid]['xyz'] = { + 'x': anchor_xyz[0], + 'y': anchor_xyz[1], + 'z': anchor_xyz[2] + } + + + for i in range(0, len(target_vids)): + + target_vid = target_vids[i] + + rel_sph = np.array(pred_sphs[i]) + + opack = { + 'img_path': f'{target_vid:03d}.png', + 'rel_sph': rel_sph.tolist() + } + + for key in aux_data: + opack[key] = aux_data[key][i+1] + + if anchor_sph is not None: + + target_sph = anchor_sph + rel_sph + + if export_xyz: + target_xyz = spherical_to_cartesian(target_sph) + + opack['xyz'] = { + 'x': target_xyz[0], + 'y': target_xyz[1], + 'z': target_xyz[2] + } + + target_rt = None + + if os.path.exists(os.path.join(obj_root, 'poses', f'{target_vid:03d}.npy')): + target_rt = np.load(os.path.join(obj_root, 'poses', f'{target_vid:03d}.npy')) + elif os.path.exists(os.path.join(obj_root, 'poses', f'{target_vid:03d}.txt')): + target_rt = np.loadtxt(os.path.join(obj_root, 'poses', f'{target_vid:03d}.txt')) + + if target_rt is not None: + + if export_xyz: + opack['gt_xyz'] = { + 'x': target_rt[0, -1], + 'y': target_rt[1, -1], + 'z': target_rt[2, -1] + } + + gt_rel_sph = relative_spherical(target_rt[:3, -1], anchor_rt[:3, -1]) + + opack['gt_rel_sph'] = gt_rel_sph.tolist() + + jdata['obs'][target_vid] = opack + + return jdata \ No newline at end of file diff --git a/src/visualizer.py b/src/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3f742e13f120a775660010f4ee1559de72f05a7 --- /dev/null +++ b/src/visualizer.py @@ -0,0 +1,253 @@ +import os + +from PIL import Image +import plotly.graph_objects as go +import numpy as np + + +def calc_cam_cone_pts_3d(c2w, fov_deg, zoom = 1.0): + + fov_rad = np.deg2rad(fov_deg) + + cam_x = c2w[0, -1] + cam_y = c2w[1, -1] + cam_z = c2w[2, -1] + + corn1 = [np.tan(fov_rad / 2.0), np.tan(fov_rad / 2.0), -1.0] + corn2 = [-np.tan(fov_rad / 2.0), np.tan(fov_rad / 2.0), -1.0] + corn3 = [-np.tan(fov_rad / 2.0), -np.tan(fov_rad / 2.0), -1.0] + corn4 = [np.tan(fov_rad / 2.0), -np.tan(fov_rad / 2.0), -1.0] + corn5 = [0, np.tan(fov_rad / 2.0), -1.0] + + corn1 = np.dot(c2w[:3, :3], corn1) + corn2 = np.dot(c2w[:3, :3], corn2) + corn3 = np.dot(c2w[:3, :3], corn3) + corn4 = np.dot(c2w[:3, :3], corn4) + corn5 = np.dot(c2w[:3, :3], corn5) + + # Now attach as offset to actual 3D camera position: + corn1 = np.array(corn1) / np.linalg.norm(corn1, ord=2) * zoom + corn_x1 = cam_x + corn1[0] + corn_y1 = cam_y + corn1[1] + corn_z1 = cam_z + corn1[2] + corn2 = np.array(corn2) / np.linalg.norm(corn2, ord=2) * zoom + corn_x2 = cam_x + corn2[0] + corn_y2 = cam_y + corn2[1] + corn_z2 = cam_z + corn2[2] + corn3 = np.array(corn3) / np.linalg.norm(corn3, ord=2) * zoom + corn_x3 = cam_x + corn3[0] + corn_y3 = cam_y + corn3[1] + corn_z3 = cam_z + corn3[2] + corn4 = np.array(corn4) / np.linalg.norm(corn4, ord=2) * zoom + corn_x4 = cam_x + corn4[0] + corn_y4 = cam_y + corn4[1] + corn_z4 = cam_z + corn4[2] + corn5 = np.array(corn5) / np.linalg.norm(corn5, ord=2) * zoom + corn_x5 = cam_x + corn5[0] + corn_y5 = cam_y + corn5[1] + corn_z5 = cam_z + corn5[2] + + xs = [cam_x, corn_x1, corn_x2, corn_x3, corn_x4, corn_x5] + ys = [cam_y, corn_y1, corn_y2, corn_y3, corn_y4, corn_y5] + zs = [cam_z, corn_z1, 
corn_z2, corn_z3, corn_z4, corn_z5] + + return np.array([xs, ys, zs]).T + + +class CameraVisualizer: + + def __init__(self, poses, legends, colors, images=None, mesh_path=None, camera_x=1.0): + self._fig = None + + self._camera_x = camera_x + + self._poses = poses + self._legends = legends + self._colors = colors + + self._raw_images = None + self._bit_images = None + self._image_colorscale = None + self.set_images(images) + + self._mesh = None + if mesh_path is not None and os.path.exists(mesh_path): + import trimesh + self._mesh = trimesh.load(mesh_path, force='mesh') + + + def set_images(self, images): + + if images is not None: + self._raw_images = images + self._bit_images = [] + self._image_colorscale = [] + for img in images: + if img is None: + self._bit_images.append(None) + self._image_colorscale.append(None) + continue + + bit_img, colorscale = self.encode_image(img) + self._bit_images.append(bit_img) + self._image_colorscale.append(colorscale) + + + def encode_image(self, raw_image): + ''' + :param raw_image (H, W, 3) array of uint8 in [0, 255]. + ''' + # https://stackoverflow.com/questions/60685749/python-plotly-how-to-add-an-image-to-a-3d-scatter-plot + + dum_img = Image.fromarray(np.ones((3, 3, 3), dtype='uint8')).convert('P', palette='WEB') + idx_to_color = np.array(dum_img.getpalette()).reshape((-1, 3)) + + bit_image = Image.fromarray(raw_image).convert('P', palette='WEB', dither=None) + # bit_image = Image.fromarray(raw_image.clip(0, 254)).convert( + # 'P', palette='WEB', dither=None) + colorscale = [[i / 255.0, 'rgb({}, {}, {})'.format(*rgb)] for i, rgb in enumerate(idx_to_color)] + + return bit_image, colorscale + + + def update_figure( + self, scene_bounds, + base_radius=0.0, zoom_scale=1.0, fov_deg=50., + mesh_z_shift=0.0, mesh_scale=1.0, + font_size=10, + show_background=False, show_grid=False, show_ticklabels=False + ): + + fig = go.Figure() + + if self._mesh is not None: + fig.add_trace( + go.Mesh3d( + x=self._mesh.vertices[:, 0] * mesh_scale, + y=self._mesh.vertices[:, 2] * -mesh_scale, + z=(self._mesh.vertices[:, 1] + mesh_z_shift) * mesh_scale, + i=self._mesh.faces[:, 0], + j=self._mesh.faces[:, 1], + k=self._mesh.faces[:, 2], + color=None, + facecolor=None, + opacity=0.8, + lighting={'ambient': 1}, + ) + ) + + for i in range(len(self._poses)): + + pose = self._poses[i] + clr = self._colors[i] + legend = self._legends[i] + + edges = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4), (4, 1), (0, 5)] + + cone = calc_cam_cone_pts_3d(pose, fov_deg) + radius = np.linalg.norm(pose[:3, -1]) + + if self._bit_images and self._bit_images[i]: + + raw_image = self._raw_images[i] + bit_image = self._bit_images[i] + colorscale = self._image_colorscale[i] + + (H, W, C) = raw_image.shape + + z = np.zeros((H, W)) + base_radius + (x, y) = np.meshgrid(np.linspace(-1.0 * self._camera_x, 1.0 * self._camera_x, W), np.linspace(1.0, -1.0, H) * H / W) + + xyz = np.concatenate([x[..., None], y[..., None], z[..., None]], axis=-1) + + rot_xyz = np.matmul(xyz, pose[:3, :3].T) + pose[:3, -1] + + x, y, z = rot_xyz[:, :, 0], rot_xyz[:, :, 1], rot_xyz[:, :, 2] + + fig.add_trace(go.Surface( + x=x, y=y, z=z, + surfacecolor=bit_image, + cmin=0, + cmax=255, + colorscale=colorscale, + showscale=False, + lighting_diffuse=1.0, + lighting_ambient=1.0, + lighting_fresnel=1.0, + lighting_roughness=1.0, + lighting_specular=0.3)) + + for (i, edge) in enumerate(edges): + (x1, x2) = (cone[edge[0], 0], cone[edge[1], 0]) + (y1, y2) = (cone[edge[0], 1], cone[edge[1], 1]) + (z1, z2) = (cone[edge[0], 2], 
cone[edge[1], 2]) + fig.add_trace(go.Scatter3d( + x=[x1, x2], y=[y1, y2], z=[z1, z2], mode='lines', + line=dict(color=clr, width=3), + name=legend, showlegend=(i == 0))) + + # Add label. + if cone[0, 2] < 0: + fig.add_trace(go.Scatter3d( + x=[cone[0, 0]], y=[cone[0, 1]], z=[cone[0, 2] - 0.05], showlegend=False, + mode='text', text=legend, textfont=dict(color=clr, size=font_size), textposition='bottom center')) + else: + fig.add_trace(go.Scatter3d( + x=[cone[0, 0]], y=[cone[0, 1]], z=[cone[0, 2] + 0.05], showlegend=False, + mode='text', text=legend, textfont=dict(color=clr, size=font_size), textposition='top center')) + + # look at the center of scene + fig.update_layout( + height=720, + autosize=True, + hovermode=False, + margin=go.layout.Margin(l=0, r=0, b=0, t=0), + showlegend=True, + legend=dict( + yanchor='bottom', + y=0.01, + xanchor='right', + x=0.99, + ), + scene=dict( + aspectmode='manual', + aspectratio=dict(x=1, y=1, z=1), + camera=dict( + eye=dict(x=1.5, y=1.5, z=1.0), + center=dict(x=0.0, y=0.0, z=0.0), + up=dict(x=0.0, y=0.0, z=1.0)), + xaxis_title='x', + yaxis_title='y', + zaxis_title='z', + xaxis=dict( + range=[-scene_bounds, scene_bounds], + showticklabels=show_ticklabels, + showgrid=show_grid, + zeroline=False, + showbackground=show_background, + showspikes=False, + showline=False, + ticks=''), + yaxis=dict( + range=[-scene_bounds, scene_bounds], + showticklabels=show_ticklabels, + showgrid=show_grid, + zeroline=False, + showbackground=show_background, + showspikes=False, + showline=False, + ticks=''), + zaxis=dict( + range=[-scene_bounds, scene_bounds], + showticklabels=show_ticklabels, + showgrid=show_grid, + zeroline=False, + showbackground=show_background, + showspikes=False, + showline=False, + ticks='') + ) + ) + + self._fig = fig + return fig \ No newline at end of file
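
As a quick sanity check of the camera conventions defined in src/utils.py above, here is a minimal sketch (illustration only, not part of the patch); it assumes the repository root is on PYTHONPATH and that NumPy is installed:

import numpy as np
from src.utils import (spherical_to_cartesian, cartesian_to_spherical,
                       elu_to_c2w, c2w_to_elu)

# theta is the polar angle measured from +z, azimuth is measured from +x in the
# xy-plane, and elu_to_c2w builds a camera-to-world matrix that looks at `lookat`
# with +z as the up direction.
theta, azimuth, radius = np.pi / 3, np.pi / 4, 4.0
eye = spherical_to_cartesian((theta, azimuth, radius))        # camera position on the sphere
c2w = elu_to_c2w(eye, np.zeros(3), np.array([0., 0., 1.]))    # look at the origin
eye_back, lookat_back, up_back = c2w_to_elu(c2w)

# The eye position and the spherical angles round-trip exactly; lookat/up are
# only recovered up to the viewing ray and re-orthogonalization.
assert np.allclose(eye, eye_back)
assert np.allclose(cartesian_to_spherical(eye), [theta, azimuth, radius])

relative_spherical in the same file composes these conversions to express one camera pose relative to another in (theta, azimuth, radius) form.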