wyysf committed on
Commit
d758270
1 Parent(s): c5f7475

Upload 107 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. apps/third_party/CRM/LICENSE +21 -0
  2. apps/third_party/CRM/README.md +85 -0
  3. apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml +21 -0
  4. apps/third_party/CRM/configs/specs_objaverse_total.json +57 -0
  5. apps/third_party/CRM/configs/stage2-v2-snr.yaml +25 -0
  6. apps/third_party/CRM/imagedream/.DS_Store +0 -0
  7. apps/third_party/CRM/imagedream/__init__.py +1 -0
  8. apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc +0 -0
  9. apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc +0 -0
  10. apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc +0 -0
  11. apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc +0 -0
  12. apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc +0 -0
  13. apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc +0 -0
  14. apps/third_party/CRM/imagedream/camera_utils.py +99 -0
  15. apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml +61 -0
  16. apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml +61 -0
  17. apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml +61 -0
  18. apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml +62 -0
  19. apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml +62 -0
  20. apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml +62 -0
  21. apps/third_party/CRM/imagedream/ldm/__init__.py +0 -0
  22. apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc +0 -0
  23. apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc +0 -0
  24. apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc +0 -0
  25. apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc +0 -0
  26. apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc +0 -0
  27. apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc +0 -0
  28. apps/third_party/CRM/imagedream/ldm/interface.py +206 -0
  29. apps/third_party/CRM/imagedream/ldm/models/__init__.py +0 -0
  30. apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc +0 -0
  31. apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc +0 -0
  32. apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc +0 -0
  33. apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  34. apps/third_party/CRM/imagedream/ldm/models/autoencoder.py +270 -0
  35. apps/third_party/CRM/imagedream/ldm/models/diffusion/__init__.py +0 -0
  36. apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  37. apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  38. apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc +0 -0
  39. apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
  40. apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py +430 -0
  41. apps/third_party/CRM/imagedream/ldm/modules/__init__.py +0 -0
  42. apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  43. apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc +0 -0
  44. apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc +0 -0
  45. apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc +0 -0
  46. apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc +0 -0
  47. apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc +0 -0
  48. apps/third_party/CRM/imagedream/ldm/modules/attention.py +456 -0
  49. apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__init__.py +0 -0
  50. apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc +0 -0
apps/third_party/CRM/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 TSAIL group
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
apps/third_party/CRM/README.md ADDED
@@ -0,0 +1,85 @@
+ # Convolutional Reconstruction Model
+
+ Official implementation for *CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model*.
+
+ **CRM is a feed-forward model that can generate a 3D textured mesh in 10 seconds.**
+
+ ## [Project Page](https://ml.cs.tsinghua.edu.cn/~zhengyi/CRM/) | [Arxiv](https://arxiv.org/abs/2403.05034) | [HF-Demo](https://huggingface.co/spaces/Zhengyi/CRM) | [Weights](https://huggingface.co/Zhengyi/CRM)
+
+ https://github.com/thu-ml/CRM/assets/40787266/8b325bc0-aa74-4c26-92e8-a8f0c1079382
+
+ ## Try CRM 🍻
+ * Try CRM at the [Huggingface Demo](https://huggingface.co/spaces/Zhengyi/CRM).
+ * Try CRM at the [Replicate Demo](https://replicate.com/camenduru/crm). Thanks [@camenduru](https://github.com/camenduru)!
+
+ ## Install
+
+ ### Step 1 - Base
+
+ Install the packages one by one; we use **python 3.9**.
+
+ ```bash
+ pip install torch==1.13.0+cu117 torchvision==0.14.0+cu117 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117
+ pip install torch-scatter==2.1.1 -f https://data.pyg.org/whl/torch-1.13.1+cu117.html
+ pip install kaolin==0.14.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-1.13.1_cu117.html
+ pip install -r requirements.txt
+ ```
+
+ Besides, you may need to install xformers manually according to the official [doc](https://github.com/facebookresearch/xformers?tab=readme-ov-file#installing-xformers) (**not needed with conda**), e.g.
+
+ ```bash
+ pip install ninja
+ pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
+ ```
+
+ ### Step 2 - Nvdiffrast
+
+ Install nvdiffrast according to the official [doc](https://nvlabs.github.io/nvdiffrast/#installation), e.g.
+
+ ```bash
+ pip install git+https://github.com/NVlabs/nvdiffrast
+ ```
+
+
+
+ ## Inference
+
+ We suggest gradio for a visualized inference.
+
+ ```
+ gradio app.py
+ ```
+
+ ![image](https://github.com/thu-ml/CRM/assets/40787266/4354d22a-a641-4531-8408-c761ead8b1a2)
+
+ For inference on the command line, simply run
+ ```bash
+ CUDA_VISIBLE_DEVICES="0" python run.py --inputdir "examples/kunkun.webp"
+ ```
+ It will output the preprocessed image, the generated 6-view images and CCMs, and a 3D model in obj format.
+
+ **Tips:** (1) If the result is unsatisfactory, please check whether the input image is correctly pre-processed onto a grey background. Otherwise the results will be unpredictable.
+ (2) Different from the [Huggingface Demo](https://huggingface.co/spaces/Zhengyi/CRM), this official implementation uses UV texture instead of vertex color. It produces better texture than the online demo but takes longer owing to the UV texturing.
+
+ ## Todo List
+ - [x] Release inference code.
+ - [x] Release pretrained models.
+ - [ ] Optimize inference code to fit on low-memory GPUs.
+ - [ ] Upload training code.
+
+ ## Acknowledgement
+ - [ImageDream](https://github.com/bytedance/ImageDream)
+ - [nvdiffrast](https://github.com/NVlabs/nvdiffrast)
+ - [kiuikit](https://github.com/ashawkey/kiuikit)
+ - [GET3D](https://github.com/nv-tlabs/GET3D)
+
+ ## Citation
+
+ ```
+ @article{wang2024crm,
+   title={CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model},
+   author={Zhengyi Wang and Yikai Wang and Yifei Chen and Chendong Xiang and Shuo Chen and Dajiang Yu and Chongxuan Li and Hang Su and Jun Zhu},
+   journal={arXiv preprint arXiv:2403.05034},
+   year={2024}
+ }
+ ```
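
The pre-processing tip in the README above expects the subject to be segmented and composited onto a flat grey background before it reaches `run.py`. A minimal sketch of that step, assuming `rembg` and `Pillow` are installed; the helper name and the grey value 127 are illustrative, not the repository's own preprocessing code:

```python
# Illustrative sketch only: composite a segmented input onto a grey background,
# as the README tip assumes. Requires `pip install rembg pillow`.
from PIL import Image
from rembg import remove

def to_grey_background(in_path, out_path, grey=127):
    rgba = remove(Image.open(in_path)).convert("RGBA")           # strip the original background
    bg = Image.new("RGBA", rgba.size, (grey, grey, grey, 255))   # flat grey canvas
    Image.alpha_composite(bg, rgba).convert("RGB").save(out_path)

to_grey_background("examples/kunkun.webp", "examples/kunkun_grey.png")
```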
apps/third_party/CRM/configs/nf7_v3_SNR_rd_size_stroke.yaml ADDED
@@ -0,0 +1,21 @@
+ config:
+   # others
+   seed: 1234
+   num_frames: 7
+   mode: pixel
+   offset_noise: true
+   # model related
+   models:
+     config: imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml
+     resume: models/pixel.pth
+   # sampler related
+   sampler:
+     target: libs.sample.ImageDreamDiffusion
+     params:
+       mode: pixel
+       num_frames: 7
+       camera_views: [1, 2, 3, 4, 5, 0, 0]
+       ref_position: 6
+       random_background: false
+       offset_noise: true
+       resize_rate: 1.0
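
As a rough sketch of how a config like this is typically consumed: the `sampler.target` string names a class that is built with `sampler.params`. The loader below is an assumption (OmegaConf plus a manual import), not the repo's actual entry point, and the sampler would additionally need the diffusion model built from `models.config` / `models.resume`:

```python
# Sketch, not the repository's own loader. Assumes `omegaconf` is installed and the
# CRM directory is on sys.path so that `libs.sample` is importable.
import importlib
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/nf7_v3_SNR_rd_size_stroke.yaml").config
module_name, cls_name = cfg.sampler.target.rsplit(".", 1)   # "libs.sample", "ImageDreamDiffusion"
SamplerCls = getattr(importlib.import_module(module_name), cls_name)
sampler_kwargs = OmegaConf.to_container(cfg.sampler.params) # mode, num_frames, camera_views, ...
# sampler = SamplerCls(model=..., **sampler_kwargs)         # model comes from cfg.models
```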
apps/third_party/CRM/configs/specs_objaverse_total.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "Input": {
+     "img_num": 16,
+     "class": "all",
+     "camera_angle_num": 8,
+     "tet_grid_size": 80,
+     "validate_num": 16,
+     "scale": 0.95,
+     "radius": 3,
+     "resolution": [256, 256]
+   },
+
+   "Pretrain": {
+     "mode": null,
+     "sdf_threshold": 0.1,
+     "sdf_scale": 10,
+     "batch_infer": false,
+     "lr": 1e-4,
+     "radius": 0.5
+   },
+
+   "Train": {
+     "mode": "rnd",
+     "num_epochs": 500,
+     "grad_acc": 1,
+     "warm_up": 0,
+     "decay": 0.000,
+     "learning_rate": {
+       "init": 1e-4,
+       "sdf_decay": 1,
+       "rgb_decay": 1
+     },
+     "batch_size": 4,
+     "eva_iter": 80,
+     "eva_all_epoch": 10,
+     "tex_sup_mode": "blender",
+     "exp_uv_mesh": false,
+     "doub": false,
+     "random_bg": false,
+     "shift": 0,
+     "aug_shift": 0,
+     "geo_type": "flex"
+   },
+
+   "ArchSpecs": {
+     "unet_type": "diffusers",
+     "use_3D_aware": false,
+     "fea_concat": false,
+     "mlp_bias": true
+   },
+
+   "DecoderSpecs": {
+     "c_dim": 32,
+     "plane_resolution": 256
+   }
+ }
+
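
A minimal sketch of reading these specs with plain `json`; which script consumes which key lives in the training and inference code, which is not part of this 50-file view:

```python
# Sketch: load the Objaverse training specs shown above and pull a few fields.
import json

with open("configs/specs_objaverse_total.json") as f:
    specs = json.load(f)

tet_grid_size = specs["Input"]["tet_grid_size"]          # 80
plane_resolution = specs["DecoderSpecs"]["plane_resolution"]  # 256
print(tet_grid_size, plane_resolution)
```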
apps/third_party/CRM/configs/stage2-v2-snr.yaml ADDED
@@ -0,0 +1,25 @@
+ config:
+   # others
+   seed: 1234
+   num_frames: 6
+   mode: pixel
+   offset_noise: true
+   gd_type: xyz
+   # model related
+   models:
+     config: imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml
+     resume: models/xyz.pth
+
+   # eval related
+   sampler:
+     target: libs.sample.ImageDreamDiffusionStage2
+     params:
+       mode: pixel
+       num_frames: 6
+       camera_views: [1, 2, 3, 4, 5, 0]
+       ref_position: null
+       random_background: false
+       offset_noise: true
+       resize_rate: 1.0
+
+
apps/third_party/CRM/imagedream/.DS_Store ADDED
Binary file (6.15 kB).
apps/third_party/CRM/imagedream/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model_zoo import build_model
apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (219 Bytes).
apps/third_party/CRM/imagedream/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (216 Bytes).
apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-310.pyc ADDED
Binary file (2.83 kB).
apps/third_party/CRM/imagedream/__pycache__/camera_utils.cpython-38.pyc ADDED
Binary file (2.75 kB).
apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-310.pyc ADDED
Binary file (1.79 kB).
apps/third_party/CRM/imagedream/__pycache__/model_zoo.cpython-38.pyc ADDED
Binary file (1.79 kB).
apps/third_party/CRM/imagedream/camera_utils.py ADDED
@@ -0,0 +1,99 @@
+ import numpy as np
+ import torch
+
+
+ def create_camera_to_world_matrix(elevation, azimuth):
+     elevation = np.radians(elevation)
+     azimuth = np.radians(azimuth)
+     # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
+     x = np.cos(elevation) * np.sin(azimuth)
+     y = np.sin(elevation)
+     z = np.cos(elevation) * np.cos(azimuth)
+
+     # Calculate camera position, target, and up vectors
+     camera_pos = np.array([x, y, z])
+     target = np.array([0, 0, 0])
+     up = np.array([0, 1, 0])
+
+     # Construct view matrix
+     forward = target - camera_pos
+     forward /= np.linalg.norm(forward)
+     right = np.cross(forward, up)
+     right /= np.linalg.norm(right)
+     new_up = np.cross(right, forward)
+     new_up /= np.linalg.norm(new_up)
+     cam2world = np.eye(4)
+     cam2world[:3, :3] = np.array([right, new_up, -forward]).T
+     cam2world[:3, 3] = camera_pos
+     return cam2world
+
+
+ def convert_opengl_to_blender(camera_matrix):
+     if isinstance(camera_matrix, np.ndarray):
+         # Construct transformation matrix to convert from OpenGL space to Blender space
+         flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
+         camera_matrix_blender = np.dot(flip_yz, camera_matrix)
+     else:
+         # Construct transformation matrix to convert from OpenGL space to Blender space
+         flip_yz = torch.tensor(
+             [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
+         )
+         if camera_matrix.ndim == 3:
+             flip_yz = flip_yz.unsqueeze(0)
+         camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
+     return camera_matrix_blender
+
+
+ def normalize_camera(camera_matrix):
+     """normalize the camera location onto a unit-sphere"""
+     if isinstance(camera_matrix, np.ndarray):
+         camera_matrix = camera_matrix.reshape(-1, 4, 4)
+         translation = camera_matrix[:, :3, 3]
+         translation = translation / (
+             np.linalg.norm(translation, axis=1, keepdims=True) + 1e-8
+         )
+         camera_matrix[:, :3, 3] = translation
+     else:
+         camera_matrix = camera_matrix.reshape(-1, 4, 4)
+         translation = camera_matrix[:, :3, 3]
+         translation = translation / (
+             torch.norm(translation, dim=1, keepdim=True) + 1e-8
+         )
+         camera_matrix[:, :3, 3] = translation
+     return camera_matrix.reshape(-1, 16)
+
+
+ def get_camera(
+     num_frames,
+     elevation=15,
+     azimuth_start=0,
+     azimuth_span=360,
+     blender_coord=True,
+     extra_view=False,
+ ):
+     angle_gap = azimuth_span / num_frames
+     cameras = []
+     for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
+         camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
+         if blender_coord:
+             camera_matrix = convert_opengl_to_blender(camera_matrix)
+         cameras.append(camera_matrix.flatten())
+
+     if extra_view:
+         dim = len(cameras[0])
+         cameras.append(np.zeros(dim))
+     return torch.tensor(np.stack(cameras, 0)).float()
+
+
+ def get_camera_for_index(data_index):
+     """
+     In our current data format, with 000 facing the viewer:
+     000 is the front view,  elevation 0,   azimuth 0
+     001 is the left view,   elevation 0,   azimuth -90
+     002 is the bottom view, elevation -90, azimuth 0
+     003 is the back view,   elevation 0,   azimuth 180
+     004 is the right view,  elevation 0,   azimuth 90
+     005 is the top view,    elevation 90,  azimuth 0
+     """
+     params = [(0, 0), (0, -90), (-90, 0), (0, 180), (0, 90), (90, 0)]
+     return get_camera(1, *params[data_index])
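
A short usage sketch for the helpers above, following the index convention in the docstring; `get_camera` returns a float tensor of flattened 4x4 camera-to-world matrices, one row per view (the CRM directory is assumed to be on `sys.path` so that `imagedream` imports):

```python
# Sketch: four evenly spaced azimuth views plus the zero-filled "extra view" slot,
# and the front-facing camera from the per-index convention documented above.
from imagedream.camera_utils import get_camera, get_camera_for_index

cams = get_camera(num_frames=4, elevation=15, azimuth_start=0, extra_view=True)
print(cams.shape)            # torch.Size([5, 16]) -- 4 views + 1 extra, each a flattened 4x4

front = get_camera_for_index(0)   # index 000: elevation 0, azimuth 0
print(front.shape)           # torch.Size([1, 16])
```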
apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv.yaml ADDED
@@ -0,0 +1,61 @@
+ model:
+   target: imagedream.ldm.interface.LatentDiffusionInterface
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     timesteps: 1000
+     scale_factor: 0.18215
+     parameterization: "eps"
+
+     unet_config:
+       target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 4
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_head_channels: 64 # need to fix for flash-attn
+         use_spatial_transformer: True
+         use_linear_in_transformer: True
+         transformer_depth: 1
+         context_dim: 1024
+         use_checkpoint: False
+         legacy: False
+         camera_dim: 16
+         with_ip: True
+         ip_dim: 16 # ip token length
+         ip_mode: "local_resample"
+
+     vae_config:
+       target: imagedream.ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           #attn_type: "vanilla-xformers"
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     clip_config:
+       target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+       params:
+         freeze: True
+         layer: "penultimate"
+         ip_mode: "local_resample"
apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_ch8.yaml ADDED
@@ -0,0 +1,61 @@
+ model:
+   target: imagedream.ldm.interface.LatentDiffusionInterface
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     timesteps: 1000
+     scale_factor: 0.18215
+     parameterization: "eps"
+
+     unet_config:
+       target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 8
+         out_channels: 8
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_head_channels: 64 # need to fix for flash-attn
+         use_spatial_transformer: True
+         use_linear_in_transformer: True
+         transformer_depth: 1
+         context_dim: 1024
+         use_checkpoint: False
+         legacy: False
+         camera_dim: 16
+         with_ip: True
+         ip_dim: 16 # ip token length
+         ip_mode: "local_resample"
+
+     vae_config:
+       target: imagedream.ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           #attn_type: "vanilla-xformers"
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     clip_config:
+       target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+       params:
+         freeze: True
+         layer: "penultimate"
+         ip_mode: "local_resample"
apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8.yaml ADDED
@@ -0,0 +1,61 @@
+ model:
+   target: imagedream.ldm.interface.LatentDiffusionInterface
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     timesteps: 1000
+     scale_factor: 0.18215
+     parameterization: "eps"
+
+     unet_config:
+       target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2
+       params:
+         image_size: 32 # unused
+         in_channels: 8
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_head_channels: 64 # need to fix for flash-attn
+         use_spatial_transformer: True
+         use_linear_in_transformer: True
+         transformer_depth: 1
+         context_dim: 1024
+         use_checkpoint: False
+         legacy: False
+         camera_dim: 16
+         with_ip: True
+         ip_dim: 16 # ip token length
+         ip_mode: "local_resample"
+
+     vae_config:
+       target: imagedream.ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           #attn_type: "vanilla-xformers"
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     clip_config:
+       target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+       params:
+         freeze: True
+         layer: "penultimate"
+         ip_mode: "local_resample"
apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_chin8_zero_snr.yaml ADDED
@@ -0,0 +1,62 @@
+ model:
+   target: imagedream.ldm.interface.LatentDiffusionInterface
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     timesteps: 1000
+     scale_factor: 0.18215
+     parameterization: "eps"
+     zero_snr: true
+
+     unet_config:
+       target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModelStage2
+       params:
+         image_size: 32 # unused
+         in_channels: 8
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_head_channels: 64 # need to fix for flash-attn
+         use_spatial_transformer: True
+         use_linear_in_transformer: True
+         transformer_depth: 1
+         context_dim: 1024
+         use_checkpoint: False
+         legacy: False
+         camera_dim: 16
+         with_ip: True
+         ip_dim: 16 # ip token length
+         ip_mode: "local_resample"
+
+     vae_config:
+       target: imagedream.ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           #attn_type: "vanilla-xformers"
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     clip_config:
+       target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+       params:
+         freeze: True
+         layer: "penultimate"
+         ip_mode: "local_resample"
apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_local.yaml ADDED
@@ -0,0 +1,62 @@
+ model:
+   target: imagedream.ldm.interface.LatentDiffusionInterface
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     timesteps: 1000
+     scale_factor: 0.18215
+     parameterization: "eps"
+
+     unet_config:
+       target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 4
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_head_channels: 64 # need to fix for flash-attn
+         use_spatial_transformer: True
+         use_linear_in_transformer: True
+         transformer_depth: 1
+         context_dim: 1024
+         use_checkpoint: False
+         legacy: False
+         camera_dim: 16
+         with_ip: True
+         ip_dim: 16 # ip token length
+         ip_mode: "local_resample"
+         ip_weight: 1.0 # adjust for similarity to image
+
+     vae_config:
+       target: imagedream.ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           #attn_type: "vanilla-xformers"
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     clip_config:
+       target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+       params:
+         freeze: True
+         layer: "penultimate"
+         ip_mode: "local_resample"
apps/third_party/CRM/imagedream/configs/sd_v2_base_ipmv_zero_SNR.yaml ADDED
@@ -0,0 +1,62 @@
+ model:
+   target: imagedream.ldm.interface.LatentDiffusionInterface
+   params:
+     linear_start: 0.00085
+     linear_end: 0.0120
+     timesteps: 1000
+     scale_factor: 0.18215
+     parameterization: "eps"
+     zero_snr: true
+
+     unet_config:
+       target: imagedream.ldm.modules.diffusionmodules.openaimodel.MultiViewUNetModel
+       params:
+         image_size: 32 # unused
+         in_channels: 4
+         out_channels: 4
+         model_channels: 320
+         attention_resolutions: [ 4, 2, 1 ]
+         num_res_blocks: 2
+         channel_mult: [ 1, 2, 4, 4 ]
+         num_head_channels: 64 # need to fix for flash-attn
+         use_spatial_transformer: True
+         use_linear_in_transformer: True
+         transformer_depth: 1
+         context_dim: 1024
+         use_checkpoint: False
+         legacy: False
+         camera_dim: 16
+         with_ip: True
+         ip_dim: 16 # ip token length
+         ip_mode: "local_resample"
+
+     vae_config:
+       target: imagedream.ldm.models.autoencoder.AutoencoderKL
+       params:
+         embed_dim: 4
+         monitor: val/rec_loss
+         ddconfig:
+           #attn_type: "vanilla-xformers"
+           double_z: true
+           z_channels: 4
+           resolution: 256
+           in_channels: 3
+           out_ch: 3
+           ch: 128
+           ch_mult:
+           - 1
+           - 2
+           - 4
+           - 4
+           num_res_blocks: 2
+           attn_resolutions: []
+           dropout: 0.0
+         lossconfig:
+           target: torch.nn.Identity
+
+     clip_config:
+       target: imagedream.ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+       params:
+         freeze: True
+         layer: "penultimate"
+         ip_mode: "local_resample"
apps/third_party/CRM/imagedream/ldm/__init__.py ADDED
File without changes
apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (178 Bytes).
apps/third_party/CRM/imagedream/ldm/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (175 Bytes).
apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-310.pyc ADDED
Binary file (6.27 kB).
apps/third_party/CRM/imagedream/ldm/__pycache__/interface.cpython-38.pyc ADDED
Binary file (6.33 kB).
apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-310.pyc ADDED
Binary file (6.75 kB).
apps/third_party/CRM/imagedream/ldm/__pycache__/util.cpython-38.pyc ADDED
Binary file (6.73 kB).
apps/third_party/CRM/imagedream/ldm/interface.py ADDED
@@ -0,0 +1,206 @@
+ from typing import List
+ from functools import partial
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+
+ from .modules.diffusionmodules.util import (
+     make_beta_schedule,
+     extract_into_tensor,
+     enforce_zero_terminal_snr,
+     noise_like,
+ )
+ from .util import exists, default, instantiate_from_config
+ from .modules.distributions.distributions import DiagonalGaussianDistribution
+
+
+ class DiffusionWrapper(nn.Module):
+     def __init__(self, diffusion_model):
+         super().__init__()
+         self.diffusion_model = diffusion_model
+
+     def forward(self, *args, **kwargs):
+         return self.diffusion_model(*args, **kwargs)
+
+
+ class LatentDiffusionInterface(nn.Module):
+     """a simple interface class for LDM inference"""
+
+     def __init__(
+         self,
+         unet_config,
+         clip_config,
+         vae_config,
+         parameterization="eps",
+         scale_factor=0.18215,
+         beta_schedule="linear",
+         timesteps=1000,
+         linear_start=0.00085,
+         linear_end=0.0120,
+         cosine_s=8e-3,
+         given_betas=None,
+         zero_snr=False,
+         *args,
+         **kwargs,
+     ):
+         super().__init__()
+
+         unet = instantiate_from_config(unet_config)
+         self.model = DiffusionWrapper(unet)
+         self.clip_model = instantiate_from_config(clip_config)
+         self.vae_model = instantiate_from_config(vae_config)
+
+         self.parameterization = parameterization
+         self.scale_factor = scale_factor
+         self.register_schedule(
+             given_betas=given_betas,
+             beta_schedule=beta_schedule,
+             timesteps=timesteps,
+             linear_start=linear_start,
+             linear_end=linear_end,
+             cosine_s=cosine_s,
+             zero_snr=zero_snr
+         )
+
+     def register_schedule(
+         self,
+         given_betas=None,
+         beta_schedule="linear",
+         timesteps=1000,
+         linear_start=1e-4,
+         linear_end=2e-2,
+         cosine_s=8e-3,
+         zero_snr=False
+     ):
+         if exists(given_betas):
+             betas = given_betas
+         else:
+             betas = make_beta_schedule(
+                 beta_schedule,
+                 timesteps,
+                 linear_start=linear_start,
+                 linear_end=linear_end,
+                 cosine_s=cosine_s,
+             )
+         if zero_snr:
+             print("--- using zero snr---")
+             betas = enforce_zero_terminal_snr(betas).numpy()
+         alphas = 1.0 - betas
+         alphas_cumprod = np.cumprod(alphas, axis=0)
+         alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+
+         (timesteps,) = betas.shape
+         self.num_timesteps = int(timesteps)
+         self.linear_start = linear_start
+         self.linear_end = linear_end
+         assert (
+             alphas_cumprod.shape[0] == self.num_timesteps
+         ), "alphas have to be defined for each timestep"
+
+         to_torch = partial(torch.tensor, dtype=torch.float32)
+
+         self.register_buffer("betas", to_torch(betas))
+         self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
+         self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
+
+         # calculations for diffusion q(x_t | x_{t-1}) and others
+         self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
+         self.register_buffer(
+             "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
+         )
+         self.register_buffer(
+             "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
+         )
+         eps = 1e-8  # adding a small epsilon value to avoid divide-by-zero errors
+         self.register_buffer(
+             "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps)))
+         )
+         self.register_buffer(
+             "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps) - 1))
+         )
+
+         # calculations for posterior q(x_{t-1} | x_t, x_0)
+         self.v_posterior = 0
+         posterior_variance = (1 - self.v_posterior) * betas * (
+             1.0 - alphas_cumprod_prev
+         ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
+         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+         self.register_buffer("posterior_variance", to_torch(posterior_variance))
+         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+         self.register_buffer(
+             "posterior_log_variance_clipped",
+             to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
+         )
+         self.register_buffer(
+             "posterior_mean_coef1",
+             to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
+         )
+         self.register_buffer(
+             "posterior_mean_coef2",
+             to_torch(
+                 (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+             ),
+         )
+
+     def q_sample(self, x_start, t, noise=None):
+         noise = default(noise, lambda: torch.randn_like(x_start))
+         return (
+             extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+             + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+             * noise
+         )
+
+     def get_v(self, x, noise, t):
+         return (
+             extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
+             - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+         )
+
+     def predict_start_from_noise(self, x_t, t, noise):
+         return (
+             extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+             - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+             * noise
+         )
+
+     def predict_start_from_z_and_v(self, x_t, t, v):
+         return (
+             extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
+             - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+         )
+
+     def predict_eps_from_z_and_v(self, x_t, t, v):
+         return (
+             extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
+             + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
+             * x_t
+         )
+
+     def apply_model(self, x_noisy, t, cond, **kwargs):
+         assert isinstance(cond, dict), "cond has to be a dictionary"
+         return self.model(x_noisy, t, **cond, **kwargs)
+
+     def get_learned_conditioning(self, prompts: List[str]):
+         return self.clip_model(prompts)
+
+     def get_learned_image_conditioning(self, images):
+         return self.clip_model.forward_image(images)
+
+     def get_first_stage_encoding(self, encoder_posterior):
+         if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+             z = encoder_posterior.sample()
+         elif isinstance(encoder_posterior, torch.Tensor):
+             z = encoder_posterior
+         else:
+             raise NotImplementedError(
+                 f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+             )
+         return self.scale_factor * z
+
+     def encode_first_stage(self, x):
+         return self.vae_model.encode(x)
+
+     def decode_first_stage(self, z):
+         z = 1.0 / self.scale_factor * z
+         return self.vae_model.decode(z)
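
For reference, `q_sample` above is the closed-form DDPM forward process: it computes `x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise` using the buffers registered in `register_schedule`. A standalone numerical sketch of the same schedule arithmetic with the default linear settings; it mirrors what an ldm-style `make_beta_schedule("linear", ...)` computes (linear in sqrt space) and is not tied to the class above:

```python
# Sketch: rebuild the beta schedule and inspect the q_sample coefficients at one timestep.
import numpy as np

T, linear_start, linear_end = 1000, 0.00085, 0.0120
betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, T) ** 2  # "linear" schedule in sqrt space
alphas_cumprod = np.cumprod(1.0 - betas)

t = 500
coef_x0 = np.sqrt(alphas_cumprod[t])          # multiplies x_start in q_sample
coef_eps = np.sqrt(1.0 - alphas_cumprod[t])   # multiplies the injected noise
print(coef_x0, coef_eps, coef_x0 ** 2 + coef_eps ** 2)  # the squared coefficients sum to 1
```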
apps/third_party/CRM/imagedream/ldm/models/__init__.py ADDED
File without changes
apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (185 Bytes).
apps/third_party/CRM/imagedream/ldm/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (182 Bytes).
apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-310.pyc ADDED
Binary file (7.79 kB).
apps/third_party/CRM/imagedream/ldm/models/__pycache__/autoencoder.cpython-38.pyc ADDED
Binary file (7.68 kB).
apps/third_party/CRM/imagedream/ldm/models/autoencoder.py ADDED
@@ -0,0 +1,270 @@
+ import torch
+ import torch.nn.functional as F
+ from contextlib import contextmanager
+
+ from ..modules.diffusionmodules.model import Encoder, Decoder
+ from ..modules.distributions.distributions import DiagonalGaussianDistribution
+
+ from ..util import instantiate_from_config
+ from ..modules.ema import LitEma
+
+
+ class AutoencoderKL(torch.nn.Module):
+     def __init__(
+         self,
+         ddconfig,
+         lossconfig,
+         embed_dim,
+         ckpt_path=None,
+         ignore_keys=[],
+         image_key="image",
+         colorize_nlabels=None,
+         monitor=None,
+         ema_decay=None,
+         learn_logvar=False,
+     ):
+         super().__init__()
+         self.learn_logvar = learn_logvar
+         self.image_key = image_key
+         self.encoder = Encoder(**ddconfig)
+         self.decoder = Decoder(**ddconfig)
+         self.loss = instantiate_from_config(lossconfig)
+         assert ddconfig["double_z"]
+         self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
+         self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+         self.embed_dim = embed_dim
+         if colorize_nlabels is not None:
+             assert type(colorize_nlabels) == int
+             self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+         if monitor is not None:
+             self.monitor = monitor
+
+         self.use_ema = ema_decay is not None
+         if self.use_ema:
+             self.ema_decay = ema_decay
+             assert 0.0 < ema_decay < 1.0
+             self.model_ema = LitEma(self, decay=ema_decay)
+             print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+         if ckpt_path is not None:
+             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+     def init_from_ckpt(self, path, ignore_keys=list()):
+         sd = torch.load(path, map_location="cpu")["state_dict"]
+         keys = list(sd.keys())
+         for k in keys:
+             for ik in ignore_keys:
+                 if k.startswith(ik):
+                     print("Deleting key {} from state_dict.".format(k))
+                     del sd[k]
+         self.load_state_dict(sd, strict=False)
+         print(f"Restored from {path}")
+
+     @contextmanager
+     def ema_scope(self, context=None):
+         if self.use_ema:
+             self.model_ema.store(self.parameters())
+             self.model_ema.copy_to(self)
+             if context is not None:
+                 print(f"{context}: Switched to EMA weights")
+         try:
+             yield None
+         finally:
+             if self.use_ema:
+                 self.model_ema.restore(self.parameters())
+                 if context is not None:
+                     print(f"{context}: Restored training weights")
+
+     def on_train_batch_end(self, *args, **kwargs):
+         if self.use_ema:
+             self.model_ema(self)
+
+     def encode(self, x):
+         h = self.encoder(x)
+         moments = self.quant_conv(h)
+         posterior = DiagonalGaussianDistribution(moments)
+         return posterior
+
+     def decode(self, z):
+         z = self.post_quant_conv(z)
+         dec = self.decoder(z)
+         return dec
+
+     def forward(self, input, sample_posterior=True):
+         posterior = self.encode(input)
+         if sample_posterior:
+             z = posterior.sample()
+         else:
+             z = posterior.mode()
+         dec = self.decode(z)
+         return dec, posterior
+
+     def get_input(self, batch, k):
+         x = batch[k]
+         if len(x.shape) == 3:
+             x = x[..., None]
+         x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+         return x
+
+     def training_step(self, batch, batch_idx, optimizer_idx):
+         inputs = self.get_input(batch, self.image_key)
+         reconstructions, posterior = self(inputs)
+
+         if optimizer_idx == 0:
+             # train encoder+decoder+logvar
+             aeloss, log_dict_ae = self.loss(
+                 inputs,
+                 reconstructions,
+                 posterior,
+                 optimizer_idx,
+                 self.global_step,
+                 last_layer=self.get_last_layer(),
+                 split="train",
+             )
+             self.log(
+                 "aeloss",
+                 aeloss,
+                 prog_bar=True,
+                 logger=True,
+                 on_step=True,
+                 on_epoch=True,
+             )
+             self.log_dict(
+                 log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False
+             )
+             return aeloss
+
+         if optimizer_idx == 1:
+             # train the discriminator
+             discloss, log_dict_disc = self.loss(
+                 inputs,
+                 reconstructions,
+                 posterior,
+                 optimizer_idx,
+                 self.global_step,
+                 last_layer=self.get_last_layer(),
+                 split="train",
+             )
+
+             self.log(
+                 "discloss",
+                 discloss,
+                 prog_bar=True,
+                 logger=True,
+                 on_step=True,
+                 on_epoch=True,
+             )
+             self.log_dict(
+                 log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False
+             )
+             return discloss
+
+     def validation_step(self, batch, batch_idx):
+         log_dict = self._validation_step(batch, batch_idx)
+         with self.ema_scope():
+             log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+         return log_dict
+
+     def _validation_step(self, batch, batch_idx, postfix=""):
+         inputs = self.get_input(batch, self.image_key)
+         reconstructions, posterior = self(inputs)
+         aeloss, log_dict_ae = self.loss(
+             inputs,
+             reconstructions,
+             posterior,
+             0,
+             self.global_step,
+             last_layer=self.get_last_layer(),
+             split="val" + postfix,
+         )
+
+         discloss, log_dict_disc = self.loss(
+             inputs,
+             reconstructions,
+             posterior,
+             1,
+             self.global_step,
+             last_layer=self.get_last_layer(),
+             split="val" + postfix,
+         )
+
+         self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+         self.log_dict(log_dict_ae)
+         self.log_dict(log_dict_disc)
+         return self.log_dict
+
+     def configure_optimizers(self):
+         lr = self.learning_rate
+         ae_params_list = (
+             list(self.encoder.parameters())
+             + list(self.decoder.parameters())
+             + list(self.quant_conv.parameters())
+             + list(self.post_quant_conv.parameters())
+         )
+         if self.learn_logvar:
+             print(f"{self.__class__.__name__}: Learning logvar")
+             ae_params_list.append(self.loss.logvar)
+         opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
+         opt_disc = torch.optim.Adam(
+             self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)
+         )
+         return [opt_ae, opt_disc], []
+
+     def get_last_layer(self):
+         return self.decoder.conv_out.weight
+
+     @torch.no_grad()
+     def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+         log = dict()
+         x = self.get_input(batch, self.image_key)
+         x = x.to(self.device)
+         if not only_inputs:
+             xrec, posterior = self(x)
+             if x.shape[1] > 3:
+                 # colorize with random projection
+                 assert xrec.shape[1] > 3
+                 x = self.to_rgb(x)
+                 xrec = self.to_rgb(xrec)
+             log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+             log["reconstructions"] = xrec
+             if log_ema or self.use_ema:
+                 with self.ema_scope():
+                     xrec_ema, posterior_ema = self(x)
+                     if x.shape[1] > 3:
+                         # colorize with random projection
+                         assert xrec_ema.shape[1] > 3
+                         xrec_ema = self.to_rgb(xrec_ema)
+                     log["samples_ema"] = self.decode(
+                         torch.randn_like(posterior_ema.sample())
+                     )
+                     log["reconstructions_ema"] = xrec_ema
+         log["inputs"] = x
+         return log
+
+     def to_rgb(self, x):
+         assert self.image_key == "segmentation"
+         if not hasattr(self, "colorize"):
+             self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+         x = F.conv2d(x, weight=self.colorize)
+         x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
+         return x
+
+
+ class IdentityFirstStage(torch.nn.Module):
+     def __init__(self, *args, vq_interface=False, **kwargs):
+         self.vq_interface = vq_interface
+         super().__init__()
+
+     def encode(self, x, *args, **kwargs):
+         return x
+
+     def decode(self, x, *args, **kwargs):
+         return x
+
+     def quantize(self, x, *args, **kwargs):
+         if self.vq_interface:
+             return x, None, [None, None, None]
+         return x
+
+     def forward(self, x, *args, **kwargs):
+         return x
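
A round-trip usage sketch for `AutoencoderKL`, using the `ddconfig` values from the YAML configs in this upload. It assumes the rest of the `imagedream` package from this 107-file upload (e.g. `modules.diffusionmodules.model`) is importable; with no checkpoint loaded the weights are random, so only the shapes are meaningful:

```python
# Sketch: encode an image tensor to a 4x32x32 latent and decode it back.
import torch
from imagedream.ldm.models.autoencoder import AutoencoderKL

ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
                ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0)
vae = AutoencoderKL(ddconfig=ddconfig, lossconfig={"target": "torch.nn.Identity"}, embed_dim=4).eval()

x = torch.randn(1, 3, 256, 256)        # an image in [-1, 1] in real use
with torch.no_grad():
    posterior = vae.encode(x)          # DiagonalGaussianDistribution over the latent
    z = posterior.sample()             # 1 x 4 x 32 x 32
    rec = vae.decode(z)                # back to 1 x 3 x 256 x 256
print(z.shape, rec.shape)
```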
apps/third_party/CRM/imagedream/ldm/models/diffusion/__init__.py ADDED
File without changes
apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (195 Bytes).
apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (192 Bytes).
apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc ADDED
Binary file (8.8 kB).
apps/third_party/CRM/imagedream/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc ADDED
Binary file (8.77 kB).
apps/third_party/CRM/imagedream/ldm/models/diffusion/ddim.py ADDED
@@ -0,0 +1,430 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+
3
+ import torch
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from ...modules.diffusionmodules.util import (
9
+ make_ddim_sampling_parameters,
10
+ make_ddim_timesteps,
11
+ noise_like,
12
+ extract_into_tensor,
13
+ )
14
+
15
+
16
+ class DDIMSampler(object):
17
+ def __init__(self, model, schedule="linear", **kwargs):
18
+ super().__init__()
19
+ self.model = model
20
+ self.ddpm_num_timesteps = model.num_timesteps
21
+ self.schedule = schedule
22
+
23
+ def register_buffer(self, name, attr):
24
+ if type(attr) == torch.Tensor:
25
+ if attr.device != torch.device("cuda"):
26
+ attr = attr.to(torch.device("cuda"))
27
+ setattr(self, name, attr)
28
+
29
+ def make_schedule(
30
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
31
+ ):
32
+ self.ddim_timesteps = make_ddim_timesteps(
33
+ ddim_discr_method=ddim_discretize,
34
+ num_ddim_timesteps=ddim_num_steps,
35
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
36
+ verbose=verbose,
37
+ )
38
+ alphas_cumprod = self.model.alphas_cumprod
39
+ assert (
40
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
41
+ ), "alphas have to be defined for each timestep"
42
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
43
+
44
+ self.register_buffer("betas", to_torch(self.model.betas))
45
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
46
+ self.register_buffer(
47
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
48
+ )
49
+
50
+ # calculations for diffusion q(x_t | x_{t-1}) and others
51
+ self.register_buffer(
52
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
53
+ )
54
+ self.register_buffer(
55
+ "sqrt_one_minus_alphas_cumprod",
56
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
57
+ )
58
+ self.register_buffer(
59
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
60
+ )
61
+ self.register_buffer(
62
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
63
+ )
64
+ self.register_buffer(
65
+ "sqrt_recipm1_alphas_cumprod",
66
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
67
+ )
68
+
69
+ # ddim sampling parameters
70
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
71
+ alphacums=alphas_cumprod.cpu(),
72
+ ddim_timesteps=self.ddim_timesteps,
73
+ eta=ddim_eta,
74
+ verbose=verbose,
75
+ )
76
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
77
+ self.register_buffer("ddim_alphas", ddim_alphas)
78
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
79
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
80
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
81
+ (1 - self.alphas_cumprod_prev)
82
+ / (1 - self.alphas_cumprod)
83
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
84
+ )
85
+ self.register_buffer(
86
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
87
+ )
88
+
89
+ @torch.no_grad()
90
+ def sample(
91
+ self,
92
+ S,
93
+ batch_size,
94
+ shape,
95
+ conditioning=None,
96
+ callback=None,
97
+ normals_sequence=None,
98
+ img_callback=None,
99
+ quantize_x0=False,
100
+ eta=0.0,
101
+ mask=None,
102
+ x0=None,
103
+ temperature=1.0,
104
+ noise_dropout=0.0,
105
+ score_corrector=None,
106
+ corrector_kwargs=None,
107
+ verbose=True,
108
+ x_T=None,
109
+ log_every_t=100,
110
+ unconditional_guidance_scale=1.0,
111
+ unconditional_conditioning=None,
112
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
113
+ **kwargs,
114
+ ):
115
+ if conditioning is not None:
116
+ if isinstance(conditioning, dict):
117
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
118
+ if cbs != batch_size:
119
+ print(
120
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
121
+ )
122
+ else:
123
+ if conditioning.shape[0] != batch_size:
124
+ print(
125
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
126
+ )
127
+
128
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
129
+ # sampling
130
+ C, H, W = shape
131
+ size = (batch_size, C, H, W)
132
+
133
+ samples, intermediates = self.ddim_sampling(
134
+ conditioning,
135
+ size,
136
+ callback=callback,
137
+ img_callback=img_callback,
138
+ quantize_denoised=quantize_x0,
139
+ mask=mask,
140
+ x0=x0,
141
+ ddim_use_original_steps=False,
142
+ noise_dropout=noise_dropout,
143
+ temperature=temperature,
144
+ score_corrector=score_corrector,
145
+ corrector_kwargs=corrector_kwargs,
146
+ x_T=x_T,
147
+ log_every_t=log_every_t,
148
+ unconditional_guidance_scale=unconditional_guidance_scale,
149
+ unconditional_conditioning=unconditional_conditioning,
150
+ **kwargs,
151
+ )
152
+ return samples, intermediates
153
+
154
+ @torch.no_grad()
155
+ def ddim_sampling(
156
+ self,
157
+ cond,
158
+ shape,
159
+ x_T=None,
160
+ ddim_use_original_steps=False,
161
+ callback=None,
162
+ timesteps=None,
163
+ quantize_denoised=False,
164
+ mask=None,
165
+ x0=None,
166
+ img_callback=None,
167
+ log_every_t=100,
168
+ temperature=1.0,
169
+ noise_dropout=0.0,
170
+ score_corrector=None,
171
+ corrector_kwargs=None,
172
+ unconditional_guidance_scale=1.0,
173
+ unconditional_conditioning=None,
174
+ **kwargs,
175
+ ):
176
+ """
177
+ when inference time: all values of parameter
178
+ cond.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
179
+ shape: (5, 4, 32, 32)
180
+ x_T: None
181
+ ddim_use_original_steps: False
182
+ timesteps: None
183
+ callback: None
184
+ quantize_denoised: False
185
+ mask: None
186
+ image_callback: None
187
+ log_every_t: 100
188
+ temperature: 1.0
189
+ noise_dropout: 0.0
190
+ score_corrector: None
191
+ corrector_kwargs: None
192
+ unconditional_guidance_scale: 5
193
+ unconditional_conditioning.keys(): dict_keys(['context', 'camera', 'num_frames', 'ip', 'ip_img'])
194
+ kwargs: {}
195
+ """
196
+ device = self.model.betas.device
197
+ b = shape[0]
198
+ if x_T is None:
199
+ img = torch.randn(shape, device=device) # shape: torch.Size([5, 4, 32, 32]) mean: -0.00, std: 1.00, min: -3.64, max: 3.94
200
+ else:
201
+ img = x_T
202
+
203
+ if timesteps is None: # equal with set time step in hf
204
+ timesteps = (
205
+ self.ddpm_num_timesteps
206
+ if ddim_use_original_steps
207
+ else self.ddim_timesteps
208
+ )
209
+ elif timesteps is not None and not ddim_use_original_steps:
210
+ subset_end = (
211
+ int(
212
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
213
+ * self.ddim_timesteps.shape[0]
214
+ )
215
+ - 1
216
+ )
217
+ timesteps = self.ddim_timesteps[:subset_end]
218
+
219
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
220
+ time_range = ( # reversed timesteps
221
+ reversed(range(0, timesteps))
222
+ if ddim_use_original_steps
223
+ else np.flip(timesteps)
224
+ )
225
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
226
+ iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
227
+ for i, step in enumerate(iterator):
228
+ index = total_steps - i - 1
229
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
230
+
231
+ if mask is not None:
232
+ assert x0 is not None
233
+ img_orig = self.model.q_sample(
234
+ x0, ts
235
+ ) # TODO: deterministic forward pass?
236
+ img = img_orig * mask + (1.0 - mask) * img
237
+
238
+ outs = self.p_sample_ddim(
239
+ img,
240
+ cond,
241
+ ts,
242
+ index=index,
243
+ use_original_steps=ddim_use_original_steps,
244
+ quantize_denoised=quantize_denoised,
245
+ temperature=temperature,
246
+ noise_dropout=noise_dropout,
247
+ score_corrector=score_corrector,
248
+ corrector_kwargs=corrector_kwargs,
249
+ unconditional_guidance_scale=unconditional_guidance_scale,
250
+ unconditional_conditioning=unconditional_conditioning,
251
+ **kwargs,
252
+ )
253
+ img, pred_x0 = outs
254
+ if callback:
255
+ callback(i)
256
+ if img_callback:
257
+ img_callback(pred_x0, i)
258
+
259
+ if index % log_every_t == 0 or index == total_steps - 1:
260
+ intermediates["x_inter"].append(img)
261
+ intermediates["pred_x0"].append(pred_x0)
262
+
263
+ return img, intermediates
264
+
265
+     @torch.no_grad()
+     def p_sample_ddim(
+         self,
+         x,
+         c,
+         t,
+         index,
+         repeat_noise=False,
+         use_original_steps=False,
+         quantize_denoised=False,
+         temperature=1.0,
+         noise_dropout=0.0,
+         score_corrector=None,
+         corrector_kwargs=None,
+         unconditional_guidance_scale=1.0,
+         unconditional_conditioning=None,
+         dynamic_threshold=None,
+         **kwargs,
+     ):
+         b, *_, device = *x.shape, x.device
+
+         if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
+             model_output = self.model.apply_model(x, t, c)
+         else:
+             x_in = torch.cat([x] * 2)
+             t_in = torch.cat([t] * 2)
+             if isinstance(c, dict):
+                 assert isinstance(unconditional_conditioning, dict)
+                 c_in = dict()
+                 for k in c:
+                     if isinstance(c[k], list):
+                         c_in[k] = [
+                             torch.cat([unconditional_conditioning[k][i], c[k][i]])
+                             for i in range(len(c[k]))
+                         ]
+                     elif isinstance(c[k], torch.Tensor):
+                         c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
+                     else:
+                         assert c[k] == unconditional_conditioning[k]
+                         c_in[k] = c[k]
+             elif isinstance(c, list):
+                 c_in = list()
+                 assert isinstance(unconditional_conditioning, list)
+                 for i in range(len(c)):
+                     c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+             else:
+                 c_in = torch.cat([unconditional_conditioning, c])
+             model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+             model_output = model_uncond + unconditional_guidance_scale * (
+                 model_t - model_uncond
+             )
+
+         if self.model.parameterization == "v":
+             print("using v!")
+             e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+         else:
+             e_t = model_output
+
+         if score_corrector is not None:
+             assert self.model.parameterization == "eps", "not implemented"
+             e_t = score_corrector.modify_score(
+                 self.model, e_t, x, t, c, **corrector_kwargs
+             )
+
+         alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+         alphas_prev = (
+             self.model.alphas_cumprod_prev
+             if use_original_steps
+             else self.ddim_alphas_prev
+         )
+         sqrt_one_minus_alphas = (
+             self.model.sqrt_one_minus_alphas_cumprod
+             if use_original_steps
+             else self.ddim_sqrt_one_minus_alphas
+         )
+         sigmas = (
+             self.model.ddim_sigmas_for_original_num_steps
+             if use_original_steps
+             else self.ddim_sigmas
+         )
+         # select parameters corresponding to the currently considered timestep
+         a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+         a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+         sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+         sqrt_one_minus_at = torch.full(
+             (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
+         )
+
+         # current prediction for x_0
+         if self.model.parameterization != "v":
+             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+         else:
+             pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+         if quantize_denoised:
+             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+         if dynamic_threshold is not None:
+             raise NotImplementedError()
+
+         # direction pointing to x_t
+         dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
+         noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+         if noise_dropout > 0.0:
+             noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+         x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+         return x_prev, pred_x0
+
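For reference, the two core operations in p_sample_ddim are classifier-free guidance (blend the unconditional and conditional predictions) and the DDIM update x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma^2) * e_t + sigma * noise. A minimal, self-contained sketch of just those two steps in plain PyTorch; the function and variable names here are illustrative and not part of this repository's API:

import torch

def cfg_ddim_step(x_t, eps_cond, eps_uncond, a_t, a_prev, sigma_t, guidance_scale=5.0):
    # classifier-free guidance: push the conditional prediction away from the unconditional one
    e_t = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
    # current estimate of the clean latent x_0
    pred_x0 = (x_t - (1.0 - a_t).sqrt() * e_t) / a_t.sqrt()
    # deterministic direction back towards x_t, plus stochastic noise when sigma_t > 0
    dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
    noise = sigma_t * torch.randn_like(x_t)
    return a_prev.sqrt() * pred_x0 + dir_xt + noise, pred_x0

# toy usage with scalar schedule values broadcast over a latent
x_t = torch.randn(1, 4, 32, 32)
eps_c, eps_u = torch.randn_like(x_t), torch.randn_like(x_t)
x_prev, x0_hat = cfg_ddim_step(
    x_t, eps_c, eps_u,
    a_t=torch.tensor(0.50), a_prev=torch.tensor(0.55), sigma_t=torch.tensor(0.0),
)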
+     @torch.no_grad()
+     def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+         # fast, but does not allow for exact reconstruction
+         # t serves as an index to gather the correct alphas
+         if use_original_steps:
+             sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+             sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+         else:
+             sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+             sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+         if noise is None:
+             noise = torch.randn_like(x0)
+         return (
+             extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+             + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
+         )
+
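stochastic_encode applies the closed-form forward process q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, where t indexes the (possibly subsampled DDIM) schedule. A standalone illustration with a toy linear schedule, not the model's actual alphas:

import torch

def noise_to_t(x0, t, alphas_cumprod, noise=None):
    # sample x_t ~ q(x_t | x_0) for a batch of per-sample timestep indices t
    noise = torch.randn_like(x0) if noise is None else noise
    a_bar = alphas_cumprod[t].view(-1, *([1] * (x0.dim() - 1)))  # broadcast over C, H, W
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise

# toy linear-beta schedule, purely for illustration
alphas_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 2e-2, 1000), dim=0)
x0 = torch.randn(2, 4, 32, 32)
x_t = noise_to_t(x0, torch.tensor([10, 500]), alphas_cumprod)   # far more noise at t=500 than at t=10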
+     @torch.no_grad()
+     def decode(
+         self,
+         x_latent,
+         cond,
+         t_start,
+         unconditional_guidance_scale=1.0,
+         unconditional_conditioning=None,
+         use_original_steps=False,
+         **kwargs,
+     ):
+         timesteps = (
+             np.arange(self.ddpm_num_timesteps)
+             if use_original_steps
+             else self.ddim_timesteps
+         )
+         timesteps = timesteps[:t_start]
+
+         time_range = np.flip(timesteps)
+         total_steps = timesteps.shape[0]
+
+         iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
+         x_dec = x_latent
+         for i, step in enumerate(iterator):
+             index = total_steps - i - 1
+             ts = torch.full(
+                 (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
+             )
+             x_dec, _ = self.p_sample_ddim(
+                 x_dec,
+                 cond,
+                 ts,
+                 index=index,
+                 use_original_steps=use_original_steps,
+                 unconditional_guidance_scale=unconditional_guidance_scale,
+                 unconditional_conditioning=unconditional_conditioning,
+                 **kwargs,
+             )
+         return x_dec
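Taken together, stochastic_encode and decode support an SDEdit-style image-to-image loop: re-noise an initial latent part-way up the schedule, then denoise it back under new conditioning. A hedged usage sketch, assuming the enclosing sampler class has been instantiated as `sampler`, that `make_schedule` has already been run with `ddim_steps`, and that `init_latent`, `cond`, and `uc` exist; these names and the surrounding API are assumptions, not verified against this repository:

import torch

# hedged sketch: `sampler`, `init_latent`, `cond`, `uc` are assumed to exist
ddim_steps = 50
strength = 0.6                          # fraction of the DDIM schedule to re-noise
t_enc = int(strength * ddim_steps)      # number of DDIM steps to decode back down

with torch.no_grad():
    t = torch.full((init_latent.shape[0],), t_enc, device=init_latent.device, dtype=torch.long)
    z_noised = sampler.stochastic_encode(init_latent, t)      # q(x_t | x_0) at step t_enc
    z_dec = sampler.decode(
        z_noised, cond, t_enc,
        unconditional_guidance_scale=7.5,
        unconditional_conditioning=uc,
    )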
apps/third_party/CRM/imagedream/ldm/modules/__init__.py ADDED
File without changes
apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (186 Bytes). View file
 
apps/third_party/CRM/imagedream/ldm/modules/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (183 Bytes). View file
 
apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
apps/third_party/CRM/imagedream/ldm/modules/__pycache__/attention.cpython-38.pyc ADDED
Binary file (11.7 kB). View file
 
apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-310.pyc ADDED
Binary file (3.24 kB). View file
 
apps/third_party/CRM/imagedream/ldm/modules/__pycache__/ema.cpython-38.pyc ADDED
Binary file (3.23 kB). View file
 
apps/third_party/CRM/imagedream/ldm/modules/attention.py ADDED
@@ -0,0 +1,456 @@
+ from inspect import isfunction
+ import math
+ import torch
+ import torch.nn.functional as F
+ from torch import nn, einsum
+ from einops import rearrange, repeat
+ from typing import Optional, Any
+
+ from .diffusionmodules.util import checkpoint
+
+
+ try:
+     import xformers
+     import xformers.ops
+
+     XFORMERS_IS_AVAILBLE = True
+ except:
+     XFORMERS_IS_AVAILBLE = False
+
+ # CrossAttn precision handling
+ import os
+
+ _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+
+ def exists(val):
+     return val is not None
+
+
+ def uniq(arr):
+     return {el: True for el in arr}.keys()
+
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if isfunction(d) else d
+
+
+ def max_neg_value(t):
+     return -torch.finfo(t.dtype).max
+
+
+ def init_(tensor):
+     dim = tensor.shape[-1]
+     std = 1 / math.sqrt(dim)
+     tensor.uniform_(-std, std)
+     return tensor
+
+
+ # feedforward
+ class GEGLU(nn.Module):
+     def __init__(self, dim_in, dim_out):
+         super().__init__()
+         self.proj = nn.Linear(dim_in, dim_out * 2)
+
+     def forward(self, x):
+         x, gate = self.proj(x).chunk(2, dim=-1)
+         return x * F.gelu(gate)
+
+
+ class FeedForward(nn.Module):
+     def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
+         super().__init__()
+         inner_dim = int(dim * mult)
+         dim_out = default(dim_out, dim)
+         project_in = (
+             nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
+             if not glu
+             else GEGLU(dim, inner_dim)
+         )
+
+         self.net = nn.Sequential(
+             project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
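GEGLU projects to twice the target width and uses one half as a GELU gate on the other, which is what FeedForward builds when glu=True. A tiny standalone shape check; the dimensions are illustrative only:

import torch
import torch.nn.functional as F
from torch import nn

proj = nn.Linear(320, 1280 * 2)              # dim_in=320, dim_out=1280, doubled for the gate
x = torch.randn(2, 77, 320)
value, gate = proj(x).chunk(2, dim=-1)       # two (2, 77, 1280) halves
out = value * F.gelu(gate)                   # gated activation, shape (2, 77, 1280)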
+ def zero_module(module):
+     """
+     Zero out the parameters of a module and return it.
+     """
+     for p in module.parameters():
+         p.detach().zero_()
+     return module
+
+
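zero_module is applied to the output projections further down so that, at initialization, each transformer block contributes nothing and the residual path dominates; training then grows the block's influence from zero. A quick standalone check of that property:

import torch
from torch import nn

proj = nn.Linear(320, 320)
for p in proj.parameters():
    p.detach().zero_()
assert torch.all(proj(torch.randn(4, 320)) == 0)   # a zeroed layer adds nothing to the residual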
+ def Normalize(in_channels):
+     return torch.nn.GroupNorm(
+         num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
+     )
+
+
+ class SpatialSelfAttention(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.in_channels = in_channels
+
+         self.norm = Normalize(in_channels)
+         self.q = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.k = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.v = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.proj_out = torch.nn.Conv2d(
+             in_channels, in_channels, kernel_size=1, stride=1, padding=0
+         )
+
+     def forward(self, x):
+         h_ = x
+         h_ = self.norm(h_)
+         q = self.q(h_)
+         k = self.k(h_)
+         v = self.v(h_)
+
+         # compute attention
+         b, c, h, w = q.shape
+         q = rearrange(q, "b c h w -> b (h w) c")
+         k = rearrange(k, "b c h w -> b c (h w)")
+         w_ = torch.einsum("bij,bjk->bik", q, k)
+
+         w_ = w_ * (int(c) ** (-0.5))
+         w_ = torch.nn.functional.softmax(w_, dim=2)
+
+         # attend to values
+         v = rearrange(v, "b c h w -> b c (h w)")
+         w_ = rearrange(w_, "b i j -> b j i")
+         h_ = torch.einsum("bij,bjk->bik", v, w_)
+         h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
+         h_ = self.proj_out(h_)
+
+         return x + h_
+
+
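SpatialSelfAttention treats every spatial position as a token: the feature map is flattened to (h*w) tokens, attention weights of shape (b, h*w, h*w) are computed with a 1/sqrt(c) scale, and the attended values are reshaped back to an image and added as a residual. The same einsum pattern in isolation, with random tensors standing in for the projected q/k/v:

import torch
from einops import rearrange

b, c, h, w = 1, 64, 16, 16
q, k, v = (torch.randn(b, c, h, w) for _ in range(3))

q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
attn = torch.softmax(torch.einsum("bij,bjk->bik", q, k) * c ** -0.5, dim=2)  # (b, hw, hw)

v = rearrange(v, "b c h w -> b c (h w)")
out = torch.einsum("bij,bjk->bik", v, rearrange(attn, "b i j -> b j i"))     # (b, c, hw)
out = rearrange(out, "b c (h w) -> b c h w", h=h)                            # back to image layout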
+ class MemoryEfficientCrossAttention(nn.Module):
+     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+     def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs):
+         super().__init__()
+         print(
+             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+             f"{heads} heads."
+         )
+         inner_dim = dim_head * heads
+         context_dim = default(context_dim, query_dim)
+
+         self.heads = heads
+         self.dim_head = dim_head
+
+         self.with_ip = kwargs.get("with_ip", False)
+         if self.with_ip and (context_dim is not None):
+             self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
+             self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
+             self.ip_dim = kwargs.get("ip_dim", 16)
+             self.ip_weight = kwargs.get("ip_weight", 1.0)
+
+         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+         self.to_out = nn.Sequential(
+             nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
+         )
+         self.attention_op = None
+
+     def forward(self, x, context=None, mask=None):
+         q = self.to_q(x)
+
+         has_ip = self.with_ip and (context is not None)
+         if has_ip:
+             # context dim [(b frame_num), (77 + img_token), 1024]
+             token_len = context.shape[1]
+             context_ip = context[:, -self.ip_dim:, :]
+             k_ip = self.to_k_ip(context_ip)
+             v_ip = self.to_v_ip(context_ip)
+             context = context[:, :(token_len - self.ip_dim), :]
+
+         context = default(context, x)
+         k = self.to_k(context)
+         v = self.to_v(context)
+
+         b, _, _ = q.shape
+         q, k, v = map(
+             lambda t: t.unsqueeze(3)
+             .reshape(b, t.shape[1], self.heads, self.dim_head)
+             .permute(0, 2, 1, 3)
+             .reshape(b * self.heads, t.shape[1], self.dim_head)
+             .contiguous(),
+             (q, k, v),
+         )
+
+         # actually compute the attention, what we cannot get enough of
+         out = xformers.ops.memory_efficient_attention(
+             q, k, v, attn_bias=None, op=self.attention_op
+         )
+
+         if has_ip:
+             k_ip, v_ip = map(
+                 lambda t: t.unsqueeze(3)
+                 .reshape(b, t.shape[1], self.heads, self.dim_head)
+                 .permute(0, 2, 1, 3)
+                 .reshape(b * self.heads, t.shape[1], self.dim_head)
+                 .contiguous(),
+                 (k_ip, v_ip),
+             )
+             # actually compute the attention, what we cannot get enough of
+             out_ip = xformers.ops.memory_efficient_attention(
+                 q, k_ip, v_ip, attn_bias=None, op=self.attention_op
+             )
+             out = out + self.ip_weight * out_ip
+
+         if exists(mask):
+             raise NotImplementedError
+         out = (
+             out.unsqueeze(0)
+             .reshape(b, self.heads, out.shape[1], self.dim_head)
+             .permute(0, 2, 1, 3)
+             .reshape(b, out.shape[1], self.heads * self.dim_head)
+         )
+         return self.to_out(out)
+
+
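The with_ip path above assumes the conditioning sequence is laid out as [text tokens, image-prompt tokens]: the final ip_dim tokens get their own to_k_ip/to_v_ip projections, are attended to with the same queries, and the result is blended back in with ip_weight. Below is the split-and-blend logic in isolation, using PyTorch's scaled-dot-product attention instead of xformers and no learned projections (single head, purely illustrative; requires a recent PyTorch):

import torch
import torch.nn.functional as F

def ip_cross_attention(q, context, ip_dim=16, ip_weight=1.0):
    # q: (b, n_query, d); context: (b, n_text + ip_dim, d)
    ctx_txt, ctx_ip = context[:, :-ip_dim, :], context[:, -ip_dim:, :]
    out = F.scaled_dot_product_attention(q, ctx_txt, ctx_txt)    # text-conditioned branch
    out_ip = F.scaled_dot_product_attention(q, ctx_ip, ctx_ip)   # image-prompt branch
    return out + ip_weight * out_ip

q = torch.randn(2, 4096, 64)
context = torch.randn(2, 77 + 16, 64)
y = ip_cross_attention(q, context)        # (2, 4096, 64)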
+ class BasicTransformerBlock(nn.Module):
+     def __init__(
+         self,
+         dim,
+         n_heads,
+         d_head,
+         dropout=0.0,
+         context_dim=None,
+         gated_ff=True,
+         checkpoint=True,
+         disable_self_attn=False,
+         **kwargs
+     ):
+         super().__init__()
+         assert XFORMERS_IS_AVAILBLE, "xformers is not available"
+         attn_cls = MemoryEfficientCrossAttention
+         self.disable_self_attn = disable_self_attn
+         self.attn1 = attn_cls(
+             query_dim=dim,
+             heads=n_heads,
+             dim_head=d_head,
+             dropout=dropout,
+             context_dim=context_dim if self.disable_self_attn else None,
+         )  # is a self-attention if not self.disable_self_attn
+         self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+         self.attn2 = attn_cls(
+             query_dim=dim,
+             context_dim=context_dim,
+             heads=n_heads,
+             dim_head=d_head,
+             dropout=dropout,
+             **kwargs
+         )  # is self-attn if context is none
+         self.norm1 = nn.LayerNorm(dim)
+         self.norm2 = nn.LayerNorm(dim)
+         self.norm3 = nn.LayerNorm(dim)
+         self.checkpoint = checkpoint
+
+     def forward(self, x, context=None):
+         return checkpoint(
+             self._forward, (x, context), self.parameters(), self.checkpoint
+         )
+
+     def _forward(self, x, context=None):
+         x = (
+             self.attn1(
+                 self.norm1(x), context=context if self.disable_self_attn else None
+             )
+             + x
+         )
+         x = self.attn2(self.norm2(x), context=context) + x
+         x = self.ff(self.norm3(x)) + x
+         return x
+
+
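The checkpoint wrapper in forward (imported from diffusionmodules.util) re-runs _forward during the backward pass instead of storing its activations, trading compute for memory; it plays the same role as PyTorch's stock utility. A standalone sketch of that idea with torch.utils.checkpoint rather than the repository's helper:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint

block = nn.Sequential(nn.LayerNorm(320), nn.Linear(320, 320), nn.GELU())
x = torch.randn(8, 77, 320, requires_grad=True)

y = torch_checkpoint(block, x, use_reentrant=False)   # activations recomputed during backward
y.sum().backward()
print(x.grad.shape)                                   # gradients flow as usual: torch.Size([8, 77, 320])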
+ class SpatialTransformer(nn.Module):
+     """
+     Transformer block for image-like data.
+     First, project the input (aka embedding)
+     and reshape to b, t, d.
+     Then apply standard transformer action.
+     Finally, reshape to image
+     NEW: use_linear for more efficiency instead of the 1x1 convs
+     """
+
+     def __init__(
+         self,
+         in_channels,
+         n_heads,
+         d_head,
+         depth=1,
+         dropout=0.0,
+         context_dim=None,
+         disable_self_attn=False,
+         use_linear=False,
+         use_checkpoint=True,
+         **kwargs
+     ):
+         super().__init__()
+         if exists(context_dim) and not isinstance(context_dim, list):
+             context_dim = [context_dim]
+         self.in_channels = in_channels
+         inner_dim = n_heads * d_head
+         self.norm = Normalize(in_channels)
+         if not use_linear:
+             self.proj_in = nn.Conv2d(
+                 in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+             )
+         else:
+             self.proj_in = nn.Linear(in_channels, inner_dim)
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim,
+                     n_heads,
+                     d_head,
+                     dropout=dropout,
+                     context_dim=context_dim[d],
+                     disable_self_attn=disable_self_attn,
+                     checkpoint=use_checkpoint,
+                     **kwargs
+                 )
+                 for d in range(depth)
+             ]
+         )
+         if not use_linear:
+             self.proj_out = zero_module(
+                 nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+             )
+         else:
+             self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+         self.use_linear = use_linear
+
+     def forward(self, x, context=None):
+         # note: if no context is given, cross-attention defaults to self-attention
+         if not isinstance(context, list):
+             context = [context]
+         b, c, h, w = x.shape
+         x_in = x
+         x = self.norm(x)
+         if not self.use_linear:
+             x = self.proj_in(x)
+         x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+         if self.use_linear:
+             x = self.proj_in(x)
+         for i, block in enumerate(self.transformer_blocks):
+             x = block(x, context=context[i])
+         if self.use_linear:
+             x = self.proj_out(x)
+         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+         if not self.use_linear:
+             x = self.proj_out(x)
+         return x + x_in
+
+
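SpatialTransformer's two projection modes are equivalent in parameter count: use_linear=False applies a 1x1 convolution while the tensor is still in image layout, whereas use_linear=True flattens pixels to tokens first and projects with nn.Linear, which is usually a bit faster. A standalone comparison of the two layouts; the channel counts are illustrative:

import torch
from einops import rearrange
from torch import nn

b, c, h, w = 1, 320, 32, 32
x = torch.randn(b, c, h, w)

# use_linear=False: 1x1 conv in image layout, then flatten to tokens
tokens_conv = rearrange(nn.Conv2d(c, c, kernel_size=1)(x), "b c h w -> b (h w) c")

# use_linear=True: flatten to tokens first, then a Linear projection
tokens_lin = nn.Linear(c, c)(rearrange(x, "b c h w -> b (h w) c"))

assert tokens_conv.shape == tokens_lin.shape == (1, 1024, 320)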
+ class BasicTransformerBlock3D(BasicTransformerBlock):
+     def forward(self, x, context=None, num_frames=1):
+         return checkpoint(
+             self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
+         )
+
+     def _forward(self, x, context=None, num_frames=1):
+         x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
+         x = (
+             self.attn1(
+                 self.norm1(x),
+                 context=context if self.disable_self_attn else None
+             )
+             + x
+         )
+         x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
+         x = self.attn2(self.norm2(x), context=context) + x
+         x = self.ff(self.norm3(x)) + x
+         return x
+
+
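BasicTransformerBlock3D's only change is in the self-attention step: tokens from all num_frames views of one object are concatenated into a single sequence, so every view attends to every other view, then split back per view before cross-attention and the feed-forward. The rearrange round-trip in isolation:

import torch
from einops import rearrange

b, f, l, c = 2, 4, 1024, 320            # batch, frames (views), tokens per view, channels
x = torch.randn(b * f, l, c)            # the UNet treats each view as its own batch element

x_joint = rearrange(x, "(b f) l c -> b (f l) c", f=f)     # (2, 4096, 320): one joint sequence per object
# ... cross-view self-attention runs on x_joint here ...
x_split = rearrange(x_joint, "b (f l) c -> (b f) l c", f=f)
assert torch.equal(x, x_split)                             # the round trip is lossless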
+ class SpatialTransformer3D(nn.Module):
+     """3D self-attention"""
+
+     def __init__(
+         self,
+         in_channels,
+         n_heads,
+         d_head,
+         depth=1,
+         dropout=0.0,
+         context_dim=None,
+         disable_self_attn=False,
+         use_linear=False,
+         use_checkpoint=True,
+         **kwargs
+     ):
+         super().__init__()
+         if exists(context_dim) and not isinstance(context_dim, list):
+             context_dim = [context_dim]
+         self.in_channels = in_channels
+         inner_dim = n_heads * d_head
+         self.norm = Normalize(in_channels)
+         if not use_linear:
+             self.proj_in = nn.Conv2d(
+                 in_channels, inner_dim, kernel_size=1, stride=1, padding=0
+             )
+         else:
+             self.proj_in = nn.Linear(in_channels, inner_dim)
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock3D(
+                     inner_dim,
+                     n_heads,
+                     d_head,
+                     dropout=dropout,
+                     context_dim=context_dim[d],
+                     disable_self_attn=disable_self_attn,
+                     checkpoint=use_checkpoint,
+                     **kwargs
+                 )
+                 for d in range(depth)
+             ]
+         )
+         if not use_linear:
+             self.proj_out = zero_module(
+                 nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+             )
+         else:
+             self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
+         self.use_linear = use_linear
+
+     def forward(self, x, context=None, num_frames=1):
+         # note: if no context is given, cross-attention defaults to self-attention
+         if not isinstance(context, list):
+             context = [context]
+         b, c, h, w = x.shape
+         x_in = x
+         x = self.norm(x)
+         if not self.use_linear:
+             x = self.proj_in(x)
+         x = rearrange(x, "b c h w -> b (h w) c").contiguous()
+         if self.use_linear:
+             x = self.proj_in(x)
+         for i, block in enumerate(self.transformer_blocks):
+             x = block(x, context=context[i], num_frames=num_frames)
+         if self.use_linear:
+             x = self.proj_out(x)
+         x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
+         if not self.use_linear:
+             x = self.proj_out(x)
+         return x + x_in
apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__init__.py ADDED
File without changes
apps/third_party/CRM/imagedream/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (203 Bytes). View file