flamehaze1115 committed on
Commit
b582ef0
1 Parent(s): 9729d73

first commit

Files changed (46)
  1. .gitignore +189 -0
  2. README.md +69 -13
  3. app.py +331 -0
  4. configs/mvdiffusion-joint-ortho-6views.yaml +42 -0
  5. example_images/14_10_29_489_Tiger_1__1.png +0 -0
  6. example_images/box.png +0 -0
  7. example_images/bread.png +0 -0
  8. example_images/cat.png +0 -0
  9. example_images/cat_head.png +0 -0
  10. example_images/chili.png +0 -0
  11. example_images/duola.png +0 -0
  12. example_images/halloween.png +0 -0
  13. example_images/head.png +0 -0
  14. example_images/kettle.png +0 -0
  15. example_images/kunkun.png +0 -0
  16. example_images/milk.png +0 -0
  17. example_images/owl.png +0 -0
  18. example_images/poro.png +0 -0
  19. example_images/pumpkin.png +0 -0
  20. example_images/skull.png +0 -0
  21. example_images/stone.png +0 -0
  22. example_images/teapot.png +0 -0
  23. example_images/tiger-head-3d-model-obj-stl.png +0 -0
  24. mvdiffusion/data/fixed_poses/four_views/000_back_RT.txt +3 -0
  25. mvdiffusion/data/fixed_poses/four_views/000_front_RT.txt +3 -0
  26. mvdiffusion/data/fixed_poses/four_views/000_left_RT.txt +3 -0
  27. mvdiffusion/data/fixed_poses/four_views/000_right_RT.txt +3 -0
  28. mvdiffusion/data/fixed_poses/nine_views/000_back_RT.txt +3 -0
  29. mvdiffusion/data/fixed_poses/nine_views/000_back_left_RT.txt +3 -0
  30. mvdiffusion/data/fixed_poses/nine_views/000_back_right_RT.txt +3 -0
  31. mvdiffusion/data/fixed_poses/nine_views/000_front_RT.txt +3 -0
  32. mvdiffusion/data/fixed_poses/nine_views/000_front_left_RT.txt +3 -0
  33. mvdiffusion/data/fixed_poses/nine_views/000_front_right_RT.txt +3 -0
  34. mvdiffusion/data/fixed_poses/nine_views/000_left_RT.txt +3 -0
  35. mvdiffusion/data/fixed_poses/nine_views/000_right_RT.txt +3 -0
  36. mvdiffusion/data/fixed_poses/nine_views/000_top_RT.txt +3 -0
  37. mvdiffusion/data/normal_utils.py +45 -0
  38. mvdiffusion/data/objaverse_dataset.py +608 -0
  39. mvdiffusion/data/single_image_dataset.py +321 -0
  40. mvdiffusion/models/transformer_mv2d.py +1005 -0
  41. mvdiffusion/models/unet_mv2d_blocks.py +880 -0
  42. mvdiffusion/models/unet_mv2d_condition.py +1462 -0
  43. mvdiffusion/pipelines/pipeline_mvdiffusion_image.py +485 -0
  44. requirements.txt +30 -0
  45. run_test.sh +1 -0
  46. utils/misc.py +54 -0
.gitignore ADDED
@@ -0,0 +1,189 @@
1
+ # Initially taken from GitHub's Python gitignore file
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # tests and logs
12
+ tests/fixtures/cached_*_text.txt
13
+ logs/
14
+ lightning_logs/
15
+ lang_code_data/
16
+
17
+ # Distribution / packaging
18
+ .Python
19
+ build/
20
+ develop-eggs/
21
+ dist/
22
+ downloads/
23
+ eggs/
24
+ .eggs/
25
+ lib/
26
+ lib64/
27
+ parts/
28
+ sdist/
29
+ var/
30
+ wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ .python-version
90
+
91
+ # celery beat schedule file
92
+ celerybeat-schedule
93
+
94
+ # SageMath parsed files
95
+ *.sage.py
96
+
97
+ # Environments
98
+ .env
99
+ .venv
100
+ env/
101
+ venv/
102
+ ENV/
103
+ env.bak/
104
+ venv.bak/
105
+
106
+ # Spyder project settings
107
+ .spyderproject
108
+ .spyproject
109
+
110
+ # Rope project settings
111
+ .ropeproject
112
+
113
+ # mkdocs documentation
114
+ /site
115
+
116
+ # mypy
117
+ .mypy_cache/
118
+ .dmypy.json
119
+ dmypy.json
120
+
121
+ # Pyre type checker
122
+ .pyre/
123
+
124
+ # vscode
125
+ .vs
126
+ .vscode
127
+
128
+ # Pycharm
129
+ .idea
130
+
131
+ # TF code
132
+ tensorflow_code
133
+
134
+ # Models
135
+ proc_data
136
+
137
+ # examples
138
+ runs
139
+ /runs_old
140
+ /wandb
141
+ /examples/runs
142
+ /examples/**/*.args
143
+ /examples/rag/sweep
144
+
145
+ # data
146
+ /data
147
+ serialization_dir
148
+
149
+ # emacs
150
+ *.*~
151
+ debug.env
152
+
153
+ # vim
154
+ .*.swp
155
+
156
+ #ctags
157
+ tags
158
+
159
+ # pre-commit
160
+ .pre-commit*
161
+
162
+ # .lock
163
+ *.lock
164
+
165
+ # DS_Store (MacOS)
166
+ .DS_Store
167
+ # RL pipelines may produce mp4 outputs
168
+ *.mp4
169
+
170
+ # dependencies
171
+ /transformers
172
+
173
+ # ruff
174
+ .ruff_cache
175
+
176
+ # ckpts
177
+ *.ckpt
178
+
179
+ outputs/*
180
+
181
+ NeuS/exp/*
182
+ NeuS/test_scenes/*
183
+ NeuS/mesh2tex/*
184
+ neus_configs
185
+ vast/*
186
+ render_results
187
+ experiments/*
188
+ neus/*
189
+ ckpts/*
README.md CHANGED
@@ -1,13 +1,69 @@
1
- ---
2
- title: Wonder3D
3
- emoji: 👁
4
- colorFrom: indigo
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: 3.50.2
8
- app_file: app.py
9
- pinned: false
10
- license: cc-by-sa-4.0
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Wonder3D
2
+ Single Image to 3D using Cross-Domain Diffusion
3
+ ## [Paper](https://arxiv.org/abs/2310.15008) | [Project page](https://www.xxlong.site/Wonder3D/)
4
+
5
+ ![](assets/fig_teaser.png)
6
+
7
+ Wonder3D reconstructs highly detailed textured meshes from a single-view image in only 2 to 3 minutes. Wonder3D first generates consistent multi-view normal maps with corresponding color images via a cross-domain diffusion model, and then leverages a novel normal fusion method to achieve fast and high-quality reconstruction.
8
+
9
+ ## Schedule
10
+ - [x] Inference code and pretrained models.
11
+ - [ ] Huggingface demo.
12
+ - [ ] Training code.
13
+ - [ ] Rendering code for data preparation.
14
+
15
+
16
+ ### Preparation for inference
17
+ 1. Install packages in `requirements.txt`.
18
+ ```bash
19
+ conda create -n wonder3d
20
+ conda activate wonder3d
21
+ pip install -r requirements.txt
22
+ ```
23
+ 2. Download the [checkpoints](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/xxlong_connect_hku_hk/EgSHPyJAtaJFpV_BjXM3zXwB-UMIrT4v-sQwGgw-coPtIA) into the root folder.
24
+
25
+ ### Inference
26
+ 1. Make sure you have the following models.
27
+ ```bash
28
+ Wonder3D
29
+ |-- ckpts
30
+ |-- unet
31
+ |-- scheduler.bin
32
+ ...
33
+ ```
34
+ 2. Predict the foreground mask as the alpha channel. We use [Clipdrop](https://clipdrop.co/remove-background) to segment the foreground object interactively.
+ You may also use `rembg` to remove the background:
36
+ ```python
+ # !pip install rembg
+ import rembg
+ from PIL import Image
+
+ image = Image.open("example_images/owl.png")
+ result = rembg.remove(image)
+ result.show()
+ ```
42
+ 3. Run Wonder3D to produce multiview-consistent normal maps and color images. You can then check the results in the folder `./outputs`. (We use `rembg` to remove the backgrounds of the results, but the segmentations are not always perfect.)
43
+ ```bash
44
+ accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py \
45
+ --config mvdiffusion-joint-ortho-6views.yaml
46
+ ```
47
+ or
48
+ ```bash
49
+ bash run_test.sh
50
+ ```
51
+
52
+ 4. Mesh Extraction
53
+ ```bash
54
+ cd ./instant-nsr-pl
55
+ bash run.sh output_folder_path scene_name
56
+ ```
57
+
58
+ ## Citation
59
+ If you find this repository useful in your project, please cite the following work. :)
60
+ ```
61
+ @misc{long2023wonder3d,
62
+ title={Wonder3D: Single Image to 3D using Cross-Domain Diffusion},
63
+ author={Xiaoxiao Long and Yuan-Chen Guo and Cheng Lin and Yuan Liu and Zhiyang Dou and Lingjie Liu and Yuexin Ma and Song-Hai Zhang and Marc Habermann and Christian Theobalt and Wenping Wang},
64
+ year={2023},
65
+ eprint={2310.15008},
66
+ archivePrefix={arXiv},
67
+ primaryClass={cs.CV}
68
+ }
69
+ ```
app.py ADDED
@@ -0,0 +1,331 @@
1
+ import os
2
+ import sys
3
+ import numpy
4
+ import torch
5
+ import rembg
6
+ import threading
7
+ import urllib.request
8
+ from PIL import Image
9
+ from typing import Dict, Optional, Tuple, List
10
+ from dataclasses import dataclass
11
+ import streamlit as st
12
+ import huggingface_hub
13
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
14
+ from mvdiffusion.models.unet_mv2d_condition import UNetMV2DConditionModel
15
+ from mvdiffusion.data.single_image_dataset import SingleImageDataset as MVDiffusionDataset
16
+ from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline
17
+ from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler
18
+
19
+ @dataclass
20
+ class TestConfig:
21
+ pretrained_model_name_or_path: str
22
+ pretrained_unet_path:str
23
+ revision: Optional[str]
24
+ validation_dataset: Dict
25
+ save_dir: str
26
+ seed: Optional[int]
27
+ validation_batch_size: int
28
+ dataloader_num_workers: int
29
+
30
+ local_rank: int
31
+
32
+ pipe_kwargs: Dict
33
+ pipe_validation_kwargs: Dict
34
+ unet_from_pretrained_kwargs: Dict
35
+ validation_guidance_scales: List[float]
36
+ validation_grid_nrow: int
37
+ camera_embedding_lr_mult: float
38
+
39
+ num_views: int
40
+ camera_embedding_type: str
41
+
42
+ pred_type: str # joint, or ablation
43
+
44
+ enable_xformers_memory_efficient_attention: bool
45
+
46
+ cond_on_normals: bool
47
+ cond_on_colors: bool
48
+
49
+ img_example_counter = 0
50
+ iret_base = 'example_images'
51
+ iret = [
52
+ dict(rimageinput=os.path.join(iret_base, x), dispi=os.path.join(iret_base, x))
53
+ for x in sorted(os.listdir(iret_base))
54
+ ]
55
+
56
+
57
+ class SAMAPI:
58
+ predictor = None
59
+
60
+ @staticmethod
61
+ @st.cache_resource
62
+ def get_instance(sam_checkpoint=None):
63
+ if SAMAPI.predictor is None:
64
+ if sam_checkpoint is None:
65
+ sam_checkpoint = "tmp/sam_vit_h_4b8939.pth"
66
+ if not os.path.exists(sam_checkpoint):
67
+ os.makedirs('tmp', exist_ok=True)
68
+ urllib.request.urlretrieve(
69
+ "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
70
+ sam_checkpoint
71
+ )
72
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
73
+ model_type = "default"
74
+
75
+ from segment_anything import sam_model_registry, SamPredictor
76
+
77
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
78
+ sam.to(device=device)
79
+
80
+ predictor = SamPredictor(sam)
81
+ SAMAPI.predictor = predictor
82
+ return SAMAPI.predictor
83
+
84
+ @staticmethod
85
+ def segment_api(rgb, mask=None, bbox=None, sam_checkpoint=None):
86
+ """
87
+
88
+ Parameters
89
+ ----------
90
+ rgb : np.ndarray h,w,3 uint8
91
+ mask: np.ndarray h,w bool
92
+
93
+ Returns
94
+ -------
95
+
96
+ """
97
+ np = numpy
98
+ predictor = SAMAPI.get_instance(sam_checkpoint)
99
+ predictor.set_image(rgb)
100
+ if mask is None and bbox is None:
101
+ box_input = None
102
+ else:
103
+ # mask to bbox
104
+ if bbox is None:
105
+ y1, y2, x1, x2 = np.nonzero(mask)[0].min(), np.nonzero(mask)[0].max(), np.nonzero(mask)[1].min(), \
106
+ np.nonzero(mask)[1].max()
107
+ else:
108
+ x1, y1, x2, y2 = bbox
109
+ box_input = np.array([[x1, y1, x2, y2]])
110
+ masks, scores, logits = predictor.predict(
111
+ box=box_input,
112
+ multimask_output=True,
113
+ return_logits=False,
114
+ )
115
+ mask = masks[-1]
116
+ return mask
117
+
118
+
119
+ def image_examples(samples, ncols, return_key=None, example_text="Examples"):
120
+ global img_example_counter
121
+ trigger = False
122
+ with st.expander(example_text, True):
123
+ for i in range(len(samples) // ncols):
124
+ cols = st.columns(ncols)
125
+ for j in range(ncols):
126
+ idx = i * ncols + j
127
+ if idx >= len(samples):
128
+ continue
129
+ entry = samples[idx]
130
+ with cols[j]:
131
+ st.image(entry['dispi'])
132
+ img_example_counter += 1
133
+ with st.columns(5)[2]:
134
+ this_trigger = st.button('\+', key='imgexuse%d' % img_example_counter)
135
+ trigger = trigger or this_trigger
136
+ if this_trigger:
137
+ trigger = entry[return_key]
138
+ return trigger
139
+
140
+
141
+ def segment_img(img: Image):
142
+ output = rembg.remove(img)
143
+ mask = numpy.array(output)[:, :, 3] > 0
144
+ sam_mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
145
+ segmented_img = Image.new("RGBA", img.size, (0, 0, 0, 0))
146
+ segmented_img.paste(img, mask=Image.fromarray(sam_mask))
147
+ return segmented_img
148
+
149
+
150
+ def segment_6imgs(imgs):
151
+ segmented_imgs = []
152
+ for i, img in enumerate(imgs):
153
+ output = rembg.remove(img)
154
+ mask = numpy.array(output)[:, :, 3]
155
+ mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
156
+ data = numpy.array(img)[:,:,:3]
157
+ data[mask == 0] = [255, 255, 255]
158
+ segmented_imgs.append(data)
159
+ result = numpy.concatenate([
160
+ numpy.concatenate([segmented_imgs[0], segmented_imgs[1]], axis=1),
161
+ numpy.concatenate([segmented_imgs[2], segmented_imgs[3]], axis=1),
162
+ numpy.concatenate([segmented_imgs[4], segmented_imgs[5]], axis=1)
163
+ ])
164
+ return Image.fromarray(result)
165
+
166
+ def pack_6imgs(imgs):
167
+ result = numpy.concatenate([
168
+ numpy.concatenate([imgs[0], imgs[1]], axis=1),
169
+ numpy.concatenate([imgs[2], imgs[3]], axis=1),
170
+ numpy.concatenate([imgs[4], imgs[5]], axis=1)
171
+ ])
172
+ return Image.fromarray(result)
173
+
174
+
175
+ def expand2square(pil_img, background_color):
176
+ width, height = pil_img.size
177
+ if width == height:
178
+ return pil_img
179
+ elif width > height:
180
+ result = Image.new(pil_img.mode, (width, width), background_color)
181
+ result.paste(pil_img, (0, (width - height) // 2))
182
+ return result
183
+ else:
184
+ result = Image.new(pil_img.mode, (height, height), background_color)
185
+ result.paste(pil_img, ((height - width) // 2, 0))
186
+ return result
187
+
188
+
189
+ @st.cache_data
190
+ def check_dependencies():
191
+ reqs = []
192
+ try:
193
+ import diffusers
194
+ except ImportError:
195
+ import traceback
196
+ traceback.print_exc()
197
+ print("Error: `diffusers` not found.", file=sys.stderr)
198
+ reqs.append("diffusers==0.20.2")
199
+ else:
200
+ if not diffusers.__version__.startswith("0.20"):
201
+ print(
202
+ f"Warning: You are using an unsupported version of diffusers ({diffusers.__version__}), which may lead to performance issues.",
203
+ file=sys.stderr
204
+ )
205
+ print("Recommended version is `diffusers==0.20.2`.", file=sys.stderr)
206
+ try:
207
+ import transformers
208
+ except ImportError:
209
+ import traceback
210
+ traceback.print_exc()
211
+ print("Error: `transformers` not found.", file=sys.stderr)
212
+ reqs.append("transformers==4.29.2")
213
+ if torch.__version__ < '2.0':
214
+ try:
215
+ import xformers
216
+ except ImportError:
217
+ print("Warning: You are using PyTorch 1.x without a working `xformers` installation.", file=sys.stderr)
218
+ print("You may see a significant memory overhead when running the model.", file=sys.stderr)
219
+ if len(reqs):
220
+ print(f"Info: Fix all dependency errors with `pip install {' '.join(reqs)}`.")
221
+
222
+
223
+ @st.cache_resource
224
+ def load_wonder3d_pipeline(cfg):
225
+ # Load scheduler, tokenizer and models.
226
+ # noise_scheduler = DDPMScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
227
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="image_encoder", revision=cfg.revision)
228
+ feature_extractor = CLIPImageProcessor.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="feature_extractor", revision=cfg.revision)
229
+ vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", revision=cfg.revision)
230
+ unet = UNetMV2DConditionModel.from_pretrained_2d(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)
231
+
232
+ weight_dtype = torch.float16
233
+ # Move text_encode and vae to gpu and cast to weight_dtype
234
+ image_encoder.to(dtype=weight_dtype)
235
+ vae.to(dtype=weight_dtype)
236
+ unet.to(dtype=weight_dtype)
237
+
238
+ pipeline = MVDiffusionImagePipeline(
239
+ image_encoder=image_encoder, feature_extractor=feature_extractor, vae=vae, unet=unet, safety_checker=None,
240
+ scheduler=DDIMScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler"),
241
+ **cfg.pipe_kwargs
242
+ )
243
+
244
+ if torch.cuda.is_available():
245
+ pipeline.to('cuda:0')
246
+ sys.main_lock = threading.Lock()
247
+ return pipeline
248
+
249
+
250
+ from utils.misc import load_config
251
+ from omegaconf import OmegaConf
252
+ # parse YAML config to OmegaConf
253
+ cfg = load_config("./configs/mvdiffusion-joint-ortho-6views.yaml")
254
+ # print(cfg)
255
+ schema = OmegaConf.structured(TestConfig)
256
+ # cfg = OmegaConf.load(args.config)
257
+ cfg = OmegaConf.merge(schema, cfg)
258
+
259
+ check_dependencies()
260
+ pipeline = load_wonder3d_pipeline(cfg)
261
+ SAMAPI.get_instance()
262
+ torch.set_grad_enabled(False)
263
+
264
+ st.title("Wonder3D Demo")
265
+ # st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")
266
+ prog = st.progress(0.0, "Idle")
267
+ pic = st.file_uploader("Upload an Image", key='imageinput', type=['png', 'jpg', 'webp'])
268
+ left, right = st.columns(2)
269
+ with left:
270
+ rem_input_bg = st.checkbox("Remove Input Background")
271
+ with right:
272
+ rem_output_bg = st.checkbox("Remove Output Background")
273
+ num_inference_steps = st.slider("Number of Inference Steps", 15, 100, 75)
274
+ st.caption("Diffusion Steps. For general real or synthetic objects, around 28 is enough. For objects with delicate details such as faces (either realistic or illustration), you may need 75 or more steps.")
275
+ cfg_scale = st.slider("Classifier Free Guidance Scale", 1.0, 10.0, 4.0)
276
+ seed = st.text_input("Seed", "42")
277
+ submit = False
278
+ if st.button("Submit"):
279
+ submit = True
280
+ results_container = st.container()
281
+ sample_got = image_examples(iret, 4, 'rimageinput')
282
+ if sample_got:
283
+ pic = sample_got
284
+ with results_container:
285
+ if sample_got or submit:
286
+ prog.progress(0.03, "Waiting in Queue...")
287
+ with sys.main_lock:
288
+ seed = int(seed)
289
+ torch.manual_seed(seed)
290
+ img = Image.open(pic)
291
+ if max(img.size) > 1280:
292
+ w, h = img.size
293
+ w = round(1280 / max(img.size) * w)
294
+ h = round(1280 / max(img.size) * h)
295
+ img = img.resize((w, h))
296
+ left, right = st.columns(2)
297
+ with left:
298
+ st.image(img)
299
+ st.caption("Input Image")
300
+ prog.progress(0.1, "Preparing Inputs")
301
+ if rem_input_bg:
302
+ with right:
303
+ img = segment_img(img)
304
+ st.image(img)
305
+ st.caption("Input (Background Removed)")
306
+ img = expand2square(img, (127, 127, 127, 0))
307
+ pipeline.set_progress_bar_config(disable=True)
308
+ result = pipeline(
309
+ img,
310
+ num_inference_steps=num_inference_steps,
311
+ guidance_scale=cfg_scale,
312
+ generator=torch.Generator(pipeline.device).manual_seed(seed),
313
+ callback=lambda i, t, latents: prog.progress(0.1 + 0.8 * i / num_inference_steps, "Diffusion Step %d" % i)
314
+ ).images
315
+ bsz = result.shape[0] // 2
316
+ normals_pred = result[:bsz]
317
+ images_pred = result[bsz:]
318
+ prog.progress(0.9, "Post Processing")
319
+ left, right = st.columns(2)
320
+ with left:
321
+ st.image(pack_6imgs(normals_pred))
322
+ st.image(pack_6imgs(images_pred))
323
+ st.caption("Result")
324
+ if rem_output_bg:
325
+ normals_pred = segment_6imgs(normals_pred)
326
+ images_pred = segment_6imgs(images_pred)
327
+ with right:
328
+ st.image(normals_pred)
329
+ st.image(images_pred)
330
+ st.caption("Result (Background Removed)")
331
+ prog.progress(1.0, "Idle")
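
For reference, a minimal sketch of the output convention used above: the joint pipeline stacks the two prediction domains along the batch axis, so with `num_views = 6` the result holds twelve images that are split in half. The sketch reuses the `pipeline` and `img` objects defined in the script above.

```python
# Sketch of the output split performed in app.py (assumes num_views = 6).
result = pipeline(img, num_inference_steps=75, guidance_scale=4.0).images  # 12 images stacked

bsz = result.shape[0] // 2        # 6
normals_pred = result[:bsz]       # first half: predicted normal maps, one per view
images_pred = result[bsz:]        # second half: predicted color images, one per view
```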
configs/mvdiffusion-joint-ortho-6views.yaml ADDED
@@ -0,0 +1,42 @@
1
+ pretrained_model_name_or_path: 'lambdalabs/sd-image-variations-diffusers'
2
+ pretrained_unet_path: './ckpts/'
3
+ revision: null
4
+ validation_dataset:
5
+ root_dir: "./example_images" # the folder path stores testing images
6
+ num_views: 6
7
+ bg_color: 'white'
8
+ img_wh: [256, 256]
9
+ num_validation_samples: 1000
10
+ crop_size: 192
11
+ filepaths: ['owl.png'] # the test image names. leave it empty, test all images in the folder
12
+
13
+ save_dir: 'outputs/'
14
+
15
+ pred_type: 'joint'
16
+ seed: 42
17
+ validation_batch_size: 1
18
+ dataloader_num_workers: 64
19
+
20
+ local_rank: -1
21
+
22
+ pipe_kwargs:
23
+ camera_embedding_type: 'e_de_da_sincos'
24
+ num_views: 6
25
+
26
+ validation_guidance_scales: [3.0]
27
+ pipe_validation_kwargs:
28
+ eta: 1.0
29
+ validation_grid_nrow: 6
30
+
31
+ unet_from_pretrained_kwargs:
32
+ camera_embedding_type: 'e_de_da_sincos'
33
+ projection_class_embeddings_input_dim: 10
34
+ num_views: 6
35
+ sample_size: 32
36
+ zero_init_conv_in: false
37
+ zero_init_camera_projection: false
38
+
39
+ num_views: 6
40
+ camera_embedding_type: 'e_de_da_sincos'
41
+
42
+ enable_xformers_memory_efficient_attention: true
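
For reference, a minimal sketch of reading this YAML with OmegaConf; `app.py` above does the same and additionally merges it onto the structured `TestConfig` schema so that missing or mistyped keys fail early.

```python
from omegaconf import OmegaConf

# Minimal sketch: parse the config and read a few of the fields defined above.
cfg = OmegaConf.load("./configs/mvdiffusion-joint-ortho-6views.yaml")
print(cfg.num_views)                          # 6
print(cfg.validation_dataset.img_wh)          # [256, 256]
print(cfg.pipe_kwargs.camera_embedding_type)  # e_de_da_sincos

# app.py goes one step further:
#   schema = OmegaConf.structured(TestConfig)
#   cfg = OmegaConf.merge(schema, cfg)
```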
example_images/14_10_29_489_Tiger_1__1.png ADDED
example_images/box.png ADDED
example_images/bread.png ADDED
example_images/cat.png ADDED
example_images/cat_head.png ADDED
example_images/chili.png ADDED
example_images/duola.png ADDED
example_images/halloween.png ADDED
example_images/head.png ADDED
example_images/kettle.png ADDED
example_images/kunkun.png ADDED
example_images/milk.png ADDED
example_images/owl.png ADDED
example_images/poro.png ADDED
example_images/pumpkin.png ADDED
example_images/skull.png ADDED
example_images/stone.png ADDED
example_images/teapot.png ADDED
example_images/tiger-head-3d-model-obj-stl.png ADDED
mvdiffusion/data/fixed_poses/four_views/000_back_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ -1.000000238418579102e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
2
+ 0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 1.746665105883948854e-07
3
+ 0.000000000000000000e+00 1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00
mvdiffusion/data/fixed_poses/four_views/000_front_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
2
+ 0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 -1.746665105883948854e-07
3
+ 0.000000000000000000e+00 -1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00
mvdiffusion/data/fixed_poses/four_views/000_left_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ -2.220446049250313081e-16 -1.000000000000000000e+00 0.000000000000000000e+00 -2.886579758146288598e-16
2
+ 0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00
3
+ -1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00
mvdiffusion/data/fixed_poses/four_views/000_right_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 2.886579758146288598e-16
2
+ 0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00
3
+ 1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00
mvdiffusion/data/fixed_poses/nine_views/000_back_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ -5.266582965850830078e-01 7.410295009613037109e-01 -4.165407419204711914e-01 -5.960464477539062500e-08
2
+ 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 -9.462351613365171943e-08
3
+ 8.500770330429077148e-01 4.590988159179687500e-01 -2.580644786357879639e-01 -1.300000071525573730e+00
mvdiffusion/data/fixed_poses/nine_views/000_back_left_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ -9.734988808631896973e-01 1.993551850318908691e-01 -1.120596975088119507e-01 -1.713633537292480469e-07
2
+ 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 1.772203575001185527e-07
3
+ 2.286916375160217285e-01 8.486189246177673340e-01 -4.770178496837615967e-01 -1.838477611541748047e+00
mvdiffusion/data/fixed_poses/nine_views/000_back_right_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ 2.286914736032485962e-01 8.486190438270568848e-01 -4.770178198814392090e-01 1.564621925354003906e-07
2
+ -3.417914484771245043e-08 4.900034070014953613e-01 8.717205524444580078e-01 -7.293811421504869941e-08
3
+ 9.734990000724792480e-01 -1.993550658226013184e-01 1.120596155524253845e-01 -1.838477969169616699e+00
mvdiffusion/data/fixed_poses/nine_views/000_front_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ 5.266583561897277832e-01 -7.410295009613037109e-01 4.165407419204711914e-01 0.000000000000000000e+00
2
+ 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 9.462351613365171943e-08
3
+ -8.500770330429077148e-01 -4.590988159179687500e-01 2.580645382404327393e-01 -1.300000071525573730e+00
mvdiffusion/data/fixed_poses/nine_views/000_front_left_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ -2.286916971206665039e-01 -8.486189842224121094e-01 4.770179092884063721e-01 -2.458691596984863281e-07
2
+ 9.085837859856837895e-09 4.900034666061401367e-01 8.717205524444580078e-01 1.205695667749751010e-07
3
+ -9.734990000724792480e-01 1.993551701307296753e-01 -1.120597645640373230e-01 -1.838477969169616699e+00
mvdiffusion/data/fixed_poses/nine_views/000_front_right_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ 9.734989404678344727e-01 -1.993551850318908691e-01 1.120596975088119507e-01 -1.415610313415527344e-07
2
+ 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 -1.772203575001185527e-07
3
+ -2.286916375160217285e-01 -8.486189246177673340e-01 4.770178794860839844e-01 -1.838477611541748047e+00
mvdiffusion/data/fixed_poses/nine_views/000_left_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ -8.500771522521972656e-01 -4.590989053249359131e-01 2.580644488334655762e-01 0.000000000000000000e+00
2
+ -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 9.006067358541258727e-08
3
+ -5.266583561897277832e-01 7.410295605659484863e-01 -4.165408313274383545e-01 -1.300000071525573730e+00
mvdiffusion/data/fixed_poses/nine_views/000_right_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ 8.500770330429077148e-01 4.590989053249359131e-01 -2.580644488334655762e-01 5.960464477539062500e-08
2
+ -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 -9.006067358541258727e-08
3
+ 5.266583561897277832e-01 -7.410295605659484863e-01 4.165407419204711914e-01 -1.300000071525573730e+00
mvdiffusion/data/fixed_poses/nine_views/000_top_RT.txt ADDED
@@ -0,0 +1,3 @@
1
+ 9.958608150482177734e-01 7.923202216625213623e-02 -4.453715682029724121e-02 -3.098167056236889039e-09
2
+ -9.089154005050659180e-02 8.681122064590454102e-01 -4.879753291606903076e-01 5.784738377201392723e-08
3
+ -2.028124157504862524e-08 4.900035560131072998e-01 8.717204332351684570e-01 -1.300000071525573730e+00
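
Each of these fixed-pose files stores a 3×4 world-to-camera matrix `[R | t]`; the dataset classes further below load them with `np.loadtxt`, and the camera center in world space can be recovered as `-Rᵀt` (cf. `ObjaverseDataset.get_T`). A minimal sketch:

```python
import numpy as np

# Load one of the fixed world-to-camera poses added in this commit.
RT = np.loadtxt("mvdiffusion/data/fixed_poses/nine_views/000_front_RT.txt")  # shape (3, 4)
R, t = RT[:3, :3], RT[:, -1]

cam_center = -R.T @ t                                # camera position in world coordinates
print(RT.shape)                                      # (3, 4)
print(round(float(np.linalg.norm(cam_center)), 2))   # ~1.3, the fixed camera distance
```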
mvdiffusion/data/normal_utils.py ADDED
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+
3
+ def camNormal2worldNormal(rot_c2w, camNormal):
4
+ H,W,_ = camNormal.shape
5
+ normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
6
+
7
+ return normal_img
8
+
9
+ def worldNormal2camNormal(rot_w2c, normal_map_world):
10
+ H,W,_ = normal_map_world.shape
11
+ # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
12
+
13
+ # faster version
14
+ # Reshape the normal map into a 2D array where each row represents a normal vector
15
+ normal_map_flat = normal_map_world.reshape(-1, 3)
16
+
17
+ # Transform the normal vectors using the transformation matrix
18
+ normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T)
19
+
20
+ # Reshape the transformed normal map back to its original shape
21
+ normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape)
22
+
23
+ return normal_map_camera
24
+
25
+ def trans_normal(normal, RT_w2c, RT_w2c_target):
26
+
27
+ # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
28
+ # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world)
29
+
30
+ relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3]))
31
+ normal_target_cam = worldNormal2camNormal(relative_RT[:3,:3], normal)
32
+
33
+ return normal_target_cam
34
+
35
+ def img2normal(img):
36
+ return (img/255.)*2-1
37
+
38
+ def normal2img(normal):
39
+ return np.uint8((normal*0.5+0.5)*255)
40
+
41
+ def norm_normalize(normal, dim=-1):
42
+
43
+ normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6)
44
+
45
+ return normal
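
A short usage sketch for the helpers above (the normal-map file name is hypothetical): re-express a normal map rendered in one view's camera frame in the frame of the conditioning front view, which is how `load_normal` in the dataset below applies `trans_normal`.

```python
import numpy as np
from PIL import Image
from mvdiffusion.data.normal_utils import img2normal, normal2img, trans_normal

RT_w2c = np.loadtxt("mvdiffusion/data/fixed_poses/nine_views/000_right_RT.txt")
RT_w2c_cond = np.loadtxt("mvdiffusion/data/fixed_poses/nine_views/000_front_RT.txt")

# Hypothetical rendered normal map for the 'right' view, stored as an RGB image.
normal_img = np.array(Image.open("normals_000_right.png").convert("RGB"))
normal = img2normal(normal_img)                          # uint8 [0, 255] -> float [-1, 1]
normal_cond = trans_normal(normal, RT_w2c, RT_w2c_cond)  # rotate into the front-view camera frame
Image.fromarray(normal2img(normal_cond)).save("normals_000_right_in_front_frame.png")
```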
mvdiffusion/data/objaverse_dataset.py ADDED
@@ -0,0 +1,608 @@
1
+ from typing import Dict
2
+ import numpy as np
3
+ from omegaconf import DictConfig, ListConfig
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from pathlib import Path
7
+ import json
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ from typing import Literal, Tuple, Optional, Any
12
+ import cv2
13
+ import random
14
+
15
+ import json
16
+ import os, sys
17
+ import math
18
+
19
+ import PIL.Image
20
+ from .normal_utils import trans_normal, normal2img, img2normal
21
+ import pdb
22
+
23
+ def shift_list(lst, n):
24
+ length = len(lst)
25
+ n = n % length # Ensure n is within the range of the list length
26
+ return lst[-n:] + lst[:-n]
27
+
28
+
29
+ class ObjaverseDataset(Dataset):
30
+ def __init__(self,
31
+ root_dir: str,
32
+ num_views: int,
33
+ bg_color: Any,
34
+ img_wh: Tuple[int, int],
35
+ object_list: str,
36
+ groups_num: int=1,
37
+ validation: bool = False,
38
+ random_views: bool = False,
39
+ num_validation_samples: int = 64,
40
+ num_samples: Optional[int] = None,
41
+ invalid_list: Optional[str] = None,
42
+ trans_norm_system: bool = True, # if True, transform all normals map into the cam system of front view
43
+ augment_data: bool = False,
44
+ read_normal: bool = True,
45
+ read_color: bool = False,
46
+ read_depth: bool = False,
47
+ mix_color_normal: bool = False,
48
+ random_view_and_domain: bool = False
49
+ ) -> None:
50
+ """Create a dataset from a folder of images.
51
+ If you pass in a root directory it will be searched for images
52
+ ending in ext (ext can be a list)
53
+ """
54
+ self.root_dir = Path(root_dir)
55
+ self.num_views = num_views
56
+ self.bg_color = bg_color
57
+ self.validation = validation
58
+ self.num_samples = num_samples
59
+ self.trans_norm_system = trans_norm_system
60
+ self.augment_data = augment_data
61
+ self.invalid_list = invalid_list
62
+ self.groups_num = groups_num
63
+ print("augment data: ", self.augment_data)
64
+ self.img_wh = img_wh
65
+ self.read_normal = read_normal
66
+ self.read_color = read_color
67
+ self.read_depth = read_depth
68
+ self.mix_color_normal = mix_color_normal # mix load color and normal maps
69
+ self.random_view_and_domain = random_view_and_domain # load normal or rgb of a single view
70
+ self.random_views = random_views
71
+ if not self.random_views:
72
+ if self.num_views == 4:
73
+ self.view_types = ['front', 'right', 'back', 'left']
74
+ elif self.num_views == 5:
75
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left']
76
+ elif self.num_views == 6 or self.num_views==1:
77
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
78
+ else:
79
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
80
+
81
+ self.fix_cam_pose_dir = "./mvdiffusion/data/fixed_poses/nine_views"
82
+
83
+ self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix
84
+
85
+ if object_list is not None:
86
+ with open(object_list) as f:
87
+ self.objects = json.load(f)
88
+ self.objects = [os.path.basename(o).replace(".glb", "") for o in self.objects]
89
+ else:
90
+ self.objects = os.listdir(self.root_dir)
91
+ self.objects = sorted(self.objects)
92
+
93
+ if self.invalid_list is not None:
94
+ with open(self.invalid_list) as f:
95
+ self.invalid_objects = json.load(f)
96
+ self.invalid_objects = [os.path.basename(o).replace(".glb", "") for o in self.invalid_objects]
97
+ else:
98
+ self.invalid_objects = []
99
+
100
+
101
+ self.all_objects = set(self.objects) - (set(self.invalid_objects) & set(self.objects))
102
+ self.all_objects = list(self.all_objects)
103
+
104
+ if not validation:
105
+ self.all_objects = self.all_objects[:-num_validation_samples]
106
+ else:
107
+ self.all_objects = self.all_objects[-num_validation_samples:]
108
+ if num_samples is not None:
109
+ self.all_objects = self.all_objects[:num_samples]
110
+
111
+ print("loading ", len(self.all_objects), " objects in the dataset")
112
+
113
+ if self.mix_color_normal:
114
+ self.backup_data = self.__getitem_mix__(0, "9438abf986c7453a9f4df7c34aa2e65b")
115
+ elif self.random_view_and_domain:
116
+ self.backup_data = self.__getitem_random_viewanddomain__(0, "9438abf986c7453a9f4df7c34aa2e65b")
117
+ else:
118
+ self.backup_data = self.__getitem_norm__(0, "9438abf986c7453a9f4df7c34aa2e65b") # "66b2134b7e3645b29d7c349645291f78")
119
+
120
122
+
123
+ def load_fixed_poses(self):
124
+ poses = {}
125
+ for face in self.view_types:
126
+ RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face)))
127
+ poses[face] = RT
128
+
129
+ return poses
130
+
131
+ def cartesian_to_spherical(self, xyz):
132
+ ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
133
+ xy = xyz[:,0]**2 + xyz[:,1]**2
134
+ z = np.sqrt(xy + xyz[:,2]**2)
135
+ theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
136
+ #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
137
+ azimuth = np.arctan2(xyz[:,1], xyz[:,0])
138
+ return np.array([theta, azimuth, z])
139
+
140
+ def get_T(self, target_RT, cond_RT):
141
+ R, T = target_RT[:3, :3], target_RT[:, -1]
142
+ T_target = -R.T @ T # change to cam2world
143
+
144
+ R, T = cond_RT[:3, :3], cond_RT[:, -1]
145
+ T_cond = -R.T @ T
146
+
147
+ theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
148
+ theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
149
+
150
+ d_theta = theta_target - theta_cond
151
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
152
+ d_z = z_target - z_cond
153
+
154
+ # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
155
+ return d_theta, d_azimuth
156
+
157
+ def get_bg_color(self):
158
+ if self.bg_color == 'white':
159
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
160
+ elif self.bg_color == 'black':
161
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
162
+ elif self.bg_color == 'gray':
163
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
164
+ elif self.bg_color == 'random':
165
+ bg_color = np.random.rand(3)
166
+ elif self.bg_color == 'three_choices':
167
+ white = np.array([1., 1., 1.], dtype=np.float32)
168
+ black = np.array([0., 0., 0.], dtype=np.float32)
169
+ gray = np.array([0.5, 0.5, 0.5], dtype=np.float32)
170
+ bg_color = random.choice([white, black, gray])
171
+ elif isinstance(self.bg_color, float):
172
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
173
+ else:
174
+ raise NotImplementedError
175
+ return bg_color
176
+
177
+
178
+
179
+ def load_mask(self, img_path, return_type='np'):
180
+ # not using cv2 as may load in uint16 format
181
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
182
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
183
+ # pil always returns uint8
184
+ img = np.array(Image.open(img_path).resize(self.img_wh))
185
+ img = np.float32(img > 0)
186
+
187
+ assert len(np.shape(img)) == 2
188
+
189
+ if return_type == "np":
190
+ pass
191
+ elif return_type == "pt":
192
+ img = torch.from_numpy(img)
193
+ else:
194
+ raise NotImplementedError
195
+
196
+ return img
197
+
198
+ def load_image(self, img_path, bg_color, alpha, return_type='np'):
199
+ # not using cv2 as may load in uint16 format
200
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
201
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
202
+ # pil always returns uint8
203
+ img = np.array(Image.open(img_path).resize(self.img_wh))
204
+ img = img.astype(np.float32) / 255. # [0, 1]
205
+ assert img.shape[-1] == 3 # RGB
206
+
207
+ if alpha.shape[-1] != 1:
208
+ alpha = alpha[:, :, None]
209
+
210
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
211
+
212
+ if return_type == "np":
213
+ pass
214
+ elif return_type == "pt":
215
+ img = torch.from_numpy(img)
216
+ else:
217
+ raise NotImplementedError
218
+
219
+ return img
220
+
221
+ def load_depth(self, img_path, bg_color, alpha, return_type='np'):
222
+ # not using cv2 as may load in uint16 format
223
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
224
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
225
+ # pil always returns uint8
226
+ img = np.array(Image.open(img_path).resize(self.img_wh))
227
+ img = img.astype(np.float32) / 65535. # [0, 1]
228
+
229
+ img[img > 0.4] = 0
230
+ img = img / 0.4
231
+
232
+ assert img.ndim == 2 # depth
233
+ img = np.stack([img]*3, axis=-1)
234
+
235
+ if alpha.shape[-1] != 1:
236
+ alpha = alpha[:, :, None]
237
+
238
+ # print(np.max(img[:, :, 0]))
239
+
240
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
241
+
242
+ if return_type == "np":
243
+ pass
244
+ elif return_type == "pt":
245
+ img = torch.from_numpy(img)
246
+ else:
247
+ raise NotImplementedError
248
+
249
+ return img
250
+
251
+ def load_normal(self, img_path, bg_color, alpha, RT_w2c=None, RT_w2c_cond=None, return_type='np'):
252
+ # not using cv2 as may load in uint16 format
253
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
254
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
255
+ # pil always returns uint8
256
+ normal = np.array(Image.open(img_path).resize(self.img_wh))
257
+
258
+ assert normal.shape[-1] == 3 # RGB
259
+
260
+ normal = trans_normal(img2normal(normal), RT_w2c, RT_w2c_cond)
261
+
262
+ img = (normal*0.5 + 0.5).astype(np.float32) # [0, 1]
263
+
264
+ if alpha.shape[-1] != 1:
265
+ alpha = alpha[:, :, None]
266
+
267
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
268
+
269
+ if return_type == "np":
270
+ pass
271
+ elif return_type == "pt":
272
+ img = torch.from_numpy(img)
273
+ else:
274
+ raise NotImplementedError
275
+
276
+ return img
277
+
278
+ def __len__(self):
279
+ return len(self.all_objects)
280
+
281
+ def __getitem_mix__(self, index, debug_object=None):
282
+ if debug_object is not None:
283
+ object_name = debug_object #
284
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
285
+ else:
286
+ object_name = self.all_objects[index%len(self.all_objects)]
287
+ set_idx = 0
288
+
289
+ if self.augment_data:
290
+ cond_view = random.sample(self.view_types, k=1)[0]
291
+ else:
292
+ cond_view = 'front'
293
+
294
+ if random.random() < 0.5:
295
+ read_color, read_normal, read_depth = True, False, False
296
+ else:
297
+ read_color, read_normal, read_depth = False, True, True
298
+
299
+ read_normal = read_normal & self.read_normal
300
+ read_depth = read_depth & self.read_depth
301
+
302
+ assert (read_color and (read_normal or read_depth)) is False
303
+
304
+ view_types = self.view_types
305
+
306
+ cond_w2c = self.fix_cam_poses[cond_view]
307
+
308
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
309
+
310
+ elevations = []
311
+ azimuths = []
312
+
313
+ # get the bg color
314
+ bg_color = self.get_bg_color()
315
+
316
+ cond_alpha = self.load_mask(os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, cond_view)), return_type='np')
317
+ img_tensors_in = [
318
+ self.load_image(os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, cond_view)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
319
+ ] * self.num_views
320
+ img_tensors_out = []
321
+
322
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
323
+ img_path = os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, view))
324
+ mask_path = os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, view))
325
+ normal_path = os.path.join(self.root_dir, object_name[:3], object_name, "normals_%03d_%s.png" % (set_idx, view))
326
+ depth_path = os.path.join(self.root_dir, object_name[:3], object_name, "depth_%03d_%s.png" % (set_idx, view))
327
+ alpha = self.load_mask(mask_path, return_type='np')
328
+
329
+ if read_color:
330
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt")
331
+ img_tensor = img_tensor.permute(2, 0, 1)
332
+ img_tensors_out.append(img_tensor)
333
+
334
+ if read_normal:
335
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt").permute(2, 0, 1)
336
+ img_tensors_out.append(normal_tensor)
337
+ if read_depth:
338
+ depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt").permute(2, 0, 1)
339
+ img_tensors_out.append(depth_tensor)
340
+
341
+ # elevations, azimuths
342
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
343
+ elevations.append(elevation)
344
+ azimuths.append(azimuth)
345
+
346
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
347
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
348
+
349
+
350
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
351
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
352
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
353
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
354
+
355
+ normal_class = torch.tensor([1, 0]).float()
356
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
357
+ color_class = torch.tensor([0, 1]).float()
358
+ color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
359
+ if read_normal or read_depth:
360
+ task_embeddings = normal_task_embeddings
361
+ if read_color:
362
+ task_embeddings = color_task_embeddings
363
+
364
+ return {
365
+ 'elevations_cond': elevations_cond,
366
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
367
+ 'elevations': elevations,
368
+ 'azimuths': azimuths,
369
+ 'elevations_deg': torch.rad2deg(elevations),
370
+ 'azimuths_deg': torch.rad2deg(azimuths),
371
+ 'imgs_in': img_tensors_in,
372
+ 'imgs_out': img_tensors_out,
373
+ 'camera_embeddings': camera_embeddings,
374
+ 'task_embeddings': task_embeddings
375
+ }
376
+
377
+
378
+ def __getitem_random_viewanddomain__(self, index, debug_object=None):
379
+ if debug_object is not None:
380
+ object_name = debug_object #
381
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
382
+ else:
383
+ object_name = self.all_objects[index%len(self.all_objects)]
384
+ set_idx = 0
385
+
386
+ if self.augment_data:
387
+ cond_view = random.sample(self.view_types, k=1)[0]
388
+ else:
389
+ cond_view = 'front'
390
+
391
+ if random.random() < 0.5:
392
+ read_color, read_normal, read_depth = True, False, False
393
+ else:
394
+ read_color, read_normal, read_depth = False, True, True
395
+
396
+ read_normal = read_normal & self.read_normal
397
+ read_depth = read_depth & self.read_depth
398
+
399
+ assert (read_color and (read_normal or read_depth)) is False
400
+
401
+ view_types = self.view_types
402
+
403
+ cond_w2c = self.fix_cam_poses[cond_view]
404
+
405
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
406
+
407
+ elevations = []
408
+ azimuths = []
409
+
410
+ # get the bg color
411
+ bg_color = self.get_bg_color()
412
+
413
+ cond_alpha = self.load_mask(os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, cond_view)), return_type='np')
414
+ img_tensors_in = [
415
+ self.load_image(os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, cond_view)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
416
+ ] * self.num_views
417
+ img_tensors_out = []
418
+
419
+ random_viewidx = random.randint(0, len(view_types)-1)
420
+
421
+ for view, tgt_w2c in zip([view_types[random_viewidx]], [tgt_w2cs[random_viewidx]]):
422
+ img_path = os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, view))
423
+ mask_path = os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, view))
424
+ normal_path = os.path.join(self.root_dir, object_name[:3], object_name, "normals_%03d_%s.png" % (set_idx, view))
425
+ depth_path = os.path.join(self.root_dir, object_name[:3], object_name, "depth_%03d_%s.png" % (set_idx, view))
426
+ alpha = self.load_mask(mask_path, return_type='np')
427
+
428
+ if read_color:
429
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt")
430
+ img_tensor = img_tensor.permute(2, 0, 1)
431
+ img_tensors_out.append(img_tensor)
432
+
433
+ if read_normal:
434
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt").permute(2, 0, 1)
435
+ img_tensors_out.append(normal_tensor)
436
+ if read_depth:
437
+ depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt").permute(2, 0, 1)
438
+ img_tensors_out.append(depth_tensor)
439
+
440
+ # elevations, azimuths
441
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
442
+ elevations.append(elevation)
443
+ azimuths.append(azimuth)
444
+
445
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
446
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
447
+
448
+
449
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
450
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
451
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
452
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
453
+
454
+ normal_class = torch.tensor([1, 0]).float()
455
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
456
+ color_class = torch.tensor([0, 1]).float()
457
+ color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
458
+ if read_normal or read_depth:
459
+ task_embeddings = normal_task_embeddings
460
+ if read_color:
461
+ task_embeddings = color_task_embeddings
462
+
463
+ return {
464
+ 'elevations_cond': elevations_cond,
465
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
466
+ 'elevations': elevations,
467
+ 'azimuths': azimuths,
468
+ 'elevations_deg': torch.rad2deg(elevations),
469
+ 'azimuths_deg': torch.rad2deg(azimuths),
470
+ 'imgs_in': img_tensors_in,
471
+ 'imgs_out': img_tensors_out,
472
+ 'camera_embeddings': camera_embeddings,
473
+ 'task_embeddings': task_embeddings
474
+ }
475
+
476
+
477
+ def __getitem_norm__(self, index, debug_object=None):
478
+ if debug_object is not None:
479
+ object_name = debug_object #
480
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
481
+ else:
482
+ object_name = self.all_objects[index%len(self.all_objects)]
483
+ set_idx = 0
484
+
485
+ if self.augment_data:
486
+ cond_view = random.sample(self.view_types, k=1)[0]
487
+ else:
488
+ cond_view = 'front'
489
+
490
+ # if self.random_views:
491
+ # view_types = ['front']+random.sample(self.view_types[1:], 3)
492
+ # else:
493
+ # view_types = self.view_types
494
+
495
+ view_types = self.view_types
496
+
497
+ cond_w2c = self.fix_cam_poses[cond_view]
498
+
499
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
500
+
501
+ elevations = []
502
+ azimuths = []
503
+
504
+ # get the bg color
505
+ bg_color = self.get_bg_color()
506
+
507
+ cond_alpha = self.load_mask(os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, cond_view)), return_type='np')
508
+ img_tensors_in = [
509
+ self.load_image(os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, cond_view)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
510
+ ] * self.num_views
511
+ img_tensors_out = []
512
+ normal_tensors_out = []
513
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
514
+ img_path = os.path.join(self.root_dir, object_name[:3], object_name, "rgb_%03d_%s.png" % (set_idx, view))
515
+ mask_path = os.path.join(self.root_dir, object_name[:3], object_name, "mask_%03d_%s.png" % (set_idx, view))
516
+ alpha = self.load_mask(mask_path, return_type='np')
517
+
518
+ if self.read_color:
519
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt")
520
+ img_tensor = img_tensor.permute(2, 0, 1)
521
+ img_tensors_out.append(img_tensor)
522
+
523
+ if self.read_normal:
524
+ normal_path = os.path.join(self.root_dir, object_name[:3], object_name, "normals_%03d_%s.png" % (set_idx, view))
525
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt").permute(2, 0, 1)
526
+ normal_tensors_out.append(normal_tensor)
527
+
528
+ # elevations, azimuths
529
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
530
+ elevations.append(elevation)
531
+ azimuths.append(azimuth)
532
+
533
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
534
+ if self.read_color:
535
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
536
+ if self.read_normal:
537
+ normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
538
+
539
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
540
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
541
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
542
+
543
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
544
+
545
+ normal_class = torch.tensor([1, 0]).float()
546
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
547
+ color_class = torch.tensor([0, 1]).float()
548
+ color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
549
+
550
+ return {
551
+ 'elevations_cond': elevations_cond,
552
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
553
+ 'elevations': elevations,
554
+ 'azimuths': azimuths,
555
+ 'elevations_deg': torch.rad2deg(elevations),
556
+ 'azimuths_deg': torch.rad2deg(azimuths),
557
+ 'imgs_in': img_tensors_in,
558
+ 'imgs_out': img_tensors_out,
559
+ 'normals_out': normal_tensors_out,
560
+ 'camera_embeddings': camera_embeddings,
561
+ 'normal_task_embeddings': normal_task_embeddings,
562
+ 'color_task_embeddings': color_task_embeddings
563
+ }
564
+
565
+ def __getitem__(self, index):
566
+
567
+ try:
568
+ if self.mix_color_normal:
569
+ data = self.__getitem_mix__(index)
570
+ elif self.random_view_and_domain:
571
+ data = self.__getitem_random_viewanddomain__(index)
572
+ else:
573
+ data = self.__getitem_norm__(index)
574
+ return data
575
+ except:
576
+ print("load error ", self.all_objects[index%len(self.all_objects)] )
577
+ return self.backup_data
578
+
579
+
580
+ class ConcatDataset(torch.utils.data.Dataset):
581
+ def __init__(self, datasets, weights):
582
+ self.datasets = datasets
583
+ self.weights = weights
584
+ self.num_datasets = len(datasets)
585
+
586
+ def __getitem__(self, i):
587
+
588
+ chosen = random.choices(self.datasets, self.weights, k=1)[0]
589
+ return chosen[i]
590
+
591
+ def __len__(self):
592
+ return max(len(d) for d in self.datasets)
593
+
594
+ if __name__ == "__main__":
+     # Smoke test: the arguments below match the ObjaverseDataset signature defined above.
+     train_dataset = ObjaverseDataset(
+         root_dir="/ghome/l5/xxlong/.objaverse/hf-objaverse-v1/renderings",
+         num_views=4,
+         bg_color='white',
+         img_wh=(128, 128),
+         object_list=None,
+         validation=False,
+     )
+     data0 = train_dataset[0]
+     data1 = train_dataset[50]
+     # print(data0)
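
Assuming the rendered Objaverse data is actually present under `root_dir`, a sample returned by the constructor call above would look roughly like this (a sketch, not part of the commit):

```python
sample = train_dataset[0]

print(sample['imgs_in'].shape)            # torch.Size([4, 3, 128, 128]) - front view repeated Nv times
print(sample['normals_out'].shape)        # torch.Size([4, 3, 128, 128]) - target normals (read_normal=True)
print(sample['camera_embeddings'].shape)  # torch.Size([4, 3]) - (cond elevation, d_elevation, d_azimuth)
# 'imgs_out' is only populated when the dataset is built with read_color=True.
```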
mvdiffusion/data/single_image_dataset.py ADDED
@@ -0,0 +1,321 @@
1
+ from typing import Dict
2
+ import numpy as np
3
+ from omegaconf import DictConfig, ListConfig
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from pathlib import Path
7
+ import json
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ from typing import Literal, Tuple, Optional, Any
12
+ import cv2
13
+ import random
14
+
15
+ import json
16
+ import os, sys
17
+ import math
18
+
19
+ from glob import glob
20
+
21
+ import PIL.Image
22
+ from .normal_utils import trans_normal, normal2img, img2normal
23
+ import pdb
24
+
25
+
26
+ import cv2
27
+ import numpy as np
28
+
29
+ def add_margin(pil_img, color=0, size=256):
30
+ width, height = pil_img.size
31
+ result = Image.new(pil_img.mode, (size, size), color)
32
+ result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
33
+ return result
34
+
35
+ def scale_and_place_object(image, scale_factor):
36
+ assert np.shape(image)[-1]==4 # RGBA
37
+
38
+ # Extract the alpha channel (transparency) and the object (RGB channels)
39
+ alpha_channel = image[:, :, 3]
40
+
41
+ # Find the bounding box coordinates of the object
42
+ coords = cv2.findNonZero(alpha_channel)
43
+ x, y, width, height = cv2.boundingRect(coords)
44
+
45
+ # Calculate the scale factor for resizing
46
+ original_height, original_width = image.shape[:2]
47
+
48
+ if width > height:
49
+ size = width
50
+ original_size = original_width
51
+ else:
52
+ size = height
53
+ original_size = original_height
54
+
55
+ scale_factor = min(scale_factor, size / (original_size+0.0))
56
+
57
+ new_size = scale_factor * original_size
58
+ scale_factor = new_size / size
59
+
60
+ # Calculate the new size based on the scale factor
61
+ new_width = int(width * scale_factor)
62
+ new_height = int(height * scale_factor)
63
+
64
+ center_x = original_width // 2
65
+ center_y = original_height // 2
66
+
67
+ paste_x = center_x - (new_width // 2)
68
+ paste_y = center_y - (new_height // 2)
69
+
70
+ # Resize the object (RGB channels) to the new size
71
+ rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height))
72
+
73
+ # Create a new RGBA image with the resized image
74
+ new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8)
75
+
76
+ new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object
77
+
78
+ return new_image
79
+
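add_margin pastes an RGBA image onto a square canvas of side `size`, centred, so the later bicubic resize never distorts the aspect ratio; scale_and_place_object does the analogous job in NumPy/OpenCV. A small sketch of add_margin with a dummy foreground (not part of the commit):

from PIL import Image

rgba = Image.new("RGBA", (120, 80), (255, 0, 0, 255))   # dummy 120x80 opaque red object
padded = add_margin(rgba, color=0, size=256)            # 256x256, object centred, transparent border
assert padded.size == (256, 256)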
80
+ class SingleImageDataset(Dataset):
81
+ def __init__(self,
82
+ root_dir: str,
83
+ num_views: int,
84
+ img_wh: Tuple[int, int],
85
+ bg_color: str,
86
+ crop_size: int = 224,
87
+ num_validation_samples: Optional[int] = None,
88
+ filepaths: Optional[list] = None,
89
+ cond_type: Optional[str] = None
90
+ ) -> None:
91
+ """Create a dataset from a folder of images.
92
+ If you pass in a root directory it will be searched for images
93
+ ending in ext (ext can be a list)
94
+ """
95
+ self.root_dir = Path(root_dir)
96
+ self.num_views = num_views
97
+ self.img_wh = img_wh
98
+ self.crop_size = crop_size
99
+ self.bg_color = bg_color
100
+ self.cond_type = cond_type
101
+
102
+ if self.num_views == 4:
103
+ self.view_types = ['front', 'right', 'back', 'left']
104
+ elif self.num_views == 5:
105
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left']
106
+ elif self.num_views == 6:
107
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
108
+
109
+ self.fix_cam_pose_dir = "./mvdiffusion/data/fixed_poses/nine_views"
110
+
111
+ self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix
112
+
113
+ if filepaths is None:
114
+ # Get a list of all files in the directory
115
+ file_list = os.listdir(self.root_dir)
116
+ else:
117
+ file_list = filepaths
118
+
119
+ if self.cond_type is None:
120
+ # Filter the files that end with .png or .jpg
121
+ self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg'))]
122
+ self.cond_dirs = None
123
+ else:
124
+ self.file_list = []
125
+ self.cond_dirs = []
126
+ for scene in file_list:
127
+ self.file_list.append(os.path.join(scene, f"{scene}.png"))
128
+ if self.cond_type == 'normals':
129
+ self.cond_dirs.append(os.path.join(self.root_dir, scene, 'outs'))
130
+ else:
131
+ self.cond_dirs.append(os.path.join(self.root_dir, scene))
132
+
133
+ # load all images
134
+ self.all_images = []
135
+ self.all_alphas = []
136
+ bg_color = self.get_bg_color()
137
+ for file in self.file_list:
138
+ image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt')
139
+ self.all_images.append(image)
140
+ self.all_alphas.append(alpha)
141
+
142
+ self.all_images = self.all_images[:num_validation_samples]
143
+ self.all_alphas = self.all_alphas[:num_validation_samples]
144
+
145
+
146
+ def __len__(self):
147
+ return len(self.all_images)
148
+
149
+ def load_fixed_poses(self):
150
+ poses = {}
151
+ for face in self.view_types:
152
+ RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face)))
153
+ poses[face] = RT
154
+
155
+ return poses
156
+
157
+ def cartesian_to_spherical(self, xyz):
158
+ ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
159
+ xy = xyz[:,0]**2 + xyz[:,1]**2
160
+ z = np.sqrt(xy + xyz[:,2]**2)
161
+ theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
162
+ #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
163
+ azimuth = np.arctan2(xyz[:,1], xyz[:,0])
164
+ return np.array([theta, azimuth, z])
165
+
166
+ def get_T(self, target_RT, cond_RT):
167
+ R, T = target_RT[:3, :3], target_RT[:, -1]
168
+ T_target = -R.T @ T # change to cam2world
169
+
170
+ R, T = cond_RT[:3, :3], cond_RT[:, -1]
171
+ T_cond = -R.T @ T
172
+
173
+ theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
174
+ theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
175
+
176
+ d_theta = theta_target - theta_cond
177
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
178
+ d_z = z_target - z_cond
179
+
180
+ # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
181
+ return d_theta, d_azimuth
182
+
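get_T recovers each camera centre with -R.T @ T, converts it to spherical coordinates and returns the target view's elevation/azimuth relative to the condition view (azimuth wrapped into [0, 2*pi)). A standalone numeric sketch of the same idea, starting directly from two hypothetical camera centres rather than RT matrices:

import math
import numpy as np

def spherical(p):                                # p: camera centre in world coordinates
    x, y, z = p
    theta = math.atan2(math.hypot(x, y), z)      # polar angle measured from +Z, as above
    azimuth = math.atan2(y, x)
    return theta, azimuth, float(np.linalg.norm(p))

c_cond = np.array([1.0, 0.0, 0.0])               # "front" camera
c_target = np.array([0.0, 1.0, 0.0])             # camera rotated 90 degrees about Z

theta_c, az_c, _ = spherical(c_cond)
theta_t, az_t, _ = spherical(c_target)
d_elevation = theta_t - theta_c                  # 0.0: same elevation
d_azimuth = (az_t - az_c) % (2 * math.pi)        # pi/2
print(math.degrees(d_elevation), math.degrees(d_azimuth))   # 0.0 90.0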
183
+ def get_bg_color(self):
184
+ if self.bg_color == 'white':
185
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
186
+ elif self.bg_color == 'black':
187
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
188
+ elif self.bg_color == 'gray':
189
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
190
+ elif self.bg_color == 'random':
191
+ bg_color = np.random.rand(3)
192
+ elif isinstance(self.bg_color, float):
193
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
194
+ else:
195
+ raise NotImplementedError
196
+ return bg_color
197
+
198
+
199
+ def load_image(self, img_path, bg_color, return_type='np'):
200
+ # pil always returns uint8
201
+ image_input = Image.open(img_path)
202
+ image_size = self.img_wh[0]
203
+
204
+ if self.crop_size!=-1:
205
+ alpha_np = np.asarray(image_input)[:, :, 3]
206
+ coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
207
+ min_x, min_y = np.min(coords, 0)
208
+ max_x, max_y = np.max(coords, 0)
209
+ ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
210
+ h, w = ref_img_.height, ref_img_.width
211
+ scale = self.crop_size / max(h, w)
212
+ h_, w_ = int(scale * h), int(scale * w)
213
+ ref_img_ = ref_img_.resize((w_, h_), resample=Image.BICUBIC)
214
+ image_input = add_margin(ref_img_, size=image_size)
215
+ else:
216
+ image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
217
+ image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
218
+
219
+ # img = scale_and_place_object(img, self.scale_ratio)
220
+ img = np.array(image_input)
221
+ img = img.astype(np.float32) / 255. # [0, 1]
222
+ assert img.shape[-1] == 4 # RGBA
223
+
224
+ alpha = img[...,3:4]
225
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
226
+
227
+ if return_type == "np":
228
+ pass
229
+ elif return_type == "pt":
230
+ img = torch.from_numpy(img)
231
+ alpha = torch.from_numpy(alpha)
232
+ else:
233
+ raise NotImplementedError
234
+
235
+ return img, alpha
236
+
237
+ def load_conds(self, directory):
238
+ assert self.crop_size == -1
239
+ image_size = self.img_wh[0]
240
+ conds = []
241
+ for view in self.view_types:
242
+ cond_file = f"{self.cond_type}_000_{view}.png"
243
+ image_input = Image.open(os.path.join(directory, cond_file))
244
+ image_input = image_input.resize((image_size, image_size), resample=Image.BICUBIC)
245
+ image_input = np.array(image_input)[:, :, :3] / 255.
246
+ conds.append(image_input)
247
+
248
+ conds = np.stack(conds, axis=0)
249
+ conds = torch.from_numpy(conds).permute(0, 3, 1, 2) # B, 3, H, W
250
+ return conds
251
+
252
+ def __len__(self):
253
+ return len(self.all_images)
254
+
255
+ def __getitem__(self, index):
256
+
257
+ image = self.all_images[index%len(self.all_images)]
258
+ alpha = self.all_alphas[index%len(self.all_images)]
259
+ filename = self.file_list[index%len(self.all_images)].replace(".png", "")
260
+
261
+ if self.cond_type is not None:
262
+ conds = self.load_conds(self.cond_dirs[index%len(self.all_images)])
263
+ else:
264
+ conds = None
265
+
266
+ cond_w2c = self.fix_cam_poses['front']
267
+
268
+ tgt_w2cs = [self.fix_cam_poses[view] for view in self.view_types]
269
+
270
+ elevations = []
271
+ azimuths = []
272
+
273
+ img_tensors_in = [
274
+ image.permute(2, 0, 1)
275
+ ] * self.num_views
276
+
277
+ alpha_tensors_in = [
278
+ alpha.permute(2, 0, 1)
279
+ ] * self.num_views
280
+
281
+ for view, tgt_w2c in zip(self.view_types, tgt_w2cs):
282
+ # elevations, azimuths
283
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
284
+ elevations.append(elevation)
285
+ azimuths.append(azimuth)
286
+
287
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
288
+ alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 3, H, W)
289
+
290
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
291
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
292
+ elevations_cond = torch.as_tensor([0] * self.num_views).float()
293
+
294
+ normal_class = torch.tensor([1, 0]).float()
295
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
296
+ color_class = torch.tensor([0, 1]).float()
297
+ color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
298
+
299
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
300
+
301
+ out = {
302
+ 'elevations_cond': elevations_cond,
303
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
304
+ 'elevations': elevations,
305
+ 'azimuths': azimuths,
306
+ 'elevations_deg': torch.rad2deg(elevations),
307
+ 'azimuths_deg': torch.rad2deg(azimuths),
308
+ 'imgs_in': img_tensors_in,
309
+ 'alphas': alpha_tensors_in,
310
+ 'camera_embeddings': camera_embeddings,
311
+ 'normal_task_embeddings': normal_task_embeddings,
312
+ 'color_task_embeddings': color_task_embeddings,
313
+ 'filename': filename,
314
+ }
315
+
316
+ if conds is not None:
317
+ out['conds'] = conds
318
+
319
+ return out
320
+
321
+
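A hedged sketch of how SingleImageDataset is typically driven at inference time: point it at a folder of RGBA .png inputs (the path below is a placeholder) and read back the replicated input views together with the camera and task embeddings:

from torch.utils.data import DataLoader

dataset = SingleImageDataset(
    root_dir="./example_images",    # placeholder: any folder of RGBA .png files
    num_views=6,
    img_wh=(256, 256),
    bg_color="white",
    crop_size=224,
)
sample = dataset[0]
print(sample["imgs_in"].shape)             # (6, 3, 256, 256): the same input image per target view
print(sample["camera_embeddings"].shape)   # (6, 3)
loader = DataLoader(dataset, batch_size=1, shuffle=False)   # ready for the inference loop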
mvdiffusion/models/transformer_mv2d.py ADDED
@@ -0,0 +1,1005 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
23
+ from diffusers.utils import BaseOutput, deprecate, maybe_allow_in_graph
24
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
25
+ from diffusers.models.embeddings import PatchEmbed
26
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from diffusers.utils.import_utils import is_xformers_available
29
+
30
+ from einops import rearrange
31
+ import pdb
32
+ import random
33
+
34
+
35
+ if is_xformers_available():
36
+ import xformers
37
+ import xformers.ops
38
+ else:
39
+ xformers = None
40
+
41
+
42
+ @dataclass
43
+ class TransformerMV2DModelOutput(BaseOutput):
44
+ """
45
+ The output of [`Transformer2DModel`].
46
+
47
+ Args:
48
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
49
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
50
+ distributions for the unnoised latent pixels.
51
+ """
52
+
53
+ sample: torch.FloatTensor
54
+
55
+
56
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
57
+ """
58
+ A 2D Transformer model for image-like data.
59
+
60
+ Parameters:
61
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
62
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
63
+ in_channels (`int`, *optional*):
64
+ The number of channels in the input and output (specify if the input is **continuous**).
65
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
66
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
67
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
68
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
69
+ This is fixed during training since it is used to learn a number of position embeddings.
70
+ num_vector_embeds (`int`, *optional*):
71
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
72
+ Includes the class for the masked latent pixel.
73
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
74
+ num_embeds_ada_norm ( `int`, *optional*):
75
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
76
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
77
+ added to the hidden states.
78
+
79
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
80
+ attention_bias (`bool`, *optional*):
81
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
82
+ """
83
+
84
+ @register_to_config
85
+ def __init__(
86
+ self,
87
+ num_attention_heads: int = 16,
88
+ attention_head_dim: int = 88,
89
+ in_channels: Optional[int] = None,
90
+ out_channels: Optional[int] = None,
91
+ num_layers: int = 1,
92
+ dropout: float = 0.0,
93
+ norm_num_groups: int = 32,
94
+ cross_attention_dim: Optional[int] = None,
95
+ attention_bias: bool = False,
96
+ sample_size: Optional[int] = None,
97
+ num_vector_embeds: Optional[int] = None,
98
+ patch_size: Optional[int] = None,
99
+ activation_fn: str = "geglu",
100
+ num_embeds_ada_norm: Optional[int] = None,
101
+ use_linear_projection: bool = False,
102
+ only_cross_attention: bool = False,
103
+ upcast_attention: bool = False,
104
+ norm_type: str = "layer_norm",
105
+ norm_elementwise_affine: bool = True,
106
+ num_views: int = 1,
107
+ joint_attention: bool=False,
108
+ joint_attention_twice: bool=False,
109
+ multiview_attention: bool=True,
110
+ cross_domain_attention: bool=False
111
+ ):
112
+ super().__init__()
113
+ self.use_linear_projection = use_linear_projection
114
+ self.num_attention_heads = num_attention_heads
115
+ self.attention_head_dim = attention_head_dim
116
+ inner_dim = num_attention_heads * attention_head_dim
117
+
118
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
119
+ # Define whether input is continuous or discrete depending on configuration
120
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
121
+ self.is_input_vectorized = num_vector_embeds is not None
122
+ self.is_input_patches = in_channels is not None and patch_size is not None
123
+
124
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
125
+ deprecation_message = (
126
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
127
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
128
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
129
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
130
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
131
+ )
132
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
133
+ norm_type = "ada_norm"
134
+
135
+ if self.is_input_continuous and self.is_input_vectorized:
136
+ raise ValueError(
137
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
138
+ " sure that either `in_channels` or `num_vector_embeds` is None."
139
+ )
140
+ elif self.is_input_vectorized and self.is_input_patches:
141
+ raise ValueError(
142
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
143
+ " sure that either `num_vector_embeds` or `num_patches` is None."
144
+ )
145
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
146
+ raise ValueError(
147
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
148
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
149
+ )
150
+
151
+ # 2. Define input layers
152
+ if self.is_input_continuous:
153
+ self.in_channels = in_channels
154
+
155
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
156
+ if use_linear_projection:
157
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
158
+ else:
159
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
160
+ elif self.is_input_vectorized:
161
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
162
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
163
+
164
+ self.height = sample_size
165
+ self.width = sample_size
166
+ self.num_vector_embeds = num_vector_embeds
167
+ self.num_latent_pixels = self.height * self.width
168
+
169
+ self.latent_image_embedding = ImagePositionalEmbeddings(
170
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
171
+ )
172
+ elif self.is_input_patches:
173
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
174
+
175
+ self.height = sample_size
176
+ self.width = sample_size
177
+
178
+ self.patch_size = patch_size
179
+ self.pos_embed = PatchEmbed(
180
+ height=sample_size,
181
+ width=sample_size,
182
+ patch_size=patch_size,
183
+ in_channels=in_channels,
184
+ embed_dim=inner_dim,
185
+ )
186
+
187
+ # 3. Define transformers blocks
188
+ self.transformer_blocks = nn.ModuleList(
189
+ [
190
+ BasicMVTransformerBlock(
191
+ inner_dim,
192
+ num_attention_heads,
193
+ attention_head_dim,
194
+ dropout=dropout,
195
+ cross_attention_dim=cross_attention_dim,
196
+ activation_fn=activation_fn,
197
+ num_embeds_ada_norm=num_embeds_ada_norm,
198
+ attention_bias=attention_bias,
199
+ only_cross_attention=only_cross_attention,
200
+ upcast_attention=upcast_attention,
201
+ norm_type=norm_type,
202
+ norm_elementwise_affine=norm_elementwise_affine,
203
+ num_views=num_views,
204
+ joint_attention=joint_attention,
205
+ joint_attention_twice=joint_attention_twice,
206
+ multiview_attention=multiview_attention,
207
+ cross_domain_attention=cross_domain_attention
208
+ )
209
+ for d in range(num_layers)
210
+ ]
211
+ )
212
+
213
+ # 4. Define output layers
214
+ self.out_channels = in_channels if out_channels is None else out_channels
215
+ if self.is_input_continuous:
216
+ # TODO: should use out_channels for continuous projections
217
+ if use_linear_projection:
218
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
219
+ else:
220
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
221
+ elif self.is_input_vectorized:
222
+ self.norm_out = nn.LayerNorm(inner_dim)
223
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
224
+ elif self.is_input_patches:
225
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
226
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
227
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
228
+
229
+ def forward(
230
+ self,
231
+ hidden_states: torch.Tensor,
232
+ encoder_hidden_states: Optional[torch.Tensor] = None,
233
+ timestep: Optional[torch.LongTensor] = None,
234
+ class_labels: Optional[torch.LongTensor] = None,
235
+ cross_attention_kwargs: Dict[str, Any] = None,
236
+ attention_mask: Optional[torch.Tensor] = None,
237
+ encoder_attention_mask: Optional[torch.Tensor] = None,
238
+ return_dict: bool = True,
239
+ ):
240
+ """
241
+ The [`Transformer2DModel`] forward method.
242
+
243
+ Args:
244
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
245
+ Input `hidden_states`.
246
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
247
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
248
+ self-attention.
249
+ timestep ( `torch.LongTensor`, *optional*):
250
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
251
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
252
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
253
+ `AdaLayerZeroNorm`.
254
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
255
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
256
+
257
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
258
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
259
+
260
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
261
+ above. This bias will be added to the cross-attention scores.
262
+ return_dict (`bool`, *optional*, defaults to `True`):
263
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
264
+ tuple.
265
+
266
+ Returns:
267
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
268
+ `tuple` where the first element is the sample tensor.
269
+ """
270
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
271
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
272
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
273
+ # expects mask of shape:
274
+ # [batch, key_tokens]
275
+ # adds singleton query_tokens dimension:
276
+ # [batch, 1, key_tokens]
277
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
278
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
279
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
280
+ if attention_mask is not None and attention_mask.ndim == 2:
281
+ # assume that mask is expressed as:
282
+ # (1 = keep, 0 = discard)
283
+ # convert mask into a bias that can be added to attention scores:
284
+ # (keep = +0, discard = -10000.0)
285
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
286
+ attention_mask = attention_mask.unsqueeze(1)
287
+
288
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
289
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
290
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
291
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
292
+
293
+ # 1. Input
294
+ if self.is_input_continuous:
295
+ batch, _, height, width = hidden_states.shape
296
+ residual = hidden_states
297
+
298
+ hidden_states = self.norm(hidden_states)
299
+ if not self.use_linear_projection:
300
+ hidden_states = self.proj_in(hidden_states)
301
+ inner_dim = hidden_states.shape[1]
302
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
303
+ else:
304
+ inner_dim = hidden_states.shape[1]
305
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
306
+ hidden_states = self.proj_in(hidden_states)
307
+ elif self.is_input_vectorized:
308
+ hidden_states = self.latent_image_embedding(hidden_states)
309
+ elif self.is_input_patches:
310
+ hidden_states = self.pos_embed(hidden_states)
311
+
312
+ # 2. Blocks
313
+ for block in self.transformer_blocks:
314
+ hidden_states = block(
315
+ hidden_states,
316
+ attention_mask=attention_mask,
317
+ encoder_hidden_states=encoder_hidden_states,
318
+ encoder_attention_mask=encoder_attention_mask,
319
+ timestep=timestep,
320
+ cross_attention_kwargs=cross_attention_kwargs,
321
+ class_labels=class_labels,
322
+ )
323
+
324
+ # 3. Output
325
+ if self.is_input_continuous:
326
+ if not self.use_linear_projection:
327
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
328
+ hidden_states = self.proj_out(hidden_states)
329
+ else:
330
+ hidden_states = self.proj_out(hidden_states)
331
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
332
+
333
+ output = hidden_states + residual
334
+ elif self.is_input_vectorized:
335
+ hidden_states = self.norm_out(hidden_states)
336
+ logits = self.out(hidden_states)
337
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
338
+ logits = logits.permute(0, 2, 1)
339
+
340
+ # log(p(x_0))
341
+ output = F.log_softmax(logits.double(), dim=1).float()
342
+ elif self.is_input_patches:
343
+ # TODO: cleanup!
344
+ conditioning = self.transformer_blocks[0].norm1.emb(
345
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
346
+ )
347
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
348
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
349
+ hidden_states = self.proj_out_2(hidden_states)
350
+
351
+ # unpatchify
352
+ height = width = int(hidden_states.shape[1] ** 0.5)
353
+ hidden_states = hidden_states.reshape(
354
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
355
+ )
356
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
357
+ output = hidden_states.reshape(
358
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
359
+ )
360
+
361
+ if not return_dict:
362
+ return (output,)
363
+
364
+ return TransformerMV2DModelOutput(sample=output)
365
+
366
+
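A minimal smoke-test sketch, not part of the commit, that instantiates the multi-view transformer and runs a dummy forward pass. The hyper-parameters and shapes are illustrative only, and it assumes a CUDA device with xformers installed, since these blocks are meant to run with the memory-efficient attention processors defined below:

import torch

model = TransformerMV2DModel(
    num_attention_heads=8,
    attention_head_dim=40,        # inner_dim = 8 * 40 = 320
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
    num_views=4,                  # the batch is laid out as (B * num_views, C, H, W)
).cuda()
model.enable_xformers_memory_efficient_attention()   # installs XFormersMVAttnProcessor in each block

hidden = torch.randn(2 * 4, 320, 32, 32, device="cuda")   # 2 objects x 4 views of latent features
text = torch.randn(2 * 4, 77, 768, device="cuda")         # per-view conditioning embeddings
with torch.no_grad():
    out = model(hidden, encoder_hidden_states=text).sample
print(out.shape)                                          # torch.Size([8, 320, 32, 32])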
367
+ @maybe_allow_in_graph
368
+ class BasicMVTransformerBlock(nn.Module):
369
+ r"""
370
+ A basic Transformer block.
371
+
372
+ Parameters:
373
+ dim (`int`): The number of channels in the input and output.
374
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
375
+ attention_head_dim (`int`): The number of channels in each head.
376
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
377
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
378
+ only_cross_attention (`bool`, *optional*):
379
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
380
+ double_self_attention (`bool`, *optional*):
381
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
382
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
383
+ num_embeds_ada_norm (:
384
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
385
+ attention_bias (:
386
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
387
+ """
388
+
389
+ def __init__(
390
+ self,
391
+ dim: int,
392
+ num_attention_heads: int,
393
+ attention_head_dim: int,
394
+ dropout=0.0,
395
+ cross_attention_dim: Optional[int] = None,
396
+ activation_fn: str = "geglu",
397
+ num_embeds_ada_norm: Optional[int] = None,
398
+ attention_bias: bool = False,
399
+ only_cross_attention: bool = False,
400
+ double_self_attention: bool = False,
401
+ upcast_attention: bool = False,
402
+ norm_elementwise_affine: bool = True,
403
+ norm_type: str = "layer_norm",
404
+ final_dropout: bool = False,
405
+ num_views: int = 1,
406
+ joint_attention: bool = False,
407
+ joint_attention_twice: bool = False,
408
+ multiview_attention: bool = True,
409
+ cross_domain_attention: bool = False
410
+ ):
411
+ super().__init__()
412
+ self.only_cross_attention = only_cross_attention
413
+
414
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
415
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
416
+
417
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
418
+ raise ValueError(
419
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
420
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
421
+ )
422
+
423
+ # Define 3 blocks. Each block has its own normalization layer.
424
+ # 1. Self-Attn
425
+ if self.use_ada_layer_norm:
426
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
427
+ elif self.use_ada_layer_norm_zero:
428
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
429
+ else:
430
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
431
+
432
+ self.multiview_attention = multiview_attention
433
+ self.cross_domain_attention = cross_domain_attention
434
+
435
+ self.attn1 = CustomAttention(
436
+ query_dim=dim,
437
+ heads=num_attention_heads,
438
+ dim_head=attention_head_dim,
439
+ dropout=dropout,
440
+ bias=attention_bias,
441
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
442
+ upcast_attention=upcast_attention,
443
+ processor=MVAttnProcessor()
444
+ )
445
+
446
+ # 2. Cross-Attn
447
+ if cross_attention_dim is not None or double_self_attention:
448
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
449
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
450
+ # the second cross attention block.
451
+ self.norm2 = (
452
+ AdaLayerNorm(dim, num_embeds_ada_norm)
453
+ if self.use_ada_layer_norm
454
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
455
+ )
456
+ self.attn2 = Attention(
457
+ query_dim=dim,
458
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
459
+ heads=num_attention_heads,
460
+ dim_head=attention_head_dim,
461
+ dropout=dropout,
462
+ bias=attention_bias,
463
+ upcast_attention=upcast_attention,
464
+ ) # is self-attn if encoder_hidden_states is none
465
+ else:
466
+ self.norm2 = None
467
+ self.attn2 = None
468
+
469
+ # 3. Feed-forward
470
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
471
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
472
+
473
+ # let chunk size default to None
474
+ self._chunk_size = None
475
+ self._chunk_dim = 0
476
+
477
+ self.num_views = num_views
478
+
479
+ self.joint_attention = joint_attention
480
+
481
+ if self.joint_attention:
482
+ # Joint task -Attn
483
+ self.attn_joint = CustomJointAttention(
484
+ query_dim=dim,
485
+ heads=num_attention_heads,
486
+ dim_head=attention_head_dim,
487
+ dropout=dropout,
488
+ bias=attention_bias,
489
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
490
+ upcast_attention=upcast_attention,
491
+ processor=JointAttnProcessor()
492
+ )
493
+ nn.init.zeros_(self.attn_joint.to_out[0].weight.data)
494
+ self.norm_joint = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
495
+
496
+
497
+ self.joint_attention_twice = joint_attention_twice
498
+
499
+ if self.joint_attention_twice:
500
+ print("joint twice")
501
+ # Joint task -Attn
502
+ self.attn_joint_twice = CustomJointAttention(
503
+ query_dim=dim,
504
+ heads=num_attention_heads,
505
+ dim_head=attention_head_dim,
506
+ dropout=dropout,
507
+ bias=attention_bias,
508
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
509
+ upcast_attention=upcast_attention,
510
+ processor=JointAttnProcessor()
511
+ )
512
+ nn.init.zeros_(self.attn_joint_twice.to_out[0].weight.data)
513
+ self.norm_joint_twice = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
514
+
515
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
516
+ # Sets chunk feed-forward
517
+ self._chunk_size = chunk_size
518
+ self._chunk_dim = dim
519
+
520
+ def forward(
521
+ self,
522
+ hidden_states: torch.FloatTensor,
523
+ attention_mask: Optional[torch.FloatTensor] = None,
524
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
525
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
526
+ timestep: Optional[torch.LongTensor] = None,
527
+ cross_attention_kwargs: Dict[str, Any] = None,
528
+ class_labels: Optional[torch.LongTensor] = None,
529
+ ):
530
+ assert attention_mask is None # not supported yet
531
+ # Notice that normalization is always applied before the real computation in the following blocks.
532
+ # 1. Self-Attention
533
+ if self.use_ada_layer_norm:
534
+ norm_hidden_states = self.norm1(hidden_states, timestep)
535
+ elif self.use_ada_layer_norm_zero:
536
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
537
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
538
+ )
539
+ else:
540
+ norm_hidden_states = self.norm1(hidden_states)
541
+
542
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
543
+
544
+ attn_output = self.attn1(
545
+ norm_hidden_states,
546
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
547
+ attention_mask=attention_mask,
548
+ num_views=self.num_views,
549
+ multiview_attention=self.multiview_attention,
550
+ cross_domain_attention=self.cross_domain_attention,
551
+ **cross_attention_kwargs,
552
+ )
553
+
554
+
555
+ if self.use_ada_layer_norm_zero:
556
+ attn_output = gate_msa.unsqueeze(1) * attn_output
557
+ hidden_states = attn_output + hidden_states
558
+
559
+ # joint attention twice
560
+ if self.joint_attention_twice:
561
+ norm_hidden_states = (
562
+ self.norm_joint_twice(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_twice(hidden_states)
563
+ )
564
+ hidden_states = self.attn_joint_twice(norm_hidden_states) + hidden_states
565
+
566
+ # 2. Cross-Attention
567
+ if self.attn2 is not None:
568
+ norm_hidden_states = (
569
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
570
+ )
571
+
572
+ attn_output = self.attn2(
573
+ norm_hidden_states,
574
+ encoder_hidden_states=encoder_hidden_states,
575
+ attention_mask=encoder_attention_mask,
576
+ **cross_attention_kwargs,
577
+ )
578
+ hidden_states = attn_output + hidden_states
579
+
580
+ # 3. Feed-forward
581
+ norm_hidden_states = self.norm3(hidden_states)
582
+
583
+ if self.use_ada_layer_norm_zero:
584
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
585
+
586
+ if self._chunk_size is not None:
587
+ # "feed_forward_chunk_size" can be used to save memory
588
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
589
+ raise ValueError(
590
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
591
+ )
592
+
593
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
594
+ ff_output = torch.cat(
595
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
596
+ dim=self._chunk_dim,
597
+ )
598
+ else:
599
+ ff_output = self.ff(norm_hidden_states)
600
+
601
+ if self.use_ada_layer_norm_zero:
602
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
603
+
604
+ hidden_states = ff_output + hidden_states
605
+
606
+ if self.joint_attention:
607
+ norm_hidden_states = (
608
+ self.norm_joint(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint(hidden_states)
609
+ )
610
+ hidden_states = self.attn_joint(norm_hidden_states) + hidden_states
611
+
612
+ return hidden_states
613
+
614
+
615
+ class CustomAttention(Attention):
616
+ def set_use_memory_efficient_attention_xformers(
617
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
618
+ ):
619
+ processor = XFormersMVAttnProcessor()
620
+ self.set_processor(processor)
621
+ # print("using xformers attention processor")
622
+
623
+
624
+ class CustomJointAttention(Attention):
625
+ def set_use_memory_efficient_attention_xformers(
626
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
627
+ ):
628
+ processor = XFormersJointAttnProcessor()
629
+ self.set_processor(processor)
630
+ # print("using xformers attention processor")
631
+
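CustomAttention and CustomJointAttention only change how xformers is enabled: whichever flag is passed, they install the multi-view (or joint) xformers processor, which is what enable_xformers_memory_efficient_attention() triggers model-wide. A tiny sketch of that swap (no GPU is touched until an actual forward pass):

attn = CustomAttention(query_dim=320, heads=8, dim_head=40, processor=MVAttnProcessor())
attn.set_use_memory_efficient_attention_xformers(True)
print(type(attn.processor).__name__)    # XFormersMVAttnProcessor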
632
+ class MVAttnProcessor:
633
+ r"""
634
+ Default processor for performing attention-related computations.
635
+ """
636
+
637
+ def __call__(
638
+ self,
639
+ attn: Attention,
640
+ hidden_states,
641
+ encoder_hidden_states=None,
642
+ attention_mask=None,
643
+ temb=None,
644
+ num_views=1,
645
+ multiview_attention=True
646
+ ):
647
+ residual = hidden_states
648
+
649
+ if attn.spatial_norm is not None:
650
+ hidden_states = attn.spatial_norm(hidden_states, temb)
651
+
652
+ input_ndim = hidden_states.ndim
653
+
654
+ if input_ndim == 4:
655
+ batch_size, channel, height, width = hidden_states.shape
656
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
657
+
658
+ batch_size, sequence_length, _ = (
659
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
660
+ )
661
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
662
+
663
+ if attn.group_norm is not None:
664
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
665
+
666
+ query = attn.to_q(hidden_states)
667
+
668
+ if encoder_hidden_states is None:
669
+ encoder_hidden_states = hidden_states
670
+ elif attn.norm_cross:
671
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
672
+
673
+ key = attn.to_k(encoder_hidden_states)
674
+ value = attn.to_v(encoder_hidden_states)
675
+
676
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
677
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
678
+ # pdb.set_trace()
679
+ # multi-view self-attention
680
+ if multiview_attention:
681
+ if num_views <= 6:
682
+ # with xformers enabled, full multi-view attention is feasible even with 6 views
683
+ key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
684
+ value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
685
+ else:  # sparse attention (currently a no-op; see the commented-out sampling below)
686
+ pass
687
+ # print("use sparse attention")
688
+ # # seems that the sparse random sampling cause problems
689
+ # # don't use random sampling, just fix the indexes
690
+ # onekey = rearrange(key, "(b t) d c -> b t d c", t=num_views)
691
+ # onevalue = rearrange(value, "(b t) d c -> b t d c", t=num_views)
692
+ # allkeys = []
693
+ # allvalues = []
694
+ # all_indexes = {
695
+ # 0 : [0, 2, 3, 4],
696
+ # 1: [0, 1, 3, 5],
697
+ # 2: [0, 2, 3, 4],
698
+ # 3: [0, 2, 3, 4],
699
+ # 4: [0, 2, 3, 4],
700
+ # 5: [0, 1, 3, 5]
701
+ # }
702
+ # for jj in range(num_views):
703
+ # # valid_index = [x for x in range(0, num_views) if x!= jj]
704
+ # # indexes = random.sample(valid_index, 3) + [jj] + [0]
705
+ # indexes = all_indexes[jj]
706
+
707
+ # indexes = torch.tensor(indexes).long().to(key.device)
708
+ # allkeys.append(onekey[:, indexes])
709
+ # allvalues.append(onevalue[:, indexes])
710
+ # keys = torch.stack(allkeys, dim=1) # checked, should be dim=1
711
+ # values = torch.stack(allvalues, dim=1)
712
+ # key = rearrange(keys, 'b t f d c -> (b t) (f d) c')
713
+ # value = rearrange(values, 'b t f d c -> (b t) (f d) c')
714
+
715
+
716
+ query = attn.head_to_batch_dim(query).contiguous()
717
+ key = attn.head_to_batch_dim(key).contiguous()
718
+ value = attn.head_to_batch_dim(value).contiguous()
719
+
720
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
721
+ hidden_states = torch.bmm(attention_probs, value)
722
+ hidden_states = attn.batch_to_head_dim(hidden_states)
723
+
724
+ # linear proj
725
+ hidden_states = attn.to_out[0](hidden_states)
726
+ # dropout
727
+ hidden_states = attn.to_out[1](hidden_states)
728
+
729
+ if input_ndim == 4:
730
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
731
+
732
+ if attn.residual_connection:
733
+ hidden_states = hidden_states + residual
734
+
735
+ hidden_states = hidden_states / attn.rescale_output_factor
736
+
737
+ return hidden_states
738
+
739
+
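The multi-view trick in MVAttnProcessor is pure tensor bookkeeping: per-view keys/values of shape ((B*Nv), D, C) are concatenated along the token axis and then repeated, so every view's queries attend over the tokens of all Nv views of the same object. A standalone shape sketch (not part of the commit):

import torch
from einops import rearrange

B, Nv, D, C = 2, 4, 1024, 320
key = torch.randn(B * Nv, D, C)                             # per-view keys, views stacked in the batch dim

key_mv = rearrange(key, "(b t) d c -> b (t d) c", t=Nv)     # (B, Nv*D, C): all views of one object
key_mv = key_mv.repeat_interleave(Nv, dim=0)                # (B*Nv, Nv*D, C): one copy per query view
print(key_mv.shape)                                         # torch.Size([8, 4096, 320])
assert torch.equal(key_mv[0], key_mv[1])                    # views 0 and 1 of object 0 see the same tokens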
740
+ class XFormersMVAttnProcessor:
741
+ r"""
742
+ Default processor for performing attention-related computations.
743
+ """
744
+
745
+ def __call__(
746
+ self,
747
+ attn: Attention,
748
+ hidden_states,
749
+ encoder_hidden_states=None,
750
+ attention_mask=None,
751
+ temb=None,
752
+ num_views=1,
753
+ multiview_attention=True,
754
+ cross_domain_attention=False,
755
+ ):
756
+ residual = hidden_states
757
+
758
+ if attn.spatial_norm is not None:
759
+ hidden_states = attn.spatial_norm(hidden_states, temb)
760
+
761
+ input_ndim = hidden_states.ndim
762
+
763
+ if input_ndim == 4:
764
+ batch_size, channel, height, width = hidden_states.shape
765
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
766
+
767
+ batch_size, sequence_length, _ = (
768
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
769
+ )
770
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
771
+
772
+ # from yuancheng; here attention_mask is None
773
+ if attention_mask is not None:
774
+ # expand our mask's singleton query_tokens dimension:
775
+ # [batch*heads, 1, key_tokens] ->
776
+ # [batch*heads, query_tokens, key_tokens]
777
+ # so that it can be added as a bias onto the attention scores that xformers computes:
778
+ # [batch*heads, query_tokens, key_tokens]
779
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
780
+ _, query_tokens, _ = hidden_states.shape
781
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
782
+
783
+ if attn.group_norm is not None:
784
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
785
+
786
+ query = attn.to_q(hidden_states)
787
+
788
+ if encoder_hidden_states is None:
789
+ encoder_hidden_states = hidden_states
790
+ elif attn.norm_cross:
791
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
792
+
793
+ key_raw = attn.to_k(encoder_hidden_states)
794
+ value_raw = attn.to_v(encoder_hidden_states)
795
+
796
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
797
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
798
+ # pdb.set_trace()
799
+ # multi-view self-attention
800
+ if multiview_attention:
801
+ key = rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
802
+ value = rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
803
+
804
+ if cross_domain_attention:
805
+ # memory efficient, cross domain attention
806
+ key_0, key_1 = torch.chunk(key_raw, dim=0, chunks=2) # keys shape (b t) d c
807
+ value_0, value_1 = torch.chunk(value_raw, dim=0, chunks=2)
808
+ key_cross = torch.concat([key_1, key_0], dim=0)
809
+ value_cross = torch.concat([value_1, value_0], dim=0) # shape (b t) d c
810
+ key = torch.cat([key, key_cross], dim=1)
811
+ value = torch.cat([value, value_cross], dim=1) # shape (b t) (t+1 d) c
812
+ else:
813
+ # print("don't use multiview attention.")
814
+ key = key_raw
815
+ value = value_raw
816
+
817
+ query = attn.head_to_batch_dim(query)
818
+ key = attn.head_to_batch_dim(key)
819
+ value = attn.head_to_batch_dim(value)
820
+
821
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
822
+ hidden_states = attn.batch_to_head_dim(hidden_states)
823
+
824
+ # linear proj
825
+ hidden_states = attn.to_out[0](hidden_states)
826
+ # dropout
827
+ hidden_states = attn.to_out[1](hidden_states)
828
+
829
+ if input_ndim == 4:
830
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
831
+
832
+ if attn.residual_connection:
833
+ hidden_states = hidden_states + residual
834
+
835
+ hidden_states = hidden_states / attn.rescale_output_factor
836
+
837
+ return hidden_states
838
+
839
+
840
+
841
+ class XFormersJointAttnProcessor:
842
+ r"""
843
+ Default processor for performing attention-related computations.
844
+ """
845
+
846
+ def __call__(
847
+ self,
848
+ attn: Attention,
849
+ hidden_states,
850
+ encoder_hidden_states=None,
851
+ attention_mask=None,
852
+ temb=None,
853
+ num_tasks=2
854
+ ):
855
+
856
+ residual = hidden_states
857
+
858
+ if attn.spatial_norm is not None:
859
+ hidden_states = attn.spatial_norm(hidden_states, temb)
860
+
861
+ input_ndim = hidden_states.ndim
862
+
863
+ if input_ndim == 4:
864
+ batch_size, channel, height, width = hidden_states.shape
865
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
866
+
867
+ batch_size, sequence_length, _ = (
868
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
869
+ )
870
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
871
+
872
+ # from yuancheng; here attention_mask is None
873
+ if attention_mask is not None:
874
+ # expand our mask's singleton query_tokens dimension:
875
+ # [batch*heads, 1, key_tokens] ->
876
+ # [batch*heads, query_tokens, key_tokens]
877
+ # so that it can be added as a bias onto the attention scores that xformers computes:
878
+ # [batch*heads, query_tokens, key_tokens]
879
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
880
+ _, query_tokens, _ = hidden_states.shape
881
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
882
+
883
+ if attn.group_norm is not None:
884
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
885
+
886
+ query = attn.to_q(hidden_states)
887
+
888
+ if encoder_hidden_states is None:
889
+ encoder_hidden_states = hidden_states
890
+ elif attn.norm_cross:
891
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
892
+
893
+ key = attn.to_k(encoder_hidden_states)
894
+ value = attn.to_v(encoder_hidden_states)
895
+
896
+ assert num_tasks == 2 # only support two tasks now
897
+
898
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
899
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
900
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
901
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
902
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
903
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
904
+
905
+
906
+ query = attn.head_to_batch_dim(query).contiguous()
907
+ key = attn.head_to_batch_dim(key).contiguous()
908
+ value = attn.head_to_batch_dim(value).contiguous()
909
+
910
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
911
+ hidden_states = attn.batch_to_head_dim(hidden_states)
912
+
913
+ # linear proj
914
+ hidden_states = attn.to_out[0](hidden_states)
915
+ # dropout
916
+ hidden_states = attn.to_out[1](hidden_states)
917
+
918
+ if input_ndim == 4:
919
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
920
+
921
+ if attn.residual_connection:
922
+ hidden_states = hidden_states + residual
923
+
924
+ hidden_states = hidden_states / attn.rescale_output_factor
925
+
926
+ return hidden_states
927
+
928
+
929
+ class JointAttnProcessor:
930
+ r"""
931
+ Default processor for performing attention-related computations.
932
+ """
933
+
934
+ def __call__(
935
+ self,
936
+ attn: Attention,
937
+ hidden_states,
938
+ encoder_hidden_states=None,
939
+ attention_mask=None,
940
+ temb=None,
941
+ num_tasks=2
942
+ ):
943
+
944
+ residual = hidden_states
945
+
946
+ if attn.spatial_norm is not None:
947
+ hidden_states = attn.spatial_norm(hidden_states, temb)
948
+
949
+ input_ndim = hidden_states.ndim
950
+
951
+ if input_ndim == 4:
952
+ batch_size, channel, height, width = hidden_states.shape
953
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
954
+
955
+ batch_size, sequence_length, _ = (
956
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
957
+ )
958
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
959
+
960
+
961
+ if attn.group_norm is not None:
962
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
963
+
964
+ query = attn.to_q(hidden_states)
965
+
966
+ if encoder_hidden_states is None:
967
+ encoder_hidden_states = hidden_states
968
+ elif attn.norm_cross:
969
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
970
+
971
+ key = attn.to_k(encoder_hidden_states)
972
+ value = attn.to_v(encoder_hidden_states)
973
+
974
+ assert num_tasks == 2 # only support two tasks now
975
+
976
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
977
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
978
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
979
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
980
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
981
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
982
+
983
+
984
+ query = attn.head_to_batch_dim(query).contiguous()
985
+ key = attn.head_to_batch_dim(key).contiguous()
986
+ value = attn.head_to_batch_dim(value).contiguous()
987
+
988
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
989
+ hidden_states = torch.bmm(attention_probs, value)
990
+ hidden_states = attn.batch_to_head_dim(hidden_states)
991
+
992
+ # linear proj
993
+ hidden_states = attn.to_out[0](hidden_states)
994
+ # dropout
995
+ hidden_states = attn.to_out[1](hidden_states)
996
+
997
+ if input_ndim == 4:
998
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
999
+
1000
+ if attn.residual_connection:
1001
+ hidden_states = hidden_states + residual
1002
+
1003
+ hidden_states = hidden_states / attn.rescale_output_factor
1004
+
1005
+ return hidden_states
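The joint processors assume the batch stacks exactly two tasks (normals, then colors) along dim 0; keys/values of both halves are concatenated along the token axis and shared by both halves, so the normal branch and the color branch attend into each other. A shape-only sketch of that bookkeeping (not part of the commit):

import torch

BT, D, C = 8, 1024, 320                 # (B * Nv) per task; dim 0 holds [normal half, color half]
key = torch.randn(2 * BT, D, C)

key_normal, key_color = torch.chunk(key, chunks=2, dim=0)   # (BT, D, C) each
joint = torch.cat([key_normal, key_color], dim=1)           # (BT, 2D, C): tokens of both tasks
joint = torch.cat([joint] * 2, dim=0)                       # (2*BT, 2D, C): reused by both halves
print(joint.shape)                                          # torch.Size([16, 2048, 320])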
mvdiffusion/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,880 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ from diffusers.models.attention import AdaGroupNorm
23
+ from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
+ from mvdiffusion.models.transformer_mv2d import TransformerMV2DModel
27
+
28
+ from diffusers.models.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
29
+ from diffusers.models.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ def get_down_block(
36
+ down_block_type,
37
+ num_layers,
38
+ in_channels,
39
+ out_channels,
40
+ temb_channels,
41
+ add_downsample,
42
+ resnet_eps,
43
+ resnet_act_fn,
44
+ transformer_layers_per_block=1,
45
+ num_attention_heads=None,
46
+ resnet_groups=None,
47
+ cross_attention_dim=None,
48
+ downsample_padding=None,
49
+ dual_cross_attention=False,
50
+ use_linear_projection=False,
51
+ only_cross_attention=False,
52
+ upcast_attention=False,
53
+ resnet_time_scale_shift="default",
54
+ resnet_skip_time_act=False,
55
+ resnet_out_scale_factor=1.0,
56
+ cross_attention_norm=None,
57
+ attention_head_dim=None,
58
+ downsample_type=None,
59
+ num_views=1
60
+ ):
61
+ # If attn head dim is not defined, we default it to the number of heads
62
+ if attention_head_dim is None:
63
+ logger.warn(
64
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
65
+ )
66
+ attention_head_dim = num_attention_heads
67
+
68
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
69
+ if down_block_type == "DownBlock2D":
70
+ return DownBlock2D(
71
+ num_layers=num_layers,
72
+ in_channels=in_channels,
73
+ out_channels=out_channels,
74
+ temb_channels=temb_channels,
75
+ add_downsample=add_downsample,
76
+ resnet_eps=resnet_eps,
77
+ resnet_act_fn=resnet_act_fn,
78
+ resnet_groups=resnet_groups,
79
+ downsample_padding=downsample_padding,
80
+ resnet_time_scale_shift=resnet_time_scale_shift,
81
+ )
82
+ elif down_block_type == "ResnetDownsampleBlock2D":
83
+ return ResnetDownsampleBlock2D(
84
+ num_layers=num_layers,
85
+ in_channels=in_channels,
86
+ out_channels=out_channels,
87
+ temb_channels=temb_channels,
88
+ add_downsample=add_downsample,
89
+ resnet_eps=resnet_eps,
90
+ resnet_act_fn=resnet_act_fn,
91
+ resnet_groups=resnet_groups,
92
+ resnet_time_scale_shift=resnet_time_scale_shift,
93
+ skip_time_act=resnet_skip_time_act,
94
+ output_scale_factor=resnet_out_scale_factor,
95
+ )
96
+ elif down_block_type == "AttnDownBlock2D":
97
+ if add_downsample is False:
98
+ downsample_type = None
99
+ else:
100
+ downsample_type = downsample_type or "conv" # default to 'conv'
101
+ return AttnDownBlock2D(
102
+ num_layers=num_layers,
103
+ in_channels=in_channels,
104
+ out_channels=out_channels,
105
+ temb_channels=temb_channels,
106
+ resnet_eps=resnet_eps,
107
+ resnet_act_fn=resnet_act_fn,
108
+ resnet_groups=resnet_groups,
109
+ downsample_padding=downsample_padding,
110
+ attention_head_dim=attention_head_dim,
111
+ resnet_time_scale_shift=resnet_time_scale_shift,
112
+ downsample_type=downsample_type,
113
+ )
114
+ elif down_block_type == "CrossAttnDownBlock2D":
115
+ if cross_attention_dim is None:
116
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
117
+ return CrossAttnDownBlock2D(
118
+ num_layers=num_layers,
119
+ transformer_layers_per_block=transformer_layers_per_block,
120
+ in_channels=in_channels,
121
+ out_channels=out_channels,
122
+ temb_channels=temb_channels,
123
+ add_downsample=add_downsample,
124
+ resnet_eps=resnet_eps,
125
+ resnet_act_fn=resnet_act_fn,
126
+ resnet_groups=resnet_groups,
127
+ downsample_padding=downsample_padding,
128
+ cross_attention_dim=cross_attention_dim,
129
+ num_attention_heads=num_attention_heads,
130
+ dual_cross_attention=dual_cross_attention,
131
+ use_linear_projection=use_linear_projection,
132
+ only_cross_attention=only_cross_attention,
133
+ upcast_attention=upcast_attention,
134
+ resnet_time_scale_shift=resnet_time_scale_shift,
135
+ )
136
+ # custom MV2D attention block
137
+ elif down_block_type == "CrossAttnDownBlockMV2D":
138
+ if cross_attention_dim is None:
139
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
140
+ return CrossAttnDownBlockMV2D(
141
+ num_layers=num_layers,
142
+ transformer_layers_per_block=transformer_layers_per_block,
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ temb_channels=temb_channels,
146
+ add_downsample=add_downsample,
147
+ resnet_eps=resnet_eps,
148
+ resnet_act_fn=resnet_act_fn,
149
+ resnet_groups=resnet_groups,
150
+ downsample_padding=downsample_padding,
151
+ cross_attention_dim=cross_attention_dim,
152
+ num_attention_heads=num_attention_heads,
153
+ dual_cross_attention=dual_cross_attention,
154
+ use_linear_projection=use_linear_projection,
155
+ only_cross_attention=only_cross_attention,
156
+ upcast_attention=upcast_attention,
157
+ resnet_time_scale_shift=resnet_time_scale_shift,
158
+ num_views=num_views
159
+ )
160
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
161
+ if cross_attention_dim is None:
162
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
163
+ return SimpleCrossAttnDownBlock2D(
164
+ num_layers=num_layers,
165
+ in_channels=in_channels,
166
+ out_channels=out_channels,
167
+ temb_channels=temb_channels,
168
+ add_downsample=add_downsample,
169
+ resnet_eps=resnet_eps,
170
+ resnet_act_fn=resnet_act_fn,
171
+ resnet_groups=resnet_groups,
172
+ cross_attention_dim=cross_attention_dim,
173
+ attention_head_dim=attention_head_dim,
174
+ resnet_time_scale_shift=resnet_time_scale_shift,
175
+ skip_time_act=resnet_skip_time_act,
176
+ output_scale_factor=resnet_out_scale_factor,
177
+ only_cross_attention=only_cross_attention,
178
+ cross_attention_norm=cross_attention_norm,
179
+ )
180
+ elif down_block_type == "SkipDownBlock2D":
181
+ return SkipDownBlock2D(
182
+ num_layers=num_layers,
183
+ in_channels=in_channels,
184
+ out_channels=out_channels,
185
+ temb_channels=temb_channels,
186
+ add_downsample=add_downsample,
187
+ resnet_eps=resnet_eps,
188
+ resnet_act_fn=resnet_act_fn,
189
+ downsample_padding=downsample_padding,
190
+ resnet_time_scale_shift=resnet_time_scale_shift,
191
+ )
192
+ elif down_block_type == "AttnSkipDownBlock2D":
193
+ return AttnSkipDownBlock2D(
194
+ num_layers=num_layers,
195
+ in_channels=in_channels,
196
+ out_channels=out_channels,
197
+ temb_channels=temb_channels,
198
+ add_downsample=add_downsample,
199
+ resnet_eps=resnet_eps,
200
+ resnet_act_fn=resnet_act_fn,
201
+ attention_head_dim=attention_head_dim,
202
+ resnet_time_scale_shift=resnet_time_scale_shift,
203
+ )
204
+ elif down_block_type == "DownEncoderBlock2D":
205
+ return DownEncoderBlock2D(
206
+ num_layers=num_layers,
207
+ in_channels=in_channels,
208
+ out_channels=out_channels,
209
+ add_downsample=add_downsample,
210
+ resnet_eps=resnet_eps,
211
+ resnet_act_fn=resnet_act_fn,
212
+ resnet_groups=resnet_groups,
213
+ downsample_padding=downsample_padding,
214
+ resnet_time_scale_shift=resnet_time_scale_shift,
215
+ )
216
+ elif down_block_type == "AttnDownEncoderBlock2D":
217
+ return AttnDownEncoderBlock2D(
218
+ num_layers=num_layers,
219
+ in_channels=in_channels,
220
+ out_channels=out_channels,
221
+ add_downsample=add_downsample,
222
+ resnet_eps=resnet_eps,
223
+ resnet_act_fn=resnet_act_fn,
224
+ resnet_groups=resnet_groups,
225
+ downsample_padding=downsample_padding,
226
+ attention_head_dim=attention_head_dim,
227
+ resnet_time_scale_shift=resnet_time_scale_shift,
228
+ )
229
+ elif down_block_type == "KDownBlock2D":
230
+ return KDownBlock2D(
231
+ num_layers=num_layers,
232
+ in_channels=in_channels,
233
+ out_channels=out_channels,
234
+ temb_channels=temb_channels,
235
+ add_downsample=add_downsample,
236
+ resnet_eps=resnet_eps,
237
+ resnet_act_fn=resnet_act_fn,
238
+ )
239
+ elif down_block_type == "KCrossAttnDownBlock2D":
240
+ return KCrossAttnDownBlock2D(
241
+ num_layers=num_layers,
242
+ in_channels=in_channels,
243
+ out_channels=out_channels,
244
+ temb_channels=temb_channels,
245
+ add_downsample=add_downsample,
246
+ resnet_eps=resnet_eps,
247
+ resnet_act_fn=resnet_act_fn,
248
+ cross_attention_dim=cross_attention_dim,
249
+ attention_head_dim=attention_head_dim,
250
+ add_self_attention=True if not add_downsample else False,
251
+ )
252
+ raise ValueError(f"{down_block_type} does not exist.")
253
+
254
+
255
+ def get_up_block(
256
+ up_block_type,
257
+ num_layers,
258
+ in_channels,
259
+ out_channels,
260
+ prev_output_channel,
261
+ temb_channels,
262
+ add_upsample,
263
+ resnet_eps,
264
+ resnet_act_fn,
265
+ transformer_layers_per_block=1,
266
+ num_attention_heads=None,
267
+ resnet_groups=None,
268
+ cross_attention_dim=None,
269
+ dual_cross_attention=False,
270
+ use_linear_projection=False,
271
+ only_cross_attention=False,
272
+ upcast_attention=False,
273
+ resnet_time_scale_shift="default",
274
+ resnet_skip_time_act=False,
275
+ resnet_out_scale_factor=1.0,
276
+ cross_attention_norm=None,
277
+ attention_head_dim=None,
278
+ upsample_type=None,
279
+ num_views=1
280
+ ):
281
+ # If attn head dim is not defined, we default it to the number of heads
282
+ if attention_head_dim is None:
283
+ logger.warn(
284
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
285
+ )
286
+ attention_head_dim = num_attention_heads
287
+
288
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
289
+ if up_block_type == "UpBlock2D":
290
+ return UpBlock2D(
291
+ num_layers=num_layers,
292
+ in_channels=in_channels,
293
+ out_channels=out_channels,
294
+ prev_output_channel=prev_output_channel,
295
+ temb_channels=temb_channels,
296
+ add_upsample=add_upsample,
297
+ resnet_eps=resnet_eps,
298
+ resnet_act_fn=resnet_act_fn,
299
+ resnet_groups=resnet_groups,
300
+ resnet_time_scale_shift=resnet_time_scale_shift,
301
+ )
302
+ elif up_block_type == "ResnetUpsampleBlock2D":
303
+ return ResnetUpsampleBlock2D(
304
+ num_layers=num_layers,
305
+ in_channels=in_channels,
306
+ out_channels=out_channels,
307
+ prev_output_channel=prev_output_channel,
308
+ temb_channels=temb_channels,
309
+ add_upsample=add_upsample,
310
+ resnet_eps=resnet_eps,
311
+ resnet_act_fn=resnet_act_fn,
312
+ resnet_groups=resnet_groups,
313
+ resnet_time_scale_shift=resnet_time_scale_shift,
314
+ skip_time_act=resnet_skip_time_act,
315
+ output_scale_factor=resnet_out_scale_factor,
316
+ )
317
+ elif up_block_type == "CrossAttnUpBlock2D":
318
+ if cross_attention_dim is None:
319
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
320
+ return CrossAttnUpBlock2D(
321
+ num_layers=num_layers,
322
+ transformer_layers_per_block=transformer_layers_per_block,
323
+ in_channels=in_channels,
324
+ out_channels=out_channels,
325
+ prev_output_channel=prev_output_channel,
326
+ temb_channels=temb_channels,
327
+ add_upsample=add_upsample,
328
+ resnet_eps=resnet_eps,
329
+ resnet_act_fn=resnet_act_fn,
330
+ resnet_groups=resnet_groups,
331
+ cross_attention_dim=cross_attention_dim,
332
+ num_attention_heads=num_attention_heads,
333
+ dual_cross_attention=dual_cross_attention,
334
+ use_linear_projection=use_linear_projection,
335
+ only_cross_attention=only_cross_attention,
336
+ upcast_attention=upcast_attention,
337
+ resnet_time_scale_shift=resnet_time_scale_shift,
338
+ )
339
+ # custom MV2D attention block
340
+ elif up_block_type == "CrossAttnUpBlockMV2D":
341
+ if cross_attention_dim is None:
342
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
343
+ return CrossAttnUpBlockMV2D(
344
+ num_layers=num_layers,
345
+ transformer_layers_per_block=transformer_layers_per_block,
346
+ in_channels=in_channels,
347
+ out_channels=out_channels,
348
+ prev_output_channel=prev_output_channel,
349
+ temb_channels=temb_channels,
350
+ add_upsample=add_upsample,
351
+ resnet_eps=resnet_eps,
352
+ resnet_act_fn=resnet_act_fn,
353
+ resnet_groups=resnet_groups,
354
+ cross_attention_dim=cross_attention_dim,
355
+ num_attention_heads=num_attention_heads,
356
+ dual_cross_attention=dual_cross_attention,
357
+ use_linear_projection=use_linear_projection,
358
+ only_cross_attention=only_cross_attention,
359
+ upcast_attention=upcast_attention,
360
+ resnet_time_scale_shift=resnet_time_scale_shift,
361
+ num_views=num_views
362
+ )
363
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
364
+ if cross_attention_dim is None:
365
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
366
+ return SimpleCrossAttnUpBlock2D(
367
+ num_layers=num_layers,
368
+ in_channels=in_channels,
369
+ out_channels=out_channels,
370
+ prev_output_channel=prev_output_channel,
371
+ temb_channels=temb_channels,
372
+ add_upsample=add_upsample,
373
+ resnet_eps=resnet_eps,
374
+ resnet_act_fn=resnet_act_fn,
375
+ resnet_groups=resnet_groups,
376
+ cross_attention_dim=cross_attention_dim,
377
+ attention_head_dim=attention_head_dim,
378
+ resnet_time_scale_shift=resnet_time_scale_shift,
379
+ skip_time_act=resnet_skip_time_act,
380
+ output_scale_factor=resnet_out_scale_factor,
381
+ only_cross_attention=only_cross_attention,
382
+ cross_attention_norm=cross_attention_norm,
383
+ )
384
+ elif up_block_type == "AttnUpBlock2D":
385
+ if add_upsample is False:
386
+ upsample_type = None
387
+ else:
388
+ upsample_type = upsample_type or "conv" # default to 'conv'
389
+
390
+ return AttnUpBlock2D(
391
+ num_layers=num_layers,
392
+ in_channels=in_channels,
393
+ out_channels=out_channels,
394
+ prev_output_channel=prev_output_channel,
395
+ temb_channels=temb_channels,
396
+ resnet_eps=resnet_eps,
397
+ resnet_act_fn=resnet_act_fn,
398
+ resnet_groups=resnet_groups,
399
+ attention_head_dim=attention_head_dim,
400
+ resnet_time_scale_shift=resnet_time_scale_shift,
401
+ upsample_type=upsample_type,
402
+ )
403
+ elif up_block_type == "SkipUpBlock2D":
404
+ return SkipUpBlock2D(
405
+ num_layers=num_layers,
406
+ in_channels=in_channels,
407
+ out_channels=out_channels,
408
+ prev_output_channel=prev_output_channel,
409
+ temb_channels=temb_channels,
410
+ add_upsample=add_upsample,
411
+ resnet_eps=resnet_eps,
412
+ resnet_act_fn=resnet_act_fn,
413
+ resnet_time_scale_shift=resnet_time_scale_shift,
414
+ )
415
+ elif up_block_type == "AttnSkipUpBlock2D":
416
+ return AttnSkipUpBlock2D(
417
+ num_layers=num_layers,
418
+ in_channels=in_channels,
419
+ out_channels=out_channels,
420
+ prev_output_channel=prev_output_channel,
421
+ temb_channels=temb_channels,
422
+ add_upsample=add_upsample,
423
+ resnet_eps=resnet_eps,
424
+ resnet_act_fn=resnet_act_fn,
425
+ attention_head_dim=attention_head_dim,
426
+ resnet_time_scale_shift=resnet_time_scale_shift,
427
+ )
428
+ elif up_block_type == "UpDecoderBlock2D":
429
+ return UpDecoderBlock2D(
430
+ num_layers=num_layers,
431
+ in_channels=in_channels,
432
+ out_channels=out_channels,
433
+ add_upsample=add_upsample,
434
+ resnet_eps=resnet_eps,
435
+ resnet_act_fn=resnet_act_fn,
436
+ resnet_groups=resnet_groups,
437
+ resnet_time_scale_shift=resnet_time_scale_shift,
438
+ temb_channels=temb_channels,
439
+ )
440
+ elif up_block_type == "AttnUpDecoderBlock2D":
441
+ return AttnUpDecoderBlock2D(
442
+ num_layers=num_layers,
443
+ in_channels=in_channels,
444
+ out_channels=out_channels,
445
+ add_upsample=add_upsample,
446
+ resnet_eps=resnet_eps,
447
+ resnet_act_fn=resnet_act_fn,
448
+ resnet_groups=resnet_groups,
449
+ attention_head_dim=attention_head_dim,
450
+ resnet_time_scale_shift=resnet_time_scale_shift,
451
+ temb_channels=temb_channels,
452
+ )
453
+ elif up_block_type == "KUpBlock2D":
454
+ return KUpBlock2D(
455
+ num_layers=num_layers,
456
+ in_channels=in_channels,
457
+ out_channels=out_channels,
458
+ temb_channels=temb_channels,
459
+ add_upsample=add_upsample,
460
+ resnet_eps=resnet_eps,
461
+ resnet_act_fn=resnet_act_fn,
462
+ )
463
+ elif up_block_type == "KCrossAttnUpBlock2D":
464
+ return KCrossAttnUpBlock2D(
465
+ num_layers=num_layers,
466
+ in_channels=in_channels,
467
+ out_channels=out_channels,
468
+ temb_channels=temb_channels,
469
+ add_upsample=add_upsample,
470
+ resnet_eps=resnet_eps,
471
+ resnet_act_fn=resnet_act_fn,
472
+ cross_attention_dim=cross_attention_dim,
473
+ attention_head_dim=attention_head_dim,
474
+ )
475
+
476
+ raise ValueError(f"{up_block_type} does not exist.")
477
+
478
+
479
+ class UNetMidBlockMV2DCrossAttn(nn.Module):
480
+ def __init__(
481
+ self,
482
+ in_channels: int,
483
+ temb_channels: int,
484
+ dropout: float = 0.0,
485
+ num_layers: int = 1,
486
+ transformer_layers_per_block: int = 1,
487
+ resnet_eps: float = 1e-6,
488
+ resnet_time_scale_shift: str = "default",
489
+ resnet_act_fn: str = "swish",
490
+ resnet_groups: int = 32,
491
+ resnet_pre_norm: bool = True,
492
+ num_attention_heads=1,
493
+ output_scale_factor=1.0,
494
+ cross_attention_dim=1280,
495
+ dual_cross_attention=False,
496
+ use_linear_projection=False,
497
+ upcast_attention=False,
498
+ num_views: int = 1,
499
+ joint_attention: bool = False,
500
+ joint_attention_twice: bool = False,
501
+ multiview_attention: bool = True,
502
+ cross_domain_attention: bool=False
503
+ ):
504
+ super().__init__()
505
+
506
+ self.has_cross_attention = True
507
+ self.num_attention_heads = num_attention_heads
508
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
509
+
510
+ # there is always at least one resnet
511
+ resnets = [
512
+ ResnetBlock2D(
513
+ in_channels=in_channels,
514
+ out_channels=in_channels,
515
+ temb_channels=temb_channels,
516
+ eps=resnet_eps,
517
+ groups=resnet_groups,
518
+ dropout=dropout,
519
+ time_embedding_norm=resnet_time_scale_shift,
520
+ non_linearity=resnet_act_fn,
521
+ output_scale_factor=output_scale_factor,
522
+ pre_norm=resnet_pre_norm,
523
+ )
524
+ ]
525
+ attentions = []
526
+
527
+ for _ in range(num_layers):
528
+ if not dual_cross_attention:
529
+ attentions.append(
530
+ TransformerMV2DModel(
531
+ num_attention_heads,
532
+ in_channels // num_attention_heads,
533
+ in_channels=in_channels,
534
+ num_layers=transformer_layers_per_block,
535
+ cross_attention_dim=cross_attention_dim,
536
+ norm_num_groups=resnet_groups,
537
+ use_linear_projection=use_linear_projection,
538
+ upcast_attention=upcast_attention,
539
+ num_views=num_views,
540
+ joint_attention=joint_attention,
541
+ joint_attention_twice=joint_attention_twice,
542
+ multiview_attention=multiview_attention,
543
+ cross_domain_attention=cross_domain_attention
544
+ )
545
+ )
546
+ else:
547
+ raise NotImplementedError
548
+ resnets.append(
549
+ ResnetBlock2D(
550
+ in_channels=in_channels,
551
+ out_channels=in_channels,
552
+ temb_channels=temb_channels,
553
+ eps=resnet_eps,
554
+ groups=resnet_groups,
555
+ dropout=dropout,
556
+ time_embedding_norm=resnet_time_scale_shift,
557
+ non_linearity=resnet_act_fn,
558
+ output_scale_factor=output_scale_factor,
559
+ pre_norm=resnet_pre_norm,
560
+ )
561
+ )
562
+
563
+ self.attentions = nn.ModuleList(attentions)
564
+ self.resnets = nn.ModuleList(resnets)
565
+
566
+ def forward(
567
+ self,
568
+ hidden_states: torch.FloatTensor,
569
+ temb: Optional[torch.FloatTensor] = None,
570
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
571
+ attention_mask: Optional[torch.FloatTensor] = None,
572
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
573
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
574
+ ) -> torch.FloatTensor:
575
+ hidden_states = self.resnets[0](hidden_states, temb)
576
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
577
+ hidden_states = attn(
578
+ hidden_states,
579
+ encoder_hidden_states=encoder_hidden_states,
580
+ cross_attention_kwargs=cross_attention_kwargs,
581
+ attention_mask=attention_mask,
582
+ encoder_attention_mask=encoder_attention_mask,
583
+ return_dict=False,
584
+ )[0]
585
+ hidden_states = resnet(hidden_states, temb)
586
+
587
+ return hidden_states
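The mid block applies its leading ResNet once and then alternates a multi-view transformer with a ResNet, so resnets always holds one module more than attentions. A stripped-down sketch of just that interleaving, with identity stand-ins for the real ResnetBlock2D/TransformerMV2DModel modules (names and shapes here are invented):

import torch
from torch import nn

class ToyMidBlock(nn.Module):
    def __init__(self, num_layers=1):
        super().__init__()
        # One more resnet than attention layers, exactly as in the block above.
        self.resnets = nn.ModuleList([nn.Identity() for _ in range(num_layers + 1)])
        self.attentions = nn.ModuleList([nn.Identity() for _ in range(num_layers)])

    def forward(self, x):
        x = self.resnets[0](x)                    # leading resnet
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            x = resnet(attn(x))                   # attention, then its paired resnet
        return x

print(ToyMidBlock(num_layers=2)(torch.randn(1, 4, 8, 8)).shape)  # torch.Size([1, 4, 8, 8])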
588
+
589
+
590
+ class CrossAttnUpBlockMV2D(nn.Module):
591
+ def __init__(
592
+ self,
593
+ in_channels: int,
594
+ out_channels: int,
595
+ prev_output_channel: int,
596
+ temb_channels: int,
597
+ dropout: float = 0.0,
598
+ num_layers: int = 1,
599
+ transformer_layers_per_block: int = 1,
600
+ resnet_eps: float = 1e-6,
601
+ resnet_time_scale_shift: str = "default",
602
+ resnet_act_fn: str = "swish",
603
+ resnet_groups: int = 32,
604
+ resnet_pre_norm: bool = True,
605
+ num_attention_heads=1,
606
+ cross_attention_dim=1280,
607
+ output_scale_factor=1.0,
608
+ add_upsample=True,
609
+ dual_cross_attention=False,
610
+ use_linear_projection=False,
611
+ only_cross_attention=False,
612
+ upcast_attention=False,
613
+ num_views: int = 1
614
+ ):
615
+ super().__init__()
616
+ resnets = []
617
+ attentions = []
618
+
619
+ self.has_cross_attention = True
620
+ self.num_attention_heads = num_attention_heads
621
+
622
+ for i in range(num_layers):
623
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
624
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
625
+
626
+ resnets.append(
627
+ ResnetBlock2D(
628
+ in_channels=resnet_in_channels + res_skip_channels,
629
+ out_channels=out_channels,
630
+ temb_channels=temb_channels,
631
+ eps=resnet_eps,
632
+ groups=resnet_groups,
633
+ dropout=dropout,
634
+ time_embedding_norm=resnet_time_scale_shift,
635
+ non_linearity=resnet_act_fn,
636
+ output_scale_factor=output_scale_factor,
637
+ pre_norm=resnet_pre_norm,
638
+ )
639
+ )
640
+ if not dual_cross_attention:
641
+ attentions.append(
642
+ TransformerMV2DModel(
643
+ num_attention_heads,
644
+ out_channels // num_attention_heads,
645
+ in_channels=out_channels,
646
+ num_layers=transformer_layers_per_block,
647
+ cross_attention_dim=cross_attention_dim,
648
+ norm_num_groups=resnet_groups,
649
+ use_linear_projection=use_linear_projection,
650
+ only_cross_attention=only_cross_attention,
651
+ upcast_attention=upcast_attention,
652
+ num_views=num_views
653
+ )
654
+ )
655
+ else:
656
+ raise NotImplementedError
657
+ self.attentions = nn.ModuleList(attentions)
658
+ self.resnets = nn.ModuleList(resnets)
659
+
660
+ if add_upsample:
661
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
662
+ else:
663
+ self.upsamplers = None
664
+
665
+ self.gradient_checkpointing = False
666
+
667
+ def forward(
668
+ self,
669
+ hidden_states: torch.FloatTensor,
670
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
671
+ temb: Optional[torch.FloatTensor] = None,
672
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
673
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
674
+ upsample_size: Optional[int] = None,
675
+ attention_mask: Optional[torch.FloatTensor] = None,
676
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
677
+ ):
678
+ for resnet, attn in zip(self.resnets, self.attentions):
679
+ # pop res hidden states
680
+ res_hidden_states = res_hidden_states_tuple[-1]
681
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
682
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
683
+
684
+ if self.training and self.gradient_checkpointing:
685
+
686
+ def create_custom_forward(module, return_dict=None):
687
+ def custom_forward(*inputs):
688
+ if return_dict is not None:
689
+ return module(*inputs, return_dict=return_dict)
690
+ else:
691
+ return module(*inputs)
692
+
693
+ return custom_forward
694
+
695
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
696
+ hidden_states = torch.utils.checkpoint.checkpoint(
697
+ create_custom_forward(resnet),
698
+ hidden_states,
699
+ temb,
700
+ **ckpt_kwargs,
701
+ )
702
+ hidden_states = torch.utils.checkpoint.checkpoint(
703
+ create_custom_forward(attn, return_dict=False),
704
+ hidden_states,
705
+ encoder_hidden_states,
706
+ None, # timestep
707
+ None, # class_labels
708
+ cross_attention_kwargs,
709
+ attention_mask,
710
+ encoder_attention_mask,
711
+ **ckpt_kwargs,
712
+ )[0]
713
+ else:
714
+ hidden_states = resnet(hidden_states, temb)
715
+ hidden_states = attn(
716
+ hidden_states,
717
+ encoder_hidden_states=encoder_hidden_states,
718
+ cross_attention_kwargs=cross_attention_kwargs,
719
+ attention_mask=attention_mask,
720
+ encoder_attention_mask=encoder_attention_mask,
721
+ return_dict=False,
722
+ )[0]
723
+
724
+ if self.upsamplers is not None:
725
+ for upsampler in self.upsamplers:
726
+ hidden_states = upsampler(hidden_states, upsample_size)
727
+
728
+ return hidden_states
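The up block consumes encoder skip connections last-in-first-out: each iteration pops the newest entry of res_hidden_states_tuple and concatenates it with the current activations along the channel axis before the ResNet. A minimal tensor-level sketch of that pattern, with invented shapes and a 1x1 convolution standing in for ResnetBlock2D:

import torch
from torch import nn

out_channels = 640
hidden = torch.randn(2, out_channels, 16, 16)            # current decoder activations
skips = (torch.randn(2, 320, 16, 16),                     # oldest encoder skip
         torch.randn(2, 640, 16, 16))                     # newest encoder skip

for _ in range(2):
    res = skips[-1]                                       # pop the newest skip first
    skips = skips[:-1]
    x = torch.cat([hidden, res], dim=1)                   # channel-wise concat
    hidden = nn.Conv2d(x.shape[1], out_channels, 1)(x)    # stand-in for ResnetBlock2D

print(hidden.shape)  # torch.Size([2, 640, 16, 16])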
729
+
730
+
731
+ class CrossAttnDownBlockMV2D(nn.Module):
732
+ def __init__(
733
+ self,
734
+ in_channels: int,
735
+ out_channels: int,
736
+ temb_channels: int,
737
+ dropout: float = 0.0,
738
+ num_layers: int = 1,
739
+ transformer_layers_per_block: int = 1,
740
+ resnet_eps: float = 1e-6,
741
+ resnet_time_scale_shift: str = "default",
742
+ resnet_act_fn: str = "swish",
743
+ resnet_groups: int = 32,
744
+ resnet_pre_norm: bool = True,
745
+ num_attention_heads=1,
746
+ cross_attention_dim=1280,
747
+ output_scale_factor=1.0,
748
+ downsample_padding=1,
749
+ add_downsample=True,
750
+ dual_cross_attention=False,
751
+ use_linear_projection=False,
752
+ only_cross_attention=False,
753
+ upcast_attention=False,
754
+ num_views: int = 1
755
+ ):
756
+ super().__init__()
757
+ resnets = []
758
+ attentions = []
759
+
760
+ self.has_cross_attention = True
761
+ self.num_attention_heads = num_attention_heads
762
+
763
+ for i in range(num_layers):
764
+ in_channels = in_channels if i == 0 else out_channels
765
+ resnets.append(
766
+ ResnetBlock2D(
767
+ in_channels=in_channels,
768
+ out_channels=out_channels,
769
+ temb_channels=temb_channels,
770
+ eps=resnet_eps,
771
+ groups=resnet_groups,
772
+ dropout=dropout,
773
+ time_embedding_norm=resnet_time_scale_shift,
774
+ non_linearity=resnet_act_fn,
775
+ output_scale_factor=output_scale_factor,
776
+ pre_norm=resnet_pre_norm,
777
+ )
778
+ )
779
+ if not dual_cross_attention:
780
+ attentions.append(
781
+ TransformerMV2DModel(
782
+ num_attention_heads,
783
+ out_channels // num_attention_heads,
784
+ in_channels=out_channels,
785
+ num_layers=transformer_layers_per_block,
786
+ cross_attention_dim=cross_attention_dim,
787
+ norm_num_groups=resnet_groups,
788
+ use_linear_projection=use_linear_projection,
789
+ only_cross_attention=only_cross_attention,
790
+ upcast_attention=upcast_attention,
791
+ num_views=num_views
792
+ )
793
+ )
794
+ else:
795
+ raise NotImplementedError
796
+ self.attentions = nn.ModuleList(attentions)
797
+ self.resnets = nn.ModuleList(resnets)
798
+
799
+ if add_downsample:
800
+ self.downsamplers = nn.ModuleList(
801
+ [
802
+ Downsample2D(
803
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
804
+ )
805
+ ]
806
+ )
807
+ else:
808
+ self.downsamplers = None
809
+
810
+ self.gradient_checkpointing = False
811
+
812
+ def forward(
813
+ self,
814
+ hidden_states: torch.FloatTensor,
815
+ temb: Optional[torch.FloatTensor] = None,
816
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
817
+ attention_mask: Optional[torch.FloatTensor] = None,
818
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
819
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
820
+ additional_residuals=None,
821
+ ):
822
+ output_states = ()
823
+
824
+ blocks = list(zip(self.resnets, self.attentions))
825
+
826
+ for i, (resnet, attn) in enumerate(blocks):
827
+ if self.training and self.gradient_checkpointing:
828
+
829
+ def create_custom_forward(module, return_dict=None):
830
+ def custom_forward(*inputs):
831
+ if return_dict is not None:
832
+ return module(*inputs, return_dict=return_dict)
833
+ else:
834
+ return module(*inputs)
835
+
836
+ return custom_forward
837
+
838
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
839
+ hidden_states = torch.utils.checkpoint.checkpoint(
840
+ create_custom_forward(resnet),
841
+ hidden_states,
842
+ temb,
843
+ **ckpt_kwargs,
844
+ )
845
+ hidden_states = torch.utils.checkpoint.checkpoint(
846
+ create_custom_forward(attn, return_dict=False),
847
+ hidden_states,
848
+ encoder_hidden_states,
849
+ None, # timestep
850
+ None, # class_labels
851
+ cross_attention_kwargs,
852
+ attention_mask,
853
+ encoder_attention_mask,
854
+ **ckpt_kwargs,
855
+ )[0]
856
+ else:
857
+ hidden_states = resnet(hidden_states, temb)
858
+ hidden_states = attn(
859
+ hidden_states,
860
+ encoder_hidden_states=encoder_hidden_states,
861
+ cross_attention_kwargs=cross_attention_kwargs,
862
+ attention_mask=attention_mask,
863
+ encoder_attention_mask=encoder_attention_mask,
864
+ return_dict=False,
865
+ )[0]
866
+
867
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
868
+ if i == len(blocks) - 1 and additional_residuals is not None:
869
+ hidden_states = hidden_states + additional_residuals
870
+
871
+ output_states = output_states + (hidden_states,)
872
+
873
+ if self.downsamplers is not None:
874
+ for downsampler in self.downsamplers:
875
+ hidden_states = downsampler(hidden_states)
876
+
877
+ output_states = output_states + (hidden_states,)
878
+
879
+ return hidden_states, output_states
880
+
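Both MV2D blocks above wrap their ResNet and transformer calls in torch.utils.checkpoint.checkpoint when gradient_checkpointing is enabled during training, using the small create_custom_forward closure so keyword arguments such as return_dict can still be routed through a positional-only checkpoint call. A self-contained sketch of that wrapper pattern, with a plain linear layer standing in for the real sub-modules:

import torch
import torch.utils.checkpoint
from torch import nn

def create_custom_forward(module, return_dict=None):
    # Same closure trick as in the blocks: checkpoint() only forwards positional args.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)
    return custom_forward

layer = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)
# use_reentrant=False is what the blocks select on torch >= 1.11.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 8])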
mvdiffusion/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1462 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
27
+ from diffusers.models.embeddings import (
28
+ GaussianFourierProjection,
29
+ ImageHintTimeEmbedding,
30
+ ImageProjection,
31
+ ImageTimeEmbedding,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
39
+ from diffusers.models.unet_2d_blocks import (
40
+ CrossAttnDownBlock2D,
41
+ CrossAttnUpBlock2D,
42
+ DownBlock2D,
43
+ UNetMidBlock2DCrossAttn,
44
+ UNetMidBlock2DSimpleCrossAttn,
45
+ UpBlock2D,
46
+ )
47
+ from diffusers.utils import (
48
+ CONFIG_NAME,
49
+ DIFFUSERS_CACHE,
50
+ FLAX_WEIGHTS_NAME,
51
+ HF_HUB_OFFLINE,
52
+ SAFETENSORS_WEIGHTS_NAME,
53
+ WEIGHTS_NAME,
54
+ _add_variant,
55
+ _get_model_file,
56
+ deprecate,
57
+ is_accelerate_available,
58
+ is_safetensors_available,
59
+ is_torch_version,
60
+ logging,
61
+ )
62
+ from diffusers import __version__
63
+ from mvdiffusion.models.unet_mv2d_blocks import (
64
+ CrossAttnDownBlockMV2D,
65
+ CrossAttnUpBlockMV2D,
66
+ UNetMidBlockMV2DCrossAttn,
67
+ get_down_block,
68
+ get_up_block,
69
+ )
70
+
71
+
72
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
73
+
74
+
75
+ @dataclass
76
+ class UNetMV2DConditionOutput(BaseOutput):
77
+ """
78
+ The output of [`UNetMV2DConditionModel`].
79
+
80
+ Args:
81
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
82
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
83
+ """
84
+
85
+ sample: torch.FloatTensor = None
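UNetMV2DConditionOutput is a standard diffusers BaseOutput dataclass, so the prediction can be read either by attribute or by index. A small sketch of those two access styles with a look-alike class (the shape and the ToyUNetOutput name are made up, and this assumes a diffusers installation):

import torch
from dataclasses import dataclass
from diffusers.utils import BaseOutput

@dataclass
class ToyUNetOutput(BaseOutput):
    # Mirrors UNetMV2DConditionOutput: a single `sample` field.
    sample: torch.FloatTensor = None

out = ToyUNetOutput(sample=torch.randn(12, 4, 32, 32))  # e.g. batch = B * num_views
print(out.sample.shape)  # attribute access
print(out[0].shape)      # index access (BaseOutput also behaves like a tuple)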
86
+
87
+
88
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
89
+ r"""
90
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
91
+ shaped output.
92
+
93
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
94
+ for all models (such as downloading or saving).
95
+
96
+ Parameters:
97
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
98
+ Height and width of input/output sample.
99
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
100
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
101
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
102
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
103
+ Whether to flip the sin to cos in the time embedding.
104
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
105
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
106
+ The tuple of downsample blocks to use.
107
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
108
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
109
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
110
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
111
+ The tuple of upsample blocks to use.
112
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
113
+ Whether to include self-attention in the basic transformer blocks, see
114
+ [`~models.attention.BasicTransformerBlock`].
115
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
116
+ The tuple of output channels for each block.
117
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
118
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
119
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
120
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
121
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
122
+ If `None`, normalization and activation layers are skipped in post-processing.
123
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
124
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
125
+ The dimension of the cross attention features.
126
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
127
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
128
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
129
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
130
+ encoder_hid_dim (`int`, *optional*, defaults to None):
131
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
132
+ dimension to `cross_attention_dim`.
133
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
134
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
135
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
136
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
137
+ num_attention_heads (`int`, *optional*):
138
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
139
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
140
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
141
+ class_embed_type (`str`, *optional*, defaults to `None`):
142
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
143
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
144
+ addition_embed_type (`str`, *optional*, defaults to `None`):
145
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
146
+ "text". "text" will use the `TextTimeEmbedding` layer.
147
+ addition_time_embed_dim (`int`, *optional*, defaults to `None`):
148
+ Dimension for the timestep embeddings.
149
+ num_class_embeds (`int`, *optional*, defaults to `None`):
150
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
151
+ class conditioning with `class_embed_type` equal to `None`.
152
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
153
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
154
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
155
+ An optional override for the dimension of the projected time embedding.
156
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
157
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
158
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
159
+ timestep_post_act (`str`, *optional*, defaults to `None`):
160
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
161
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
162
+ The dimension of `cond_proj` layer in the timestep embedding.
163
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer.
164
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer.
165
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
166
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
167
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
168
+ embeddings with the class embeddings.
169
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
170
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
171
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
172
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
173
+ otherwise.
174
+ """
175
+
176
+ _supports_gradient_checkpointing = True
177
+
178
+ @register_to_config
179
+ def __init__(
180
+ self,
181
+ sample_size: Optional[int] = None,
182
+ in_channels: int = 4,
183
+ out_channels: int = 4,
184
+ center_input_sample: bool = False,
185
+ flip_sin_to_cos: bool = True,
186
+ freq_shift: int = 0,
187
+ down_block_types: Tuple[str] = (
188
+ "CrossAttnDownBlockMV2D",
189
+ "CrossAttnDownBlockMV2D",
190
+ "CrossAttnDownBlockMV2D",
191
+ "DownBlock2D",
192
+ ),
193
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
194
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
195
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
196
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
197
+ layers_per_block: Union[int, Tuple[int]] = 2,
198
+ downsample_padding: int = 1,
199
+ mid_block_scale_factor: float = 1,
200
+ act_fn: str = "silu",
201
+ norm_num_groups: Optional[int] = 32,
202
+ norm_eps: float = 1e-5,
203
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
204
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
205
+ encoder_hid_dim: Optional[int] = None,
206
+ encoder_hid_dim_type: Optional[str] = None,
207
+ attention_head_dim: Union[int, Tuple[int]] = 8,
208
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
209
+ dual_cross_attention: bool = False,
210
+ use_linear_projection: bool = False,
211
+ class_embed_type: Optional[str] = None,
212
+ addition_embed_type: Optional[str] = None,
213
+ addition_time_embed_dim: Optional[int] = None,
214
+ num_class_embeds: Optional[int] = None,
215
+ upcast_attention: bool = False,
216
+ resnet_time_scale_shift: str = "default",
217
+ resnet_skip_time_act: bool = False,
218
+ resnet_out_scale_factor: int = 1.0,
219
+ time_embedding_type: str = "positional",
220
+ time_embedding_dim: Optional[int] = None,
221
+ time_embedding_act_fn: Optional[str] = None,
222
+ timestep_post_act: Optional[str] = None,
223
+ time_cond_proj_dim: Optional[int] = None,
224
+ conv_in_kernel: int = 3,
225
+ conv_out_kernel: int = 3,
226
+ projection_class_embeddings_input_dim: Optional[int] = None,
227
+ class_embeddings_concat: bool = False,
228
+ mid_block_only_cross_attention: Optional[bool] = None,
229
+ cross_attention_norm: Optional[str] = None,
230
+ addition_embed_type_num_heads=64,
231
+ num_views: int = 1,
232
+ joint_attention: bool = False,
233
+ joint_attention_twice: bool = False,
234
+ multiview_attention: bool = True,
235
+ cross_domain_attention: bool = False
236
+ ):
237
+ super().__init__()
238
+
239
+ self.sample_size = sample_size
240
+
241
+ if num_attention_heads is not None:
242
+ raise ValueError(
243
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
244
+ )
245
+
246
+ # If `num_attention_heads` is not defined (which is the case for most models)
247
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
248
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
249
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
250
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
251
+ # which is why we correct for the naming here.
252
+ num_attention_heads = num_attention_heads or attention_head_dim
253
+
254
+ # Check inputs
255
+ if len(down_block_types) != len(up_block_types):
256
+ raise ValueError(
257
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
258
+ )
259
+
260
+ if len(block_out_channels) != len(down_block_types):
261
+ raise ValueError(
262
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
263
+ )
264
+
265
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
266
+ raise ValueError(
267
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
268
+ )
269
+
270
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
271
+ raise ValueError(
272
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
273
+ )
274
+
275
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
276
+ raise ValueError(
277
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
278
+ )
279
+
280
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
281
+ raise ValueError(
282
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
283
+ )
284
+
285
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
286
+ raise ValueError(
287
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
288
+ )
289
+
290
+ # input
291
+ conv_in_padding = (conv_in_kernel - 1) // 2
292
+ self.conv_in = nn.Conv2d(
293
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
294
+ )
295
+
296
+ # time
297
+ if time_embedding_type == "fourier":
298
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
299
+ if time_embed_dim % 2 != 0:
300
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
301
+ self.time_proj = GaussianFourierProjection(
302
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
303
+ )
304
+ timestep_input_dim = time_embed_dim
305
+ elif time_embedding_type == "positional":
306
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
307
+
308
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
309
+ timestep_input_dim = block_out_channels[0]
310
+ else:
311
+ raise ValueError(
312
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
313
+ )
314
+
315
+ self.time_embedding = TimestepEmbedding(
316
+ timestep_input_dim,
317
+ time_embed_dim,
318
+ act_fn=act_fn,
319
+ post_act_fn=timestep_post_act,
320
+ cond_proj_dim=time_cond_proj_dim,
321
+ )
322
+
323
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
324
+ encoder_hid_dim_type = "text_proj"
325
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
326
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
327
+
328
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
329
+ raise ValueError(
330
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
331
+ )
332
+
333
+ if encoder_hid_dim_type == "text_proj":
334
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
335
+ elif encoder_hid_dim_type == "text_image_proj":
336
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
337
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
338
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
339
+ self.encoder_hid_proj = TextImageProjection(
340
+ text_embed_dim=encoder_hid_dim,
341
+ image_embed_dim=cross_attention_dim,
342
+ cross_attention_dim=cross_attention_dim,
343
+ )
344
+ elif encoder_hid_dim_type == "image_proj":
345
+ # Kandinsky 2.2
346
+ self.encoder_hid_proj = ImageProjection(
347
+ image_embed_dim=encoder_hid_dim,
348
+ cross_attention_dim=cross_attention_dim,
349
+ )
350
+ elif encoder_hid_dim_type is not None:
351
+ raise ValueError(
352
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
353
+ )
354
+ else:
355
+ self.encoder_hid_proj = None
356
+
357
+ # class embedding
358
+ if class_embed_type is None and num_class_embeds is not None:
359
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
360
+ elif class_embed_type == "timestep":
361
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
362
+ elif class_embed_type == "identity":
363
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
364
+ elif class_embed_type == "projection":
365
+ if projection_class_embeddings_input_dim is None:
366
+ raise ValueError(
367
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
368
+ )
369
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
370
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
371
+ # 2. it projects from an arbitrary input dimension.
372
+ #
373
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
374
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
375
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
376
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
377
+ elif class_embed_type == "simple_projection":
378
+ if projection_class_embeddings_input_dim is None:
379
+ raise ValueError(
380
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
381
+ )
382
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
383
+ else:
384
+ self.class_embedding = None
385
+
386
+ if addition_embed_type == "text":
387
+ if encoder_hid_dim is not None:
388
+ text_time_embedding_from_dim = encoder_hid_dim
389
+ else:
390
+ text_time_embedding_from_dim = cross_attention_dim
391
+
392
+ self.add_embedding = TextTimeEmbedding(
393
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
394
+ )
395
+ elif addition_embed_type == "text_image":
396
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
397
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
398
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
399
+ self.add_embedding = TextImageTimeEmbedding(
400
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
401
+ )
402
+ elif addition_embed_type == "text_time":
403
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
404
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
405
+ elif addition_embed_type == "image":
406
+ # Kandinsky 2.2
407
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
408
+ elif addition_embed_type == "image_hint":
409
+ # Kandinsky 2.2 ControlNet
410
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
411
+ elif addition_embed_type is not None:
412
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
413
+
414
+ if time_embedding_act_fn is None:
415
+ self.time_embed_act = None
416
+ else:
417
+ self.time_embed_act = get_activation(time_embedding_act_fn)
418
+
419
+ self.down_blocks = nn.ModuleList([])
420
+ self.up_blocks = nn.ModuleList([])
421
+
422
+ if isinstance(only_cross_attention, bool):
423
+ if mid_block_only_cross_attention is None:
424
+ mid_block_only_cross_attention = only_cross_attention
425
+
426
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
427
+
428
+ if mid_block_only_cross_attention is None:
429
+ mid_block_only_cross_attention = False
430
+
431
+ if isinstance(num_attention_heads, int):
432
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
433
+
434
+ if isinstance(attention_head_dim, int):
435
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
436
+
437
+ if isinstance(cross_attention_dim, int):
438
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
439
+
440
+ if isinstance(layers_per_block, int):
441
+ layers_per_block = [layers_per_block] * len(down_block_types)
442
+
443
+ if isinstance(transformer_layers_per_block, int):
444
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
445
+
446
+ if class_embeddings_concat:
447
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
448
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
449
+ # regular time embeddings
450
+ blocks_time_embed_dim = time_embed_dim * 2
451
+ else:
452
+ blocks_time_embed_dim = time_embed_dim
453
+
454
+ # down
455
+ output_channel = block_out_channels[0]
456
+ for i, down_block_type in enumerate(down_block_types):
457
+ input_channel = output_channel
458
+ output_channel = block_out_channels[i]
459
+ is_final_block = i == len(block_out_channels) - 1
460
+
461
+ down_block = get_down_block(
462
+ down_block_type,
463
+ num_layers=layers_per_block[i],
464
+ transformer_layers_per_block=transformer_layers_per_block[i],
465
+ in_channels=input_channel,
466
+ out_channels=output_channel,
467
+ temb_channels=blocks_time_embed_dim,
468
+ add_downsample=not is_final_block,
469
+ resnet_eps=norm_eps,
470
+ resnet_act_fn=act_fn,
471
+ resnet_groups=norm_num_groups,
472
+ cross_attention_dim=cross_attention_dim[i],
473
+ num_attention_heads=num_attention_heads[i],
474
+ downsample_padding=downsample_padding,
475
+ dual_cross_attention=dual_cross_attention,
476
+ use_linear_projection=use_linear_projection,
477
+ only_cross_attention=only_cross_attention[i],
478
+ upcast_attention=upcast_attention,
479
+ resnet_time_scale_shift=resnet_time_scale_shift,
480
+ resnet_skip_time_act=resnet_skip_time_act,
481
+ resnet_out_scale_factor=resnet_out_scale_factor,
482
+ cross_attention_norm=cross_attention_norm,
483
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
484
+ num_views=num_views
485
+ )
486
+ self.down_blocks.append(down_block)
487
+
488
+ # mid
489
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
490
+ self.mid_block = UNetMidBlock2DCrossAttn(
491
+ transformer_layers_per_block=transformer_layers_per_block[-1],
492
+ in_channels=block_out_channels[-1],
493
+ temb_channels=blocks_time_embed_dim,
494
+ resnet_eps=norm_eps,
495
+ resnet_act_fn=act_fn,
496
+ output_scale_factor=mid_block_scale_factor,
497
+ resnet_time_scale_shift=resnet_time_scale_shift,
498
+ cross_attention_dim=cross_attention_dim[-1],
499
+ num_attention_heads=num_attention_heads[-1],
500
+ resnet_groups=norm_num_groups,
501
+ dual_cross_attention=dual_cross_attention,
502
+ use_linear_projection=use_linear_projection,
503
+ upcast_attention=upcast_attention,
504
+ )
505
+ # custom MV2D attention block
506
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
507
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
508
+ transformer_layers_per_block=transformer_layers_per_block[-1],
509
+ in_channels=block_out_channels[-1],
510
+ temb_channels=blocks_time_embed_dim,
511
+ resnet_eps=norm_eps,
512
+ resnet_act_fn=act_fn,
513
+ output_scale_factor=mid_block_scale_factor,
514
+ resnet_time_scale_shift=resnet_time_scale_shift,
515
+ cross_attention_dim=cross_attention_dim[-1],
516
+ num_attention_heads=num_attention_heads[-1],
517
+ resnet_groups=norm_num_groups,
518
+ dual_cross_attention=dual_cross_attention,
519
+ use_linear_projection=use_linear_projection,
520
+ upcast_attention=upcast_attention,
521
+ num_views=num_views,
522
+ joint_attention=joint_attention,
523
+ joint_attention_twice=joint_attention_twice,
524
+ multiview_attention=multiview_attention,
525
+ cross_domain_attention=cross_domain_attention
526
+ )
527
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
528
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
529
+ in_channels=block_out_channels[-1],
530
+ temb_channels=blocks_time_embed_dim,
531
+ resnet_eps=norm_eps,
532
+ resnet_act_fn=act_fn,
533
+ output_scale_factor=mid_block_scale_factor,
534
+ cross_attention_dim=cross_attention_dim[-1],
535
+ attention_head_dim=attention_head_dim[-1],
536
+ resnet_groups=norm_num_groups,
537
+ resnet_time_scale_shift=resnet_time_scale_shift,
538
+ skip_time_act=resnet_skip_time_act,
539
+ only_cross_attention=mid_block_only_cross_attention,
540
+ cross_attention_norm=cross_attention_norm,
541
+ )
542
+ elif mid_block_type is None:
543
+ self.mid_block = None
544
+ else:
545
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
546
+
547
+ # count how many layers upsample the images
548
+ self.num_upsamplers = 0
549
+
550
+ # up
551
+ reversed_block_out_channels = list(reversed(block_out_channels))
552
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
553
+ reversed_layers_per_block = list(reversed(layers_per_block))
554
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
555
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
556
+ only_cross_attention = list(reversed(only_cross_attention))
557
+
558
+ output_channel = reversed_block_out_channels[0]
559
+ for i, up_block_type in enumerate(up_block_types):
560
+ is_final_block = i == len(block_out_channels) - 1
561
+
562
+ prev_output_channel = output_channel
563
+ output_channel = reversed_block_out_channels[i]
564
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
565
+
566
+ # add upsample block for all BUT final layer
567
+ if not is_final_block:
568
+ add_upsample = True
569
+ self.num_upsamplers += 1
570
+ else:
571
+ add_upsample = False
572
+
573
+ up_block = get_up_block(
574
+ up_block_type,
575
+ num_layers=reversed_layers_per_block[i] + 1,
576
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
577
+ in_channels=input_channel,
578
+ out_channels=output_channel,
579
+ prev_output_channel=prev_output_channel,
580
+ temb_channels=blocks_time_embed_dim,
581
+ add_upsample=add_upsample,
582
+ resnet_eps=norm_eps,
583
+ resnet_act_fn=act_fn,
584
+ resnet_groups=norm_num_groups,
585
+ cross_attention_dim=reversed_cross_attention_dim[i],
586
+ num_attention_heads=reversed_num_attention_heads[i],
587
+ dual_cross_attention=dual_cross_attention,
588
+ use_linear_projection=use_linear_projection,
589
+ only_cross_attention=only_cross_attention[i],
590
+ upcast_attention=upcast_attention,
591
+ resnet_time_scale_shift=resnet_time_scale_shift,
592
+ resnet_skip_time_act=resnet_skip_time_act,
593
+ resnet_out_scale_factor=resnet_out_scale_factor,
594
+ cross_attention_norm=cross_attention_norm,
595
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
596
+ num_views=num_views
597
+ )
598
+ self.up_blocks.append(up_block)
599
+ prev_output_channel = output_channel
600
+
601
+ # out
602
+ if norm_num_groups is not None:
603
+ self.conv_norm_out = nn.GroupNorm(
604
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
605
+ )
606
+
607
+ self.conv_act = get_activation(act_fn)
608
+
609
+ else:
610
+ self.conv_norm_out = None
611
+ self.conv_act = None
612
+
613
+ conv_out_padding = (conv_out_kernel - 1) // 2
614
+ self.conv_out = nn.Conv2d(
615
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
616
+ )
617
+
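The constructor above broadcasts scalar settings (attention heads, layers per block, cross-attention dims) to one value per down block before building the blocks. A minimal standalone sketch of that pattern, with assumed values:

```python
# Illustrative only: mirrors the scalar-to-per-block broadcasting done in __init__ above.
down_block_types = ("CrossAttnDownBlockMV2D",) * 3 + ("DownBlock2D",)  # assumed 4 blocks
num_attention_heads = 8   # scalar from the config
layers_per_block = 2      # scalar from the config

if isinstance(num_attention_heads, int):
    num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(layers_per_block, int):
    layers_per_block = [layers_per_block] * len(down_block_types)

print(num_attention_heads)  # (8, 8, 8, 8)
print(layers_per_block)     # [2, 2, 2, 2]
```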
618
+ @property
619
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
620
+ r"""
621
+ Returns:
622
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
623
+ indexed by its weight name.
624
+ """
625
+ # set recursively
626
+ processors = {}
627
+
628
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
629
+ if hasattr(module, "set_processor"):
630
+ processors[f"{name}.processor"] = module.processor
631
+
632
+ for sub_name, child in module.named_children():
633
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
634
+
635
+ return processors
636
+
637
+ for name, module in self.named_children():
638
+ fn_recursive_add_processors(name, module, processors)
639
+
640
+ return processors
641
+
642
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
643
+ r"""
644
+ Sets the attention processor to use to compute attention.
645
+
646
+ Parameters:
647
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
648
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
649
+ for **all** `Attention` layers.
650
+
651
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
652
+ processor. This is strongly recommended when setting trainable attention processors.
653
+
654
+ """
655
+ count = len(self.attn_processors.keys())
656
+
657
+ if isinstance(processor, dict) and len(processor) != count:
658
+ raise ValueError(
659
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
660
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
661
+ )
662
+
663
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
664
+ if hasattr(module, "set_processor"):
665
+ if not isinstance(processor, dict):
666
+ module.set_processor(processor)
667
+ else:
668
+ module.set_processor(processor.pop(f"{name}.processor"))
669
+
670
+ for sub_name, child in module.named_children():
671
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
672
+
673
+ for name, module in self.named_children():
674
+ fn_recursive_attn_processor(name, module, processor)
675
+
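A usage sketch for the two methods above, assuming `unet` is an instance of this model and using the default processor class from diffusers:

```python
# Illustrative only; `unet` is assumed to be an instance of this UNet class.
from diffusers.models.attention_processor import AttnProcessor

procs = unet.attn_processors                 # dict keyed by weight name
print(len(procs), sorted(procs)[:2])

unet.set_attn_processor(AttnProcessor())     # one processor for every attention layer
unet.set_attn_processor({name: AttnProcessor() for name in procs})  # or one per layer
```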
676
+ def set_default_attn_processor(self):
677
+ """
678
+ Disables custom attention processors and sets the default attention implementation.
679
+ """
680
+ self.set_attn_processor(AttnProcessor())
681
+
682
+ def set_attention_slice(self, slice_size):
683
+ r"""
684
+ Enable sliced attention computation.
685
+
686
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
687
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
688
+
689
+ Args:
690
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
691
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
692
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
693
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
694
+ must be a multiple of `slice_size`.
695
+ """
696
+ sliceable_head_dims = []
697
+
698
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
699
+ if hasattr(module, "set_attention_slice"):
700
+ sliceable_head_dims.append(module.sliceable_head_dim)
701
+
702
+ for child in module.children():
703
+ fn_recursive_retrieve_sliceable_dims(child)
704
+
705
+ # retrieve number of attention layers
706
+ for module in self.children():
707
+ fn_recursive_retrieve_sliceable_dims(module)
708
+
709
+ num_sliceable_layers = len(sliceable_head_dims)
710
+
711
+ if slice_size == "auto":
712
+ # half the attention head size is usually a good trade-off between
713
+ # speed and memory
714
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
715
+ elif slice_size == "max":
716
+ # make smallest slice possible
717
+ slice_size = num_sliceable_layers * [1]
718
+
719
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
720
+
721
+ if len(slice_size) != len(sliceable_head_dims):
722
+ raise ValueError(
723
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
724
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
725
+ )
726
+
727
+ for i in range(len(slice_size)):
728
+ size = slice_size[i]
729
+ dim = sliceable_head_dims[i]
730
+ if size is not None and size > dim:
731
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
732
+
733
+ # Recursively walk through all the children.
734
+ # Any children which exposes the set_attention_slice method
735
+ # gets the message
736
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
737
+ if hasattr(module, "set_attention_slice"):
738
+ module.set_attention_slice(slice_size.pop())
739
+
740
+ for child in module.children():
741
+ fn_recursive_set_attention_slice(child, slice_size)
742
+
743
+ reversed_slice_size = list(reversed(slice_size))
744
+ for module in self.children():
745
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
746
+
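A usage sketch for `set_attention_slice`, assuming `unet` is an instance of this model; the accepted values follow the docstring above:

```python
# Illustrative only; `unet` is assumed to be an instance of this UNet class.
unet.set_attention_slice("auto")  # slice size = half of each sliceable head dim
unet.set_attention_slice("max")   # slice size 1: lowest memory, slowest
unet.set_attention_slice(4)       # roughly attention_head_dim // 4 slices per layer
```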
747
+ def _set_gradient_checkpointing(self, module, value=False):
748
+ if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
749
+ module.gradient_checkpointing = value
750
+
751
+ def forward(
752
+ self,
753
+ sample: torch.FloatTensor,
754
+ timestep: Union[torch.Tensor, float, int],
755
+ encoder_hidden_states: torch.Tensor,
756
+ class_labels: Optional[torch.Tensor] = None,
757
+ timestep_cond: Optional[torch.Tensor] = None,
758
+ attention_mask: Optional[torch.Tensor] = None,
759
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
760
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
761
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
762
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
763
+ encoder_attention_mask: Optional[torch.Tensor] = None,
764
+ return_dict: bool = True,
765
+ ) -> Union[UNetMV2DConditionOutput, Tuple]:
766
+ r"""
767
+ The [`UNet2DConditionModel`] forward method.
768
+
769
+ Args:
770
+ sample (`torch.FloatTensor`):
771
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
772
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
773
+ encoder_hidden_states (`torch.FloatTensor`):
774
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
775
+ encoder_attention_mask (`torch.Tensor`):
776
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
777
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
778
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
779
+ return_dict (`bool`, *optional*, defaults to `True`):
780
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
781
+ tuple.
782
+ cross_attention_kwargs (`dict`, *optional*):
783
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
784
+ added_cond_kwargs: (`dict`, *optional*):
785
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
786
+ are passed along to the UNet blocks.
787
+
788
+ Returns:
789
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
790
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
791
+ a `tuple` is returned where the first element is the sample tensor.
792
+ """
793
+ # By default samples have to be at least a multiple of the overall upsampling factor.
794
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
795
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
796
+ # on the fly if necessary.
797
+ default_overall_up_factor = 2**self.num_upsamplers
798
+
799
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
800
+ forward_upsample_size = False
801
+ upsample_size = None
802
+
803
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
804
+ logger.info("Forward upsample size to force interpolation output size.")
805
+ forward_upsample_size = True
806
+
807
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
808
+ # expects mask of shape:
809
+ # [batch, key_tokens]
810
+ # adds singleton query_tokens dimension:
811
+ # [batch, 1, key_tokens]
812
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
813
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
814
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
815
+ if attention_mask is not None:
816
+ # assume that mask is expressed as:
817
+ # (1 = keep, 0 = discard)
818
+ # convert mask into a bias that can be added to attention scores:
819
+ # (keep = +0, discard = -10000.0)
820
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
821
+ attention_mask = attention_mask.unsqueeze(1)
822
+
823
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
824
+ if encoder_attention_mask is not None:
825
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
826
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
827
+
828
+ # 0. center input if necessary
829
+ if self.config.center_input_sample:
830
+ sample = 2 * sample - 1.0
831
+
832
+ # 1. time
833
+ timesteps = timestep
834
+ if not torch.is_tensor(timesteps):
835
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
836
+ # This would be a good case for the `match` statement (Python 3.10+)
837
+ is_mps = sample.device.type == "mps"
838
+ if isinstance(timestep, float):
839
+ dtype = torch.float32 if is_mps else torch.float64
840
+ else:
841
+ dtype = torch.int32 if is_mps else torch.int64
842
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
843
+ elif len(timesteps.shape) == 0:
844
+ timesteps = timesteps[None].to(sample.device)
845
+
846
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
847
+ timesteps = timesteps.expand(sample.shape[0])
848
+
849
+ t_emb = self.time_proj(timesteps)
850
+
851
+ # `Timesteps` does not contain any weights and will always return f32 tensors
852
+ # but time_embedding might actually be running in fp16. so we need to cast here.
853
+ # there might be better ways to encapsulate this.
854
+ t_emb = t_emb.to(dtype=sample.dtype)
855
+
856
+ emb = self.time_embedding(t_emb, timestep_cond)
857
+ aug_emb = None
858
+
859
+ if self.class_embedding is not None:
860
+ if class_labels is None:
861
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
862
+
863
+ if self.config.class_embed_type == "timestep":
864
+ class_labels = self.time_proj(class_labels)
865
+
866
+ # `Timesteps` does not contain any weights and will always return f32 tensors
867
+ # there might be better ways to encapsulate this.
868
+ class_labels = class_labels.to(dtype=sample.dtype)
869
+
870
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
871
+
872
+ if self.config.class_embeddings_concat:
873
+ emb = torch.cat([emb, class_emb], dim=-1)
874
+ else:
875
+ emb = emb + class_emb
876
+
877
+ if self.config.addition_embed_type == "text":
878
+ aug_emb = self.add_embedding(encoder_hidden_states)
879
+ elif self.config.addition_embed_type == "text_image":
880
+ # Kandinsky 2.1 - style
881
+ if "image_embeds" not in added_cond_kwargs:
882
+ raise ValueError(
883
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
884
+ )
885
+
886
+ image_embs = added_cond_kwargs.get("image_embeds")
887
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
888
+ aug_emb = self.add_embedding(text_embs, image_embs)
889
+ elif self.config.addition_embed_type == "text_time":
890
+ # SDXL - style
891
+ if "text_embeds" not in added_cond_kwargs:
892
+ raise ValueError(
893
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
894
+ )
895
+ text_embeds = added_cond_kwargs.get("text_embeds")
896
+ if "time_ids" not in added_cond_kwargs:
897
+ raise ValueError(
898
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
899
+ )
900
+ time_ids = added_cond_kwargs.get("time_ids")
901
+ time_embeds = self.add_time_proj(time_ids.flatten())
902
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
903
+
904
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
905
+ add_embeds = add_embeds.to(emb.dtype)
906
+ aug_emb = self.add_embedding(add_embeds)
907
+ elif self.config.addition_embed_type == "image":
908
+ # Kandinsky 2.2 - style
909
+ if "image_embeds" not in added_cond_kwargs:
910
+ raise ValueError(
911
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
912
+ )
913
+ image_embs = added_cond_kwargs.get("image_embeds")
914
+ aug_emb = self.add_embedding(image_embs)
915
+ elif self.config.addition_embed_type == "image_hint":
916
+ # Kandinsky 2.2 - style
917
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
918
+ raise ValueError(
919
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
920
+ )
921
+ image_embs = added_cond_kwargs.get("image_embeds")
922
+ hint = added_cond_kwargs.get("hint")
923
+ aug_emb, hint = self.add_embedding(image_embs, hint)
924
+ sample = torch.cat([sample, hint], dim=1)
925
+
926
+ emb = emb + aug_emb if aug_emb is not None else emb
927
+
928
+ if self.time_embed_act is not None:
929
+ emb = self.time_embed_act(emb)
930
+
931
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
932
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
933
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
934
+ # Kandinsky 2.1 - style
935
+ if "image_embeds" not in added_cond_kwargs:
936
+ raise ValueError(
937
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
938
+ )
939
+
940
+ image_embeds = added_cond_kwargs.get("image_embeds")
941
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
942
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
943
+ # Kandinsky 2.2 - style
944
+ if "image_embeds" not in added_cond_kwargs:
945
+ raise ValueError(
946
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
947
+ )
948
+ image_embeds = added_cond_kwargs.get("image_embeds")
949
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
950
+ # 2. pre-process
951
+ sample = self.conv_in(sample)
952
+
953
+ # 3. down
954
+
955
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
956
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
957
+
958
+ down_block_res_samples = (sample,)
959
+ for downsample_block in self.down_blocks:
960
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
961
+ # For t2i-adapter CrossAttnDownBlock2D
962
+ additional_residuals = {}
963
+ if is_adapter and len(down_block_additional_residuals) > 0:
964
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
965
+
966
+ sample, res_samples = downsample_block(
967
+ hidden_states=sample,
968
+ temb=emb,
969
+ encoder_hidden_states=encoder_hidden_states,
970
+ attention_mask=attention_mask,
971
+ cross_attention_kwargs=cross_attention_kwargs,
972
+ encoder_attention_mask=encoder_attention_mask,
973
+ **additional_residuals,
974
+ )
975
+ else:
976
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
977
+
978
+ if is_adapter and len(down_block_additional_residuals) > 0:
979
+ sample += down_block_additional_residuals.pop(0)
980
+
981
+ down_block_res_samples += res_samples
982
+
983
+ if is_controlnet:
984
+ new_down_block_res_samples = ()
985
+
986
+ for down_block_res_sample, down_block_additional_residual in zip(
987
+ down_block_res_samples, down_block_additional_residuals
988
+ ):
989
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
990
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
991
+
992
+ down_block_res_samples = new_down_block_res_samples
993
+
994
+ # 4. mid
995
+ if self.mid_block is not None:
996
+ sample = self.mid_block(
997
+ sample,
998
+ emb,
999
+ encoder_hidden_states=encoder_hidden_states,
1000
+ attention_mask=attention_mask,
1001
+ cross_attention_kwargs=cross_attention_kwargs,
1002
+ encoder_attention_mask=encoder_attention_mask,
1003
+ )
1004
+
1005
+ if is_controlnet:
1006
+ sample = sample + mid_block_additional_residual
1007
+
1008
+ # 5. up
1009
+ for i, upsample_block in enumerate(self.up_blocks):
1010
+ is_final_block = i == len(self.up_blocks) - 1
1011
+
1012
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1013
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1014
+
1015
+ # if we have not reached the final block and need to forward the
1016
+ # upsample size, we do it here
1017
+ if not is_final_block and forward_upsample_size:
1018
+ upsample_size = down_block_res_samples[-1].shape[2:]
1019
+
1020
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1021
+ sample = upsample_block(
1022
+ hidden_states=sample,
1023
+ temb=emb,
1024
+ res_hidden_states_tuple=res_samples,
1025
+ encoder_hidden_states=encoder_hidden_states,
1026
+ cross_attention_kwargs=cross_attention_kwargs,
1027
+ upsample_size=upsample_size,
1028
+ attention_mask=attention_mask,
1029
+ encoder_attention_mask=encoder_attention_mask,
1030
+ )
1031
+ else:
1032
+ sample = upsample_block(
1033
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1034
+ )
1035
+
1036
+ # 6. post-process
1037
+ if self.conv_norm_out:
1038
+ sample = self.conv_norm_out(sample)
1039
+ sample = self.conv_act(sample)
1040
+ sample = self.conv_out(sample)
1041
+
1042
+ if not return_dict:
1043
+ return (sample,)
1044
+
1045
+ return UNetMV2DConditionOutput(sample=sample)
1046
+
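A shape-level sketch of calling `forward`, with every dimension assumed for illustration (8 input channels and a 6-dimensional camera embedding, matching the `from_pretrained_2d` defaults below; the hidden size of 768 is also an assumption):

```python
# Illustrative shapes only; `unet` is assumed to be an instance of this model.
import torch

B, V = 1, 6                                           # one object, six views
sample = torch.randn(B * V, 8, 32, 32)                # noisy latents + conditioning channels
timestep = torch.tensor([10] * (B * V))
encoder_hidden_states = torch.randn(B * V, 1, 768)    # per-view image embedding (dim assumed)
camera = torch.randn(B * V, 6)                        # sin/cos camera embedding

out = unet(sample, timestep, encoder_hidden_states, class_labels=camera)
print(out.sample.shape)                               # (B * V, out_channels, 32, 32)
```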
1047
+ @classmethod
1048
+ def from_pretrained_2d(
1049
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1050
+ camera_embedding_type: str, num_views: int, sample_size: int,
1051
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1052
+ projection_class_embeddings_input_dim: int=6, joint_attention: bool = False,
1053
+ joint_attention_twice: bool = False, multiview_attention: bool = True,
1054
+ cross_domain_attention: bool = False,
1055
+ in_channels: int = 8, out_channels: int = 4,
1056
+ **kwargs
1057
+ ):
1058
+ r"""
1059
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1060
+
1061
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1062
+ train the model, set it back in training mode with `model.train()`.
1063
+
1064
+ Parameters:
1065
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1066
+ Can be either:
1067
+
1068
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1069
+ the Hub.
1070
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1071
+ with [`~ModelMixin.save_pretrained`].
1072
+
1073
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1074
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1075
+ is not used.
1076
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1077
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1078
+ dtype is automatically derived from the model's weights.
1079
+ force_download (`bool`, *optional*, defaults to `False`):
1080
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1081
+ cached versions if they exist.
1082
+ resume_download (`bool`, *optional*, defaults to `False`):
1083
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1084
+ incompletely downloaded files are deleted.
1085
+ proxies (`Dict[str, str]`, *optional*):
1086
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1087
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1088
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1089
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1090
+ local_files_only(`bool`, *optional*, defaults to `False`):
1091
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1092
+ won't be downloaded from the Hub.
1093
+ use_auth_token (`str` or *bool*, *optional*):
1094
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1095
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1096
+ revision (`str`, *optional*, defaults to `"main"`):
1097
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1098
+ allowed by Git.
1099
+ from_flax (`bool`, *optional*, defaults to `False`):
1100
+ Load the model weights from a Flax checkpoint save file.
1101
+ subfolder (`str`, *optional*, defaults to `""`):
1102
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1103
+ mirror (`str`, *optional*):
1104
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1105
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1106
+ information.
1107
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1108
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1109
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1110
+ same device.
1111
+
1112
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1113
+ more information about each option see [designing a device
1114
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1115
+ max_memory (`Dict`, *optional*):
1116
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
1117
+ each GPU and the available CPU RAM if unset.
1118
+ offload_folder (`str` or `os.PathLike`, *optional*):
1119
+ The path to offload weights if `device_map` contains the value `"disk"`.
1120
+ offload_state_dict (`bool`, *optional*):
1121
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1122
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1123
+ when there is some disk offload.
1124
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1125
+ Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
1126
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1127
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1128
+ argument to `True` will raise an error.
1129
+ variant (`str`, *optional*):
1130
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1131
+ loading `from_flax`.
1132
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1133
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1134
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1135
+ weights. If set to `False`, `safetensors` weights are not loaded.
1136
+
1137
+ <Tip>
1138
+
1139
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1140
+ `huggingface-cli login`. You can also activate the special
1141
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1142
+ firewalled environment.
1143
+
1144
+ </Tip>
1145
+
1146
+ Example:
1147
+
1148
+ ```py
1149
+ from diffusers import UNet2DConditionModel
1150
+
1151
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1152
+ ```
1153
+
1154
+ If you get the error message below, you need to finetune the weights for your downstream task:
1155
+
1156
+ ```bash
1157
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1158
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1159
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1160
+ ```
1161
+ """
1162
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1163
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1164
+ force_download = kwargs.pop("force_download", False)
1165
+ from_flax = kwargs.pop("from_flax", False)
1166
+ resume_download = kwargs.pop("resume_download", False)
1167
+ proxies = kwargs.pop("proxies", None)
1168
+ output_loading_info = kwargs.pop("output_loading_info", False)
1169
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1170
+ use_auth_token = kwargs.pop("use_auth_token", None)
1171
+ revision = kwargs.pop("revision", None)
1172
+ torch_dtype = kwargs.pop("torch_dtype", None)
1173
+ subfolder = kwargs.pop("subfolder", None)
1174
+ device_map = kwargs.pop("device_map", None)
1175
+ max_memory = kwargs.pop("max_memory", None)
1176
+ offload_folder = kwargs.pop("offload_folder", None)
1177
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1178
+ variant = kwargs.pop("variant", None)
1179
+ use_safetensors = kwargs.pop("use_safetensors", None)
1180
+
1181
+ if use_safetensors and not is_safetensors_available():
1182
+ raise ValueError(
1183
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1184
+ )
1185
+
1186
+ allow_pickle = False
1187
+ if use_safetensors is None:
1188
+ use_safetensors = is_safetensors_available()
1189
+ allow_pickle = True
1190
+
1191
+ if device_map is not None and not is_accelerate_available():
1192
+ raise NotImplementedError(
1193
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1194
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1195
+ )
1196
+
1197
+ # Check if we can handle device_map and dispatching the weights
1198
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1199
+ raise NotImplementedError(
1200
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1201
+ " `device_map=None`."
1202
+ )
1203
+
1204
+ # Load config if we don't provide a configuration
1205
+ config_path = pretrained_model_name_or_path
1206
+
1207
+ user_agent = {
1208
+ "diffusers": __version__,
1209
+ "file_type": "model",
1210
+ "framework": "pytorch",
1211
+ }
1212
+
1213
+ # load config
1214
+ config, unused_kwargs, commit_hash = cls.load_config(
1215
+ config_path,
1216
+ cache_dir=cache_dir,
1217
+ return_unused_kwargs=True,
1218
+ return_commit_hash=True,
1219
+ force_download=force_download,
1220
+ resume_download=resume_download,
1221
+ proxies=proxies,
1222
+ local_files_only=local_files_only,
1223
+ use_auth_token=use_auth_token,
1224
+ revision=revision,
1225
+ subfolder=subfolder,
1226
+ device_map=device_map,
1227
+ max_memory=max_memory,
1228
+ offload_folder=offload_folder,
1229
+ offload_state_dict=offload_state_dict,
1230
+ user_agent=user_agent,
1231
+ **kwargs,
1232
+ )
1233
+
1234
+ # modify config
1235
+ config["_class_name"] = cls.__name__
1236
+ config['in_channels'] = in_channels
1237
+ config['out_channels'] = out_channels
1238
+ config['sample_size'] = sample_size # training resolution
1239
+ config['num_views'] = num_views
1240
+ config['joint_attention'] = joint_attention
1241
+ config['joint_attention_twice'] = joint_attention_twice
1242
+ config['multiview_attention'] = multiview_attention
1243
+ config['cross_domain_attention'] = cross_domain_attention
1244
+ config["down_block_types"] = [
1245
+ "CrossAttnDownBlockMV2D",
1246
+ "CrossAttnDownBlockMV2D",
1247
+ "CrossAttnDownBlockMV2D",
1248
+ "DownBlock2D"
1249
+ ]
1250
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1251
+ config["up_block_types"] = [
1252
+ "UpBlock2D",
1253
+ "CrossAttnUpBlockMV2D",
1254
+ "CrossAttnUpBlockMV2D",
1255
+ "CrossAttnUpBlockMV2D"
1256
+ ]
1257
+ config['class_embed_type'] = 'projection'
1258
+ if camera_embedding_type == 'e_de_da_sincos':
1259
+ config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
1260
+ else:
1261
+ raise NotImplementedError
1262
+
1263
+ # load model
1264
+ model_file = None
1265
+ if from_flax:
1266
+ raise NotImplementedError
1267
+ else:
1268
+ if use_safetensors:
1269
+ try:
1270
+ model_file = _get_model_file(
1271
+ pretrained_model_name_or_path,
1272
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1273
+ cache_dir=cache_dir,
1274
+ force_download=force_download,
1275
+ resume_download=resume_download,
1276
+ proxies=proxies,
1277
+ local_files_only=local_files_only,
1278
+ use_auth_token=use_auth_token,
1279
+ revision=revision,
1280
+ subfolder=subfolder,
1281
+ user_agent=user_agent,
1282
+ commit_hash=commit_hash,
1283
+ )
1284
+ except IOError as e:
1285
+ if not allow_pickle:
1286
+ raise e
1287
+ pass
1288
+ if model_file is None:
1289
+ model_file = _get_model_file(
1290
+ pretrained_model_name_or_path,
1291
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1292
+ cache_dir=cache_dir,
1293
+ force_download=force_download,
1294
+ resume_download=resume_download,
1295
+ proxies=proxies,
1296
+ local_files_only=local_files_only,
1297
+ use_auth_token=use_auth_token,
1298
+ revision=revision,
1299
+ subfolder=subfolder,
1300
+ user_agent=user_agent,
1301
+ commit_hash=commit_hash,
1302
+ )
1303
+
1304
+ model = cls.from_config(config, **unused_kwargs)
1305
+
1306
+ state_dict = load_state_dict(model_file, variant=variant)
1307
+ model._convert_deprecated_attention_blocks(state_dict)
1308
+
1309
+ conv_in_weight = state_dict['conv_in.weight']
1310
+ conv_out_weight = state_dict['conv_out.weight']
1311
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1312
+ model,
1313
+ state_dict,
1314
+ model_file,
1315
+ pretrained_model_name_or_path,
1316
+ ignore_mismatched_sizes=True,
1317
+ )
1318
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1319
+ # initialize from the original SD structure
1320
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1321
+
1322
+ # optionally zero-initialize the weights of the newly added input channels
1323
+ if zero_init_conv_in:
1324
+ model.conv_in.weight.data[:,4:] = 0.
1325
+
1326
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1327
+ # initialize from the original SD structure
1328
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1329
+ if out_channels == 8: # copy for the last 4 channels
1330
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
1331
+
1332
+ if zero_init_camera_projection:
1333
+ for p in model.class_embedding.parameters():
1334
+ torch.nn.init.zeros_(p)
1335
+
1336
+ loading_info = {
1337
+ "missing_keys": missing_keys,
1338
+ "unexpected_keys": unexpected_keys,
1339
+ "mismatched_keys": mismatched_keys,
1340
+ "error_msgs": error_msgs,
1341
+ }
1342
+
1343
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1344
+ raise ValueError(
1345
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1346
+ )
1347
+ elif torch_dtype is not None:
1348
+ model = model.to(torch_dtype)
1349
+
1350
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1351
+
1352
+ # Set model in evaluation mode to deactivate DropOut modules by default
1353
+ model.eval()
1354
+ if output_loading_info:
1355
+ return model, loading_info
1356
+
1357
+ return model
1358
+
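A usage sketch for `from_pretrained_2d`; the checkpoint name is a placeholder, the class name is assumed to be `UNetMV2DConditionModel`, and the keyword arguments are drawn from the signature above:

```python
# Illustrative only: checkpoint name and argument values are assumptions.
unet = UNetMV2DConditionModel.from_pretrained_2d(
    "lambdalabs/sd-image-variations-diffusers",   # placeholder base checkpoint
    subfolder="unet",
    camera_embedding_type="e_de_da_sincos",
    num_views=6,
    sample_size=32,
    zero_init_conv_in=True,
    multiview_attention=True,
)
```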
1359
+ @classmethod
1360
+ def _load_pretrained_model_2d(
1361
+ cls,
1362
+ model,
1363
+ state_dict,
1364
+ resolved_archive_file,
1365
+ pretrained_model_name_or_path,
1366
+ ignore_mismatched_sizes=False,
1367
+ ):
1368
+ # Retrieve missing & unexpected_keys
1369
+ model_state_dict = model.state_dict()
1370
+ loaded_keys = list(state_dict.keys())
1371
+
1372
+ expected_keys = list(model_state_dict.keys())
1373
+
1374
+ original_loaded_keys = loaded_keys
1375
+
1376
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
1377
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1378
+
1379
+ # Make sure we are able to load base models as well as derived models (with heads)
1380
+ model_to_load = model
1381
+
1382
+ def _find_mismatched_keys(
1383
+ state_dict,
1384
+ model_state_dict,
1385
+ loaded_keys,
1386
+ ignore_mismatched_sizes,
1387
+ ):
1388
+ mismatched_keys = []
1389
+ if ignore_mismatched_sizes:
1390
+ for checkpoint_key in loaded_keys:
1391
+ model_key = checkpoint_key
1392
+
1393
+ if (
1394
+ model_key in model_state_dict
1395
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1396
+ ):
1397
+ mismatched_keys.append(
1398
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1399
+ )
1400
+ del state_dict[checkpoint_key]
1401
+ return mismatched_keys
1402
+
1403
+ if state_dict is not None:
1404
+ # Whole checkpoint
1405
+ mismatched_keys = _find_mismatched_keys(
1406
+ state_dict,
1407
+ model_state_dict,
1408
+ original_loaded_keys,
1409
+ ignore_mismatched_sizes,
1410
+ )
1411
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1412
+
1413
+ if len(error_msgs) > 0:
1414
+ error_msg = "\n\t".join(error_msgs)
1415
+ if "size mismatch" in error_msg:
1416
+ error_msg += (
1417
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1418
+ )
1419
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1420
+
1421
+ if len(unexpected_keys) > 0:
1422
+ logger.warning(
1423
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1424
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1425
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1426
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1427
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1428
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1429
+ " identical (initializing a BertForSequenceClassification model from a"
1430
+ " BertForSequenceClassification model)."
1431
+ )
1432
+ else:
1433
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1434
+ if len(missing_keys) > 0:
1435
+ logger.warning(
1436
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1437
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1438
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1439
+ )
1440
+ elif len(mismatched_keys) == 0:
1441
+ logger.info(
1442
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1443
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1444
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1445
+ " without further training."
1446
+ )
1447
+ if len(mismatched_keys) > 0:
1448
+ mismatched_warning = "\n".join(
1449
+ [
1450
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1451
+ for key, shape1, shape2 in mismatched_keys
1452
+ ]
1453
+ )
1454
+ logger.warning(
1455
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1456
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1457
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1458
+ " able to use it for predictions and inference."
1459
+ )
1460
+
1461
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1462
+
mvdiffusion/pipelines/pipeline_mvdiffusion_image.py ADDED
@@ -0,0 +1,485 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import warnings
17
+ from typing import Callable, List, Optional, Union
18
+
19
+ import PIL
20
+ import torch
21
+ import torchvision.transforms.functional as TF
22
+ from packaging import version
23
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
24
+
25
+ from diffusers.configuration_utils import FrozenDict
26
+ from diffusers.image_processor import VaeImageProcessor
27
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
28
+ from diffusers.schedulers import KarrasDiffusionSchedulers
29
+ from diffusers.utils import deprecate, logging, randn_tensor
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
32
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ class MVDiffusionImagePipeline(DiffusionPipeline):
39
+ r"""
40
+ Pipeline to generate image variations from an input image using Stable Diffusion.
41
+
42
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
43
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
44
+
45
+ Args:
46
+ vae ([`AutoencoderKL`]):
47
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
48
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
49
+ Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
50
+ text_encoder ([`~transformers.CLIPTextModel`]):
51
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
52
+ tokenizer ([`~transformers.CLIPTokenizer`]):
53
+ A `CLIPTokenizer` to tokenize text.
54
+ unet ([`UNet2DConditionModel`]):
55
+ A `UNet2DConditionModel` to denoise the encoded image latents.
56
+ scheduler ([`SchedulerMixin`]):
57
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
58
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
59
+ safety_checker ([`StableDiffusionSafetyChecker`]):
60
+ Classification module that estimates whether generated images could be considered offensive or harmful.
61
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
62
+ about a model's potential harms.
63
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
64
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
65
+ """
66
+ # TODO: feature_extractor is required to encode images (if they are in PIL format),
67
+ # we should give a descriptive message if the pipeline doesn't have one.
68
+ _optional_components = ["safety_checker"]
69
+
70
+ def __init__(
71
+ self,
72
+ vae: AutoencoderKL,
73
+ image_encoder: CLIPVisionModelWithProjection,
74
+ unet: UNet2DConditionModel,
75
+ scheduler: KarrasDiffusionSchedulers,
76
+ safety_checker: StableDiffusionSafetyChecker,
77
+ feature_extractor: CLIPImageProcessor,
78
+ requires_safety_checker: bool = True,
79
+ camera_embedding_type: str = 'e_de_da_sincos',
80
+ num_views: int = 4
81
+ ):
82
+ super().__init__()
83
+
84
+ if safety_checker is None and requires_safety_checker:
85
+ logger.warning(
86
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
87
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
88
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
89
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
90
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
91
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
92
+ )
93
+
94
+ if safety_checker is not None and feature_extractor is None:
95
+ raise ValueError(
96
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
97
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
98
+ )
99
+
100
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
101
+ version.parse(unet.config._diffusers_version).base_version
102
+ ) < version.parse("0.9.0.dev0")
103
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
104
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
105
+ deprecation_message = (
106
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
107
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
108
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
109
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
110
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
111
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
112
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
113
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
114
+ " the `unet/config.json` file"
115
+ )
116
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
117
+ new_config = dict(unet.config)
118
+ new_config["sample_size"] = 64
119
+ unet._internal_dict = FrozenDict(new_config)
120
+
121
+ self.register_modules(
122
+ vae=vae,
123
+ image_encoder=image_encoder,
124
+ unet=unet,
125
+ scheduler=scheduler,
126
+ safety_checker=safety_checker,
127
+ feature_extractor=feature_extractor,
128
+ )
129
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
130
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
131
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
132
+
133
+ self.camera_embedding_type: str = camera_embedding_type
134
+ self.num_views: int = num_views
135
+
136
+ def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance):
137
+ dtype = next(self.image_encoder.parameters()).dtype
138
+
139
+ image_pt = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values
140
+ image_pt = image_pt.to(device=device, dtype=dtype)
141
+ image_embeddings = self.image_encoder(image_pt).image_embeds
142
+ image_embeddings = image_embeddings.unsqueeze(1)
143
+
144
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
145
+ # Note: repeat differently from official pipelines
146
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
147
+ bs_embed, seq_len, _ = image_embeddings.shape
148
+ image_embeddings = image_embeddings.repeat(num_images_per_prompt, 1, 1)
149
+
150
+ if do_classifier_free_guidance:
151
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
152
+
153
+ # For classifier free guidance, we need to do two forward passes.
154
+ # Here we concatenate the unconditional and text embeddings into a single batch
155
+ # to avoid doing two forward passes
156
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
157
+
158
+ image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device)
159
+ image_pt = image_pt * 2.0 - 1.0
160
+ image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor
161
+ # Note: repeat differently from official pipelines
162
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
163
+ image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)
164
+
165
+ if do_classifier_free_guidance:
166
+ image_latents = torch.cat([torch.zeros_like(image_latents), image_latents])
167
+
168
+ return image_embeddings, image_latents
169
+
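The "repeat differently from official pipelines" comments above refer to whole-batch tiling rather than per-sample interleaving; a small sketch of the difference:

```python
# Illustrative only: whole-batch tiling vs. per-sample interleaving.
import torch

batch = torch.tensor([1, 2, 3, 4])        # B1 B2 B3 B4
print(batch.repeat(2))                    # tensor([1, 2, 3, 4, 1, 2, 3, 4])  <- this pipeline
print(batch.repeat_interleave(2))         # tensor([1, 1, 2, 2, 3, 3, 4, 4])  <- per-sample repeat
```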
170
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
171
+ def run_safety_checker(self, image, device, dtype):
172
+ if self.safety_checker is None:
173
+ has_nsfw_concept = None
174
+ else:
175
+ if torch.is_tensor(image):
176
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
177
+ else:
178
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
179
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
180
+ image, has_nsfw_concept = self.safety_checker(
181
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
182
+ )
183
+ return image, has_nsfw_concept
184
+
185
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
186
+ def decode_latents(self, latents):
187
+ warnings.warn(
188
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
189
+ " use VaeImageProcessor instead",
190
+ FutureWarning,
191
+ )
192
+ latents = 1 / self.vae.config.scaling_factor * latents
193
+ image = self.vae.decode(latents, return_dict=False)[0]
194
+ image = (image / 2 + 0.5).clamp(0, 1)
195
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
196
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
197
+ return image
198
+
199
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
200
+ def prepare_extra_step_kwargs(self, generator, eta):
201
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
202
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
203
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
204
+ # and should be between [0, 1]
205
+
206
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
207
+ extra_step_kwargs = {}
208
+ if accepts_eta:
209
+ extra_step_kwargs["eta"] = eta
210
+
211
+ # check if the scheduler accepts generator
212
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
213
+ if accepts_generator:
214
+ extra_step_kwargs["generator"] = generator
215
+ return extra_step_kwargs
216
+
217
+ def check_inputs(self, image, height, width, callback_steps):
218
+ if (
219
+ not isinstance(image, torch.Tensor)
220
+ and not isinstance(image, PIL.Image.Image)
221
+ and not isinstance(image, list)
222
+ ):
223
+ raise ValueError(
224
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
225
+ f" {type(image)}"
226
+ )
227
+
228
+ if height % 8 != 0 or width % 8 != 0:
229
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
230
+
231
+ if (callback_steps is None) or (
232
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
233
+ ):
234
+ raise ValueError(
235
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
236
+ f" {type(callback_steps)}."
237
+ )
238
+
239
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
240
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
241
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
242
+ if isinstance(generator, list) and len(generator) != batch_size:
243
+ raise ValueError(
244
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
245
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
246
+ )
247
+
248
+ if latents is None:
249
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
250
+ else:
251
+ latents = latents.to(device)
252
+
253
+ # scale the initial noise by the standard deviation required by the scheduler
254
+ latents = latents * self.scheduler.init_noise_sigma
255
+ return latents
256
+
257
+ def prepare_camera_embedding(self, camera_embedding: Union[float, torch.Tensor], do_classifier_free_guidance, num_images_per_prompt=1):
258
+ # (B, 3)
259
+ camera_embedding = camera_embedding.to(dtype=self.unet.dtype, device=self.unet.device)
260
+
261
+ if self.camera_embedding_type == 'e_de_da_sincos':
262
+ # (B, 6)
263
+ camera_embedding = torch.cat([
264
+ torch.sin(camera_embedding),
265
+ torch.cos(camera_embedding)
266
+ ], dim=-1)
267
+ assert self.unet.config.class_embed_type == 'projection'
268
+ assert self.unet.config.projection_class_embeddings_input_dim == 6 or self.unet.config.projection_class_embeddings_input_dim == 10
269
+ else:
270
+ raise NotImplementedError
271
+
272
+ # Note: repeat differently from official pipelines
273
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
274
+ camera_embedding = camera_embedding.repeat(num_images_per_prompt, 1)
275
+
276
+ if do_classifier_free_guidance:
277
+ camera_embedding = torch.cat([
278
+ camera_embedding,
279
+ camera_embedding
280
+ ], dim=0)
281
+
282
+ return camera_embedding
283
+
284
+ @torch.no_grad()
285
+ def __call__(
286
+ self,
287
+ image: Union[List[PIL.Image.Image], torch.FloatTensor],
288
+ # elevation_cond: torch.FloatTensor,
289
+ # elevation: torch.FloatTensor,
290
+ # azimuth: torch.FloatTensor,
291
+ camera_embedding: torch.FloatTensor,
292
+ height: Optional[int] = None,
293
+ width: Optional[int] = None,
294
+ num_inference_steps: int = 50,
295
+ guidance_scale: float = 7.5,
296
+ num_images_per_prompt: Optional[int] = 1,
297
+ eta: float = 0.0,
298
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
299
+ latents: Optional[torch.FloatTensor] = None,
300
+ output_type: Optional[str] = "pil",
301
+ return_dict: bool = True,
302
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
303
+ callback_steps: int = 1,
304
+ normal_cond: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
305
+ ):
306
+ r"""
307
+ The call function to the pipeline for generation.
308
+
309
+ Args:
310
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
311
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
312
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
313
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
314
+ The height in pixels of the generated image.
315
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
316
+ The width in pixels of the generated image.
317
+ num_inference_steps (`int`, *optional*, defaults to 50):
318
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
319
+ expense of slower inference.
320
+ guidance_scale (`float`, *optional*, defaults to 7.5):
321
+ A higher guidance scale value encourages the model to generate images that closely match the conditioning
322
+ image, at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
323
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
324
+ The number of images to generate per prompt.
325
+ eta (`float`, *optional*, defaults to 0.0):
326
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
327
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
328
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
329
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
330
+ generation deterministic.
331
+ latents (`torch.FloatTensor`, *optional*):
332
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
333
+ generation. Can be used to tweak the same generation with different inputs. If not provided, a latents
334
+ tensor is generated by sampling using the supplied random `generator`.
335
+ output_type (`str`, *optional*, defaults to `"pil"`):
336
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
337
+ return_dict (`bool`, *optional*, defaults to `True`):
338
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
339
+ plain tuple.
340
+ callback (`Callable`, *optional*):
341
+ A function that is called every `callback_steps` steps during inference. The function is called with the
342
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
343
+ callback_steps (`int`, *optional*, defaults to 1):
344
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
345
+ every step.
346
+
347
+ Returns:
348
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
349
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
350
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
351
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
352
+ "not-safe-for-work" (nsfw) content.
353
+
354
+ Examples:
355
+
356
+ ```py
357
+ from diffusers import StableDiffusionImageVariationPipeline
358
+ from PIL import Image
359
+ from io import BytesIO
360
+ import requests
361
+
362
+ pipe = StableDiffusionImageVariationPipeline.from_pretrained(
363
+ "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
364
+ )
365
+ pipe = pipe.to("cuda")
366
+
367
+ url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
368
+
369
+ response = requests.get(url)
370
+ image = Image.open(BytesIO(response.content)).convert("RGB")
371
+
372
+ out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
373
+ out["images"][0].save("result.jpg")
374
+ ```
375
+ """
376
+ # 0. Default height and width to unet
377
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
378
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
379
+
380
+ # 1. Check inputs. Raise error if not correct
381
+ self.check_inputs(image, height, width, callback_steps)
382
+
383
+
384
+ # 2. Define call parameters
385
+ if isinstance(image, list):
386
+ batch_size = len(image)
387
+ else:
388
+ batch_size = image.shape[0]
389
+ assert batch_size >= self.num_views and batch_size % self.num_views == 0
390
+ device = self._execution_device
391
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
392
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
393
+ # corresponds to doing no classifier free guidance.
394
+ do_classifier_free_guidance = guidance_scale != 1.0
395
+
396
+ # 3. Encode input image
397
+ if isinstance(image, list):
398
+ image_pil = image
399
+ elif isinstance(image, torch.Tensor):
400
+ image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
401
+ image_embeddings, image_latents = self._encode_image(image_pil, device, num_images_per_prompt, do_classifier_free_guidance)
402
+
403
+ if normal_cond is not None:
404
+ if isinstance(normal_cond, list):
405
+ normal_cond_pil = normal_cond
406
+ elif isinstance(normal_cond, torch.Tensor):
407
+ normal_cond_pil = [TF.to_pil_image(normal_cond[i]) for i in range(normal_cond.shape[0])]
408
+ _, image_latents = self._encode_image(normal_cond_pil, device, num_images_per_prompt, do_classifier_free_guidance)
409
+
410
+
411
+ # assert len(elevation_cond) == batch_size and len(elevation) == batch_size and len(azimuth) == batch_size
412
+ # camera_embeddings = self.prepare_camera_condition(elevation_cond, elevation, azimuth, do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt)
413
+ assert len(camera_embedding) == batch_size
414
+ camera_embeddings = self.prepare_camera_embedding(camera_embedding, do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt)
415
+
416
+ # 4. Prepare timesteps
417
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
418
+ timesteps = self.scheduler.timesteps
419
+
420
+ # 5. Prepare latent variables
421
+ num_channels_latents = self.unet.config.out_channels
422
+ latents = self.prepare_latents(
423
+ batch_size * num_images_per_prompt,
424
+ num_channels_latents,
425
+ height,
426
+ width,
427
+ image_embeddings.dtype,
428
+ device,
429
+ generator,
430
+ latents,
431
+ )
432
+
433
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
434
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
435
+
436
+ # 7. Denoising loop
437
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
438
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
439
+ for i, t in enumerate(timesteps):
440
+ # expand the latents if we are doing classifier free guidance
441
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
442
+ latent_model_input = torch.cat([
443
+ latent_model_input, image_latents
444
+ ], dim=1)
445
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
446
+
447
+ # predict the noise residual
448
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings, class_labels=camera_embeddings).sample
449
+
450
+ # perform guidance
451
+ if do_classifier_free_guidance:
452
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
453
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
454
+
455
+ # compute the previous noisy sample x_t -> x_t-1
456
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
457
+
458
+ # call the callback, if provided
459
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
460
+ progress_bar.update()
461
+ if callback is not None and i % callback_steps == 0:
462
+ callback(i, t, latents)
463
+
464
+ if not output_type == "latent":
465
+ if num_channels_latents == 8:
466
+ latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0)
467
+
468
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
469
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
470
+ else:
471
+ image = latents
472
+ has_nsfw_concept = None
473
+
474
+ if has_nsfw_concept is None:
475
+ do_denormalize = [True] * image.shape[0]
476
+ else:
477
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
478
+
479
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
480
+
481
+ if not return_dict:
482
+ return (image, has_nsfw_concept)
483
+
484
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
485
+
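The docstring example above is inherited from Diffusers' image-variation pipeline and omits the `camera_embedding` argument that this `__call__` requires. Below is a minimal usage sketch under stated assumptions: the pipeline class defined earlier in this file is assumed to be `MVDiffusionImagePipeline`, the checkpoint path is a placeholder, `pipe.num_views` is assumed to be 6 (matching the six-view config in this commit), and the `(B, 3)` camera embedding is assumed to hold (conditioning elevation, relative elevation, relative azimuth), as the `e_de_da_sincos` branch of `prepare_camera_embedding` suggests.

```python
import numpy as np
import torch
from PIL import Image

# Assumed class name / import path; adjust to the actual definition earlier in this file.
from mvdiffusion.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline

pipe = MVDiffusionImagePipeline.from_pretrained("path/to/mvdiffusion-checkpoint")  # placeholder path
pipe = pipe.to("cuda")

num_views = 6
cond = Image.open("example_images/owl.png").convert("RGB").resize((256, 256))

# The batch size must be a multiple of pipe.num_views; repeat the conditioning image once per view.
images = [cond] * num_views

# Assumed (B, 3) layout: (cond-view elevation, relative elevation, relative azimuth), in radians.
azimuths = np.deg2rad([0, 45, 90, 180, 270, 315])
camera_embedding = torch.stack([
    torch.zeros(num_views),                       # elevation of the conditioning view
    torch.zeros(num_views),                       # relative elevation of each target view
    torch.tensor(azimuths, dtype=torch.float32),  # relative azimuth of each target view
], dim=-1)

out = pipe(images, camera_embedding=camera_embedding,
           num_inference_steps=50, guidance_scale=3.0, output_type="pil")
out.images[0].save("view_000.png")
```

Note that when the UNet predicts 8 latent channels (a joint model), the pipeline splits the latents into two 4-channel halves and concatenates them along the batch dimension before decoding, so `out.images` then contains both decoded domains back to back.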
requirements.txt ADDED
@@ -0,0 +1,30 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu113
2
+ torch==1.12.1
3
+ torchvision==0.13.1
4
+ diffusers[torch]==0.11.1
5
+ transformers>=4.25.1
6
+ bitsandbytes==0.35.4
7
+ decord==0.6.0
8
+ pytorch-lightning<2
9
+ omegaconf==2.2.3
10
+ nerfacc==0.3.3
11
+ trimesh==3.9.8
12
+ pyhocon==0.3.57
13
+ icecream==2.1.0
14
+ PyMCubes==0.1.2
15
+ xformers
16
+ accelerate
17
+ modelcards
18
+ einops
19
+ ftfy
20
+ piq
21
+ matplotlib
22
+ opencv-python
23
+ imageio
24
+ imageio-ffmpeg
25
+ scipy
26
+ pyransac3d
27
+ torch_efficient_distloss
28
+ tensorboard
29
+ rembg
30
+ segment_anything
run_test.sh ADDED
@@ -0,0 +1 @@
1
+ accelerate launch --config_file 1gpu.yaml test_mvdiffusion_seq.py --config configs/mvdiffusion-joint-ortho-6views.yaml
utils/misc.py ADDED
@@ -0,0 +1,54 @@
1
+ import os
2
+ from omegaconf import OmegaConf
3
+ from packaging import version
4
+
5
+
6
+ # ============ Register OmegaConf Resolvers ============= #
7
+ OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n))
8
+ OmegaConf.register_new_resolver('add', lambda a, b: a + b)
9
+ OmegaConf.register_new_resolver('sub', lambda a, b: a - b)
10
+ OmegaConf.register_new_resolver('mul', lambda a, b: a * b)
11
+ OmegaConf.register_new_resolver('div', lambda a, b: a / b)
12
+ OmegaConf.register_new_resolver('idiv', lambda a, b: a // b)
13
+ OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p))
14
+ # ======================================================= #
15
+
16
+
17
+ def prompt(question):
18
+ inp = input(f"{question} (y/n)").lower().strip()
19
+ if inp and inp == 'y':
20
+ return True
21
+ if inp and inp == 'n':
22
+ return False
23
+ return prompt(question)
24
+
25
+
26
+ def load_config(*yaml_files, cli_args=[]):
27
+ yaml_confs = [OmegaConf.load(f) for f in yaml_files]
28
+ cli_conf = OmegaConf.from_cli(cli_args)
29
+ conf = OmegaConf.merge(*yaml_confs, cli_conf)
30
+ OmegaConf.resolve(conf)
31
+ return conf
32
+
33
+
34
+ def config_to_primitive(config, resolve=True):
35
+ return OmegaConf.to_container(config, resolve=resolve)
36
+
37
+
38
+ def dump_config(path, config):
39
+ with open(path, 'w') as fp:
40
+ OmegaConf.save(config=config, f=fp)
41
+
42
+ def get_rank():
43
+ # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
44
+ # therefore LOCAL_RANK needs to be checked first
45
+ rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
46
+ for key in rank_keys:
47
+ rank = os.environ.get(key)
48
+ if rank is not None:
49
+ return int(rank)
50
+ return 0
51
+
52
+
53
+ def parse_version(ver):
54
+ return version.parse(ver)
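A short, self-contained sketch of how these helpers compose: `load_config` merges one or more YAML files with CLI-style overrides and resolves the custom OmegaConf resolvers registered above. The YAML content and the override below are illustrative only (not one of the repo's real configs), and the script is assumed to run from the repository root so that `utils.misc` is importable.

```python
import os
import tempfile

from utils.misc import load_config, config_to_primitive, get_rank

# Write a small illustrative config (placeholder values, not a real repo config).
cfg_text = """
num_views: 6
batch_size: ${mul:${num_views},2}  # num_views * 2 via the 'mul' resolver
exp_name: ${basename:'configs/mvdiffusion-joint-ortho-6views.yaml'}
"""
with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(cfg_text)
    cfg_path = f.name

# CLI-style dotlist overrides are merged on top of the YAML files before resolution.
conf = load_config(cfg_path, cli_args=["num_views=4"])
print(config_to_primitive(conf))  # {'num_views': 4, 'batch_size': 8, 'exp_name': 'mvdiffusion-joint-ortho-6views.yaml'}
print("rank:", get_rank())        # 0 unless RANK/LOCAL_RANK/SLURM_PROCID is set

os.remove(cfg_path)
```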