diff --git a/apps/third_party/CRM/LICENSE b/apps/third_party/CRM/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..8840910d3e0d809884ce88440d615cf493475272 --- /dev/null +++ b/apps/third_party/CRM/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 TSAIL group + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/apps/third_party/CRM/README.md b/apps/third_party/CRM/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0fb9821a04f51330e00206575bab88fce08fb786 --- /dev/null +++ b/apps/third_party/CRM/README.md @@ -0,0 +1,85 @@ +# Convolutional Reconstruction Model + +Official implementation of *CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model*. + +**CRM is a feed-forward model that can generate a 3D textured mesh in 10 seconds.** + +## [Project Page](https://ml.cs.tsinghua.edu.cn/~zhengyi/CRM/) | [arXiv](https://arxiv.org/abs/2403.05034) | [HF-Demo](https://huggingface.co/spaces/Zhengyi/CRM) | [Weights](https://huggingface.co/Zhengyi/CRM) + +https://github.com/thu-ml/CRM/assets/40787266/8b325bc0-aa74-4c26-92e8-a8f0c1079382 + +## Try CRM 🍻 +* Try CRM at the [Huggingface Demo](https://huggingface.co/spaces/Zhengyi/CRM). +* Try CRM at the [Replicate Demo](https://replicate.com/camenduru/crm). Thanks [@camenduru](https://github.com/camenduru)! + +## Install + +### Step 1 - Base + +Install the packages one by one; we use **Python 3.9**. + +```bash +pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117 +pip install torch-scatter==2.1.1 -f https://data.pyg.org/whl/torch-1.13.1+cu117.html +pip install kaolin==0.14.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.13.1_cu117.html +pip install -r requirements.txt +``` + +Besides, xformers needs to be installed manually according to the official [doc](https://github.com/facebookresearch/xformers?tab=readme-ov-file#installing-xformers) (**not needed for conda installs**), e.g. + +```bash +pip install ninja +pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers +``` + +### Step 2 - Nvdiffrast + +Install nvdiffrast according to the official [doc](https://nvlabs.github.io/nvdiffrast/#installation), e.g. + +```bash +pip install git+https://github.com/NVlabs/nvdiffrast +``` + + + +## Inference + +We suggest using Gradio for visualized inference.
+ +``` +gradio app.py +``` + +![image](https://github.com/thu-ml/CRM/assets/40787266/4354d22a-a641-4531-8408-c761ead8b1a2) + +For inference from the command line, simply run +```bash +CUDA_VISIBLE_DEVICES="0" python run.py --inputdir "examples/kunkun.webp" +``` +It will output the preprocessed image, the generated 6-view images and CCMs, and a 3D model in OBJ format. + +**Tips:** (1) If the result is unsatisfactory, please check whether the input image is correctly pre-processed to a grey background. Otherwise the results will be unpredictable. +(2) Unlike the [Huggingface Demo](https://huggingface.co/spaces/Zhengyi/CRM), this official implementation uses a UV texture instead of vertex colors. It has better texture than the online demo but a longer generation time owing to the UV texturing. + +## Todo List +- [x] Release inference code. +- [x] Release pretrained models. +- [ ] Optimize inference code to fit on low-memory GPUs. +- [ ] Upload training code. + +## Acknowledgement +- [ImageDream](https://github.com/bytedance/ImageDream) +- [nvdiffrast](https://github.com/NVlabs/nvdiffrast) +- [kiuikit](https://github.com/ashawkey/kiuikit) +- [GET3D](https://github.com/nv-tlabs/GET3D) + +## Citation + +``` +@article{wang2024crm, + title={CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model}, + author={Zhengyi Wang and Yikai Wang and Yifei Chen and Chendong Xiang and Shuo Chen and Dajiang Yu and Chongxuan Li and Hang Su and Jun Zhu}, + journal={arXiv preprint arXiv:2403.05034}, + year={2024} +} +``` diff --git a/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml b/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml new file mode 100644 index 0000000000000000000000000000000000000000..760f41f2728a94114f674e2160a75a65a8a3a656 --- /dev/null +++ b/apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml @@ -0,0 +1,21 @@ +config: +# others + seed: 1234 + num_frames: 7 + mode: pixel + offset_noise: true +# model related + models: + config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml + resume: models/pixel.pth +# sampler related + sampler: + target: libs.sample.ImageDreamDiffusion + params: + mode: pixel + num_frames: 7 + camera_views: [1, 2, 3, 4, 5, 0, 0] + ref_position: 6 + random_background: false + offset_noise: true + resize_rate: 1.0 \ No newline at end of file diff --git a/apps/third_party/CRM/configs/specs_objaverse_total.json b/apps/third_party/CRM/configs/specs_objaverse_total.json new file mode 100644 index 0000000000000000000000000000000000000000..c99ebee563a7d44859338382b197ef55963e87d0 --- /dev/null +++ b/apps/third_party/CRM/configs/specs_objaverse_total.json @@ -0,0 +1,57 @@ +{ + "Input": { + "img_num": 16, + "class": "all", + "camera_angle_num": 8, + "tet_grid_size": 80, + "validate_num": 16, + "scale": 0.95, + "radius": 3, + "resolution": [256, 256] + }, + + "Pretrain": { + "mode": null, + "sdf_threshold": 0.1, + "sdf_scale": 10, + "batch_infer": false, + "lr": 1e-4, + "radius": 0.5 + }, + + "Train": { + "mode": "rnd", + "num_epochs": 500, + "grad_acc": 1, + "warm_up": 0, + "decay": 0.000, + "learning_rate": { + "init": 1e-4, + "sdf_decay": 1, + "rgb_decay": 1 + }, + "batch_size": 4, + "eva_iter": 80, + "eva_all_epoch": 10, + "tex_sup_mode": "blender", + "exp_uv_mesh": false, + "doub": false, + "random_bg": false, + "shift": 0, + "aug_shift": 0, + "geo_type": "flex" + }, + + "ArchSpecs": { + "unet_type": "diffusers", + "use_3D_aware": false, + "fea_concat": false, + "mlp_bias": true + }, + + "DecoderSpecs": { + "c_dim": 32, +
"plane_resolution": 256 + } +} + diff --git a/apps/third_party/CRM/configs/stage2-v2-snr.yaml b/apps/third_party/CRM/configs/stage2-v2-snr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e76d1a2a8ff71ba1318c9ff6ff6c59a7a9e606e --- /dev/null +++ b/apps/third_party/CRM/configs/stage2-v2-snr.yaml @@ -0,0 +1,25 @@ +config: +# others + seed: 1234 + num_frames: 6 + mode: pixel + offset_noise: true + gd_type: xyz +# model related + models: + config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml + resume: models/xyz.pth + +# eval related + sampler: + target: libs.sample.ImageDreamDiffusionStage2 + params: + mode: pixel + num_frames: 6 + camera_views: [1, 2, 3, 4, 5, 0] + ref_position: null + random_background: false + offset_noise: true + resize_rate: 1.0 + + diff --git a/apps/third_party/CRM/imagedream/.DS_Store b/apps/third_party/CRM/imagedream/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5e042f9aa5cf18c3b51e766fee10546755d5ed4d Binary files /dev/null and b/apps/third_party/CRM/imagedream/.DS_Store differ diff --git a/apps/third_party/CRM/imagedream/__init__.py b/apps/third_party/CRM/imagedream/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..326f18c2d65a018b1c214c71f1c44428a8a8089b --- /dev/null +++ b/apps/third_party/CRM/imagedream/__init__.py @@ -0,0 +1 @@ +from .model_zoo import build_model diff --git a/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5514f1066fcdb863f45a135e26483e3b31c2a44d Binary files /dev/null and b/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7c150c7cb625327e1ccb59742e7c4842561eb1d Binary files /dev/null and b/apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc b/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaaf53fc96fd79461da559cd87213eac03d63e1a Binary files /dev/null and b/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc b/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..599048b52c049602f55ce7f6b569e278ae8e8dd6 Binary files /dev/null and b/apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc b/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7eaa9fbbbfe649e49709f705e513ef0f40d7690d Binary files /dev/null and b/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc b/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bd45cca067b99143299e1ba9b78a2a23846e390 Binary files 
/dev/null and b/apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/camera_utils.py b/apps/third_party/CRM/imagedream/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6fb352745d737d96b2652f80c23778058e81f6e6 --- /dev/null +++ b/apps/third_party/CRM/imagedream/camera_utils.py @@ -0,0 +1,99 @@ +import numpy as np +import torch + + +def create_camera_to_world_matrix(elevation, azimuth): + elevation = np.radians(elevation) + azimuth = np.radians(azimuth) + # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere + x = np.cos(elevation) * np.sin(azimuth) + y = np.sin(elevation) + z = np.cos(elevation) * np.cos(azimuth) + + # Calculate camera position, target, and up vectors + camera_pos = np.array([x, y, z]) + target = np.array([0, 0, 0]) + up = np.array([0, 1, 0]) + + # Construct view matrix + forward = target - camera_pos + forward /= np.linalg.norm(forward) + right = np.cross(forward, up) + right /= np.linalg.norm(right) + new_up = np.cross(right, forward) + new_up /= np.linalg.norm(new_up) + cam2world = np.eye(4) + cam2world[:3, :3] = np.array([right, new_up, -forward]).T + cam2world[:3, 3] = camera_pos + return cam2world + + +def convert_opengl_to_blender(camera_matrix): + if isinstance(camera_matrix, np.ndarray): + # Construct transformation matrix to convert from OpenGL space to Blender space + flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) + camera_matrix_blender = np.dot(flip_yz, camera_matrix) + else: + # Construct transformation matrix to convert from OpenGL space to Blender space + flip_yz = torch.tensor( + [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]] + ) + if camera_matrix.ndim == 3: + flip_yz = flip_yz.unsqueeze(0) + camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix) + return camera_matrix_blender + + +def normalize_camera(camera_matrix): + """normalize the camera location onto a unit-sphere""" + if isinstance(camera_matrix, np.ndarray): + camera_matrix = camera_matrix.reshape(-1, 4, 4) + translation = camera_matrix[:, :3, 3] + translation = translation / ( + np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8 + ) + camera_matrix[:, :3, 3] = translation + else: + camera_matrix = camera_matrix.reshape(-1, 4, 4) + translation = camera_matrix[:, :3, 3] + translation = translation / ( + torch.norm(translation, dim=1, keepdim=True) + 1e-8 + ) + camera_matrix[:, :3, 3] = translation + return camera_matrix.reshape(-1, 16) + + +def get_camera( + num_frames, + elevation=15, + azimuth_start=0, + azimuth_span=360, + blender_coord=True, + extra_view=False, +): + angle_gap = azimuth_span / num_frames + cameras = [] + for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap): + camera_matrix = create_camera_to_world_matrix(elevation, azimuth) + if blender_coord: + camera_matrix = convert_opengl_to_blender(camera_matrix) + cameras.append(camera_matrix.flatten()) + + if extra_view: + dim = len(cameras[0]) + cameras.append(np.zeros(dim)) + return torch.tensor(np.stack(cameras, 0)).float() + + +def get_camera_for_index(data_index): + """ + In our current data format, view 000 is the one facing us: + 000 is the front, ev: 0, azimuth: 0 + 001 is the left, ev: 0, azimuth: -90 + 002 is the bottom, ev: -90, azimuth: 0 + 003 is the back, ev: 0, azimuth: 180 + 004 is the right, ev: 0, azimuth: 90 + 005 is the top, ev: 90, azimuth: 0 + """ + params = [(0, 0), (0, -90), (-90, 0), (0, 180), (0, 90), (90, 0)] + return get_camera(1, *params[data_index]) \ No newline at
end of file diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b4ecc2a694dbbde82d45bc3b16d1ebbb27ac552a --- /dev/null +++ b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml @@ -0,0 +1,61 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee80712395fa6323bceb437a40b4829bb497adf7 --- /dev/null +++ b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml @@ -0,0 +1,61 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 8 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..ce64b053fc76ded2ed85af7c5d398e2981e51c3d --- /dev/null +++ b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml @@ -0,0 +1,61 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2 + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bd9c835f06707aeffcd5cd8fd4cd7ec262646905 --- /dev/null +++ b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml @@ -0,0 +1,62 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + zero_snr: true + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2 + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b4fbe27c84d5c2b05737894dc8970f45f08ba606 --- /dev/null +++ 
b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml @@ -0,0 +1,62 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + ip_weight: 1.0 # adjust for similarity to image + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5824cd7f25a9d1ef2e397341a39f7c724eb1cf76 --- /dev/null +++ b/apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml @@ -0,0 +1,62 @@ +model: + target: imagedream.ldm.interface.LatentDiffusionInterface + params: + linear_start: 0.00085 + linear_end: 0.0120 + timesteps: 1000 + scale_factor: 0.18215 + parameterization: "eps" + zero_snr: true + + unet_config: + target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel + params: + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + use_checkpoint: False + legacy: False + camera_dim: 16 + with_ip: True + ip_dim: 16 # ip token length + ip_mode: "local_resample" + + vae_config: + target: imagedream.ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + clip_config: + target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + ip_mode: "local_resample" diff --git a/apps/third_party/CRM/imagedream/ldm/__init__.py b/apps/third_party/CRM/imagedream/ldm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..d029d6984122456fadb1fcb2eb074aa9471b82d4 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02bf8990d7b4f5cdaf387895d86be5e17bf8ed39 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01d44982e9c3c83b6c7479b28fe0604fd8f79fbd Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9f7089053acea4170588b4fc2a48adf647beaf0 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0890976472bd4db53fbbf6883313845de8002651 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e30a7fbad8ea87fce676138a6a8e481191d09dd7 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/interface.py b/apps/third_party/CRM/imagedream/ldm/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..3bbeed7921ea30efb81573c3efed42d8375cb21a --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/interface.py @@ -0,0 +1,206 @@ +from typing import List +from functools import partial + +import numpy as np +import torch +import torch.nn as nn + +from .modules.diffusionmodules.util import ( + make_beta_schedule, + extract_into_tensor, + enforce_zero_terminal_snr, + noise_like, +) +from .util import exists, default, instantiate_from_config +from .modules.distributions.distributions import DiagonalGaussianDistribution + + +class DiffusionWrapper(nn.Module): + def __init__(self, diffusion_model): + super().__init__() + self.diffusion_model = diffusion_model + + def forward(self, *args, **kwargs): + return self.diffusion_model(*args, **kwargs) + + +class LatentDiffusionInterface(nn.Module): + """a simple interface class for LDM inference""" + + def __init__( + self, + unet_config, + clip_config, + vae_config, + parameterization="eps", + scale_factor=0.18215, + beta_schedule="linear", + timesteps=1000, + linear_start=0.00085, + linear_end=0.0120, + cosine_s=8e-3, + given_betas=None, + zero_snr=False, + *args, + **kwargs, + ): + super().__init__() + + unet = instantiate_from_config(unet_config) + self.model = DiffusionWrapper(unet) + self.clip_model = 
instantiate_from_config(clip_config) + self.vae_model = instantiate_from_config(vae_config) + + self.parameterization = parameterization + self.scale_factor = scale_factor + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + zero_snr=zero_snr + ) + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + zero_snr=False + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + if zero_snr: + print("--- using zero snr---") + betas = enforce_zero_terminal_snr(betas).numpy() + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + eps = 1e-8 # add a small epsilon to avoid divide-by-zero errors + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps))) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps) - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.v_posterior = 0 + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1.
- alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t + - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + * x_t + ) + + def apply_model(self, x_noisy, t, cond, **kwargs): + assert isinstance(cond, dict), "cond has to be a dictionary" + return self.model(x_noisy, t, **cond, **kwargs) + + def get_learned_conditioning(self, prompts: List[str]): + return self.clip_model(prompts) + + def get_learned_image_conditioning(self, images): + return self.clip_model.forward_image(images) + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def encode_first_stage(self, x): + return self.vae_model.encode(x) + + def decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + return self.vae_model.decode(z) diff --git a/apps/third_party/CRM/imagedream/ldm/models/__init__.py b/apps/third_party/CRM/imagedream/ldm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..672ea939d54368780d02dd86a921336018342e63 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e64c9b27cd3f052820627d22906cadc4b3e9dd43 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ab9a034b2a331fc87a7e50e1a44af00ff39efac Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e3a48c1e459902bdb6f0621a7ede086b9c8cd08 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/autoencoder.py b/apps/third_party/CRM/imagedream/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..92f83096ddaf2146772f6b49b23d8f99e787fbb4 --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/models/autoencoder.py @@ -0,0 +1,270 @@ +import torch +import torch.nn.functional as F +from contextlib import contextmanager + +from ..modules.diffusionmodules.model import Encoder, Decoder +from ..modules.distributions.distributions import DiagonalGaussianDistribution + +from ..util import instantiate_from_config +from ..modules.ema import LitEma + + +class AutoencoderKL(torch.nn.Module): + def __init__( + self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False, + ): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0.0 < ema_decay < 1.0 + self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + 
print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + self.log( + "aeloss", + aeloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + optimizer_idx, + self.global_step, + last_layer=self.get_last_layer(), + split="train", + ) + + self.log( + "discloss", + discloss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict( + log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False + ) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss( + inputs, + reconstructions, + posterior, + 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + discloss, log_dict_disc = self.loss( + inputs, + reconstructions, + posterior, + 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val" + postfix, + ) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list = ( + list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.quant_conv.parameters()) + + list(self.post_quant_conv.parameters()) + ) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam( + self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9) + ) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = 
self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode( + torch.randn_like(posterior_ema.sample()) + ) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__init__.py b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5c8f08217241b9ddc7d371105a50328a41ca897 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f083019b05ad6e286010c0cd0a21e6d9f9b2b9d Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62e7bce7e47a73aa87c94c0cfdd772df724aff7a Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0512edd43d2c8d4c576a7076c33332b30b0e0d6e Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py b/apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..4c10321c39078e985f18a0ca7b388086a4aa4e2f --- /dev/null +++ 
b/apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py @@ -0,0 +1,430 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ...modules.diffusionmodules.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + **kwargs, + ): + """ + when inference time: all values of parameter + cond.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img']) + shape: (5, 4, 32, 32) + x_T: None + ddim_use_original_steps: False + timesteps: None + callback: None + quantize_denoised: False + mask: None + image_callback: None + log_every_t: 100 + temperature: 1.0 + noise_dropout: 0.0 + score_corrector: None + corrector_kwargs: None + unconditional_guidance_scale: 5 + unconditional_conditioning.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img']) + kwargs: {} + """ + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) # shape: torch.Size([5, 4, 32, 32]) mean: -0.00, std: 1.00, min: -3.64, max: 3.94 + else: + img = x_T + + if timesteps is None: # equal with set time step in hf + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( # reversed timesteps + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps) + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO: deterministic forward pass? 
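+ # inpainting-style blend: where mask == 1 the latent is reset to the noised reference (img_orig); only the mask == 0 region continues to be denoised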
+ img = img_orig * mask + (1.0 - mask) * img + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + dynamic_threshold=None, + **kwargs, + ): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + model_output = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + elif isinstance(c[k], torch.Tensor): + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + assert c[k] == unconditional_conditioning[k] + c_in[k] = c[k] + elif isinstance(c, list): + c_in = list() + assert isinstance(unconditional_conditioning, list) + for i in range(len(c)): + c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) + else: + c_in = torch.cat([unconditional_conditioning, c]) + model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = model_uncond + unconditional_guidance_scale * ( + model_t - model_uncond + ) + + + if self.model.parameterization == "v": + print("using v!") + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", "not implemented" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, 
_, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise NotImplementedError() + + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + **kwargs, + ): + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + return x_dec diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..188b2eab520e6aaf281181afb6d84e1a3098a914 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b89c02f3c6827cdac3513cca77f2efd913115b0 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e77dd1a9bb9eae984eb7a6f2f33a98a54662d64a Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc differ diff --git 
a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..378dad50ffdfa150bf000062d153d2dd6fb715b0 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6c836e51e276424666c8d22441a804bf8ae4722 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed990a3f549b22da782b4f7e7cfe07bf64d4969c Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/attention.py b/apps/third_party/CRM/imagedream/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9578027d3d9e9b8766941dc0c986d42cd93b04bf --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/modules/attention.py @@ -0,0 +1,456 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + +from .diffusionmodules.util import checkpoint + + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +# CrossAttn precision handling +import os + +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs): + super().__init__() + print( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads." 
+ ) + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.with_ip = kwargs.get("with_ip", False) + if self.with_ip and (context_dim is not None): + self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False) + self.ip_dim= kwargs.get("ip_dim", 16) + self.ip_weight = kwargs.get("ip_weight", 1.0) + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.attention_op = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + + has_ip = self.with_ip and (context is not None) + if has_ip: + # context dim [(b frame_num), (77 + img_token), 1024] + token_len = context.shape[1] + context_ip = context[:, -self.ip_dim:, :] + k_ip = self.to_k_ip(context_ip) + v_ip = self.to_v_ip(context_ip) + context = context[:, :(token_len - self.ip_dim), :] + + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + if has_ip: + k_ip, v_ip = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (k_ip, v_ip), + ) + # actually compute the attention, what we cannot get enough of + out_ip = xformers.ops.memory_efficient_attention( + q, k_ip, v_ip, attn_bias=None, op=self.attention_op + ) + out = out + self.ip_weight * out_ip + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + **kwargs + ): + super().__init__() + assert XFORMERS_IS_AVAILBLE, "xformers is not available" + attn_cls = MemoryEfficientCrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + **kwargs + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = ( + self.attn1( + self.norm1(x), context=context if self.disable_self_attn else None + ) + + x + ) + x = 
self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + **kwargs + ): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + **kwargs + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class BasicTransformerBlock3D(BasicTransformerBlock): + def forward(self, x, context=None, num_frames=1): + return checkpoint( + self._forward, (x, context, num_frames), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None, num_frames=1): + x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous() + x = ( + self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None + ) + + x + ) + x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous() + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer3D(nn.Module): + """3D self-attention""" + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + **kwargs + ): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock3D( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + 
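A condensed view of the image-prompt (`with_ip`) path in `MemoryEfficientCrossAttention` above: the last `ip_dim` context tokens get their own key/value projections, and their attention output is blended back in with `ip_weight`. This sketch uses plain scaled-dot-product attention instead of the xformers kernel, and `to_kv` / `to_kv_ip` are hypothetical stand-ins for the module's separate `to_k`/`to_v` linears:

```python
import torch

def sdp_attention(q, k, v):
    # plain softmax attention over (batch, tokens, dim) tensors, single head for brevity
    scale = q.shape[-1] ** -0.5
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v

def ip_cross_attention(q, context, to_kv, to_kv_ip, ip_dim=16, ip_weight=1.0):
    # split the text tokens from the trailing image-prompt tokens, attend to each, then blend
    ctx_txt, ctx_ip = context[:, :-ip_dim, :], context[:, -ip_dim:, :]
    k, v = to_kv(ctx_txt).chunk(2, dim=-1)
    k_ip, v_ip = to_kv_ip(ctx_ip).chunk(2, dim=-1)
    return sdp_attention(q, k, v) + ip_weight * sdp_attention(q, k_ip, v_ip)
```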
disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint, + **kwargs + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None, num_frames=1): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i], num_frames=num_frames) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87936969c3da946fe6a3d832741e28e3f8c5a465 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..709d11f26d4279dc79749621e1fadb3d3214664c Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a776d040c73ba392b8c4498ceae40daa0ac5c375 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..009edfe033ed8828bdc0ae3ae92f56076fe6752a Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/adaptors.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e42f7d66fb4493c5d29a2e00d78db7ee9459f090 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-310.pyc 
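The 3D transformer blocks above make self-attention span every view of the object by folding the frame axis into the token axis before `attn1` and unfolding it afterwards. A minimal sketch of just that reshaping, with `attn` standing in for any self-attention module over (batch, tokens, dim):

```python
import torch
from einops import rearrange

def cross_view_self_attention(x, attn, num_frames):
    # x: (batch * num_frames, tokens, dim) -> attend across all views jointly, then restore the layout
    x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames)
    x = attn(x) + x
    return rearrange(x, "b (f l) c -> (b f) l c", f=num_frames)

# shape check with an identity "attention"
out = cross_view_self_attention(torch.randn(2 * 5, 64, 32), attn=lambda t: t, num_frames=5)
print(out.shape)  # torch.Size([10, 64, 32])
```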
differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef2270fc4896075b92413258975349f6c68c9bec Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1a0a6de705d51589cf78005348fa64f9e7eb29e Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5e8589aeb17fb3c779c10c01f058ddaf3165198 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8aa9c29edf8fd31ddfd84c3b6784e57a6037b2fb Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a455c6a98a6d27b74477615cf6bf53f0899da51 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/adaptors.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/adaptors.py new file mode 100644 index 0000000000000000000000000000000000000000..8d66e480728073294015cf0eb906dba471d602ca --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/adaptors.py @@ -0,0 +1,163 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +import math + +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + 
self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + def __init__(self, + cross_attention_dim=1024, + clip_embeddings_dim=1024, + clip_extra_context_tokens=4): + super().__init__() + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + + # from 1024 -> 4 * 1024 + self.proj = torch.nn.Linear( + clip_embeddings_dim, + self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class SimpleReSampler(nn.Module): + def __init__(self, embedding_dim=1280, output_dim=1024): + super().__init__() + self.proj_out = nn.Linear(embedding_dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + def forward(self, latents): + """ + latents: B 256 N + """ + latents = self.proj_out(latents) + return self.norm_out(latents) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + ): + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + self.proj_in = nn.Linear(embedding_dim, dim) + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, + dim_head=dim_head, + heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + latents = self.latents.repeat(x.size(0), 1, 1) + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +if __name__ == '__main__': + resampler = Resampler(embedding_dim=1280) + resampler = SimpleReSampler(embedding_dim=1280) + tensor = torch.rand(4, 257, 1280) + embed = resampler(tensor) + # embed = (tensor) + print(embed.shape) diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/model.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/model.py new file 
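The `Resampler` above (the IP-adapter "local_resample" path) lets a small set of learned query tokens cross-attend to the full CLIP image feature sequence. A usage sketch: the input shape follows the module's own `__main__` block, and `num_queries=16` matches the "local" `ip_dim` used elsewhere in this diff, but treat the exact keyword values as illustrative rather than the released checkpoint's configuration:

```python
import torch
from imagedream.ldm.modules.diffusionmodules.adaptors import Resampler

resampler = Resampler(embedding_dim=1280, num_queries=16, output_dim=1024)
clip_feats = torch.rand(4, 257, 1280)   # (batch, CLIP patch tokens + CLS, feature dim)
ip_tokens = resampler(clip_feats)       # (4, 16, 1024): extra context tokens for cross-attention
print(ip_tokens.shape)
```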
mode 100644 index 0000000000000000000000000000000000000000..52a1d5c8e8ba62dd25133ffe76d370c637c5d25e --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,1018 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from typing import Optional, Any + +from ..attention import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = 
torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w) + q = q.permute(0, 2, 1) # b,hw,c + k = k.reshape(b, c, h * w) # b,c,hw + w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w) + w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x + h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.attention_op = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x + out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, "b c h w -> b (h w) c") + out = super().forward(x, context=context, mask=mask) 
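`get_timestep_embedding` near the top of this file is the standard sinusoidal embedding. A self-contained equivalent for a quick shape check; `sinusoidal_embedding` is an illustrative stand-in, not the library function itself:

```python
import math
import torch
import torch.nn.functional as F

def sinusoidal_embedding(timesteps, dim):
    # pair each timestep with sin/cos waves at geometrically spaced frequencies
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / (half - 1))
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    return F.pad(emb, (0, 1)) if dim % 2 == 1 else emb

print(sinusoidal_embedding(torch.arange(4), 64).shape)  # torch.Size([4, 64])
```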
+ out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in 
attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, 
attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + 
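The `Encoder` above halves the spatial resolution at every level except the last, so the output is `resolution / 2**(len(ch_mult) - 1)` with `2 * z_channels` channels when `double_z=True` (mean and log-variance). A hedged instantiation sketch with typical Stable-Diffusion-style hyperparameters, not necessarily this repository's checkpoint config, and assuming the vanilla (non-xformers) attention path:

```python
import torch
from imagedream.ldm.modules.diffusionmodules.model import Encoder

enc = Encoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
              attn_resolutions=[], dropout=0.0, in_channels=3,
              resolution=256, z_channels=4, double_z=True)
x = torch.randn(1, 3, 256, 256)
moments = enc(x)                # (1, 8, 32, 32): 4 mean + 4 log-variance channels
print(moments.shape)
```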
h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + 
ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" + ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must 
do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/openaimodel.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..2f12a389584a729e1177af8683d631e5c5d77fd5 --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,1135 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from imagedream.ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, + convert_module_to_f16, + convert_module_to_f32 +) +from imagedream.ldm.modules.attention import ( + SpatialTransformer, + SpatialTransformer3D, + exists +) +from imagedream.ldm.modules.diffusionmodules.adaptors import ( + Resampler, + ImageProjModel +) + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None, num_frames=1): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer3D): + x = layer(x, context, num_frames=num_frames) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + +class MultiViewUNetModel(nn.Module): + """ + The full multi-view UNet model with attention, timestep embedding and camera embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + :param camera_dim: dimensionality of camera input. 
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + use_bf16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + adm_in_channels=None, + camera_dim=None, + with_ip=False, # wether add image prompt images + ip_dim=0, # number of extra token, 4 for global 16 for local + ip_weight=1.0, # weight for image prompt context + ip_mode="local_resample", # which mode of adaptor, global or local + ): + super().__init__() + if use_spatial_transformer: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + from omegaconf.listconfig import ListConfig + + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + print( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.dtype = th.bfloat16 if use_bf16 else self.dtype + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + self.with_ip = with_ip # wether there is image prompt + self.ip_dim = ip_dim # num of extra token, 4 for global 16 for local + self.ip_weight = ip_weight + assert ip_mode in ["global", "local_resample"] + self.ip_mode = ip_mode # which mode of adaptor + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if camera_dim is not None: + time_embed_dim = model_channels * 4 + self.camera_embed = nn.Sequential( + linear(camera_dim, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + if self.with_ip and (context_dim is not None) and ip_dim > 0: + if self.ip_mode == "local_resample": + # ip-adapter-plus + hidden_dim = 1280 + self.image_embed = Resampler( + dim=context_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=ip_dim, # num token + embedding_dim=hidden_dim, + output_dim=context_dim, + ff_mult=4, + ) + elif self.ip_mode == "global": + self.image_embed = ImageProjModel( + cross_attention_dim=context_dim, + clip_extra_context_tokens=ip_dim) + else: + raise ValueError(f"{self.ip_mode} is not supported") + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or nr < num_attention_blocks[level] + ): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer3D( + ch, + num_heads, + 
dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + with_ip=self.with_ip, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer3D( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + with_ip=self.with_ip, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or i < num_attention_blocks[level] + ): + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer3D( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + with_ip=self.with_ip, + ip_dim=self.ip_dim, + ip_weight=self.ip_weight + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + 
out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward( + self, + x, + timesteps=None, + context=None, + y=None, + camera=None, + num_frames=1, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). + :param timesteps: a 1-D batch of timesteps. + :param context: a dict conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional, default None. + :param num_frames: a integer indicating number of frames for tensor reshaping. + :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). + """ + assert ( + x.shape[0] % num_frames == 0 + ), "[UNet] input batch size must be dividable by num_frames!" 
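+        # Views are interleaved along the batch dimension: sample n occupies rows
+        # n * num_frames .. (n + 1) * num_frames - 1, and the image-prompt view is
+        # written into the last slot of each group (see the ip_img handling below).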
+ assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00 + emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51 + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # Add camera embeddings + if camera is not None: + assert camera.shape[0] == emb.shape[0] + # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04 + emb = emb + self.camera_embed(camera) + ip = kwargs.get("ip", None) + ip_img = kwargs.get("ip_img", None) + + if ip_img is not None: + x[(num_frames-1)::num_frames, :, :, :] = ip_img + + if ip is not None: + ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 + context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context, num_frames=num_frames) + hs.append(h) + h = self.middle_block(h, emb, context, num_frames=num_frames) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, num_frames=num_frames) + h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58 + if self.predict_codebook_ids: # False + return self.id_predictor(h) + else: + return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93 + + + + +class MultiViewUNetModelStage2(MultiViewUNetModel): + """ + The full multi-view UNet model with attention, timestep embedding and camera embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + :param camera_dim: dimensionality of camera input. 
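+
+    Construction arguments are identical to MultiViewUNetModel. The difference is
+    in forward(): a `pixel_images` tensor is expected in the kwargs and is
+    concatenated to the noisy input along the channel dimension before the UNet
+    is applied.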
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + use_bf16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + adm_in_channels=None, + camera_dim=None, + with_ip=False, # wether add image prompt images + ip_dim=0, # number of extra token, 4 for global 16 for local + ip_weight=1.0, # weight for image prompt context + ip_mode="local_resample", # which mode of adaptor, global or local + ): + super().__init__( + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout, + channel_mult, + conv_resample, + dims, + num_classes, + use_checkpoint, + use_fp16, + use_bf16, + num_heads, + num_head_channels, + num_heads_upsample, + use_scale_shift_norm, + resblock_updown, + use_new_attention_order, + use_spatial_transformer, + transformer_depth, + context_dim, + n_embed, + legacy, + disable_self_attentions, + num_attention_blocks, + disable_middle_self_attn, + use_linear_in_transformer, + adm_in_channels, + camera_dim, + with_ip, + ip_dim, + ip_weight, + ip_mode, + ) + + def forward( + self, + x, + timesteps=None, + context=None, + y=None, + camera=None, + num_frames=1, + **kwargs, + ): + """ + Apply the model to an input batch. + :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views). + :param timesteps: a 1-D batch of timesteps. + :param context: a dict conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional, default None. + :param num_frames: a integer indicating number of frames for tensor reshaping. + :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views). + """ + assert ( + x.shape[0] % num_frames == 0 + ), "[UNet] input batch size must be dividable by num_frames!" 
+ assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # shape: torch.Size([B, 320]) mean: 0.18, std: 0.68, min: -1.00, max: 1.00 + emb = self.time_embed(t_emb) # shape: torch.Size([B, 1280]) mean: 0.12, std: 0.57, min: -5.73, max: 6.51 + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + # Add camera embeddings + if camera is not None: + assert camera.shape[0] == emb.shape[0] + # camera embed: shape: torch.Size([B, 1280]) mean: -0.02, std: 0.27, min: -7.23, max: 2.04 + emb = emb + self.camera_embed(camera) + ip = kwargs.get("ip", None) + ip_img = kwargs.get("ip_img", None) + pixel_images = kwargs.get("pixel_images", None) + + if ip_img is not None: + x[(num_frames-1)::num_frames, :, :, :] = ip_img + + x = torch.cat((x, pixel_images), dim=1) + + if ip is not None: + ip_emb = self.image_embed(ip) # shape: torch.Size([B, 16, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 + context = torch.cat((context, ip_emb), 1) # shape: torch.Size([B, 93, 1024]) mean: -0.00, std: 1.00, min: -11.65, max: 7.31 + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context, num_frames=num_frames) + hs.append(h) + h = self.middle_block(h, emb, context, num_frames=num_frames) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context, num_frames=num_frames) + h = h.type(x.dtype) # shape: torch.Size([10, 320, 32, 32]) mean: -0.67, std: 3.96, min: -42.74, max: 25.58 + if self.predict_codebook_ids: # False + return self.id_predictor(h) + else: + return self.out(h) # shape: torch.Size([10, 4, 32, 32]) mean: -0.00, std: 0.91, min: -3.65, max: 3.93 + \ No newline at end of file diff --git a/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/util.py b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..af744261d41deab5d686aead790726efcdfaf961 --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,353 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
+ + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat +import importlib + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + +def enforce_zero_terminal_snr(betas): + betas = torch.tensor(betas) if not isinstance(betas, torch.Tensor) else betas + # Convert betas to alphas_bar_sqrt + alphas =1 - betas + alphas_bar = alphas.cumprod(0) + alphas_bar_sqrt = alphas_bar.sqrt() + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + # Shift so last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + # Scale so first timestep is back to old value. 
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt ** 2 + alphas = alphas_bar[1:] / alphas_bar[:-1] + alphas = torch.cat ([alphas_bar[0:1], alphas]) + betas = 1 - alphas + return betas + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
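+    Illustrative call, mirroring how ResBlock.forward uses it:
+        return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)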
+ """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + # import pdb; pdb.set_trace() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. 
+ """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +# dummy replace +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..312f98ad511d7f30d3483685fa70528b1910b20a Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30cc9dbbbadd654864db556064cf17f299a66b28 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c80371d54b3d6110586bfe2f53628dbae828ecde Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee436d35f6a01e2ff87c0588b6ec2bdf7211f43b Binary files /dev/null and 
b/apps/third_party/CRM/imagedream/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/distributions/distributions.py b/apps/third_party/CRM/imagedream/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..92f4428a3defd8fbae18fcd323c9d404036c652e --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/modules/distributions/distributions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
+ logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/apps/third_party/CRM/imagedream/ldm/modules/ema.py b/apps/third_party/CRM/imagedream/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..a073e116975f3313fb630b7d4ac115171c1fe31d --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/modules/ema.py @@ -0,0 +1,86 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. 
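+
+        Typical pattern (illustrative; `ema` is a LitEma instance wrapping `model`):
+            ema.store(model.parameters())
+            ema.copy_to(model)
+            # ... run validation or save the EMA weights ...
+            ema.restore(model.parameters())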
+ """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__init__.py b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ad2b831f070c88d8b1b6d35c697cbb2b8466d66 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98745a81153dfd563cb689b76944075d93cd4feb Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c722f5bdac9cc12253ad9934ee36901b8468d57 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-310.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70899c6fc30bc670956be3ac51675b27beae2d02 Binary files /dev/null and b/apps/third_party/CRM/imagedream/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc differ diff --git a/apps/third_party/CRM/imagedream/ldm/modules/encoders/modules.py b/apps/third_party/CRM/imagedream/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..a19d2c193d660535f101fa1cc3f1b857ce1197fc --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/modules/encoders/modules.py @@ -0,0 +1,329 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +import numpy as np +import open_clip +from PIL import Image +from ...util import default, count_params + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0.0 and not disable_dropout: + mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) + c = c.long() + c 
= self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = ( + self.n_classes - 1 + ) # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + ): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer( + input_ids=tokens, output_hidden_states=self.layer == "hidden" + ) + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + + LAYERS = [ + # "pooled", + "last", + "penultimate", + ] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last", + ip_mode=None + ): + """_summary_ + + Args: + ip_mode (str, optional): what is the image promcessing mode. Defaults to None. 
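+                When ip_mode is None the CLIP visual tower is deleted and only
+                text encoding is available; "global" makes forward_image return
+                pooled, L2-normalized CLIP image features, while "local_resample"
+                returns token-level features from the penultimate layer of the
+                visual transformer.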
+ + """ + super().__init__() + assert layer in self.LAYERS + model, _, preprocess = open_clip.create_model_and_transforms( + arch, device=torch.device("cpu"), pretrained=version + ) + if ip_mode is None: + del model.visual + + self.model = model + self.preprocess = preprocess + self.device = device + self.max_length = max_length + self.ip_mode = ip_mode + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def forward_image(self, pil_image): + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + if isinstance(pil_image, torch.Tensor): + pil_image = pil_image.cpu().numpy() + if isinstance(pil_image, np.ndarray): + if pil_image.ndim == 3: + pil_image = pil_image[None, :, :, :] + pil_image = [Image.fromarray(x) for x in pil_image] + + images = [] + for image in pil_image: + images.append(self.preprocess(image).to(self.device)) + + image = torch.stack(images, 0) # to [b, 3, h, w] + if self.ip_mode == "global": + image_features = self.model.encode_image(image) + image_features /= image_features.norm(dim=-1, keepdim=True) + elif "local" in self.ip_mode: + image_features = self.encode_image_with_transformer(image) + + return image_features # b, l + + def encode_image_with_transformer(self, x): + visual = self.model.visual + x = visual.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + # class embeddings and positional embeddings + x = torch.cat( + [visual.class_embedding.to(x.dtype) + \ + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), + x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + visual.positional_embedding.to(x.dtype) + + # a patch_dropout of 0. 
would mean it is disabled and this function would do nothing but return what was passed in + # x = visual.patch_dropout(x) + x = visual.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + hidden = self.image_transformer_forward(x) + x = hidden[-2].permute(1, 0, 2) # LND -> NLD + return x + + def image_transformer_forward(self, x): + encoder_states = () + trans = self.model.visual.transformer + for r in trans.resblocks: + if trans.grad_checkpointing and not torch.jit.is_scripting(): + # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 + x = checkpoint(r, x, None, None, None) + else: + x = r(x, attn_mask=None) + encoder_states = encoder_states + (x, ) + return encoder_states + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if ( + self.model.transformer.grad_checkpointing + and not torch.jit.is_scripting() + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder( + clip_version, device, max_length=clip_max_length + ) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params." + ) + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] diff --git a/apps/third_party/CRM/imagedream/ldm/util.py b/apps/third_party/CRM/imagedream/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed44393d153aad633175d3afda2a1d923c7d815 --- /dev/null +++ b/apps/third_party/CRM/imagedream/ldm/util.py @@ -0,0 +1,231 @@ +import importlib + +import random +import torch +import numpy as np +from collections import abc + +import multiprocessing as mp +from threading import Thread +from queue import Queue + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join( + xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc) + ) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. 
Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + # import pdb; pdb.set_trace() + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + # import pdb; pdb.set_trace() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + + if 'imagedream' in module: + module = 'apps.third_party.CRM.'+module + if 'lib' in module: + module = 'apps.third_party.CRM.'+module + return getattr(importlib.import_module(module, package=None), cls) + + +def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): + # create dummy dataset instance + + # run prefetching + if idx_to_fn: + res = func(data, worker_id=idx) + else: + res = func(data) + Q.put([idx, res]) + Q.put("Done") + + +def parallel_data_prefetch( + func: callable, + data, + n_proc, + target_data_type="ndarray", + cpu_intensive=True, + use_worker_id=False, +): + # if target_data_type not in ["ndarray", "list"]: + # raise ValueError( + # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." + # ) + if isinstance(data, np.ndarray) and target_data_type == "list": + raise ValueError("list expected but function got ndarray.") + elif isinstance(data, abc.Iterable): + if isinstance(data, dict): + print( + f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' + ) + data = list(data.values()) + if target_data_type == "ndarray": + data = np.asarray(data) + else: + data = list(data) + else: + raise TypeError( + f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
+ ) + + if cpu_intensive: + Q = mp.Queue(1000) + proc = mp.Process + else: + Q = Queue(1000) + proc = Thread + # spawn processes + if target_data_type == "ndarray": + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate(np.array_split(data, n_proc)) + ] + else: + step = ( + int(len(data) / n_proc + 1) + if len(data) % n_proc != 0 + else int(len(data) / n_proc) + ) + arguments = [ + [func, Q, part, i, use_worker_id] + for i, part in enumerate( + [data[i : i + step] for i in range(0, len(data), step)] + ) + ] + processes = [] + for i in range(n_proc): + p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) + processes += [p] + + # start processes + print(f"Start prefetching...") + import time + + start = time.time() + gather_res = [[] for _ in range(n_proc)] + try: + for p in processes: + p.start() + + k = 0 + while k < n_proc: + # get result + res = Q.get() + if res == "Done": + k += 1 + else: + gather_res[res[0]] = res[1] + + except Exception as e: + print("Exception: ", e) + for p in processes: + p.terminate() + + raise e + finally: + for p in processes: + p.join() + print(f"Prefetching complete. [{time.time() - start} sec.]") + + if target_data_type == "ndarray": + if not isinstance(gather_res[0], np.ndarray): + return np.concatenate([np.asarray(r) for r in gather_res], axis=0) + + # order outputs + return np.concatenate(gather_res, axis=0) + elif target_data_type == "list": + out = [] + for r in gather_res: + out.extend(r) + return out + else: + return gather_res + +def set_seed(seed=None): + random.seed(seed) + np.random.seed(seed) + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +def add_random_background(image, bg_color=None): + bg_color = np.random.rand() * 255 if bg_color is None else bg_color + image = np.array(image) + rgb, alpha = image[..., :3], image[..., 3:] + alpha = alpha.astype(np.float32) / 255.0 + image_new = rgb * alpha + bg_color * (1 - alpha) + return Image.fromarray(image_new.astype(np.uint8)) \ No newline at end of file diff --git a/apps/third_party/CRM/imagedream/model_zoo.py b/apps/third_party/CRM/imagedream/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..45d6b678bdc554f5b2ad19903f8ed9976ece024e --- /dev/null +++ b/apps/third_party/CRM/imagedream/model_zoo.py @@ -0,0 +1,64 @@ +""" Utiliy functions to load pre-trained models more easily """ +import os +import pkg_resources +from omegaconf import OmegaConf + +import torch +from huggingface_hub import hf_hub_download + +from imagedream.ldm.util import instantiate_from_config + + +PRETRAINED_MODELS = { + "sd-v2.1-base-4view-ipmv": { + "config": "sd_v2_base_ipmv.yaml", + "repo_id": "Peng-Wang/ImageDream", + "filename": "sd-v2.1-base-4view-ipmv.pt", + }, + "sd-v2.1-base-4view-ipmv-local": { + "config": "sd_v2_base_ipmv_local.yaml", + "repo_id": "Peng-Wang/ImageDream", + "filename": "sd-v2.1-base-4view-ipmv-local.pt", + }, +} + + +def get_config_file(config_path): + cfg_file = pkg_resources.resource_filename( + "imagedream", os.path.join("configs", config_path) + ) + if not os.path.exists(cfg_file): + raise RuntimeError(f"Config {config_path} not available!") + return cfg_file + + +def build_model(model_name, config_path=None, ckpt_path=None, cache_dir=None): + if (config_path is not None) and (ckpt_path is not None): + config = OmegaConf.load(config_path) + model = instantiate_from_config(config.model) + model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False) + return model + + if not model_name 
in PRETRAINED_MODELS: + raise RuntimeError( + f"Model name {model_name} is not a pre-trained model. Available models are:\n- " + + "\n- ".join(PRETRAINED_MODELS.keys()) + ) + model_info = PRETRAINED_MODELS[model_name] + + # Instiantiate the model + print(f"Loading model from config: {model_info['config']}") + config_file = get_config_file(model_info["config"]) + config = OmegaConf.load(config_file) + model = instantiate_from_config(config.model) + + # Load pre-trained checkpoint from huggingface + if not ckpt_path: + ckpt_path = hf_hub_download( + repo_id=model_info["repo_id"], + filename=model_info["filename"], + cache_dir=cache_dir, + ) + print(f"Loading model from cache file: {ckpt_path}") + model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False) + return model diff --git a/apps/third_party/CRM/inference.py b/apps/third_party/CRM/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a6fc2a9e49d606dc44d0cb8ae0cfbf223b44dcd4 --- /dev/null +++ b/apps/third_party/CRM/inference.py @@ -0,0 +1,91 @@ +import numpy as np +import torch +import time +import nvdiffrast.torch as dr +from util.utils import get_tri +import tempfile +from mesh import Mesh +import zipfile +def generate3d(model, rgb, ccm, device): + + color_tri = torch.from_numpy(rgb)/255 + xyz_tri = torch.from_numpy(ccm[:,:,(2,1,0)])/255 + color = color_tri.permute(2,0,1) + xyz = xyz_tri.permute(2,0,1) + + + def get_imgs(color): + # color : [C, H, W*6] + color_list = [] + color_list.append(color[:,:,256*5:256*(1+5)]) + for i in range(0,5): + color_list.append(color[:,:,256*i:256*(1+i)]) + return torch.stack(color_list, dim=0)# [6, C, H, W] + + triplane_color = get_imgs(color).permute(0,2,3,1).unsqueeze(0).to(device)# [1, 6, H, W, C] + + color = get_imgs(color) + xyz = get_imgs(xyz) + + color = get_tri(color, dim=0, blender= True, scale = 1).unsqueeze(0) + xyz = get_tri(xyz, dim=0, blender= True, scale = 1, fix= True).unsqueeze(0) + + triplane = torch.cat([color,xyz],dim=1).to(device) + # 3D visualize + model.eval() + glctx = dr.RasterizeCudaContext() + + if model.denoising == True: + tnew = 20 + tnew = torch.randint(tnew, tnew+1, [triplane.shape[0]], dtype=torch.long, device=triplane.device) + noise_new = torch.randn_like(triplane) *0.5+0.5 + triplane = model.scheduler.add_noise(triplane, noise_new, tnew) + start_time = time.time() + with torch.no_grad(): + triplane_feature2 = model.unet2(triplane,tnew) + end_time = time.time() + elapsed_time = end_time - start_time + print(f"unet takes {elapsed_time}s") + else: + triplane_feature2 = model.unet2(triplane) + + + with torch.no_grad(): + data_config = { + 'resolution': [1024, 1024], + "triview_color": triplane_color.to(device), + } + + verts, faces = model.decode(data_config, triplane_feature2) + + data_config['verts'] = verts[0] + data_config['faces'] = faces + + + from kiui.mesh_utils import clean_mesh + verts, faces = clean_mesh(data_config['verts'].squeeze().cpu().numpy().astype(np.float32), data_config['faces'].squeeze().cpu().numpy().astype(np.int32), repair = False, remesh=False, remesh_size=0.005) + data_config['verts'] = torch.from_numpy(verts).cuda().contiguous() + data_config['faces'] = torch.from_numpy(faces).cuda().contiguous() + + start_time = time.time() + with torch.no_grad(): + mesh_path_obj = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name + model.export_mesh_wt_uv(glctx, data_config, mesh_path_obj, "", device, res=(1024,1024), tri_fea_2=triplane_feature2) + + mesh = Mesh.load(mesh_path_obj+".obj", bound=0.9, 
front_dir="+z") + mesh_path_glb = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name + mesh.write(mesh_path_glb+".glb") + + # mesh_obj2 = trimesh.load(mesh_path_glb+".glb", file_type='glb') + # mesh_path_obj2 = tempfile.NamedTemporaryFile(suffix=f"", delete=False).name + # mesh_obj2.export(mesh_path_obj2+".obj") + + with zipfile.ZipFile(mesh_path_obj+'.zip', 'w') as myzip: + myzip.write(mesh_path_obj+'.obj', mesh_path_obj.split("/")[-1]+'.obj') + myzip.write(mesh_path_obj+'.png', mesh_path_obj.split("/")[-1]+'.png') + myzip.write(mesh_path_obj+'.mtl', mesh_path_obj.split("/")[-1]+'.mtl') + + end_time = time.time() + elapsed_time = end_time - start_time + print(f"uv takes {elapsed_time}s") + return mesh_path_glb+".glb", mesh_path_obj+'.zip' diff --git a/apps/third_party/CRM/libs/__init__.py b/apps/third_party/CRM/libs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/third_party/CRM/libs/__pycache__/__init__.cpython-310.pyc b/apps/third_party/CRM/libs/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..377a573aa67f7cc53af4c8b05123f278611521fc Binary files /dev/null and b/apps/third_party/CRM/libs/__pycache__/__init__.cpython-310.pyc differ diff --git a/apps/third_party/CRM/libs/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/libs/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b54287998f0608201936c65b8676e9a2f8d2c9f Binary files /dev/null and b/apps/third_party/CRM/libs/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-310.pyc b/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b19825c00604c13eafa9ac4db0ce70a9bf04e1df Binary files /dev/null and b/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-310.pyc differ diff --git a/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-38.pyc b/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc8c696f4c72d09559bcbc1519c671680fd92a35 Binary files /dev/null and b/apps/third_party/CRM/libs/__pycache__/base_utils.cpython-38.pyc differ diff --git a/apps/third_party/CRM/libs/__pycache__/sample.cpython-310.pyc b/apps/third_party/CRM/libs/__pycache__/sample.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b872ef355b5c9cac32066264ac6d3fe8e6f49d0 Binary files /dev/null and b/apps/third_party/CRM/libs/__pycache__/sample.cpython-310.pyc differ diff --git a/apps/third_party/CRM/libs/__pycache__/sample.cpython-38.pyc b/apps/third_party/CRM/libs/__pycache__/sample.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..888166c9fae24f2fd70b9939ac5803ec24ab5895 Binary files /dev/null and b/apps/third_party/CRM/libs/__pycache__/sample.cpython-38.pyc differ diff --git a/apps/third_party/CRM/libs/base_utils.py b/apps/third_party/CRM/libs/base_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c90a548286e80ded1f317126d8f560f67a85e0d8 --- /dev/null +++ b/apps/third_party/CRM/libs/base_utils.py @@ -0,0 +1,84 @@ +import numpy as np +import cv2 +import torch +import numpy as np +from PIL import Image + + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + 
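    # Hypothetical usage sketch for instantiate_from_config (names below are
    # illustrative, assuming the target class is importable): the config is a
    # mapping with a dotted "target" path plus optional "params", e.g.
    #     layer = instantiate_from_config(
    #         {"target": "torch.nn.Linear",
    #          "params": {"in_features": 8, "out_features": 4}})
    # which resolves torch.nn.Linear via get_obj_from_str and is equivalent to
    # calling torch.nn.Linear(in_features=8, out_features=4).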
return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + import importlib + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def tensor_detail(t): + assert type(t) == torch.Tensor + print(f"shape: {t.shape} mean: {t.mean():.2f}, std: {t.std():.2f}, min: {t.min():.2f}, max: {t.max():.2f}") + + + +def drawRoundRec(draw, color, x, y, w, h, r): + drawObject = draw + + '''Rounds''' + drawObject.ellipse((x, y, x + r, y + r), fill=color) + drawObject.ellipse((x + w - r, y, x + w, y + r), fill=color) + drawObject.ellipse((x, y + h - r, x + r, y + h), fill=color) + drawObject.ellipse((x + w - r, y + h - r, x + w, y + h), fill=color) + + '''rec.s''' + drawObject.rectangle((x + r / 2, y, x + w - (r / 2), y + h), fill=color) + drawObject.rectangle((x, y + r / 2, x + w, y + h - (r / 2)), fill=color) + + +def do_resize_content(original_image: Image, scale_rate): + # resize image content wile retain the original image size + if scale_rate != 1: + # Calculate the new size after rescaling + new_size = tuple(int(dim * scale_rate) for dim in original_image.size) + # Resize the image while maintaining the aspect ratio + resized_image = original_image.resize(new_size) + # Create a new image with the original size and black background + padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) + paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) + padded_image.paste(resized_image, paste_position) + return padded_image + else: + return original_image + +def add_stroke(img, color=(255, 255, 255), stroke_radius=3): + # color in R, G, B format + if isinstance(img, Image.Image): + assert img.mode == "RGBA" + img = cv2.cvtColor(np.array(img), cv2.COLOR_RGBA2BGRA) + else: + assert img.shape[2] == 4 + gray = img[:,:, 3] + ret, binary = cv2.threshold(gray,127,255,cv2.THRESH_BINARY) + contours, hierarchy = cv2.findContours(binary,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE) + res = cv2.drawContours(img, contours,-1, tuple(color)[::-1] + (255,), stroke_radius) + return Image.fromarray(cv2.cvtColor(res,cv2.COLOR_BGRA2RGBA)) + +def make_blob(image_size=(512, 512), sigma=0.2): + """ + make 2D blob image with: + I(x, y)=1-\exp \left(-\frac{(x-H / 2)^2+(y-W / 2)^2}{2 \sigma^2 HS}\right) + """ + import numpy as np + H, W = image_size + x = np.arange(0, W, 1, float) + y = np.arange(0, H, 1, float) + x, y = np.meshgrid(x, y) + x0 = W // 2 + y0 = H // 2 + img = 1 - np.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2 * H * W)) + return (img * 255).astype(np.uint8) \ No newline at end of file diff --git a/apps/third_party/CRM/libs/sample.py b/apps/third_party/CRM/libs/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..6a2e5cc453796c10165401bea47814113b6381ba --- /dev/null +++ b/apps/third_party/CRM/libs/sample.py @@ -0,0 +1,384 @@ +import numpy as np +import torch +from imagedream.camera_utils import get_camera_for_index +from imagedream.ldm.util import set_seed, add_random_background +# import os +# import sys +# proj_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +# sys.path.append(proj_dir) +from apps.third_party.CRM.libs.base_utils import do_resize_content +from imagedream.ldm.models.diffusion.ddim import DDIMSampler +from torchvision import transforms as T + + +class 
ImageDreamDiffusion: + def __init__( + self, + model, + device, + dtype, + mode, + num_frames, + camera_views, + ref_position, + random_background=False, + offset_noise=False, + resize_rate=1, + image_size=256, + seed=1234, + ) -> None: + assert mode in ["pixel", "local"] + size = image_size + self.seed = seed + batch_size = max(4, num_frames) + + neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear." + uc = model.get_learned_conditioning([neg_texts]).to(device) + sampler = DDIMSampler(model) + + # pre-compute camera matrices + camera = [get_camera_for_index(i).squeeze() for i in camera_views] + camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero + camera = torch.stack(camera) + camera = camera.repeat(batch_size // num_frames, 1).to(device) + + self.image_transform = T.Compose( + [ + T.Resize((size, size)), + T.ToTensor(), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + self.dtype = dtype + self.ref_position = ref_position + self.mode = mode + self.random_background = random_background + self.resize_rate = resize_rate + self.num_frames = num_frames + self.size = size + self.device = device + self.batch_size = batch_size + self.model = model + self.sampler = sampler + self.uc = uc + self.camera = camera + self.offset_noise = offset_noise + + @staticmethod + def i2i( + model, + image_size, + prompt, + uc, + sampler, + ip=None, + step=20, + scale=5.0, + batch_size=8, + ddim_eta=0.0, + dtype=torch.float32, + device="cuda", + camera=None, + num_frames=4, + pixel_control=False, + transform=None, + offset_noise=False, + ): + """ The function supports additional image prompt. + Args: + model (_type_): the image dream model + image_size (_type_): size of diffusion output (standard 256) + prompt (_type_): text prompt for the image (prompt in type str) + uc (_type_): unconditional vector (tensor in shape [1, 77, 1024]) + sampler (_type_): imagedream.ldm.models.diffusion.ddim.DDIMSampler + ip (Image, optional): the image prompt. Defaults to None. + step (int, optional): _description_. Defaults to 20. + scale (float, optional): _description_. Defaults to 7.5. + batch_size (int, optional): _description_. Defaults to 8. + ddim_eta (float, optional): _description_. Defaults to 0.0. + dtype (_type_, optional): _description_. Defaults to torch.float32. + device (str, optional): _description_. Defaults to "cuda". + camera (_type_, optional): camera info in tensor, shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 + num_frames (int, optional): _num of frames (views) to generate + pixel_control: whether to use pixel conditioning. 
Defaults to False, True when using pixel mode + transform: Compose( + Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=warn) + ToTensor() + Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) + ) + """ + ip_raw = ip + if type(prompt) != list: + prompt = [prompt] + with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype): + c = model.get_learned_conditioning(prompt).to( + device + ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05 + c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size + uc_ = {"context": uc.repeat(batch_size, 1, 1)} + + if camera is not None: + c_["camera"] = uc_["camera"] = ( + camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 + ) + c_["num_frames"] = uc_["num_frames"] = num_frames + + if ip is not None: + ip_embed = model.get_learned_image_conditioning(ip).to( + device + ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12 + ip_ = ip_embed.repeat(batch_size, 1, 1) + c_["ip"] = ip_ + uc_["ip"] = torch.zeros_like(ip_) + + if pixel_control: + assert camera is not None + ip = transform(ip).to( + device + ) # shape: torch.Size([3, 256, 256]) mean: 0.33, std: 0.37, min: -1.00, max: 1.00 + ip_img = model.get_first_stage_encoding( + model.encode_first_stage(ip[None, :, :, :]) + ) # shape: torch.Size([1, 4, 32, 32]) mean: 0.23, std: 0.77, min: -4.42, max: 3.55 + c_["ip_img"] = ip_img + uc_["ip_img"] = torch.zeros_like(ip_img) + + shape = [4, image_size // 8, image_size // 8] # [4, 32, 32] + if offset_noise: + ref = transform(ip_raw).to(device) + ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :])) + ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True) + time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device) + x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps) + + samples_ddim, _ = ( + sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43 + S=step, + conditioning=c_, + batch_size=batch_size, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc_, + eta=ddim_eta, + x_T=x_T if offset_noise else None, + ) + ) + + x_sample = model.decode_first_stage(samples_ddim) + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy() + + return list(x_sample.astype(np.uint8)) + + def diffuse(self, t, ip, n_test=2): + set_seed(self.seed) + ip = do_resize_content(ip, self.resize_rate) + if self.random_background: + ip = add_random_background(ip) + + images = [] + for _ in range(n_test): + img = self.i2i( + self.model, + self.size, + t, + self.uc, + self.sampler, + ip=ip, + step=50, + scale=5, + batch_size=self.batch_size, + ddim_eta=0.0, + dtype=self.dtype, + device=self.device, + camera=self.camera, + num_frames=self.num_frames, + pixel_control=(self.mode == "pixel"), + transform=self.image_transform, + offset_noise=self.offset_noise, + ) + img = np.concatenate(img, 1) + img = np.concatenate((img, ip.resize((self.size, self.size))), axis=1) + images.append(img) + set_seed() # unset random and numpy seed + return images + + +class ImageDreamDiffusionStage2: + def __init__( + self, + model, + device, + dtype, + num_frames, + camera_views, + ref_position, + random_background=False, + offset_noise=False, + resize_rate=1, + mode="pixel", + image_size=256, + 
seed=1234, + ) -> None: + assert mode in ["pixel", "local"] + + size = image_size + self.seed = seed + batch_size = max(4, num_frames) + + neg_texts = "uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear." + uc = model.get_learned_conditioning([neg_texts]).to(device) + sampler = DDIMSampler(model) + + # pre-compute camera matrices + camera = [get_camera_for_index(i).squeeze() for i in camera_views] + if ref_position is not None: + camera[ref_position] = torch.zeros_like(camera[ref_position]) # set ref camera to zero + camera = torch.stack(camera) + camera = camera.repeat(batch_size // num_frames, 1).to(device) + + self.image_transform = T.Compose( + [ + T.Resize((size, size)), + T.ToTensor(), + T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + + self.dtype = dtype + self.mode = mode + self.ref_position = ref_position + self.random_background = random_background + self.resize_rate = resize_rate + self.num_frames = num_frames + self.size = size + self.device = device + self.batch_size = batch_size + self.model = model + self.sampler = sampler + self.uc = uc + self.camera = camera + self.offset_noise = offset_noise + + @staticmethod + def i2iStage2( + model, + image_size, + prompt, + uc, + sampler, + pixel_images, + ip=None, + step=20, + scale=5.0, + batch_size=8, + ddim_eta=0.0, + dtype=torch.float32, + device="cuda", + camera=None, + num_frames=4, + pixel_control=False, + transform=None, + offset_noise=False, + ): + ip_raw = ip + if type(prompt) != list: + prompt = [prompt] + with torch.no_grad(), torch.autocast(device_type=torch.device(device).type, dtype=dtype): + c = model.get_learned_conditioning(prompt).to( + device + ) # shape: torch.Size([1, 77, 1024]) mean: -0.17, std: 1.02, min: -7.50, max: 13.05 + c_ = {"context": c.repeat(batch_size, 1, 1)} # batch_size + uc_ = {"context": uc.repeat(batch_size, 1, 1)} + + if camera is not None: + c_["camera"] = uc_["camera"] = ( + camera # shape: torch.Size([5, 16]) mean: 0.11, std: 0.49, min: -1.00, max: 1.00 + ) + c_["num_frames"] = uc_["num_frames"] = num_frames + + if ip is not None: + ip_embed = model.get_learned_image_conditioning(ip).to( + device + ) # shape: torch.Size([1, 257, 1280]) mean: 0.06, std: 0.53, min: -6.83, max: 11.12 + ip_ = ip_embed.repeat(batch_size, 1, 1) + c_["ip"] = ip_ + uc_["ip"] = torch.zeros_like(ip_) + + if pixel_control: + assert camera is not None + + transed_pixel_images = torch.stack([transform(i).to(device) for i in pixel_images]) + latent_pixel_images = model.get_first_stage_encoding(model.encode_first_stage(transed_pixel_images)) + + c_["pixel_images"] = latent_pixel_images + uc_["pixel_images"] = torch.zeros_like(latent_pixel_images) + + shape = [4, image_size // 8, image_size // 8] # [4, 32, 32] + if offset_noise: + ref = transform(ip_raw).to(device) + ref_latent = model.get_first_stage_encoding(model.encode_first_stage(ref[None, :, :, :])) + ref_mean = ref_latent.mean(dim=(-1, -2), keepdim=True) + time_steps = torch.randint(model.num_timesteps - 1, model.num_timesteps, (batch_size,), device=device) + x_T = model.q_sample(torch.ones([batch_size] + shape, device=device) * ref_mean, time_steps) + + samples_ddim, _ = ( + sampler.sample( # shape: torch.Size([5, 4, 32, 32]) mean: 0.29, std: 0.85, min: -3.38, max: 4.43 + S=step, + conditioning=c_, + batch_size=batch_size, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc_, + eta=ddim_eta, + x_T=x_T if offset_noise else None, + ) 
+ ) + x_sample = model.decode_first_stage(samples_ddim) + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255.0 * x_sample.permute(0, 2, 3, 1).cpu().numpy() + + return list(x_sample.astype(np.uint8)) + + @torch.no_grad() + def diffuse(self, t, ip, pixel_images, n_test=2): + set_seed(self.seed) + ip = do_resize_content(ip, self.resize_rate) + pixel_images = [do_resize_content(i, self.resize_rate) for i in pixel_images] + + if self.random_background: + bg_color = np.random.rand() * 255 + ip = add_random_background(ip, bg_color) + pixel_images = [add_random_background(i, bg_color) for i in pixel_images] + + images = [] + for _ in range(n_test): + img = self.i2iStage2( + self.model, + self.size, + t, + self.uc, + self.sampler, + pixel_images=pixel_images, + ip=ip, + step=50, + scale=5, + batch_size=self.batch_size, + ddim_eta=0.0, + dtype=self.dtype, + device=self.device, + camera=self.camera, + num_frames=self.num_frames, + pixel_control=(self.mode == "pixel"), + transform=self.image_transform, + offset_noise=self.offset_noise, + ) + img = np.concatenate(img, 1) + img = np.concatenate( + (img, ip.resize((self.size, self.size)), *[i.resize((self.size, self.size)) for i in pixel_images]), + axis=1, + ) + images.append(img) + set_seed() # unset random and numpy seed + return images diff --git a/apps/third_party/CRM/mesh.py b/apps/third_party/CRM/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..b98dea041fb41d207b6e95ed927216344854d25c --- /dev/null +++ b/apps/third_party/CRM/mesh.py @@ -0,0 +1,845 @@ +import os +import cv2 +import torch +import trimesh +import numpy as np + +from kiui.op import safe_normalize, dot +from kiui.typing import * + +class Mesh: + """ + A torch-native trimesh class, with support for ``ply/obj/glb`` formats. + + Note: + This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture). + """ + def __init__( + self, + v: Optional[Tensor] = None, + f: Optional[Tensor] = None, + vn: Optional[Tensor] = None, + fn: Optional[Tensor] = None, + vt: Optional[Tensor] = None, + ft: Optional[Tensor] = None, + vc: Optional[Tensor] = None, # vertex color + albedo: Optional[Tensor] = None, + metallicRoughness: Optional[Tensor] = None, + device: Optional[torch.device] = None, + ): + """Init a mesh directly using all attributes. + + Args: + v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None. + f (Optional[Tensor]): faces, int [M, 3]. Defaults to None. + vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None. + fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None. + vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None. + ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None. + vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None. + albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None. + metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None. + device (Optional[torch.device]): torch device. Defaults to None. 
+ """ + self.device = device + self.v = v + self.vn = vn + self.vt = vt + self.f = f + self.fn = fn + self.ft = ft + # will first see if there is vertex color to use + self.vc = vc + # only support a single albedo image + self.albedo = albedo + # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1] + # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html + self.metallicRoughness = metallicRoughness + + self.ori_center = 0 + self.ori_scale = 1 + + @classmethod + def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs): + """load mesh from path. + + Args: + path (str): path to mesh file, supports ply, obj, glb. + clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False. + resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True. + renormal (bool, optional): re-calc the vertex normals. Defaults to True. + retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False. + bound (float, optional): bound to resize. Defaults to 0.9. + front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'. + device (torch.device, optional): torch device. Defaults to None. + + Note: + a ``device`` keyword argument can be provided to specify the torch device. + If it's not provided, we will try to use ``'cuda'`` as the device if it's available. + + Returns: + Mesh: the loaded Mesh object. + """ + # obj supports face uv + if path.endswith(".obj"): + mesh = cls.load_obj(path, **kwargs) + # trimesh only supports vertex uv, but can load more formats + else: + mesh = cls.load_trimesh(path, **kwargs) + + # clean + if clean: + from kiui.mesh_utils import clean_mesh + vertices = mesh.v.detach().cpu().numpy() + triangles = mesh.f.detach().cpu().numpy() + vertices, triangles = clean_mesh(vertices, triangles, remesh=False) + mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device) + mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device) + + print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}") + # auto-normalize + if resize: + mesh.auto_size(bound=bound) + # auto-fix normal + if renormal or mesh.vn is None: + mesh.auto_normal() + print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}") + # auto-fix texcoords + if retex or (mesh.albedo is not None and mesh.vt is None): + mesh.auto_uv(cache_path=path) + print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}") + + # rotate front dir to +z + if front_dir != "+z": + # axis switch + if "-z" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32) + elif "+x" in front_dir: + T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) + elif "-x" in front_dir: + T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) + elif "+y" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) + elif "-y" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) + else: + T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + # rotation (how many 90 degrees) + if '1' in front_dir: + T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, 
dtype=torch.float32) + elif '2' in front_dir: + T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + elif '3' in front_dir: + T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + mesh.v @= T + mesh.vn @= T + + return mesh + + # load from obj file + @classmethod + def load_obj(cls, path, albedo_path=None, device=None): + """load an ``obj`` mesh. + + Args: + path (str): path to mesh. + albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None. + device (torch.device, optional): torch device. Defaults to None. + + Note: + We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension. + The `usemtl` statement is ignored, and we only use the last material path in `mtl` file. + + Returns: + Mesh: the loaded Mesh object. + """ + assert os.path.splitext(path)[-1] == ".obj" + + mesh = cls() + + # device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + mesh.device = device + + # load obj + with open(path, "r") as f: + lines = f.readlines() + + def parse_f_v(fv): + # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided) + # supported forms: + # f v1 v2 v3 + # f v1/vt1 v2/vt2 v3/vt3 + # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3 + # f v1//vn1 v2//vn2 v3//vn3 + xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")] + xs.extend([-1] * (3 - len(xs))) + return xs[0], xs[1], xs[2] + + vertices, texcoords, normals = [], [], [] + faces, tfaces, nfaces = [], [], [] + mtl_path = None + + for line in lines: + split_line = line.split() + # empty line + if len(split_line) == 0: + continue + prefix = split_line[0].lower() + # mtllib + if prefix == "mtllib": + mtl_path = split_line[1] + # usemtl + elif prefix == "usemtl": + pass # ignored + # v/vn/vt + elif prefix == "v": + vertices.append([float(v) for v in split_line[1:]]) + elif prefix == "vn": + normals.append([float(v) for v in split_line[1:]]) + elif prefix == "vt": + val = [float(v) for v in split_line[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == "f": + vs = split_line[1:] + nv = len(vs) + v0, t0, n0 = parse_f_v(vs[0]) + for i in range(nv - 2): # triangulate (assume vertices are ordered) + v1, t1, n1 = parse_f_v(vs[i + 1]) + v2, t2, n2 = parse_f_v(vs[i + 2]) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + nfaces.append([n0, n1, n2]) + + mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) + mesh.vt = ( + torch.tensor(texcoords, dtype=torch.float32, device=device) + if len(texcoords) > 0 + else None + ) + mesh.vn = ( + torch.tensor(normals, dtype=torch.float32, device=device) + if len(normals) > 0 + else None + ) + + mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) + mesh.ft = ( + torch.tensor(tfaces, dtype=torch.int32, device=device) + if len(texcoords) > 0 + else None + ) + mesh.fn = ( + torch.tensor(nfaces, dtype=torch.int32, device=device) + if len(normals) > 0 + else None + ) + + # see if there is vertex color + use_vertex_color = False + if mesh.v.shape[1] == 6: + use_vertex_color = True + mesh.vc = mesh.v[:, 3:] + mesh.v = mesh.v[:, :3] + print(f"[load_obj] use vertex color: {mesh.vc.shape}") + + # try to load texture image + if not use_vertex_color: + # try to retrieve mtl file + mtl_path_candidates = [] + if mtl_path is not None: + mtl_path_candidates.append(mtl_path) + 
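            # Note on the candidate order built here: the raw `mtllib` value, that value
            # resolved relative to the .obj file's directory, and the .obj path with its
            # extension swapped to ".mtl" are tried in turn; the first candidate that
            # exists on disk is then parsed for map_Kd / map_Pm / map_Pr texture paths.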
mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path)) + mtl_path_candidates.append(path.replace(".obj", ".mtl")) + + mtl_path = None + for candidate in mtl_path_candidates: + if os.path.exists(candidate): + mtl_path = candidate + break + + # if albedo_path is not provided, try retrieve it from mtl + metallic_path = None + roughness_path = None + if mtl_path is not None and albedo_path is None: + with open(mtl_path, "r") as f: + lines = f.readlines() + + for line in lines: + split_line = line.split() + # empty line + if len(split_line) == 0: + continue + prefix = split_line[0] + + if "map_Kd" in prefix: + # assume relative path! + albedo_path = os.path.join(os.path.dirname(path), split_line[1]) + print(f"[load_obj] use texture from: {albedo_path}") + elif "map_Pm" in prefix: + metallic_path = os.path.join(os.path.dirname(path), split_line[1]) + elif "map_Pr" in prefix: + roughness_path = os.path.join(os.path.dirname(path), split_line[1]) + + # still not found albedo_path, or the path doesn't exist + if albedo_path is None or not os.path.exists(albedo_path): + # init an empty texture + print(f"[load_obj] init empty albedo!") + # albedo = np.random.rand(1024, 1024, 3).astype(np.float32) + albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color + else: + albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED) + albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB) + albedo = albedo.astype(np.float32) / 255 + print(f"[load_obj] load texture: {albedo.shape}") + + mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device) + + # try to load metallic and roughness + if metallic_path is not None and roughness_path is not None: + print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}") + metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED) + metallic = metallic.astype(np.float32) / 255 + roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED) + roughness = roughness.astype(np.float32) / 255 + metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1) + + mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous() + + return mesh + + @classmethod + def load_trimesh(cls, path, device=None): + """load a mesh using ``trimesh.load()``. + + Can load various formats like ``glb`` and serves as a fallback. + + Note: + We will try to merge all meshes if the glb contains more than one, + but **this may cause the texture to lose**, since we only support one texture image! + + Args: + path (str): path to the mesh file. + device (torch.device, optional): torch device. Defaults to None. + + Returns: + Mesh: the loaded Mesh object. 
+ """ + mesh = cls() + + # device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + mesh.device = device + + # use trimesh to load ply/glb + _data = trimesh.load(path) + if isinstance(_data, trimesh.Scene): + if len(_data.geometry) == 1: + _mesh = list(_data.geometry.values())[0] + else: + print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.") + _concat = [] + # loop the scene graph and apply transform to each mesh + scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}} + for k, v in scene_graph.items(): + name = v['geometry'] + if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh): + transform = v['transform'] + _concat.append(_data.geometry[name].apply_transform(transform)) + _mesh = trimesh.util.concatenate(_concat) + else: + _mesh = _data + + if _mesh.visual.kind == 'vertex': + vertex_colors = _mesh.visual.vertex_colors + vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255 + mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device) + print(f"[load_trimesh] use vertex color: {mesh.vc.shape}") + elif _mesh.visual.kind == 'texture': + _material = _mesh.visual.material + if isinstance(_material, trimesh.visual.material.PBRMaterial): + texture = np.array(_material.baseColorTexture).astype(np.float32) / 255 + # load metallicRoughness if present + if _material.metallicRoughnessTexture is not None: + metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255 + mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous() + elif isinstance(_material, trimesh.visual.material.SimpleMaterial): + texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255 + else: + raise NotImplementedError(f"material type {type(_material)} not supported!") + mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous() + print(f"[load_trimesh] load texture: {texture.shape}") + else: + texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) + mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device) + print(f"[load_trimesh] failed to load texture.") + + vertices = _mesh.vertices + + try: + texcoords = _mesh.visual.uv + texcoords[:, 1] = 1 - texcoords[:, 1] + except Exception as e: + texcoords = None + + try: + normals = _mesh.vertex_normals + except Exception as e: + normals = None + + # trimesh only support vertex uv... + faces = tfaces = nfaces = _mesh.faces + + mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) + mesh.vt = ( + torch.tensor(texcoords, dtype=torch.float32, device=device) + if texcoords is not None + else None + ) + mesh.vn = ( + torch.tensor(normals, dtype=torch.float32, device=device) + if normals is not None + else None + ) + + mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) + mesh.ft = ( + torch.tensor(tfaces, dtype=torch.int32, device=device) + if texcoords is not None + else None + ) + mesh.fn = ( + torch.tensor(nfaces, dtype=torch.int32, device=device) + if normals is not None + else None + ) + + return mesh + + # sample surface (using trimesh) + def sample_surface(self, count: int): + """sample points on the surface of the mesh. + + Args: + count (int): number of points to sample. + + Returns: + torch.Tensor: the sampled points, float [count, 3]. 
+ """ + _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy()) + points, face_idx = trimesh.sample.sample_surface(_mesh, count) + points = torch.from_numpy(points).float().to(self.device) + return points + + # aabb + def aabb(self): + """get the axis-aligned bounding box of the mesh. + + Returns: + Tuple[torch.Tensor]: the min xyz and max xyz of the mesh. + """ + return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values + + # unit size + @torch.no_grad() + def auto_size(self, bound=0.9): + """auto resize the mesh. + + Args: + bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9. + """ + vmin, vmax = self.aabb() + self.ori_center = (vmax + vmin) / 2 + self.ori_scale = 2 * bound / torch.max(vmax - vmin).item() + self.v = (self.v - self.ori_center) * self.ori_scale + + def auto_normal(self): + """auto calculate the vertex normals. + """ + i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long() + v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + vn = torch.zeros_like(self.v) + vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + vn = torch.where( + dot(vn, vn) > 1e-20, + vn, + torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device), + ) + vn = safe_normalize(vn) + + self.vn = vn + self.fn = self.f + + def auto_uv(self, cache_path=None, vmap=True): + """auto calculate the uv coordinates. + + Args: + cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None. + vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf). + Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True. + """ + # try to load cache + if cache_path is not None: + cache_path = os.path.splitext(cache_path)[0] + "_uv.npz" + if cache_path is not None and os.path.exists(cache_path): + data = np.load(cache_path) + vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"] + else: + import xatlas + + v_np = self.v.detach().cpu().numpy() + f_np = self.f.detach().int().cpu().numpy() + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + chart_options = xatlas.ChartOptions() + # chart_options.max_iterations = 4 + atlas.generate(chart_options=chart_options) + vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] + + # save to cache + if cache_path is not None: + np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping) + + vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device) + ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device) + self.vt = vt + self.ft = ft + + if vmap: + vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device) + self.align_v_to_vt(vmapping) + + def align_v_to_vt(self, vmapping=None): + """ remap v/f and vn/fn to vt/ft. + + Args: + vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None. 
+ """ + if vmapping is None: + ft = self.ft.view(-1).long() + f = self.f.view(-1).long() + vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device) + vmapping[ft] = f # scatter, randomly choose one if index is not unique + + self.v = self.v[vmapping] + self.f = self.ft + + if self.vn is not None: + self.vn = self.vn[vmapping] + self.fn = self.ft + + def to(self, device): + """move all tensor attributes to device. + + Args: + device (torch.device): target device. + + Returns: + Mesh: self. + """ + self.device = device + for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]: + tensor = getattr(self, name) + if tensor is not None: + setattr(self, name, tensor.to(device)) + return self + + def write(self, path): + """write the mesh to a path. + + Args: + path (str): path to write, supports ply, obj and glb. + """ + if path.endswith(".ply"): + self.write_ply(path) + elif path.endswith(".obj"): + self.write_obj(path) + elif path.endswith(".glb") or path.endswith(".gltf"): + self.write_glb(path) + else: + raise NotImplementedError(f"format {path} not supported!") + + def write_ply(self, path): + """write the mesh in ply format. Only for geometry! + + Args: + path (str): path to write. + """ + + if self.albedo is not None: + print(f'[WARN] ply format does not support exporting texture, will ignore!') + + v_np = self.v.detach().cpu().numpy() + f_np = self.f.detach().cpu().numpy() + + _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np) + _mesh.export(path) + + + def write_glb(self, path): + """write the mesh in glb/gltf format. + This will create a scene with a single mesh. + + Args: + path (str): path to write. + """ + + # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0] + if self.vt is not None and self.v.shape[0] != self.vt.shape[0]: + self.align_v_to_vt() + + import pygltflib + + f_np = self.f.detach().cpu().numpy().astype(np.uint32) + f_np_blob = f_np.flatten().tobytes() + + v_np = self.v.detach().cpu().numpy().astype(np.float32) + v_np_blob = v_np.tobytes() + + blob = f_np_blob + v_np_blob + byteOffset = len(blob) + + # base mesh + gltf = pygltflib.GLTF2( + scene=0, + scenes=[pygltflib.Scene(nodes=[0])], + nodes=[pygltflib.Node(mesh=0)], + meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive( + # indices to accessors (0 is triangles) + attributes=pygltflib.Attributes( + POSITION=1, + ), + indices=0, + )])], + buffers=[ + pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob)) + ], + # buffer view (based on dtype) + bufferViews=[ + # triangles; as flatten (element) array + pygltflib.BufferView( + buffer=0, + byteLength=len(f_np_blob), + target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963) + ), + # positions; as vec3 array + pygltflib.BufferView( + buffer=0, + byteOffset=len(f_np_blob), + byteLength=len(v_np_blob), + byteStride=12, # vec3 + target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962) + ), + ], + accessors=[ + # 0 = triangles + pygltflib.Accessor( + bufferView=0, + componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125) + count=f_np.size, + type=pygltflib.SCALAR, + max=[int(f_np.max())], + min=[int(f_np.min())], + ), + # 1 = positions + pygltflib.Accessor( + bufferView=1, + componentType=pygltflib.FLOAT, # GL_FLOAT (5126) + count=len(v_np), + type=pygltflib.VEC3, + max=v_np.max(axis=0).tolist(), + min=v_np.min(axis=0).tolist(), + ), + ], + ) + + # append texture info + if self.vt is not None: + + vt_np = self.vt.detach().cpu().numpy().astype(np.float32) + vt_np_blob 
= vt_np.tobytes() + + albedo = self.albedo.detach().cpu().numpy() + albedo = (albedo * 255).astype(np.uint8) + albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR) + albedo_blob = cv2.imencode('.png', albedo)[1].tobytes() + + # update primitive + gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2 + gltf.meshes[0].primitives[0].material = 0 + + # update materials + gltf.materials.append(pygltflib.Material( + pbrMetallicRoughness=pygltflib.PbrMetallicRoughness( + baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0), + metallicFactor=0.0, + roughnessFactor=1.0, + ), + alphaMode=pygltflib.OPAQUE, + alphaCutoff=None, + doubleSided=True, + )) + + gltf.textures.append(pygltflib.Texture(sampler=0, source=0)) + gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT)) + gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png")) + + # update buffers + gltf.bufferViews.append( + # index = 2, texcoords; as vec2 array + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(vt_np_blob), + byteStride=8, # vec2 + target=pygltflib.ARRAY_BUFFER, + ) + ) + + gltf.accessors.append( + # 2 = texcoords + pygltflib.Accessor( + bufferView=2, + componentType=pygltflib.FLOAT, + count=len(vt_np), + type=pygltflib.VEC2, + max=vt_np.max(axis=0).tolist(), + min=vt_np.min(axis=0).tolist(), + ) + ) + + blob += vt_np_blob + byteOffset += len(vt_np_blob) + + gltf.bufferViews.append( + # index = 3, albedo texture; as none target + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(albedo_blob), + ) + ) + + blob += albedo_blob + byteOffset += len(albedo_blob) + + gltf.buffers[0].byteLength = byteOffset + + # append metllic roughness + if self.metallicRoughness is not None: + metallicRoughness = self.metallicRoughness.detach().cpu().numpy() + metallicRoughness = (metallicRoughness * 255).astype(np.uint8) + metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR) + metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes() + + # update texture definition + gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0 + gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0 + gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0) + + gltf.textures.append(pygltflib.Texture(sampler=1, source=1)) + gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT)) + gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png")) + + # update buffers + gltf.bufferViews.append( + # index = 4, metallicRoughness texture; as none target + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(metallicRoughness_blob), + ) + ) + + blob += metallicRoughness_blob + byteOffset += len(metallicRoughness_blob) + + gltf.buffers[0].byteLength = byteOffset + + + # set actual data + gltf.set_binary_blob(blob) + + # glb = b"".join(gltf.save_to_bytes()) + gltf.save(path) + + + def write_obj(self, path): + """write the mesh in obj format. Will also write the texture and mtl files. + + Args: + path (str): path to write. 
+ """ + + mtl_path = path.replace(".obj", ".mtl") + albedo_path = path.replace(".obj", "_albedo.png") + metallic_path = path.replace(".obj", "_metallic.png") + roughness_path = path.replace(".obj", "_roughness.png") + + v_np = self.v.detach().cpu().numpy() + vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None + vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None + f_np = self.f.detach().cpu().numpy() + ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None + fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None + + with open(path, "w") as fp: + fp.write(f"mtllib {os.path.basename(mtl_path)} \n") + + for v in v_np: + fp.write(f"v {v[0]} {v[1]} {v[2]} \n") + + if vt_np is not None: + for v in vt_np: + fp.write(f"vt {v[0]} {1 - v[1]} \n") + + if vn_np is not None: + for v in vn_np: + fp.write(f"vn {v[0]} {v[1]} {v[2]} \n") + + fp.write(f"usemtl defaultMat \n") + for i in range(len(f_np)): + fp.write( + f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \ + {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \ + {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n' + ) + + with open(mtl_path, "w") as fp: + fp.write(f"newmtl defaultMat \n") + fp.write(f"Ka 1 1 1 \n") + fp.write(f"Kd 1 1 1 \n") + fp.write(f"Ks 0 0 0 \n") + fp.write(f"Tr 1 \n") + fp.write(f"illum 1 \n") + fp.write(f"Ns 0 \n") + if self.albedo is not None: + fp.write(f"map_Kd {os.path.basename(albedo_path)} \n") + if self.metallicRoughness is not None: + # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering + fp.write(f"map_Pm {os.path.basename(metallic_path)} \n") + fp.write(f"map_Pr {os.path.basename(roughness_path)} \n") + + if self.albedo is not None: + albedo = self.albedo.detach().cpu().numpy() + albedo = (albedo * 255).astype(np.uint8) + cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)) + + if self.metallicRoughness is not None: + metallicRoughness = self.metallicRoughness.detach().cpu().numpy() + metallicRoughness = (metallicRoughness * 255).astype(np.uint8) + cv2.imwrite(metallic_path, metallicRoughness[..., 2]) + cv2.imwrite(roughness_path, metallicRoughness[..., 1]) + diff --git a/apps/third_party/CRM/model/.DS_Store b/apps/third_party/CRM/model/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e2fd35f3e9054910a42dde149c88e430130d66d7 Binary files /dev/null and b/apps/third_party/CRM/model/.DS_Store differ diff --git a/apps/third_party/CRM/model/__init__.py b/apps/third_party/CRM/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b339e3ea9dac5482a6daed63b069e4d1eda000a8 --- /dev/null +++ b/apps/third_party/CRM/model/__init__.py @@ -0,0 +1 @@ +from model.crm.model import CRM \ No newline at end of file diff --git a/apps/third_party/CRM/model/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de5f1cf642bb8b776bf8da5e5d582f6001996da8 Binary files /dev/null and b/apps/third_party/CRM/model/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/model/archs/__init__.py b/apps/third_party/CRM/model/archs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 
diff --git a/apps/third_party/CRM/model/archs/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/model/archs/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b1f2b99930a199372b91624098bf60f9baa8469 Binary files /dev/null and b/apps/third_party/CRM/model/archs/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/model/archs/__pycache__/mlp_head.cpython-38.pyc b/apps/third_party/CRM/model/archs/__pycache__/mlp_head.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da7fbd72afdb4465902b7f80d4dda57c8f3b9ee5 Binary files /dev/null and b/apps/third_party/CRM/model/archs/__pycache__/mlp_head.cpython-38.pyc differ diff --git a/apps/third_party/CRM/model/archs/__pycache__/unet.cpython-38.pyc b/apps/third_party/CRM/model/archs/__pycache__/unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..932b63f9d151ae809cb092048dd560ac436567f7 Binary files /dev/null and b/apps/third_party/CRM/model/archs/__pycache__/unet.cpython-38.pyc differ diff --git a/apps/third_party/CRM/model/archs/decoders/__init__.py b/apps/third_party/CRM/model/archs/decoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f5a12faa99758192ecc4ed3fc22c9249232e86 --- /dev/null +++ b/apps/third_party/CRM/model/archs/decoders/__init__.py @@ -0,0 +1 @@ + diff --git a/apps/third_party/CRM/model/archs/decoders/__pycache__/__init__.cpython-38.pyc b/apps/third_party/CRM/model/archs/decoders/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da3e1e7536c21056bb8268bd03799770646488ae Binary files /dev/null and b/apps/third_party/CRM/model/archs/decoders/__pycache__/__init__.cpython-38.pyc differ diff --git a/apps/third_party/CRM/model/archs/decoders/__pycache__/shape_texture_net.cpython-38.pyc b/apps/third_party/CRM/model/archs/decoders/__pycache__/shape_texture_net.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fab55d21095a1a7e4841140ca2efb1c71f9730c Binary files /dev/null and b/apps/third_party/CRM/model/archs/decoders/__pycache__/shape_texture_net.cpython-38.pyc differ diff --git a/apps/third_party/CRM/model/archs/decoders/shape_texture_net.py b/apps/third_party/CRM/model/archs/decoders/shape_texture_net.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5ddd78215b9f48b281757a91b8de6f73e4742a --- /dev/null +++ b/apps/third_party/CRM/model/archs/decoders/shape_texture_net.py @@ -0,0 +1,62 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TetTexNet(nn.Module): + def __init__(self, plane_reso=64, padding=0.1, fea_concat=True): + super().__init__() + # self.c_dim = c_dim + self.plane_reso = plane_reso + self.padding = padding + self.fea_concat = fea_concat + + def forward(self, rolled_out_feature, query): + # rolled_out_feature: rolled-out triplane feature + # query: queried xyz coordinates (should be scaled consistently to ptr cloud) + + plane_reso = self.plane_reso + + triplane_feature = dict() + triplane_feature['xy'] = rolled_out_feature[:, :, :, 0: plane_reso] + triplane_feature['yz'] = rolled_out_feature[:, :, :, plane_reso: 2 * plane_reso] + triplane_feature['zx'] = rolled_out_feature[:, :, :, 2 * plane_reso:] + + query_feature_xy = self.sample_plane_feature(query, triplane_feature['xy'], 'xy') + query_feature_yz = self.sample_plane_feature(query, triplane_feature['yz'], 'yz') + query_feature_zx = 
self.sample_plane_feature(query, triplane_feature['zx'], 'zx') + + if self.fea_concat: + query_feature = torch.cat((query_feature_xy, query_feature_yz, query_feature_zx), dim=1) + else: + query_feature = query_feature_xy + query_feature_yz + query_feature_zx + + output = query_feature.permute(0, 2, 1) + + return output + + # uses values from plane_feature and pixel locations from vgrid to interpolate feature + def sample_plane_feature(self, query, plane_feature, plane): + # CYF note: + # for pretraining, query are uniformly sampled positions w.i. [-scale, scale] + # for training, query are essentially tetrahedra grid vertices, which are + # also within [-scale, scale] in the current version! + # xy range [-scale, scale] + if plane == 'xy': + xy = query[:, :, [0, 1]] + elif plane == 'yz': + xy = query[:, :, [1, 2]] + elif plane == 'zx': + xy = query[:, :, [2, 0]] + else: + raise ValueError("Error! Invalid plane type!") + + xy = xy[:, :, None].float() + # not seem necessary to rescale the grid, because from + # https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html, + # it specifies sampling locations normalized by plane_feature's spatial dimension, + # which is within [-scale, scale] as specified by encoder's calling of coordinate2index() + vgrid = 1.0 * xy + sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1) + + return sampled_feat diff --git a/apps/third_party/CRM/model/archs/mlp_head.py b/apps/third_party/CRM/model/archs/mlp_head.py new file mode 100644 index 0000000000000000000000000000000000000000..33d7dcdbf58374dd036d9f3f5f0bfd3f248e845b --- /dev/null +++ b/apps/third_party/CRM/model/archs/mlp_head.py @@ -0,0 +1,40 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class SdfMlp(nn.Module): + def __init__(self, input_dim, hidden_dim=512, bias=True): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias) + self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias) + self.fc3 = nn.Linear(hidden_dim, 4, bias=bias) + + + def forward(self, input): + x = F.relu(self.fc1(input)) + x = F.relu(self.fc2(x)) + out = self.fc3(x) + return out + + +class RgbMlp(nn.Module): + def __init__(self, input_dim, hidden_dim=512, bias=True): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias) + self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=bias) + self.fc3 = nn.Linear(hidden_dim, 3, bias=bias) + + def forward(self, input): + x = F.relu(self.fc1(input)) + x = F.relu(self.fc2(x)) + out = self.fc3(x) + + return out + + \ No newline at end of file diff --git a/apps/third_party/CRM/model/archs/unet.py b/apps/third_party/CRM/model/archs/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..e427c18b1bed00089e87fe25b0e810042538d6b3 --- /dev/null +++ b/apps/third_party/CRM/model/archs/unet.py @@ -0,0 +1,53 @@ +''' +Codes are from: +https://github.com/jaxony/unet-pytorch/blob/master/model.py +''' + +import torch +import torch.nn as nn +from diffusers import UNet2DModel +import einops +class UNetPP(nn.Module): + ''' + Wrapper for UNet in diffusers + ''' + def __init__(self, in_channels): + super(UNetPP, self).__init__() + self.in_channels = in_channels + self.unet = UNet2DModel( + sample_size=[256, 256*3], + in_channels=in_channels, + out_channels=32, + layers_per_block=2, + block_out_channels=(64, 128, 128, 128*2, 
128*2, 128*4, 128*4), + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + + self.unet.enable_xformers_memory_efficient_attention() + if in_channels > 12: + self.learned_plane = torch.nn.parameter.Parameter(torch.zeros([1,in_channels-12,256,256*3])) + + def forward(self, x, t=256): + learned_plane = self.learned_plane + if x.shape[1] < self.in_channels: + learned_plane = einops.repeat(learned_plane, '1 C H W -> B C H W', B=x.shape[0]).to(x.device) + x = torch.cat([x, learned_plane], dim = 1) + return self.unet(x, t).sample + diff --git a/apps/third_party/CRM/model/crm/__pycache__/model.cpython-38.pyc b/apps/third_party/CRM/model/crm/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93b38087a8e1a71bab18af950898769b765939ab Binary files /dev/null and b/apps/third_party/CRM/model/crm/__pycache__/model.cpython-38.pyc differ diff --git a/apps/third_party/CRM/model/crm/model.py b/apps/third_party/CRM/model/crm/model.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb164266a0c23295ab3451d99d720ec12afa468 --- /dev/null +++ b/apps/third_party/CRM/model/crm/model.py @@ -0,0 +1,213 @@ +import torch.nn as nn +import torch +import torch.nn.functional as F + +import numpy as np + + +from pathlib import Path +import cv2 +import trimesh +import nvdiffrast.torch as dr + +from model.archs.decoders.shape_texture_net import TetTexNet +from model.archs.unet import UNetPP +from util.renderer import Renderer +from model.archs.mlp_head import SdfMlp, RgbMlp +import xatlas + + +class Dummy: + pass + +class CRM(nn.Module): + def __init__(self, specs): + super(CRM, self).__init__() + + self.specs = specs + # configs + input_specs = specs["Input"] + self.input = Dummy() + self.input.scale = input_specs['scale'] + self.input.resolution = input_specs['resolution'] + self.tet_grid_size = input_specs['tet_grid_size'] + self.camera_angle_num = input_specs['camera_angle_num'] + + self.arch = Dummy() + self.arch.fea_concat = specs["ArchSpecs"]["fea_concat"] + self.arch.mlp_bias = specs["ArchSpecs"]["mlp_bias"] + + self.dec = Dummy() + self.dec.c_dim = specs["DecoderSpecs"]["c_dim"] + self.dec.plane_resolution = specs["DecoderSpecs"]["plane_resolution"] + + self.geo_type = specs["Train"].get("geo_type", "flex") # "dmtet" or "flex" + + self.unet2 = UNetPP(in_channels=self.dec.c_dim) + + mlp_chnl_s = 3 if self.arch.fea_concat else 1 # 3 for queried triplane feature concatenation + self.decoder = TetTexNet(plane_reso=self.dec.plane_resolution, fea_concat=self.arch.fea_concat) + + if self.geo_type == "flex": + self.weightMlp = nn.Sequential( + nn.Linear(mlp_chnl_s * 32 * 8, 512), + nn.SiLU(), + nn.Linear(512, 21)) + + self.sdfMlp = SdfMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias) + self.rgbMlp = RgbMlp(mlp_chnl_s * 32, 512, bias=self.arch.mlp_bias) + self.renderer = Renderer(tet_grid_size=self.tet_grid_size, camera_angle_num=self.camera_angle_num, + scale=self.input.scale, geo_type = self.geo_type) + + + self.spob = True if specs['Pretrain']['mode'] is None else False # whether to add sphere + self.radius = specs['Pretrain']['radius'] # used when spob + + self.denoising = True + from diffusers import DDIMScheduler + self.scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", 
subfolder="scheduler") + + def decode(self, data, triplane_feature2): + if self.geo_type == "flex": + tet_verts = self.renderer.flexicubes.verts.unsqueeze(0) + tet_indices = self.renderer.flexicubes.indices + + dec_verts = self.decoder(triplane_feature2, tet_verts) + out = self.sdfMlp(dec_verts) + + weight = None + if self.geo_type == "flex": + grid_feat = torch.index_select(input=dec_verts, index=self.renderer.flexicubes.indices.reshape(-1),dim=1) + grid_feat = grid_feat.reshape(dec_verts.shape[0], self.renderer.flexicubes.indices.shape[0], self.renderer.flexicubes.indices.shape[1] * dec_verts.shape[-1]) + weight = self.weightMlp(grid_feat) + weight = weight * 0.1 + + pred_sdf, deformation = out[..., 0], out[..., 1:] + if self.spob: + pred_sdf = pred_sdf + self.radius - torch.sqrt((tet_verts**2).sum(-1)) + + _, verts, faces = self.renderer(data, pred_sdf, deformation, tet_verts, tet_indices, weight= weight) + return verts[0].unsqueeze(0), faces[0].int() + + def export_mesh(self, data, out_dir, ind, device=None, tri_fea_2 = None): + verts = data['verts'] + faces = data['faces'] + + dec_verts = self.decoder(tri_fea_2, verts.unsqueeze(0)) + colors = self.rgbMlp(dec_verts).squeeze().detach().cpu().numpy() + # Expect predicted colors value range from [-1, 1] + colors = (colors * 0.5 + 0.5).clip(0, 1) + + verts = verts.squeeze().cpu().numpy() + faces = faces[..., [2, 1, 0]].squeeze().cpu().numpy() + + # export the final mesh + with torch.no_grad(): + mesh = trimesh.Trimesh(verts, faces, vertex_colors=colors, process=False) # important, process=True leads to seg fault... + mesh.export(out_dir / f'{ind}.obj') + + def export_mesh_wt_uv(self, ctx, data, out_dir, ind, device, res, tri_fea_2=None): + + mesh_v = data['verts'].squeeze().cpu().numpy() + mesh_pos_idx = data['faces'].squeeze().cpu().numpy() + + def interpolate(attr, rast, attr_idx, rast_db=None): + return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, + diff_attrs=None if rast_db is None else 'all') + + vmapping, indices, uvs = xatlas.parametrize(mesh_v, mesh_pos_idx) + + mesh_v = torch.tensor(mesh_v, dtype=torch.float32, device=device) + mesh_pos_idx = torch.tensor(mesh_pos_idx, dtype=torch.int64, device=device) + + # Convert to tensors + indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) + + uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) + mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) + # mesh_v_tex. ture + uv_clip = uvs[None, ...] 
* 2.0 - 1.0 + + # pad to four component coordinate + uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) + + # rasterize + rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), res) + + # Interpolate world space position + gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) + mask = rast[..., 3:4] > 0 + + # return uvs, mesh_tex_idx, gb_pos, mask + gb_pos_unsqz = gb_pos.view(-1, 3) + mask_unsqz = mask.view(-1) + tex_unsqz = torch.zeros_like(gb_pos_unsqz) + 1 + + gb_mask_pos = gb_pos_unsqz[mask_unsqz] + + gb_mask_pos = gb_mask_pos[None, ] + + with torch.no_grad(): + + dec_verts = self.decoder(tri_fea_2, gb_mask_pos) + colors = self.rgbMlp(dec_verts).squeeze() + + # Expect predicted colors value range from [-1, 1] + lo, hi = (-1, 1) + colors = (colors - lo) * (255 / (hi - lo)) + colors = colors.clip(0, 255) + + tex_unsqz[mask_unsqz] = colors + + tex = tex_unsqz.view(res + (3,)) + + verts = mesh_v.squeeze().cpu().numpy() + faces = mesh_pos_idx[..., [2, 1, 0]].squeeze().cpu().numpy() + # faces = mesh_pos_idx + # faces = faces.detach().cpu().numpy() + # faces = faces[..., [2, 1, 0]] + indices = indices[..., [2, 1, 0]] + + # xatlas.export(f"{out_dir}/{ind}.obj", verts[vmapping], indices, uvs) + matname = f'{out_dir}.mtl' + # matname = f'{out_dir}/{ind}.mtl' + fid = open(matname, 'w') + fid.write('newmtl material_0\n') + fid.write('Kd 1 1 1\n') + fid.write('Ka 1 1 1\n') + # fid.write('Ks 0 0 0\n') + fid.write('Ks 0.4 0.4 0.4\n') + fid.write('Ns 10\n') + fid.write('illum 2\n') + fid.write(f'map_Kd {out_dir.split("/")[-1]}.png\n') + fid.close() + + fid = open(f'{out_dir}.obj', 'w') + # fid = open(f'{out_dir}/{ind}.obj', 'w') + fid.write('mtllib %s.mtl\n' % out_dir.split("/")[-1]) + + for pidx, p in enumerate(verts): + pp = p + fid.write('v %f %f %f\n' % (pp[0], pp[2], - pp[1])) + + for pidx, p in enumerate(uvs): + pp = p + fid.write('vt %f %f\n' % (pp[0], 1 - pp[1])) + + fid.write('usemtl material_0\n') + for i, f in enumerate(faces): + f1 = f + 1 + f2 = indices[i] + 1 + fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + img = np.asarray(tex.data.cpu().numpy(), dtype=np.float32) + mask = np.sum(img.astype(float), axis=-1, keepdims=True) + mask = (mask <= 3.0).astype(float) + kernel = np.ones((3, 3), 'uint8') + dilate_img = cv2.dilate(img, kernel, iterations=1) + img = img * (1 - mask) + dilate_img * mask + img = img.clip(0, 255).astype(np.uint8) + + cv2.imwrite(f'{out_dir}.png', img[..., [2, 1, 0]]) + # cv2.imwrite(f'{out_dir}/{ind}.png', img[..., [2, 1, 0]]) diff --git a/apps/third_party/CRM/pipelines.py b/apps/third_party/CRM/pipelines.py new file mode 100644 index 0000000000000000000000000000000000000000..0ef19dc84dd7789197d02a29239d99b0a82558b1 --- /dev/null +++ b/apps/third_party/CRM/pipelines.py @@ -0,0 +1,205 @@ +import torch +import os +import sys +proj_dir = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(proj_dir) +from .libs.base_utils import do_resize_content +from .imagedream.ldm.util import ( + instantiate_from_config, + get_obj_from_str, +) +from omegaconf import OmegaConf +from PIL import Image +import PIL +import rembg +class TwoStagePipeline(object): + def __init__( + self, + stage1_model_config, + stage1_sampler_config, + device="cuda", + dtype=torch.float16, + resize_rate=1, + ) -> None: + """ + only for two stage generate process. 
+ - the first stage was condition on single pixel image, gererate multi-view pixel image, based on the v2pp config + - the second stage was condition on multiview pixel image generated by the first stage, generate the final image, based on the stage2-test config + """ + self.resize_rate = resize_rate + + self.stage1_model = instantiate_from_config(OmegaConf.load(stage1_model_config.config).model) + self.stage1_model.load_state_dict(torch.load(stage1_model_config.resume, map_location="cpu"), strict=False) + self.stage1_model = self.stage1_model.to(device).to(dtype) + + self.stage1_model.device = device + self.device = device + self.dtype = dtype + self.stage1_sampler = get_obj_from_str(stage1_sampler_config.target)( + self.stage1_model, device=device, dtype=dtype, **stage1_sampler_config.params + ) + + def stage1_sample( + self, + pixel_img, + prompt="3D assets", + neg_texts="uniform low no texture ugly, boring, bad anatomy, blurry, pixelated, obscure, unnatural colors, poor lighting, dull, and unclear.", + step=50, + scale=5, + ddim_eta=0.0, + ): + if type(pixel_img) == str: + pixel_img = Image.open(pixel_img) + + if isinstance(pixel_img, Image.Image): + if pixel_img.mode == "RGBA": + background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0)) + pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB") + else: + pixel_img = pixel_img.convert("RGB") + else: + raise + uc = self.stage1_sampler.model.get_learned_conditioning([neg_texts]).to(self.device) + stage1_images = self.stage1_sampler.i2i( + self.stage1_sampler.model, + self.stage1_sampler.size, + prompt, + uc=uc, + sampler=self.stage1_sampler.sampler, + ip=pixel_img, + step=step, + scale=scale, + batch_size=self.stage1_sampler.batch_size, + ddim_eta=ddim_eta, + dtype=self.stage1_sampler.dtype, + device=self.stage1_sampler.device, + camera=self.stage1_sampler.camera, + num_frames=self.stage1_sampler.num_frames, + pixel_control=(self.stage1_sampler.mode == "pixel"), + transform=self.stage1_sampler.image_transform, + offset_noise=self.stage1_sampler.offset_noise, + ) + + stage1_images = [Image.fromarray(img) for img in stage1_images] + stage1_images.pop(self.stage1_sampler.ref_position) + return stage1_images + + def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50): + if type(pixel_img) == str: + pixel_img = Image.open(pixel_img) + + if isinstance(pixel_img, Image.Image): + if pixel_img.mode == "RGBA": + background = Image.new('RGBA', pixel_img.size, (0, 0, 0, 0)) + pixel_img = Image.alpha_composite(background, pixel_img).convert("RGB") + else: + pixel_img = pixel_img.convert("RGB") + else: + raise + stage2_images = self.stage2_sampler.i2iStage2( + self.stage2_sampler.model, + self.stage2_sampler.size, + "3D assets", + self.stage2_sampler.uc, + self.stage2_sampler.sampler, + pixel_images=stage1_images, + ip=pixel_img, + step=step, + scale=scale, + batch_size=self.stage2_sampler.batch_size, + ddim_eta=0.0, + dtype=self.stage2_sampler.dtype, + device=self.stage2_sampler.device, + camera=self.stage2_sampler.camera, + num_frames=self.stage2_sampler.num_frames, + pixel_control=(self.stage2_sampler.mode == "pixel"), + transform=self.stage2_sampler.image_transform, + offset_noise=self.stage2_sampler.offset_noise, + ) + stage2_images = [Image.fromarray(img) for img in stage2_images] + return stage2_images + + def set_seed(self, seed): + self.stage1_sampler.seed = seed + # self.stage2_sampler.seed = seed + + def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50): + pixel_img = do_resize_content(pixel_img, 
self.resize_rate) + stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step) + # stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step) + + return { + "ref_img": pixel_img, + "stage1_images": stage1_images, + # "stage2_images": stage2_images, + } + +rembg_session = rembg.new_session() + +def expand_to_square(image, bg_color=(0, 0, 0, 0)): + # expand image to 1:1 + width, height = image.size + if width == height: + return image + new_size = (max(width, height), max(width, height)) + new_image = Image.new("RGBA", new_size, bg_color) + paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) + new_image.paste(image, paste_position) + return new_image + +def remove_background( + image: PIL.Image.Image, + rembg_session = None, + force: bool = False, + **rembg_kwargs, +) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + # explain why current do not rm bg + print("alhpa channl not enpty, skip remove background, using alpha channel as mask") + background = Image.new("RGBA", image.size, (0, 0, 0, 0)) + image = Image.alpha_composite(background, image) + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + +def do_resize_content(original_image: Image, scale_rate): + # resize image content wile retain the original image size + if scale_rate != 1: + # Calculate the new size after rescaling + new_size = tuple(int(dim * scale_rate) for dim in original_image.size) + # Resize the image while maintaining the aspect ratio + resized_image = original_image.resize(new_size) + # Create a new image with the original size and black background + padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) + paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) + padded_image.paste(resized_image, paste_position) + return padded_image + else: + return original_image + +def add_background(image, bg_color=(255, 255, 255)): + # given an RGBA image, alpha channel is used as mask to add background color + background = Image.new("RGBA", image.size, bg_color) + return Image.alpha_composite(background, image) + + +def preprocess_image(image, background_choice, foreground_ratio, backgroud_color): + """ + input image is a pil image in RGBA, return RGB image + """ + print(background_choice) + if background_choice == "Alpha as mask": + background = Image.new("RGBA", image.size, (0, 0, 0, 0)) + image = Image.alpha_composite(background, image) + else: + image = remove_background(image, rembg_session, force_remove=True) + image = do_resize_content(image, foreground_ratio) + image = expand_to_square(image) + image = add_background(image, backgroud_color) + return image.convert("RGB") + + + diff --git a/apps/third_party/CRM/requirements.txt b/apps/third_party/CRM/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..8501f40d7ec31f64b3b6b77549fb2b7623f2d382 --- /dev/null +++ b/apps/third_party/CRM/requirements.txt @@ -0,0 +1,16 @@ +gradio +huggingface-hub +diffusers==0.24.0 +einops==0.7.0 +Pillow==10.1.0 +transformers==4.27.1 +open-clip-torch==2.7.0 +opencv-contrib-python-headless==4.9.0.80 +opencv-python-headless==4.9.0.80 +omegaconf +rembg +pygltflib +kiui +trimesh +xatlas +pymeshlab diff --git a/apps/third_party/CRM/run.py b/apps/third_party/CRM/run.py new file mode 100644 index 
0000000000000000000000000000000000000000..8e14be6e7c7cb1d314d1e82a23a6d250e79ce3b7 --- /dev/null +++ b/apps/third_party/CRM/run.py @@ -0,0 +1,160 @@ +import torch +from libs.base_utils import do_resize_content +from imagedream.ldm.util import ( + instantiate_from_config, + get_obj_from_str, +) +from omegaconf import OmegaConf +from PIL import Image +import numpy as np +from inference import generate3d +from huggingface_hub import hf_hub_download +import json +import argparse +import shutil +from model import CRM +import PIL +import rembg +import os +from pipelines import TwoStagePipeline + +rembg_session = rembg.new_session() + +def expand_to_square(image, bg_color=(0, 0, 0, 0)): + # expand image to 1:1 + width, height = image.size + if width == height: + return image + new_size = (max(width, height), max(width, height)) + new_image = Image.new("RGBA", new_size, bg_color) + paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2) + new_image.paste(image, paste_position) + return new_image + +def remove_background( + image: PIL.Image.Image, + rembg_session = None, + force: bool = False, + **rembg_kwargs, +) -> PIL.Image.Image: + do_remove = True + if image.mode == "RGBA" and image.getextrema()[3][0] < 255: + # explain why current do not rm bg + print("alhpa channl not enpty, skip remove background, using alpha channel as mask") + background = Image.new("RGBA", image.size, (0, 0, 0, 0)) + image = Image.alpha_composite(background, image) + do_remove = False + do_remove = do_remove or force + if do_remove: + image = rembg.remove(image, session=rembg_session, **rembg_kwargs) + return image + +def do_resize_content(original_image: Image, scale_rate): + # resize image content wile retain the original image size + if scale_rate != 1: + # Calculate the new size after rescaling + new_size = tuple(int(dim * scale_rate) for dim in original_image.size) + # Resize the image while maintaining the aspect ratio + resized_image = original_image.resize(new_size) + # Create a new image with the original size and black background + padded_image = Image.new("RGBA", original_image.size, (0, 0, 0, 0)) + paste_position = ((original_image.width - resized_image.width) // 2, (original_image.height - resized_image.height) // 2) + padded_image.paste(resized_image, paste_position) + return padded_image + else: + return original_image + +def add_background(image, bg_color=(255, 255, 255)): + # given an RGBA image, alpha channel is used as mask to add background color + background = Image.new("RGBA", image.size, bg_color) + return Image.alpha_composite(background, image) + + +def preprocess_image(image, background_choice, foreground_ratio, backgroud_color): + """ + input image is a pil image in RGBA, return RGB image + """ + print(background_choice) + if background_choice == "Alpha as mask": + background = Image.new("RGBA", image.size, (0, 0, 0, 0)) + image = Image.alpha_composite(background, image) + else: + image = remove_background(image, rembg_session, force_remove=True) + image = do_resize_content(image, foreground_ratio) + image = expand_to_square(image) + image = add_background(image, backgroud_color) + return image.convert("RGB") + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "--inputdir", + type=str, + default="examples/kunkun.webp", + help="dir for input image", + ) + parser.add_argument( + "--scale", + type=float, + default=5.0, + ) + parser.add_argument( + "--step", + type=int, + default=50, + ) + parser.add_argument( + "--bg_choice", + type=str, 
+ default="Auto Remove background", + help="[Auto Remove background] or [Alpha as mask]", + ) + parser.add_argument( + "--outdir", + type=str, + default="out/", + ) + args = parser.parse_args() + + + img = Image.open(args.inputdir) + img = preprocess_image(img, args.bg_choice, 1.0, (127, 127, 127)) + os.makedirs(args.outdir, exist_ok=True) + img.save(args.outdir+"preprocessed_image.png") + + crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth") + specs = json.load(open("configs/specs_objaverse_total.json")) + model = CRM(specs).to("cuda") + model.load_state_dict(torch.load(crm_path, map_location = "cuda"), strict=False) + + stage1_config = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config + stage2_config = OmegaConf.load("configs/stage2-v2-snr.yaml").config + stage2_sampler_config = stage2_config.sampler + stage1_sampler_config = stage1_config.sampler + + stage1_model_config = stage1_config.models + stage2_model_config = stage2_config.models + + xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth") + pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth") + stage1_model_config.resume = pixel_path + stage2_model_config.resume = xyz_path + + pipeline = TwoStagePipeline( + stage1_model_config, + stage2_model_config, + stage1_sampler_config, + stage2_sampler_config, + ) + + rt_dict = pipeline(img, scale=args.scale, step=args.step) + stage1_images = rt_dict["stage1_images"] + stage2_images = rt_dict["stage2_images"] + np_imgs = np.concatenate(stage1_images, 1) + np_xyzs = np.concatenate(stage2_images, 1) + Image.fromarray(np_imgs).save(args.outdir+"pixel_images.png") + Image.fromarray(np_xyzs).save(args.outdir+"xyz_images.png") + + glb_path, obj_path = generate3d(model, np_imgs, np_xyzs, "cuda") + shutil.copy(obj_path, args.outdir+"output3d.zip") \ No newline at end of file