Yuxiao319 commited on
Commit
4bbe787
1 Parent(s): fc21bba
mv_diffusion_30/data/depth_utils.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import numpy as np
3
+ import torch
4
+
5
+ def colorize_depth_maps(
6
+ depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
7
+ ):
8
+ """
9
+ Colorize depth maps.
10
+ """
11
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
12
+
13
+ if isinstance(depth_map, torch.Tensor):
14
+ depth = depth_map.detach().squeeze().numpy()
15
+ elif isinstance(depth_map, np.ndarray):
16
+ depth = depth_map.copy().squeeze()
17
+ # reshape to [ (B,) H, W ]
18
+ if depth.ndim < 3:
19
+ depth = depth[np.newaxis, :, :]
20
+
21
+ # colorize
22
+ cm = matplotlib.colormaps[cmap]
23
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
24
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
25
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
26
+
27
+ if valid_mask is not None:
28
+ if isinstance(depth_map, torch.Tensor):
29
+ valid_mask = valid_mask.detach().numpy()
30
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
31
+ if valid_mask.ndim < 3:
32
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
33
+ else:
34
+ valid_mask = valid_mask[:, np.newaxis, :, :]
35
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
36
+ img_colored_np[~valid_mask] = 0
37
+
38
+ if isinstance(depth_map, torch.Tensor):
39
+ img_colored = torch.from_numpy(img_colored_np).float()
40
+ elif isinstance(depth_map, np.ndarray):
41
+ img_colored = img_colored_np
42
+
43
+ return img_colored
44
+
45
+
46
+ def scale_depth_to_model(depth, camera_type='ortho'):
47
+ """
48
+ Scale depth from the original range.
49
+ """
50
+ assert camera_type == 'ortho' or camera_type == 'persp'
51
+ w, h = depth.shape
52
+
53
+ if camera_type == 'ortho':
54
+ original_min = 9000
55
+ original_max = 17000
56
+ target_min = 2000
57
+ target_max = 62000
58
+
59
+ mask = depth != 0
60
+ # Scale depth to [0, 1]
61
+ depth_normalized = np.zeros([w, h])
62
+ depth_normalized[mask] = (depth[mask] - original_min) / (original_max - original_min)
63
+
64
+ # Scale depth to [2000, 60000]
65
+ scaled_depth = np.zeros([w, h])
66
+ scaled_depth[mask] = depth_normalized[mask] * (target_max - target_min) + target_min
67
+
68
+ else:
69
+ original_min = 4000
70
+ original_max = 13000
71
+ target_min = 2000
72
+ target_max = 62000
73
+
74
+ mask = depth != 0
75
+ # Scale depth to [0, 1]
76
+ depth_normalized = np.zeros([w, h])
77
+ depth_normalized[mask] = (depth[mask] - original_min) / (original_max - original_min)
78
+
79
+ # Scale depth to [2000, 60000]
80
+ scaled_depth = np.zeros([w, h])
81
+ scaled_depth[mask] = depth_normalized[mask] * (target_max - target_min) + target_min
82
+
83
+ scaled_depth[scaled_depth > 62000] = 0
84
+ scaled_depth = scaled_depth / 65535. # [0, 1]
85
+
86
+ return scaled_depth
87
+
88
+ def rescale_depth_to_world(scaled_depth, camera_type='ortho'):
89
+ """
90
+ Rescale depth from the scaled range back to the original range.
91
+ """
92
+ assert camera_type == 'ortho' or camera_type == 'persp'
93
+ scaled_depth = scaled_depth * 65535.
94
+ w, h = scaled_depth.shape
95
+
96
+ if camera_type == 'ortho':
97
+ original_min = 9000
98
+ original_max = 17000
99
+ target_min = 2000
100
+ target_max = 62000
101
+
102
+ mask = scaled_depth != 0
103
+ rescaled_depth_norm = np.zeros([w, h])
104
+ # Rescale depth to [0, 1]
105
+ rescaled_depth_norm[mask] = (scaled_depth[mask] - target_min) / (target_max - target_min)
106
+
107
+ # Rescale depth to [9000, 17000]
108
+ rescaled_depth = np.zeros([w, h])
109
+ rescaled_depth[mask] = rescaled_depth_norm[mask] * (original_max - original_min) + original_min
110
+
111
+ else:
112
+ original_min = 4000
113
+ original_max = 13000
114
+ target_min = 2000
115
+ target_max = 62000
116
+
117
+ mask = scaled_depth != 0
118
+ rescaled_depth_norm = np.zeros([w, h])
119
+ # Rescale depth to [0, 1]
120
+ rescaled_depth_norm[mask] = (scaled_depth[mask] - target_min) / (target_max - target_min)
121
+
122
+ # Rescale depth to [9000, 17000]
123
+ rescaled_depth = np.zeros([w, h])
124
+ rescaled_depth[mask] = rescaled_depth_norm[mask] * (original_max - original_min) + original_min
125
+
126
+ return rescaled_depth
mv_diffusion_30/data/fixed_poses/nine_views/000_back_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -5.266582965850830078e-01 7.410295009613037109e-01 -4.165407419204711914e-01 -5.960464477539062500e-08
2
+ 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 -9.462351613365171943e-08
3
+ 8.500770330429077148e-01 4.590988159179687500e-01 -2.580644786357879639e-01 -1.300000071525573730e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_back_left_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -9.734988808631896973e-01 1.993551850318908691e-01 -1.120596975088119507e-01 -1.713633537292480469e-07
2
+ 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 1.772203575001185527e-07
3
+ 2.286916375160217285e-01 8.486189246177673340e-01 -4.770178496837615967e-01 -1.838477611541748047e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_back_right_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 2.286914736032485962e-01 8.486190438270568848e-01 -4.770178198814392090e-01 1.564621925354003906e-07
2
+ -3.417914484771245043e-08 4.900034070014953613e-01 8.717205524444580078e-01 -7.293811421504869941e-08
3
+ 9.734990000724792480e-01 -1.993550658226013184e-01 1.120596155524253845e-01 -1.838477969169616699e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_front_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 5.266583561897277832e-01 -7.410295009613037109e-01 4.165407419204711914e-01 0.000000000000000000e+00
2
+ 5.865638996738198330e-08 4.900035560131072998e-01 8.717204332351684570e-01 9.462351613365171943e-08
3
+ -8.500770330429077148e-01 -4.590988159179687500e-01 2.580645382404327393e-01 -1.300000071525573730e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_front_left_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -2.286916971206665039e-01 -8.486189842224121094e-01 4.770179092884063721e-01 -2.458691596984863281e-07
2
+ 9.085837859856837895e-09 4.900034666061401367e-01 8.717205524444580078e-01 1.205695667749751010e-07
3
+ -9.734990000724792480e-01 1.993551701307296753e-01 -1.120597645640373230e-01 -1.838477969169616699e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_front_right_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 9.734989404678344727e-01 -1.993551850318908691e-01 1.120596975088119507e-01 -1.415610313415527344e-07
2
+ 3.790224578636980368e-09 4.900034964084625244e-01 8.717204928398132324e-01 -1.772203575001185527e-07
3
+ -2.286916375160217285e-01 -8.486189246177673340e-01 4.770178794860839844e-01 -1.838477611541748047e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_left_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ -8.500771522521972656e-01 -4.590989053249359131e-01 2.580644488334655762e-01 0.000000000000000000e+00
2
+ -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 9.006067358541258727e-08
3
+ -5.266583561897277832e-01 7.410295605659484863e-01 -4.165408313274383545e-01 -1.300000071525573730e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_right_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 8.500770330429077148e-01 4.590989053249359131e-01 -2.580644488334655762e-01 5.960464477539062500e-08
2
+ -4.257411134744870651e-08 4.900034964084625244e-01 8.717204928398132324e-01 -9.006067358541258727e-08
3
+ 5.266583561897277832e-01 -7.410295605659484863e-01 4.165407419204711914e-01 -1.300000071525573730e+00
mv_diffusion_30/data/fixed_poses/nine_views/000_top_RT.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 9.958608150482177734e-01 7.923202216625213623e-02 -4.453715682029724121e-02 -3.098167056236889039e-09
2
+ -9.089154005050659180e-02 8.681122064590454102e-01 -4.879753291606903076e-01 5.784738377201392723e-08
3
+ -2.028124157504862524e-08 4.900035560131072998e-01 8.717204332351684570e-01 -1.300000071525573730e+00
mv_diffusion_30/data/multiview_image_dataset.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import numpy as np
3
+ from omegaconf import DictConfig, ListConfig
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from pathlib import Path
7
+ import json
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ from typing import Literal, Tuple, Optional, Any
12
+ import cv2
13
+ import random
14
+
15
+ import json
16
+ import os, sys
17
+ import math
18
+
19
+ from glob import glob
20
+
21
+ import PIL.Image
22
+ from .normal_utils import trans_normal, normal2img, img2normal
23
+ import pdb
24
+
25
+
26
+ import cv2
27
+ import numpy as np
28
+
29
+ def add_margin(pil_img, color=0, size=256):
30
+ width, height = pil_img.size
31
+ result = Image.new(pil_img.mode, (size, size), color)
32
+ result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
33
+ return result
34
+
35
+ def scale_and_place_object(image, scale_factor):
36
+ assert np.shape(image)[-1]==4 # RGBA
37
+
38
+ # Extract the alpha channel (transparency) and the object (RGB channels)
39
+ alpha_channel = image[:, :, 3]
40
+
41
+ # Find the bounding box coordinates of the object
42
+ coords = cv2.findNonZero(alpha_channel)
43
+ x, y, width, height = cv2.boundingRect(coords)
44
+
45
+ # Calculate the scale factor for resizing
46
+ original_height, original_width = image.shape[:2]
47
+
48
+ if width > height:
49
+ size = width
50
+ original_size = original_width
51
+ else:
52
+ size = height
53
+ original_size = original_height
54
+
55
+ scale_factor = min(scale_factor, size / (original_size+0.0))
56
+
57
+ new_size = scale_factor * original_size
58
+ scale_factor = new_size / size
59
+
60
+ # Calculate the new size based on the scale factor
61
+ new_width = int(width * scale_factor)
62
+ new_height = int(height * scale_factor)
63
+
64
+ center_x = original_width // 2
65
+ center_y = original_height // 2
66
+
67
+ paste_x = center_x - (new_width // 2)
68
+ paste_y = center_y - (new_height // 2)
69
+
70
+ # Resize the object (RGB channels) to the new size
71
+ rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height))
72
+
73
+ # Create a new RGBA image with the resized image
74
+ new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8)
75
+
76
+ new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object
77
+
78
+ return new_image
79
+
80
+ class InferenceImageDataset(Dataset):
81
+ def __init__(self,
82
+ root_dir: str,
83
+ num_views: int,
84
+ img_wh: Tuple[int, int],
85
+ bg_color: str,
86
+ crop_size: int = 224,
87
+ single_image: Optional[PIL.Image.Image] = None,
88
+ num_validation_samples: Optional[int] = None,
89
+ filepaths: Optional[list] = None,
90
+ cam_types: Optional[list] = None,
91
+ cond_type: Optional[str] = None,
92
+ load_cam_type: Optional[bool] = True
93
+ ) -> None:
94
+ """Create a dataset from a folder of images.
95
+ If you pass in a root directory it will be searched for images
96
+ ending in ext (ext can be a list)
97
+ """
98
+ self.root_dir = root_dir
99
+ self.num_views = num_views
100
+ self.img_wh = img_wh
101
+ self.crop_size = crop_size
102
+ self.bg_color = bg_color
103
+ self.cond_type = cond_type
104
+ self.load_cam_type = load_cam_type
105
+ self.cam_types = cam_types
106
+
107
+ if self.num_views == 4:
108
+ self.view_types = ['front', 'right', 'back', 'left']
109
+ elif self.num_views == 5:
110
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left']
111
+ elif self.num_views == 6:
112
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
113
+
114
+ self.fix_cam_pose_dir = "./mvdiffusion/data/fixed_poses/nine_views"
115
+
116
+ self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix
117
+
118
+
119
+
120
+ if filepaths is None:
121
+ # Get a list of all files in the directory
122
+ file_list = os.listdir(self.root_dir)
123
+ self.cam_types = ['ortho'] * len(file_list) + ['persp']* len(file_list)
124
+ file_list = file_list * 2
125
+ else:
126
+ file_list = filepaths
127
+ print(filepaths, root_dir)
128
+ # Filter the files that end with .png or .jpg
129
+ self.file_list = [file for file in file_list]
130
+
131
+ self.bg_color = self.get_bg_color()
132
+
133
+
134
+
135
+
136
+ def __len__(self):
137
+ return len(self.file_list)
138
+
139
+ def load_fixed_poses(self):
140
+ poses = {}
141
+ for face in self.view_types:
142
+ RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face)))
143
+ poses[face] = RT
144
+
145
+ return poses
146
+
147
+ def cartesian_to_spherical(self, xyz):
148
+ ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
149
+ xy = xyz[:,0]**2 + xyz[:,1]**2
150
+ z = np.sqrt(xy + xyz[:,2]**2)
151
+ theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
152
+ #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
153
+ azimuth = np.arctan2(xyz[:,1], xyz[:,0])
154
+ return np.array([theta, azimuth, z])
155
+
156
+ def get_T(self, target_RT, cond_RT):
157
+ R, T = target_RT[:3, :3], target_RT[:, -1]
158
+ T_target = -R.T @ T # change to cam2world
159
+
160
+ R, T = cond_RT[:3, :3], cond_RT[:, -1]
161
+ T_cond = -R.T @ T
162
+
163
+ theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
164
+ theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
165
+
166
+ d_theta = theta_target - theta_cond
167
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
168
+ d_z = z_target - z_cond
169
+
170
+ # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
171
+ return d_theta, d_azimuth
172
+
173
+ def get_bg_color(self):
174
+ if self.bg_color == 'white':
175
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
176
+ elif self.bg_color == 'black':
177
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
178
+ elif self.bg_color == 'gray':
179
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
180
+ elif self.bg_color == 'random':
181
+ bg_color = np.random.rand(3)
182
+ elif isinstance(self.bg_color, float):
183
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
184
+ else:
185
+ raise NotImplementedError
186
+ return bg_color
187
+
188
+
189
+ def load_image(self, img_path, bg_color, return_type='pt', Imagefile=None):
190
+ # pil always returns uint8
191
+ if Imagefile is None:
192
+ image_input = Image.open(img_path)
193
+ else:
194
+ image_input = Imagefile
195
+ image_size = self.img_wh[0]
196
+
197
+ # if self.crop_size!=-1:
198
+ # alpha_np = np.asarray(image_input)[:, :, 3]
199
+ # coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
200
+ # min_x, min_y = np.min(coords, 0)
201
+ # max_x, max_y = np.max(coords, 0)
202
+ # ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
203
+ # h, w = ref_img_.height, ref_img_.width
204
+ # scale = self.crop_size / max(h, w)
205
+ # h_, w_ = int(scale * h), int(scale * w)
206
+ # ref_img_ = ref_img_.resize((w_, h_))
207
+ # image_input = add_margin(ref_img_, size=image_size)
208
+ # else:
209
+ # image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
210
+ # image_input = image_input.resize((image_size, image_size))
211
+
212
+ # img = scale_and_place_object(img, self.scale_ratio)
213
+ img = np.array(image_input)
214
+ img = img.astype(np.float32) / 255. # [0, 1]
215
+ assert img.shape[-1] == 4 # RGBA
216
+
217
+ alpha = img[...,3:4]
218
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
219
+
220
+ if return_type == "np":
221
+ pass
222
+ elif return_type == "pt":
223
+ img = torch.from_numpy(img)
224
+ alpha = torch.from_numpy(alpha)
225
+ else:
226
+ raise NotImplementedError
227
+
228
+ return img, alpha
229
+
230
+
231
+ def __len__(self):
232
+ return len(self.file_list)
233
+
234
+ def __getitem__(self, index):
235
+
236
+ # image = self.all_images[index%len(self.all_images)]
237
+ # alpha = self.all_alphas[index%len(self.all_images)]
238
+ cam_type = self.cam_types[index%len(self.file_list)]
239
+ if self.file_list is not None:
240
+ filename = self.file_list[index%len(self.file_list)].replace(".png", "")
241
+ else:
242
+ filename = 'null'
243
+
244
+ cond_w2c = self.fix_cam_poses['front']
245
+
246
+ tgt_w2cs = [self.fix_cam_poses[view] for view in self.view_types]
247
+
248
+ elevations = []
249
+ azimuths = []
250
+
251
+ img_tensors_in = []
252
+ for view in self.view_types:
253
+ img_path = os.path.join(self.root_dir, filename, cam_type,"color_000_%s.png" % (view))
254
+ img_tensor, alpha = self.load_image(img_path, self.bg_color, return_type="pt")
255
+ img_tensor = img_tensor.permute(2, 0, 1)
256
+ img_tensors_in.append(img_tensor)
257
+
258
+ alpha_tensors_in = [
259
+ alpha.permute(2, 0, 1)
260
+ ] * self.num_views
261
+
262
+ for view, tgt_w2c in zip(self.view_types, tgt_w2cs):
263
+ # evelations, azimuths
264
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
265
+ elevations.append(elevation)
266
+ azimuths.append(azimuth)
267
+
268
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
269
+ # alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 3, H, W)
270
+
271
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
272
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
273
+ elevations_cond = torch.as_tensor([0] * self.num_views).float()
274
+
275
+ normal_class = torch.tensor([1, 0]).float()
276
+ normal_task_embeddings = torch.stack([normal_class] * self.num_views, dim=0) # (Nv, 2)
277
+ color_class = torch.tensor([0, 1]).float()
278
+ depth_task_embeddings = torch.stack([color_class] * self.num_views, dim=0) # (Nv, 2)
279
+
280
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
281
+
282
+ if cam_type == 'ortho':
283
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
284
+ else:
285
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
286
+
287
+ if self.load_cam_type:
288
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
289
+
290
+ out = {
291
+ 'elevations_cond': elevations_cond,
292
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
293
+ 'elevations': elevations,
294
+ 'azimuths': azimuths,
295
+ 'elevations_deg': torch.rad2deg(elevations),
296
+ 'azimuths_deg': torch.rad2deg(azimuths),
297
+ 'imgs_in': img_tensors_in,
298
+ 'alphas': alpha_tensors_in,
299
+ 'camera_embeddings': camera_embeddings,
300
+ 'normal_task_embeddings': normal_task_embeddings,
301
+ 'depth_task_embeddings': depth_task_embeddings,
302
+ 'filename': filename,
303
+ 'cam_type': cam_type
304
+ }
305
+
306
+ return out
307
+
308
+
mv_diffusion_30/data/normal_utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ def camNormal2worldNormal(rot_c2w, camNormal):
4
+ H,W,_ = camNormal.shape
5
+ normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
6
+
7
+ return normal_img
8
+
9
+ def worldNormal2camNormal(rot_w2c, normal_map_world):
10
+ H,W,_ = normal_map_world.shape
11
+ # normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3])
12
+
13
+ # faster version
14
+ # Reshape the normal map into a 2D array where each row represents a normal vector
15
+ normal_map_flat = normal_map_world.reshape(-1, 3)
16
+
17
+ # Transform the normal vectors using the transformation matrix
18
+ normal_map_camera_flat = np.dot(normal_map_flat, rot_w2c.T)
19
+
20
+ # Reshape the transformed normal map back to its original shape
21
+ normal_map_camera = normal_map_camera_flat.reshape(normal_map_world.shape)
22
+
23
+ return normal_map_camera
24
+
25
+ def trans_normal(normal, RT_w2c, RT_w2c_target):
26
+
27
+ # normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal)
28
+ # normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world)
29
+
30
+ relative_RT = np.matmul(RT_w2c_target[:3,:3], np.linalg.inv(RT_w2c[:3,:3]))
31
+ normal_target_cam = worldNormal2camNormal(relative_RT[:3,:3], normal)
32
+
33
+ return normal_target_cam
34
+
35
+ def img2normal(img):
36
+ return (img/255.)*2-1
37
+
38
+ def normal2img(normal):
39
+ return np.uint8((normal*0.5+0.5)*255)
40
+
41
+ def norm_normalize(normal, dim=-1):
42
+
43
+ normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6)
44
+
45
+ return normal
mv_diffusion_30/data/objaverse_dataset.py ADDED
@@ -0,0 +1,1359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import numpy as np
3
+ from omegaconf import DictConfig, ListConfig
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from pathlib import Path
7
+ import json
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ from typing import Literal, Tuple, Optional, Any
12
+ import cv2
13
+ import random
14
+
15
+ import json
16
+ import os, sys
17
+ import math
18
+
19
+ import PIL.Image
20
+ from .normal_utils import trans_normal, normal2img, img2normal
21
+ import pdb
22
+ from .depth_utils import scale_depth_to_model
23
+ import traceback
24
+
25
+
26
+ class ObjaverseDataset(Dataset):
27
+ def __init__(self,
28
+ root_dir_ortho: str,
29
+ root_dir_persp: str,
30
+ pred_ortho: bool,
31
+ pred_persp: bool,
32
+ num_views: int,
33
+ bg_color: Any,
34
+ img_wh: Tuple[int, int],
35
+ object_list: str,
36
+ groups_num: int=1,
37
+ validation: bool = False,
38
+ data_view_num: int = 6,
39
+ num_validation_samples: int = 64,
40
+ num_samples: Optional[int] = None,
41
+ invalid_list: Optional[str] = None,
42
+ trans_norm_system: bool = True, # if True, transform all normals map into the cam system of front view
43
+ augment_data: bool = False,
44
+ read_normal: bool = True,
45
+ read_color: bool = False,
46
+ read_depth: bool = False,
47
+ read_mask: bool = False,
48
+ pred_type: str = 'color',
49
+ suffix: str = 'png',
50
+ subscene_tag: int = 2,
51
+ load_cam_type: bool = False,
52
+ backup_scene: str = "0306b42594fb447ca574f597352d4b56",
53
+ ortho_crop_size: int = 360,
54
+ persp_crop_size: int = 440,
55
+ load_switcher: bool = True
56
+ ) -> None:
57
+ """Create a dataset from a folder of images.
58
+ If you pass in a root directory it will be searched for images
59
+ ending in ext (ext can be a list)
60
+ """
61
+ self.load_cam_type = load_cam_type
62
+ self.root_dir_ortho = Path(root_dir_ortho)
63
+ self.root_dir_persp = Path(root_dir_persp)
64
+ self.pred_ortho = pred_ortho
65
+ self.pred_persp = pred_persp
66
+ self.num_views = num_views
67
+ self.bg_color = bg_color
68
+ self.validation = validation
69
+ self.num_samples = num_samples
70
+ self.trans_norm_system = trans_norm_system
71
+ self.augment_data = augment_data
72
+ self.invalid_list = invalid_list
73
+ self.groups_num = groups_num
74
+ print("augment data: ", self.augment_data)
75
+ self.img_wh = img_wh
76
+ self.read_normal = read_normal
77
+ self.read_color = read_color
78
+ self.read_depth = read_depth
79
+ self.read_mask = read_mask
80
+ self.pred_type = pred_type # load type
81
+ self.suffix = suffix
82
+ self.subscene_tag = subscene_tag
83
+
84
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
85
+ self.fix_cam_pose_dir = "./mvdiffusion/data/fixed_poses/nine_views"
86
+
87
+ self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix
88
+ self.ortho_crop_size = ortho_crop_size
89
+ self.persp_crop_size = persp_crop_size
90
+ self.load_switcher = load_switcher
91
+
92
+ if object_list is not None:
93
+ with open(object_list) as f:
94
+ self.objects = json.load(f)
95
+ self.objects = [os.path.basename(o).replace(".glb", "") for o in self.objects]
96
+ else:
97
+ self.objects = os.listdir(self.root_dir)
98
+ self.objects = sorted(self.objects)
99
+
100
+ if self.invalid_list is not None:
101
+ with open(self.invalid_list) as f:
102
+ self.invalid_objects = json.load(f)
103
+ self.invalid_objects = [os.path.basename(o).replace(".glb", "") for o in self.invalid_objects]
104
+ else:
105
+ self.invalid_objects = []
106
+
107
+
108
+ self.all_objects = set(self.objects) - (set(self.invalid_objects) & set(self.objects))
109
+ self.all_objects = list(self.all_objects)
110
+
111
+ if not validation:
112
+ self.all_objects = self.all_objects[:-num_validation_samples]
113
+ else:
114
+ self.all_objects = self.all_objects[-num_validation_samples:]
115
+ if num_samples is not None:
116
+ self.all_objects = self.all_objects[:num_samples]
117
+
118
+ print("loading ", len(self.all_objects), " objects in the dataset")
119
+
120
+ if self.pred_type == 'color':
121
+ self.backup_data = self.__getitem_color__(0, backup_scene)
122
+ elif self.pred_type == 'normal_depth':
123
+ self.backup_data = self.__getitem_normal_depth__(0, backup_scene)
124
+ elif self.pred_type == 'mixed_rgb_normal_depth':
125
+ self.backup_data = self.__getitem_mixed__(0, backup_scene)
126
+ elif self.pred_type == 'mixed_color_normal':
127
+ self.backup_data = self.__getitem_image_normal_mixed__(0, backup_scene)
128
+ elif self.pred_type == 'mixed_rgb_noraml_mask':
129
+ self.backup_data = self.__getitem_mixed_rgb_noraml_mask__(0, backup_scene)
130
+ elif self.pred_type == 'joint_color_normal':
131
+ self.backup_data = self.__getitem_joint_rgb_noraml__(0, backup_scene)
132
+
133
+
134
+ def __len__(self):
135
+ return len(self.objects)*self.total_view
136
+
137
+ def load_fixed_poses(self):
138
+ poses = {}
139
+ for face in self.view_types:
140
+ RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face)))
141
+ poses[face] = RT
142
+
143
+ return poses
144
+
145
+ def cartesian_to_spherical(self, xyz):
146
+ ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
147
+ xy = xyz[:,0]**2 + xyz[:,1]**2
148
+ z = np.sqrt(xy + xyz[:,2]**2)
149
+ theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
150
+ #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
151
+ azimuth = np.arctan2(xyz[:,1], xyz[:,0])
152
+ return np.array([theta, azimuth, z])
153
+
154
+ def get_T(self, target_RT, cond_RT):
155
+ R, T = target_RT[:3, :3], target_RT[:, -1]
156
+ T_target = -R.T @ T # change to cam2world
157
+
158
+ R, T = cond_RT[:3, :3], cond_RT[:, -1]
159
+ T_cond = -R.T @ T
160
+
161
+ theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
162
+ theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
163
+
164
+ d_theta = theta_target - theta_cond
165
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
166
+ d_z = z_target - z_cond
167
+
168
+ # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
169
+ return d_theta, d_azimuth
170
+
171
+ def get_bg_color(self):
172
+ if self.bg_color == 'white':
173
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
174
+ elif self.bg_color == 'black':
175
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
176
+ elif self.bg_color == 'gray':
177
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
178
+ elif self.bg_color == 'random':
179
+ bg_color = np.random.rand(3)
180
+ elif self.bg_color == 'three_choices':
181
+ white = np.array([1., 1., 1.], dtype=np.float32)
182
+ black = np.array([0., 0., 0.], dtype=np.float32)
183
+ gray = np.array([0.5, 0.5, 0.5], dtype=np.float32)
184
+ bg_color = random.choice([white, black, gray])
185
+ elif isinstance(self.bg_color, float):
186
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
187
+ else:
188
+ raise NotImplementedError
189
+ return bg_color
190
+
191
+
192
+
193
+ def load_mask(self, img_path, return_type='np'):
194
+ # not using cv2 as may load in uint16 format
195
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
196
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
197
+ # pil always returns uint8
198
+ img = np.array(Image.open(img_path).resize(self.img_wh))
199
+ img = np.float32(img > 0)
200
+
201
+ assert len(np.shape(img)) == 2
202
+
203
+ if return_type == "np":
204
+ pass
205
+ elif return_type == "pt":
206
+ img = torch.from_numpy(img)
207
+ else:
208
+ raise NotImplementedError
209
+
210
+ return img
211
+
212
+ def load_mask_from_rgba(self, img_path, camera_type):
213
+ img = Image.open(img_path)
214
+
215
+ if camera_type == 'ortho':
216
+ left = (img.width - self.ortho_crop_size) // 2
217
+ right = (img.width + self.ortho_crop_size) // 2
218
+ top = (img.height - self.ortho_crop_size) // 2
219
+ bottom = (img.height + self.ortho_crop_size) // 2
220
+ img = img.crop((left, top, right, bottom))
221
+ if camera_type == 'persp':
222
+ left = (img.width - self.persp_crop_size) // 2
223
+ right = (img.width + self.persp_crop_size) // 2
224
+ top = (img.height - self.persp_crop_size) // 2
225
+ bottom = (img.height + self.persp_crop_size) // 2
226
+ img = img.crop((left, top, right, bottom))
227
+
228
+ img = img.resize(self.img_wh)
229
+ img = np.array(img).astype(np.float32) / 255. # [0, 1]
230
+ assert img.shape[-1] == 4 # must RGBA
231
+
232
+ alpha = img[:, :, 3:]
233
+
234
+ if alpha.shape[-1] != 1:
235
+ alpha = alpha[:, :, None]
236
+
237
+ return alpha
238
+
239
+ def load_image(self, img_path, bg_color, alpha, return_type='np', camera_type=None, read_depth=False, center_crop_size=None):
240
+ # not using cv2 as may load in uint16 format
241
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
242
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
243
+ # pil always returns uint8
244
+ img = Image.open(img_path)
245
+ if center_crop_size == None:
246
+ if camera_type == 'ortho':
247
+ left = (img.width - self.ortho_crop_size) // 2
248
+ right = (img.width + self.ortho_crop_size) // 2
249
+ top = (img.height - self.ortho_crop_size) // 2
250
+ bottom = (img.height + self.ortho_crop_size) // 2
251
+ img = img.crop((left, top, right, bottom))
252
+ if camera_type == 'persp':
253
+ left = (img.width - self.persp_crop_size) // 2
254
+ right = (img.width + self.persp_crop_size) // 2
255
+ top = (img.height - self.persp_crop_size) // 2
256
+ bottom = (img.height + self.persp_crop_size) // 2
257
+ img = img.crop((left, top, right, bottom))
258
+ else:
259
+ center_crop_size = min(center_crop_size, 512)
260
+ left = (img.width - center_crop_size) // 2
261
+ right = (img.width + center_crop_size) // 2
262
+ top = (img.height - center_crop_size) // 2
263
+ bottom = (img.height + center_crop_size) // 2
264
+ img = img.crop((left, top, right, bottom))
265
+
266
+ img = img.resize(self.img_wh)
267
+ img = np.array(img).astype(np.float32) / 255. # [0, 1]
268
+ assert img.shape[-1] == 3 or img.shape[-1] == 4 # RGB or RGBA
269
+
270
+ if alpha is None and img.shape[-1] == 4:
271
+ alpha = img[:, :, 3:]
272
+ img = img[:, :, :3]
273
+
274
+ if alpha.shape[-1] != 1:
275
+ alpha = alpha[:, :, None]
276
+
277
+ if read_depth:
278
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
279
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
280
+
281
+ if return_type == "np":
282
+ pass
283
+ elif return_type == "pt":
284
+ img = torch.from_numpy(img)
285
+ else:
286
+ raise NotImplementedError
287
+
288
+ return img
289
+
290
+ def load_depth(self, img_path, bg_color, alpha, return_type='np', camera_type=None):
291
+ # not using cv2 as may load in uint16 format
292
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
293
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
294
+ # pil always returns uint8
295
+ depth_bg_color = np.array([1., 1., 1.], dtype=np.float32) # white for depth
296
+ depth_map = Image.open(img_path)
297
+
298
+ if camera_type == 'ortho':
299
+ left = (depth_map.width - self.ortho_crop_size) // 2
300
+ right = (depth_map.width + self.ortho_crop_size) // 2
301
+ top = (depth_map.height - self.ortho_crop_size) // 2
302
+ bottom = (depth_map.height + self.ortho_crop_size) // 2
303
+ depth_map = depth_map.crop((left, top, right, bottom))
304
+ if camera_type == 'persp':
305
+ left = (depth_map.width - self.persp_crop_size) // 2
306
+ right = (depth_map.width + self.persp_crop_size) // 2
307
+ top = (depth_map.height - self.persp_crop_size) // 2
308
+ bottom = (depth_map.height + self.persp_crop_size) // 2
309
+ depth_map = depth_map.crop((left, top, right, bottom))
310
+
311
+ depth_map = depth_map.resize(self.img_wh)
312
+ depth_map = np.array(depth_map)
313
+
314
+ # scale the depth map:
315
+ depth_map = scale_depth_to_model(depth_map.astype(np.float32))
316
+ # depth_map = depth_map / 65535. # [0, 1]
317
+ # depth_map[depth_map > 0.4] = 0
318
+ # depth_map = depth_map / 0.4
319
+
320
+ assert depth_map.ndim == 2 # depth
321
+ img = np.stack([depth_map]*3, axis=-1)
322
+
323
+ if alpha.shape[-1] != 1:
324
+ alpha = alpha[:, :, None]
325
+
326
+
327
+ # print(np.max(img[:, :, 0]))
328
+ # print(np.min(img[...,:3]), np.max(img[...,:3]))
329
+ img = img[...,:3] * alpha + depth_bg_color * (1 - alpha)
330
+
331
+ if return_type == "np":
332
+ pass
333
+ elif return_type == "pt":
334
+ img = torch.from_numpy(img)
335
+ else:
336
+ raise NotImplementedError
337
+
338
+ return img
339
+
340
+ def transform_mask_as_input(self, mask, return_type='np'):
341
+
342
+ # mask = mask * 255
343
+ # print(np.max(mask))
344
+
345
+ # mask = mask.resize(self.img_wh)
346
+ mask = np.squeeze(mask, axis=-1)
347
+ assert mask.ndim == 2 #
348
+ mask = np.stack([mask]*3, axis=-1)
349
+ if return_type == "np":
350
+ pass
351
+ elif return_type == "pt":
352
+ mask = torch.from_numpy(mask)
353
+ else:
354
+ raise NotImplementedError
355
+ return mask
356
+
357
+
358
+
359
+ def load_normal(self, img_path, bg_color, alpha, RT_w2c=None, RT_w2c_cond=None, return_type='np', camera_type=None, center_crop_size=None):
360
+ # not using cv2 as may load in uint16 format
361
+ # img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) # [0, 255]
362
+ # img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_CUBIC)
363
+ # pil always returns uint8
364
+ # normal = Image.open(img_path)
365
+
366
+ img = Image.open(img_path)
367
+ if center_crop_size == None:
368
+ if camera_type == 'ortho':
369
+ left = (img.width - self.ortho_crop_size) // 2
370
+ right = (img.width + self.ortho_crop_size) // 2
371
+ top = (img.height - self.ortho_crop_size) // 2
372
+ bottom = (img.height + self.ortho_crop_size) // 2
373
+ img = img.crop((left, top, right, bottom))
374
+ if camera_type == 'persp':
375
+ left = (img.width - self.persp_crop_size) // 2
376
+ right = (img.width + self.persp_crop_size) // 2
377
+ top = (img.height - self.persp_crop_size) // 2
378
+ bottom = (img.height + self.persp_crop_size) // 2
379
+ img = img.crop((left, top, right, bottom))
380
+ else:
381
+ center_crop_size = min(center_crop_size, 512)
382
+ left = (img.width - center_crop_size) // 2
383
+ right = (img.width + center_crop_size) // 2
384
+ top = (img.height - center_crop_size) // 2
385
+ bottom = (img.height + center_crop_size) // 2
386
+ img = img.crop((left, top, right, bottom))
387
+
388
+ normal = np.array(img.resize(self.img_wh))
389
+
390
+ assert normal.shape[-1] == 3 or normal.shape[-1] == 4 # RGB or RGBA
391
+
392
+ if alpha is None and normal.shape[-1] == 4:
393
+ alpha = normal[:, :, 3:] / 255.
394
+ normal = normal[:, :, :3]
395
+
396
+ normal = trans_normal(img2normal(normal), RT_w2c, RT_w2c_cond)
397
+
398
+ img = (normal*0.5 + 0.5).astype(np.float32) # [0, 1]
399
+
400
+ if alpha.shape[-1] != 1:
401
+ alpha = alpha[:, :, None]
402
+
403
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
404
+
405
+ if return_type == "np":
406
+ pass
407
+ elif return_type == "pt":
408
+ img = torch.from_numpy(img)
409
+ else:
410
+ raise NotImplementedError
411
+
412
+ return img
413
+
414
+ def __len__(self):
415
+ return len(self.all_objects)
416
+
417
+ def __getitem_color__(self, index, debug_object=None):
418
+ if debug_object is not None:
419
+ object_name = debug_object #
420
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
421
+ else:
422
+ object_name = self.all_objects[index % len(self.all_objects)]
423
+ set_idx = 0
424
+
425
+ if self.augment_data:
426
+ cond_view = random.sample(self.view_types, k=1)[0]
427
+ else:
428
+ cond_view = 'front'
429
+
430
+ assert self.pred_ortho or self.pred_persp
431
+ if self.pred_ortho and self.pred_persp:
432
+ if random.random() < 0.5:
433
+ load_dir = self.root_dir_ortho
434
+ load_cam_type = 'ortho'
435
+ else:
436
+ load_dir = self.root_dir_persp
437
+ load_cam_type = 'persp'
438
+ elif self.pred_ortho and not self.pred_persp:
439
+ load_dir = self.root_dir_ortho
440
+ load_cam_type = 'ortho'
441
+ elif self.pred_persp and not self.pred_ortho:
442
+ load_dir = self.root_dir_persp
443
+ load_cam_type = 'persp'
444
+
445
+ # ! if you would like predict depth; modify here
446
+
447
+ read_color, read_normal, read_depth = True, False, False
448
+
449
+
450
+ assert (read_color and (read_normal or read_depth)) is False
451
+
452
+ view_types = self.view_types
453
+
454
+ cond_w2c = self.fix_cam_poses[cond_view]
455
+
456
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
457
+
458
+ elevations = []
459
+ azimuths = []
460
+
461
+ # get the bg color
462
+ bg_color = self.get_bg_color()
463
+
464
+ if self.read_mask:
465
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
466
+ "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
467
+ return_type='np')
468
+ else:
469
+ cond_alpha = None
470
+ img_tensors_in = [
471
+ self.load_image(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
472
+ "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
473
+ bg_color, cond_alpha, return_type='pt', camera_type=load_cam_type).permute(2, 0, 1)
474
+ ] * self.num_views
475
+ img_tensors_out = []
476
+
477
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
478
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
479
+ "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
480
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
481
+ "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
482
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
483
+ "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
484
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
485
+ "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
486
+ if self.read_mask:
487
+ alpha = self.load_mask(mask_path, return_type='np')
488
+ else:
489
+ alpha = None
490
+
491
+ if read_color:
492
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type)
493
+ img_tensor = img_tensor.permute(2, 0, 1)
494
+ img_tensors_out.append(img_tensor)
495
+
496
+ if read_normal:
497
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c,
498
+ return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
499
+ img_tensors_out.append(normal_tensor)
500
+ if read_depth:
501
+ depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
502
+ img_tensors_out.append(depth_tensor)
503
+
504
+ # evelations, azimuths
505
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
506
+ elevations.append(elevation)
507
+ azimuths.append(azimuth)
508
+
509
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
510
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
511
+
512
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
513
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
514
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
515
+
516
+ if load_cam_type == 'ortho':
517
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
518
+ else:
519
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
520
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
521
+ # if self.pred_ortho and self.pred_persp:
522
+ if self.load_cam_type:
523
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
524
+
525
+ normal_class = torch.tensor([1, 0]).float()
526
+ normal_task_embeddings = torch.stack([normal_class] * self.num_views, dim=0) # (Nv, 2)
527
+ color_class = torch.tensor([0, 1]).float()
528
+ color_task_embeddings = torch.stack([color_class] * self.num_views, dim=0) # (Nv, 2)
529
+ if read_normal or read_depth:
530
+ task_embeddings = normal_task_embeddings
531
+ if read_color:
532
+ task_embeddings = color_task_embeddings
533
+ # print(elevations)
534
+ # print(azimuths)
535
+ return {
536
+ 'elevations_cond': elevations_cond,
537
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
538
+ 'elevations': elevations,
539
+ 'azimuths': azimuths,
540
+ 'elevations_deg': torch.rad2deg(elevations),
541
+ 'azimuths_deg': torch.rad2deg(azimuths),
542
+ 'imgs_in': img_tensors_in,
543
+ 'imgs_out': img_tensors_out,
544
+ 'camera_embeddings': camera_embeddings,
545
+ 'task_embeddings': task_embeddings
546
+ }
547
+
548
+ def __getitem_normal_depth__(self, index, debug_object=None):
549
+ if debug_object is not None:
550
+ object_name = debug_object #
551
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
552
+ else:
553
+ object_name = self.all_objects[index%len(self.all_objects)]
554
+ set_idx = 0
555
+
556
+ if self.augment_data:
557
+ cond_view = random.sample(self.view_types, k=1)[0]
558
+ else:
559
+ cond_view = 'front'
560
+
561
+ assert self.pred_ortho or self.pred_persp
562
+ if self.pred_ortho and self.pred_persp:
563
+ if random.random() < 0.5:
564
+ load_dir = self.root_dir_ortho
565
+ load_cam_type = 'ortho'
566
+ else:
567
+ load_dir = self.root_dir_persp
568
+ load_cam_type = 'persp'
569
+ elif self.pred_ortho and not self.pred_persp:
570
+ load_dir = self.root_dir_ortho
571
+ load_cam_type = 'ortho'
572
+ elif self.pred_persp and not self.pred_ortho:
573
+ load_dir = self.root_dir_persp
574
+ load_cam_type = 'persp'
575
+
576
+ view_types = self.view_types
577
+
578
+ cond_w2c = self.fix_cam_poses[cond_view]
579
+
580
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
581
+
582
+ elevations = []
583
+ azimuths = []
584
+
585
+ # get the bg color
586
+ bg_color = self.get_bg_color()
587
+
588
+ if self.read_mask:
589
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), return_type='np')
590
+ else:
591
+ cond_alpha = None
592
+ # img_tensors_in = [
593
+ # self.load_image(os.path.join(self.root_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
594
+ # ] * self.num_views
595
+ img_tensors_out = []
596
+ normal_tensors_out = []
597
+ depth_tensors_out = []
598
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
599
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
600
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
601
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
602
+
603
+ if self.read_mask:
604
+ alpha = self.load_mask(mask_path, return_type='np')
605
+ else:
606
+ alpha = None
607
+
608
+ if self.read_color:
609
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type)
610
+ img_tensor = img_tensor.permute(2, 0, 1)
611
+ img_tensors_out.append(img_tensor)
612
+
613
+ if self.read_normal:
614
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
615
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
616
+ normal_tensors_out.append(normal_tensor)
617
+
618
+ if self.read_depth:
619
+ if alpha is None:
620
+ alpha = self.load_mask_from_rgba(img_path, camera_type=load_cam_type)
621
+ depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
622
+ depth_tensors_out.append(depth_tensor)
623
+
624
+
625
+ # evelations, azimuths
626
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
627
+ elevations.append(elevation)
628
+ azimuths.append(azimuth)
629
+
630
+ img_tensors_in = img_tensors_out
631
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
632
+ if self.read_color:
633
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
634
+ if self.read_normal:
635
+ normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
636
+ if self.read_depth:
637
+ depth_tensors_out = torch.stack(depth_tensors_out, dim=0).float() # (Nv, 3, H, W)
638
+
639
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
640
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
641
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
642
+
643
+ if load_cam_type == 'ortho':
644
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
645
+ else:
646
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
647
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
648
+ # if self.pred_ortho and self.pred_persp:
649
+ if self.load_cam_type:
650
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
651
+
652
+ normal_class = torch.tensor([1, 0]).float()
653
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
654
+ color_class = torch.tensor([0, 1]).float()
655
+ depth_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
656
+
657
+ return {
658
+ 'elevations_cond': elevations_cond,
659
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
660
+ 'elevations': elevations,
661
+ 'azimuths': azimuths,
662
+ 'elevations_deg': torch.rad2deg(elevations),
663
+ 'azimuths_deg': torch.rad2deg(azimuths),
664
+ 'imgs_in': img_tensors_in,
665
+ 'imgs_out': img_tensors_out,
666
+ 'normals_out': normal_tensors_out,
667
+ 'depth_out': depth_tensors_out,
668
+ 'camera_embeddings': camera_embeddings,
669
+ 'normal_task_embeddings': normal_task_embeddings,
670
+ 'depth_task_embeddings': depth_task_embeddings
671
+ }
672
+
673
+ def __getitem_mixed_rgb_noraml_mask__(self, index, debug_object=None):
674
+ if debug_object is not None:
675
+ object_name = debug_object #
676
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
677
+ else:
678
+ object_name = self.all_objects[index%len(self.all_objects)]
679
+ set_idx = 0
680
+
681
+ if self.augment_data:
682
+ cond_view = random.sample(self.view_types, k=1)[0]
683
+ else:
684
+ cond_view = 'front'
685
+
686
+ assert self.pred_ortho or self.pred_persp
687
+ if self.pred_ortho and self.pred_persp:
688
+ if random.random() < 0.5:
689
+ load_dir = self.root_dir_ortho
690
+ load_cam_type = 'ortho'
691
+ else:
692
+ load_dir = self.root_dir_persp
693
+ load_cam_type = 'persp'
694
+ elif self.pred_ortho and not self.pred_persp:
695
+ load_dir = self.root_dir_ortho
696
+ load_cam_type = 'ortho'
697
+ elif self.pred_persp and not self.pred_ortho:
698
+ load_dir = self.root_dir_persp
699
+ load_cam_type = 'persp'
700
+
701
+ view_types = self.view_types
702
+
703
+ cond_w2c = self.fix_cam_poses[cond_view]
704
+
705
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
706
+
707
+ elevations = []
708
+ azimuths = []
709
+
710
+ # get the bg color
711
+ bg_color = self.get_bg_color()
712
+
713
+ if self.read_mask:
714
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), return_type='np')
715
+ else:
716
+ cond_alpha = None
717
+
718
+ img_tensors_out = []
719
+ normal_tensors_out = []
720
+ depth_tensors_out = []
721
+
722
+ random_select = random.random()
723
+ read_color, read_normal, read_mask = [random_select < 1 / 3, 1 / 3 <= random_select <= 2 / 3,
724
+ random_select > 2 / 3]
725
+ # print(read_color, read_normal, read_depth)
726
+
727
+ assert sum([read_color, read_normal, read_mask]) == 1, "Only one variable should be True"
728
+
729
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
730
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
731
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
732
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
733
+
734
+ if self.read_mask:
735
+ alpha = self.load_mask(mask_path, return_type='np')
736
+ else:
737
+ alpha = None
738
+
739
+ if read_color:
740
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type, read_depth=False)
741
+ img_tensor = img_tensor.permute(2, 0, 1)
742
+ img_tensors_out.append(img_tensor)
743
+
744
+ if read_normal:
745
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
746
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
747
+ img_tensors_out.append(normal_tensor)
748
+
749
+ if read_mask:
750
+ if alpha is None:
751
+ alpha = self.load_mask_from_rgba(img_path, camera_type=load_cam_type)
752
+ mask_tensor = self.transform_mask_as_input(alpha, return_type='pt').permute(2, 0, 1)
753
+ img_tensors_out.append(mask_tensor)
754
+
755
+ # evelations, azimuths
756
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
757
+ elevations.append(elevation)
758
+ azimuths.append(azimuth)
759
+
760
+ if self.load_switcher: # rgb input, use domain switcher to control the output type
761
+ img_tensors_in = [
762
+ self.load_image(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
763
+ "normals_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
764
+ bg_color, cond_alpha, RT_w2c=cond_w2c, RT_w2c_cond=cond_w2c, return_type='pt', camera_type=load_cam_type).permute(
765
+ 2, 0, 1)
766
+ ] * self.num_views
767
+ color_class = torch.tensor([0, 1]).float()
768
+ color_task_embeddings = torch.stack([color_class] * self.num_views, dim=0) # (Nv, 2)
769
+
770
+ normal_class = torch.tensor([1, 0]).float()
771
+ normal_task_embeddings = torch.stack([normal_class] * self.num_views, dim=0) # (Nv, 2)
772
+
773
+ mask_class = torch.tensor([1, 1]).float()
774
+ mask_task_embeddings = torch.stack([mask_class] * self.num_views, dim=0)
775
+
776
+ if read_color:
777
+ task_embeddings = color_task_embeddings
778
+ # img_tensors_out = depth_tensors_out
779
+ elif read_normal:
780
+ task_embeddings = normal_task_embeddings
781
+ # img_tensors_out = normal_tensors_out
782
+ elif read_mask:
783
+ task_embeddings = mask_task_embeddings
784
+ # img_tensors_out = depth_tensors_out
785
+
786
+ else: # for stage 1 training, the input and the output are in the same domain
787
+ img_tensors_in = [img_tensors_out[0]] * self.num_views
788
+
789
+ empty_class = torch.tensor([0, 0]).float() # empty task
790
+ empty_task_embeddings = torch.stack([empty_class] * self.num_views, dim=0)
791
+ task_embeddings = empty_task_embeddings
792
+
793
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
794
+
795
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
796
+
797
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
798
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
799
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
800
+
801
+ if load_cam_type == 'ortho':
802
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
803
+ else:
804
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
805
+
806
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
807
+
808
+ if self.load_cam_type:
809
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
810
+
811
+ return {
812
+ 'elevations_cond': elevations_cond,
813
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
814
+ 'elevations': elevations,
815
+ 'azimuths': azimuths,
816
+ 'elevations_deg': torch.rad2deg(elevations),
817
+ 'azimuths_deg': torch.rad2deg(azimuths),
818
+ 'imgs_in': img_tensors_in,
819
+ 'imgs_out': img_tensors_out,
820
+ 'normals_out': normal_tensors_out,
821
+ 'depth_out': depth_tensors_out,
822
+ 'camera_embeddings': camera_embeddings,
823
+ 'task_embeddings': task_embeddings,
824
+ }
825
+
826
+ def __getitem_mixed__(self, index, debug_object=None):
827
+ if debug_object is not None:
828
+ object_name = debug_object #
829
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
830
+ else:
831
+ object_name = self.all_objects[index%len(self.all_objects)]
832
+ set_idx = 0
833
+
834
+ if self.augment_data:
835
+ cond_view = random.sample(self.view_types, k=1)[0]
836
+ else:
837
+ cond_view = 'front'
838
+
839
+ assert self.pred_ortho or self.pred_persp
840
+ if self.pred_ortho and self.pred_persp:
841
+ if random.random() < 0.5:
842
+ load_dir = self.root_dir_ortho
843
+ load_cam_type = 'ortho'
844
+ else:
845
+ load_dir = self.root_dir_persp
846
+ load_cam_type = 'persp'
847
+ elif self.pred_ortho and not self.pred_persp:
848
+ load_dir = self.root_dir_ortho
849
+ load_cam_type = 'ortho'
850
+ elif self.pred_persp and not self.pred_ortho:
851
+ load_dir = self.root_dir_persp
852
+ load_cam_type = 'persp'
853
+
854
+ view_types = self.view_types
855
+
856
+ cond_w2c = self.fix_cam_poses[cond_view]
857
+
858
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
859
+
860
+ elevations = []
861
+ azimuths = []
862
+
863
+ # get the bg color
864
+ bg_color = self.get_bg_color()
865
+
866
+ if self.read_mask:
867
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), return_type='np')
868
+ else:
869
+ cond_alpha = None
870
+ # img_tensors_in = [
871
+ # self.load_image(os.path.join(self.root_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
872
+ # ] * self.num_views
873
+ img_tensors_out = []
874
+ normal_tensors_out = []
875
+ depth_tensors_out = []
876
+
877
+ random_select = random.random()
878
+ read_color, read_normal, read_depth = [random_select < 1 / 3, 1 / 3 <= random_select <= 2 / 3,
879
+ random_select > 2 / 3]
880
+ # print(read_color, read_normal, read_depth)
881
+
882
+ assert sum([read_color, read_normal, read_depth]) == 1, "Only one variable should be True"
883
+
884
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
885
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
886
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
887
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
888
+
889
+ if self.read_mask:
890
+ alpha = self.load_mask(mask_path, return_type='np')
891
+ else:
892
+ alpha = None
893
+
894
+ if read_color:
895
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type, read_depth=read_depth)
896
+ img_tensor = img_tensor.permute(2, 0, 1)
897
+ img_tensors_out.append(img_tensor)
898
+
899
+ if read_normal:
900
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
901
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
902
+ img_tensors_out.append(normal_tensor)
903
+
904
+ if read_depth:
905
+ if alpha is None:
906
+ alpha = self.load_mask_from_rgba(img_path, camera_type=load_cam_type)
907
+ depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
908
+ img_tensors_out.append(depth_tensor)
909
+
910
+
911
+ # evelations, azimuths
912
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
913
+ elevations.append(elevation)
914
+ azimuths.append(azimuth)
915
+
916
+ img_tensors_in = [
917
+ self.load_image(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
918
+ "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
919
+ bg_color, cond_alpha, return_type='pt', camera_type=load_cam_type, read_depth=read_depth).permute(
920
+ 2, 0, 1)
921
+ ] * self.num_views
922
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
923
+ # if self.read_color:
924
+ # img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
925
+ # if self.read_normal:
926
+ # normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
927
+ # if self.read_depth:
928
+ # depth_tensors_out = torch.stack(depth_tensors_out, dim=0).float() # (Nv, 3, H, W)
929
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
930
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
931
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
932
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
933
+
934
+ if load_cam_type == 'ortho':
935
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
936
+ else:
937
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
938
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
939
+ # if self.pred_ortho and self.pred_persp:
940
+ if self.load_cam_type:
941
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
942
+
943
+ color_class = torch.tensor([0, 1]).float()
944
+ color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
945
+
946
+ normal_class = torch.tensor([1, 0]).float()
947
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
948
+
949
+ depth_class = torch.tensor([1, 1]).float()
950
+ depth_task_embeddings = torch.stack([depth_class]*self.num_views, dim=0)
951
+
952
+ if read_color:
953
+ task_embeddings = color_task_embeddings
954
+ # img_tensors_out = depth_tensors_out
955
+ elif read_normal:
956
+ task_embeddings = normal_task_embeddings
957
+ # img_tensors_out = normal_tensors_out
958
+ elif read_depth:
959
+ task_embeddings = depth_task_embeddings
960
+ # img_tensors_out = depth_tensors_out
961
+
962
+ return {
963
+ 'elevations_cond': elevations_cond,
964
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
965
+ 'elevations': elevations,
966
+ 'azimuths': azimuths,
967
+ 'elevations_deg': torch.rad2deg(elevations),
968
+ 'azimuths_deg': torch.rad2deg(azimuths),
969
+ 'imgs_in': img_tensors_in,
970
+ 'imgs_out': img_tensors_out,
971
+ 'normals_out': normal_tensors_out,
972
+ 'depth_out': depth_tensors_out,
973
+ 'camera_embeddings': camera_embeddings,
974
+ 'task_embeddings': task_embeddings,
975
+ }
976
+
977
+ def __getitem_image_normal_mixed__(self, index, debug_object=None):
978
+ if debug_object is not None:
979
+ object_name = debug_object #
980
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
981
+ else:
982
+ object_name = self.all_objects[index%len(self.all_objects)]
983
+ set_idx = 0
984
+
985
+ if self.augment_data:
986
+ cond_view = random.sample(self.view_types, k=1)[0]
987
+ else:
988
+ cond_view = 'front'
989
+
990
+ assert self.pred_ortho or self.pred_persp
991
+ if self.pred_ortho and self.pred_persp:
992
+ if random.random() < 0.5:
993
+ load_dir = self.root_dir_ortho
994
+ load_cam_type = 'ortho'
995
+ else:
996
+ load_dir = self.root_dir_persp
997
+ load_cam_type = 'persp'
998
+ elif self.pred_ortho and not self.pred_persp:
999
+ load_dir = self.root_dir_ortho
1000
+ load_cam_type = 'ortho'
1001
+ elif self.pred_persp and not self.pred_ortho:
1002
+ load_dir = self.root_dir_persp
1003
+ load_cam_type = 'persp'
1004
+
1005
+ view_types = self.view_types
1006
+
1007
+ cond_w2c = self.fix_cam_poses[cond_view]
1008
+
1009
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
1010
+
1011
+ elevations = []
1012
+ azimuths = []
1013
+
1014
+ # get the bg color
1015
+ bg_color = self.get_bg_color()
1016
+
1017
+ # get crop size for each mv instance:
1018
+ center_crop_size = 0
1019
+ for view in view_types:
1020
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
1021
+
1022
+ img = Image.open(img_path)
1023
+ img = img.resize([512,512])
1024
+ img = np.array(img).astype(np.float32) / 255. # [0, 1]
1025
+
1026
+ max_w_h = self.cal_single_view_crop(img)
1027
+ center_crop_size = max(center_crop_size, max_w_h)
1028
+
1029
+ center_crop_size = center_crop_size * 4. / 3.
1030
+ center_crop_size = center_crop_size + (random.random()-0.5) * 10.
1031
+
1032
+ if self.read_mask:
1033
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), return_type='np')
1034
+ else:
1035
+ cond_alpha = None
1036
+ # img_tensors_in = [
1037
+ # self.load_image(os.path.join(self.root_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), bg_color, cond_alpha, return_type='pt').permute(2, 0, 1)
1038
+ # ] * self.num_views
1039
+ img_tensors_out = []
1040
+ normal_tensors_out = []
1041
+ depth_tensors_out = []
1042
+
1043
+ random_select = random.random()
1044
+ read_color, read_normal = [random_select < 1 / 2, 1 / 2 <= random_select <= 1]
1045
+ # print(read_color, read_normal, read_depth)
1046
+
1047
+ assert sum([read_color, read_normal]) == 1, "Only one variable should be True"
1048
+
1049
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
1050
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
1051
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
1052
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
1053
+
1054
+ if self.read_mask:
1055
+ alpha = self.load_mask(mask_path, return_type='np')
1056
+ else:
1057
+ alpha = None
1058
+
1059
+ if read_color:
1060
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type, read_depth=False, center_crop_size=center_crop_size)
1061
+ img_tensor = img_tensor.permute(2, 0, 1)
1062
+ img_tensors_out.append(img_tensor)
1063
+
1064
+ if read_normal:
1065
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
1066
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt", camera_type=load_cam_type, center_crop_size=center_crop_size).permute(2, 0, 1)
1067
+ img_tensors_out.append(normal_tensor)
1068
+
1069
+ # if read_depth:
1070
+ # if alpha is None:
1071
+ # alpha = self.load_mask_from_rgba(img_path, camera_type=load_cam_type)
1072
+ # depth_tensor = self.load_depth(depth_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type).permute(2, 0, 1)
1073
+ # img_tensors_out.append(depth_tensor)
1074
+
1075
+
1076
+ # evelations, azimuths
1077
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
1078
+ elevations.append(elevation)
1079
+ azimuths.append(azimuth)
1080
+
1081
+ img_tensors_in = [
1082
+ self.load_image(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
1083
+ "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
1084
+ bg_color, cond_alpha, return_type='pt', camera_type=load_cam_type, read_depth=False, center_crop_size=center_crop_size).permute(
1085
+ 2, 0, 1)
1086
+ ] * self.num_views
1087
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
1088
+ # if self.read_color:
1089
+ # img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
1090
+ # if self.read_normal:
1091
+ # normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
1092
+ # if self.read_depth:
1093
+ # depth_tensors_out = torch.stack(depth_tensors_out, dim=0).float() # (Nv, 3, H, W)
1094
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
1095
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
1096
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
1097
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
1098
+
1099
+ if load_cam_type == 'ortho':
1100
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
1101
+ else:
1102
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
1103
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
1104
+ # if self.pred_ortho and self.pred_persp:
1105
+ if self.load_cam_type:
1106
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
1107
+
1108
+ color_class = torch.tensor([0, 1]).float()
1109
+ color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
1110
+
1111
+ normal_class = torch.tensor([1, 0]).float()
1112
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
1113
+
1114
+ # depth_class = torch.tensor([1, 1]).float()
1115
+ # depth_task_embeddings = torch.stack([depth_class]*self.num_views, dim=0)
1116
+
1117
+ if read_color:
1118
+ task_embeddings = color_task_embeddings
1119
+ # img_tensors_out = depth_tensors_out
1120
+ elif read_normal:
1121
+ task_embeddings = normal_task_embeddings
1122
+ # img_tensors_out = normal_tensors_out
1123
+ # elif read_depth:
1124
+ # task_embeddings = depth_task_embeddings
1125
+ # img_tensors_out = depth_tensors_out
1126
+
1127
+ return {
1128
+ 'elevations_cond': elevations_cond,
1129
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
1130
+ 'elevations': elevations,
1131
+ 'azimuths': azimuths,
1132
+ 'elevations_deg': torch.rad2deg(elevations),
1133
+ 'azimuths_deg': torch.rad2deg(azimuths),
1134
+ 'imgs_in': img_tensors_in,
1135
+ 'imgs_out': img_tensors_out,
1136
+ 'normals_out': normal_tensors_out,
1137
+ 'depth_out': depth_tensors_out,
1138
+ 'camera_embeddings': camera_embeddings,
1139
+ 'task_embeddings': task_embeddings,
1140
+ }
1141
+
1142
+ def cal_single_view_crop(self, image):
1143
+ assert np.shape(image)[-1] == 4 # RGBA
1144
+
1145
+ # Extract the alpha channel (transparency) and the object (RGB channels)
1146
+ alpha_channel = image[:, :, 3]
1147
+
1148
+ # Find the bounding box coordinates of the object
1149
+ coords = cv2.findNonZero(alpha_channel)
1150
+ x, y, width, height = cv2.boundingRect(coords)
1151
+
1152
+ return max(width, height)
1153
+
1154
+ def __getitem_joint_rgb_noraml__(self, index, debug_object=None):
1155
+ if debug_object is not None:
1156
+ object_name = debug_object #
1157
+ set_idx = random.sample(range(0, self.groups_num), 1)[0] # without replacement
1158
+ else:
1159
+ object_name = self.all_objects[index%len(self.all_objects)]
1160
+ set_idx = 0
1161
+
1162
+ if self.augment_data:
1163
+ cond_view = random.sample(self.view_types, k=1)[0]
1164
+ else:
1165
+ cond_view = 'front'
1166
+
1167
+ assert self.pred_ortho or self.pred_persp
1168
+ if self.pred_ortho and self.pred_persp:
1169
+ if random.random() < 0.5:
1170
+ load_dir = self.root_dir_ortho
1171
+ load_cam_type = 'ortho'
1172
+ else:
1173
+ load_dir = self.root_dir_persp
1174
+ load_cam_type = 'persp'
1175
+ elif self.pred_ortho and not self.pred_persp:
1176
+ load_dir = self.root_dir_ortho
1177
+ load_cam_type = 'ortho'
1178
+ elif self.pred_persp and not self.pred_ortho:
1179
+ load_dir = self.root_dir_persp
1180
+ load_cam_type = 'persp'
1181
+
1182
+ view_types = self.view_types
1183
+
1184
+ cond_w2c = self.fix_cam_poses[cond_view]
1185
+
1186
+ tgt_w2cs = [self.fix_cam_poses[view] for view in view_types]
1187
+
1188
+ elevations = []
1189
+ azimuths = []
1190
+
1191
+ # get the bg color
1192
+ bg_color = self.get_bg_color()
1193
+
1194
+ if self.read_mask:
1195
+ cond_alpha = self.load_mask(os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, cond_view, self.suffix)), return_type='np')
1196
+ else:
1197
+ cond_alpha = None
1198
+
1199
+ img_tensors_out = []
1200
+ normal_tensors_out = []
1201
+
1202
+
1203
+ read_color, read_normal = True, True
1204
+ # print(read_color, read_normal, read_depth)
1205
+
1206
+ # get crop size for each mv instance:
1207
+ center_crop_size = 0
1208
+ for view in view_types:
1209
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
1210
+
1211
+ img = Image.open(img_path)
1212
+ img = img.resize([512,512])
1213
+ img = np.array(img).astype(np.float32) / 255. # [0, 1]
1214
+
1215
+ max_w_h = self.cal_single_view_crop(img)
1216
+ center_crop_size = max(center_crop_size, max_w_h)
1217
+
1218
+ center_crop_size = center_crop_size * 4. / 3.
1219
+ center_crop_size = center_crop_size + (random.random()-0.5) * 10.
1220
+
1221
+
1222
+
1223
+ for view, tgt_w2c in zip(view_types, tgt_w2cs):
1224
+ img_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "rgb_%03d_%s.%s" % (set_idx, view, self.suffix))
1225
+ mask_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "mask_%03d_%s.%s" % (set_idx, view, self.suffix))
1226
+ depth_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "depth_%03d_%s.%s" % (set_idx, view, self.suffix))
1227
+
1228
+ if self.read_mask:
1229
+ alpha = self.load_mask(mask_path, return_type='np')
1230
+ else:
1231
+ alpha = None
1232
+
1233
+ if read_color:
1234
+ img_tensor = self.load_image(img_path, bg_color, alpha, return_type="pt", camera_type=load_cam_type, read_depth=False, center_crop_size=center_crop_size)
1235
+ img_tensor = img_tensor.permute(2, 0, 1)
1236
+ img_tensors_out.append(img_tensor)
1237
+
1238
+ if read_normal:
1239
+ normal_path = os.path.join(load_dir, object_name[:self.subscene_tag], object_name, "normals_%03d_%s.%s" % (set_idx, view, self.suffix))
1240
+ normal_tensor = self.load_normal(normal_path, bg_color, alpha, RT_w2c=tgt_w2c, RT_w2c_cond=cond_w2c, return_type="pt", camera_type=load_cam_type, center_crop_size=center_crop_size).permute(2, 0, 1)
1241
+ normal_tensors_out.append(normal_tensor)
1242
+
1243
+ # evelations, azimuths
1244
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
1245
+ elevations.append(elevation)
1246
+ azimuths.append(azimuth)
1247
+
1248
+ if self.load_switcher: # rgb input, use domain switcher to control the output type
1249
+ img_tensors_in = [
1250
+ self.load_image(os.path.join(load_dir, object_name[:self.subscene_tag], object_name,
1251
+ "rgb_%03d_%s.%s" % (set_idx, cond_view, self.suffix)),
1252
+ bg_color, cond_alpha, return_type='pt', camera_type=load_cam_type,
1253
+ read_depth=False, center_crop_size=center_crop_size).permute(
1254
+ 2, 0, 1)
1255
+ ] * self.num_views
1256
+
1257
+ color_class = torch.tensor([0, 1]).float()
1258
+ color_task_embeddings = torch.stack([color_class] * self.num_views, dim=0) # (Nv, 2)
1259
+
1260
+ normal_class = torch.tensor([1, 0]).float()
1261
+ normal_task_embeddings = torch.stack([normal_class] * self.num_views, dim=0) # (Nv, 2)
1262
+
1263
+
1264
+ if read_color:
1265
+ task_embeddings = color_task_embeddings
1266
+ # img_tensors_out = depth_tensors_out
1267
+ elif read_normal:
1268
+ task_embeddings = normal_task_embeddings
1269
+ # img_tensors_out = normal_tensors_out
1270
+
1271
+ else: # for stage 1 training, the input and the output are in the same domain
1272
+ img_tensors_in = [img_tensors_out[0]] * self.num_views
1273
+
1274
+ empty_class = torch.tensor([0, 0]).float() # empty task
1275
+ empty_task_embeddings = torch.stack([empty_class] * self.num_views, dim=0)
1276
+ task_embeddings = empty_task_embeddings
1277
+
1278
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
1279
+
1280
+ img_tensors_out = torch.stack(img_tensors_out, dim=0).float() # (Nv, 3, H, W)
1281
+ normal_tensors_out = torch.stack(normal_tensors_out, dim=0).float() # (Nv, 3, H, W)
1282
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
1283
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
1284
+ elevations_cond = torch.as_tensor([0] * self.num_views).float() # fixed only use 4 views to train
1285
+
1286
+ if load_cam_type == 'ortho':
1287
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
1288
+ else:
1289
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
1290
+
1291
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1)
1292
+
1293
+ if self.load_cam_type:
1294
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
1295
+
1296
+ return {
1297
+ 'elevations_cond': elevations_cond,
1298
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
1299
+ 'elevations': elevations,
1300
+ 'azimuths': azimuths,
1301
+ 'elevations_deg': torch.rad2deg(elevations),
1302
+ 'azimuths_deg': torch.rad2deg(azimuths),
1303
+ 'imgs_in': img_tensors_in,
1304
+ 'imgs_out': img_tensors_out,
1305
+ 'normals_out': normal_tensors_out,
1306
+ 'camera_embeddings': camera_embeddings,
1307
+ 'color_task_embeddings': color_task_embeddings,
1308
+ 'normal_task_embeddings': normal_task_embeddings
1309
+ }
1310
+
1311
+ def __getitem__(self, index):
1312
+ try:
1313
+ if self.pred_type == 'color':
1314
+ data = self.backup_data = self.__getitem_color__(index)
1315
+ elif self.pred_type == 'normal_depth':
1316
+ data = self.backup_data = self.__getitem_normal_depth__(index)
1317
+ elif self.pred_type == 'mixed_rgb_normal_depth':
1318
+ data = self.backup_data = self.__getitem_mixed__(index)
1319
+ elif self.pred_type == 'mixed_color_normal':
1320
+ data = self.backup_data = self.__getitem_image_normal_mixed__(index)
1321
+ elif self.pred_type == 'mixed_rgb_noraml_mask':
1322
+ data = self.backup_data = self.__getitem_mixed_rgb_noraml_mask__(index)
1323
+ elif self.pred_type == 'joint_color_normal':
1324
+ data = self.backup_data = self.__getitem_joint_rgb_noraml__(index)
1325
+ return data
1326
+
1327
+ except:
1328
+ print("load error ", self.all_objects[index%len(self.all_objects)])
1329
+ return self.backup_data
1330
+
1331
+ class ConcatDataset(torch.utils.data.Dataset):
1332
+ def __init__(self, datasets, weights):
1333
+ self.datasets = datasets
1334
+ self.weights = weights
1335
+ self.num_datasets = len(datasets)
1336
+
1337
+ def __getitem__(self, i):
1338
+
1339
+ chosen = random.choices(self.datasets, self.weights, k=1)[0]
1340
+ return chosen[i]
1341
+
1342
+ def __len__(self):
1343
+ return max(len(d) for d in self.datasets)
1344
+
1345
+ if __name__ == "__main__":
1346
+ train_dataset = ObjaverseDataset(
1347
+ root_dir="/ghome/l5/xxlong/.objaverse/hf-objaverse-v1/renderings",
1348
+ size=(128, 128),
1349
+ ext="hdf5",
1350
+ default_trans=torch.zeros(3),
1351
+ return_paths=False,
1352
+ total_view=8,
1353
+ validation=False,
1354
+ object_list=None,
1355
+ views_mode='fourviews'
1356
+ )
1357
+ data0 = train_dataset[0]
1358
+ data1 = train_dataset[50]
1359
+ # print(data)
mv_diffusion_30/data/single_image_dataset.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import numpy as np
3
+ from omegaconf import DictConfig, ListConfig
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from pathlib import Path
7
+ import json
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+ from einops import rearrange
11
+ from typing import Literal, Tuple, Optional, Any
12
+ import cv2
13
+ import random
14
+
15
+ import json
16
+ import os, sys
17
+ import math
18
+
19
+ from glob import glob
20
+
21
+ import PIL.Image
22
+ from .normal_utils import trans_normal, normal2img, img2normal
23
+ import pdb
24
+ from rembg import remove
25
+
26
+
27
+ import cv2
28
+ import numpy as np
29
+
30
+ def add_margin(pil_img, color=0, size=256):
31
+ width, height = pil_img.size
32
+ result = Image.new(pil_img.mode, (size, size), color)
33
+ result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
34
+ return result
35
+
36
+ def scale_and_place_object(image, scale_factor):
37
+ assert np.shape(image)[-1]==4 # RGBA
38
+
39
+ # Extract the alpha channel (transparency) and the object (RGB channels)
40
+ alpha_channel = image[:, :, 3]
41
+
42
+ # Find the bounding box coordinates of the object
43
+ coords = cv2.findNonZero(alpha_channel)
44
+ x, y, width, height = cv2.boundingRect(coords)
45
+
46
+ # Calculate the scale factor for resizing
47
+ original_height, original_width = image.shape[:2]
48
+
49
+ if width > height:
50
+ size = width
51
+ original_size = original_width
52
+ else:
53
+ size = height
54
+ original_size = original_height
55
+
56
+ scale_factor = min(scale_factor, size / (original_size+0.0))
57
+
58
+ new_size = scale_factor * original_size
59
+ scale_factor = new_size / size
60
+
61
+ # Calculate the new size based on the scale factor
62
+ new_width = int(width * scale_factor)
63
+ new_height = int(height * scale_factor)
64
+
65
+ center_x = original_width // 2
66
+ center_y = original_height // 2
67
+
68
+ paste_x = center_x - (new_width // 2)
69
+ paste_y = center_y - (new_height // 2)
70
+
71
+ # Resize the object (RGB channels) to the new size
72
+ rescaled_object = cv2.resize(image[y:y+height, x:x+width], (new_width, new_height))
73
+
74
+ # Create a new RGBA image with the resized image
75
+ new_image = np.zeros((original_height, original_width, 4), dtype=np.uint8)
76
+
77
+ new_image[paste_y:paste_y + new_height, paste_x:paste_x + new_width] = rescaled_object
78
+
79
+ return new_image
80
+
81
+ class SingleImageDataset(Dataset):
82
+ def __init__(self,
83
+ root_dir: str,
84
+ num_views: int,
85
+ img_wh: Tuple[int, int],
86
+ bg_color: str,
87
+ crop_size: int = 224,
88
+ single_image: Optional[PIL.Image.Image] = None,
89
+ num_validation_samples: Optional[int] = None,
90
+ filepaths: Optional[list] = None,
91
+ cam_types: Optional[list] = None,
92
+ cond_type: Optional[str] = None,
93
+ load_cam_type: Optional[bool] = True
94
+ ) -> None:
95
+ """Create a dataset from a folder of images.
96
+ If you pass in a root directory it will be searched for images
97
+ ending in ext (ext can be a list)
98
+ """
99
+ self.root_dir = root_dir
100
+ self.num_views = num_views
101
+ self.img_wh = img_wh
102
+ self.crop_size = crop_size
103
+ self.bg_color = bg_color
104
+ self.cond_type = cond_type
105
+ self.load_cam_type = load_cam_type
106
+ self.cam_types = cam_types
107
+
108
+ if self.num_views == 4:
109
+ self.view_types = ['front', 'right', 'back', 'left']
110
+ elif self.num_views == 5:
111
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left']
112
+ elif self.num_views == 6:
113
+ self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
114
+
115
+ self.fix_cam_pose_dir = "./mv_diffusion_30/data/fixed_poses/nine_views"
116
+
117
+ self.fix_cam_poses = self.load_fixed_poses() # world2cam matrix
118
+
119
+ if single_image is None:
120
+ if filepaths is None:
121
+ # Get a list of all files in the directory
122
+ file_list = os.listdir(self.root_dir)
123
+ self.cam_types = ['ortho'] * len(file_list) + ['persp']* len(file_list)
124
+ file_list = file_list * 2
125
+ else:
126
+ file_list = filepaths
127
+
128
+ # Filter the files that end with .png or .jpg
129
+ self.file_list = [file for file in file_list if file.endswith(('.png', '.jpg'))]
130
+ else:
131
+ self.file_list = None
132
+
133
+ # load all images
134
+ self.all_images = []
135
+ self.all_alphas = []
136
+ bg_color = self.get_bg_color()
137
+
138
+ if single_image is not None:
139
+ image, alpha = self.load_image(None, bg_color, return_type='pt', Imagefile=single_image)
140
+ self.all_images.append(image)
141
+ self.all_alphas.append(alpha)
142
+ else:
143
+ for file in self.file_list:
144
+ print(os.path.join(self.root_dir, file))
145
+ image, alpha = self.load_image(os.path.join(self.root_dir, file), bg_color, return_type='pt')
146
+ self.all_images.append(image)
147
+ self.all_alphas.append(alpha)
148
+ #
149
+ # assert len(self.file_list) == len(self.cam_types)
150
+ self.all_images = self.all_images[:num_validation_samples]
151
+ self.all_alphas = self.all_alphas[:num_validation_samples]
152
+
153
+ def __len__(self):
154
+ return len(self.all_images)
155
+
156
+ def load_fixed_poses(self):
157
+ poses = {}
158
+ for face in self.view_types:
159
+ RT = np.loadtxt(os.path.join(self.fix_cam_pose_dir,'%03d_%s_RT.txt'%(0, face)))
160
+ poses[face] = RT
161
+
162
+ return poses
163
+
164
+ def cartesian_to_spherical(self, xyz):
165
+ ptsnew = np.hstack((xyz, np.zeros(xyz.shape)))
166
+ xy = xyz[:,0]**2 + xyz[:,1]**2
167
+ z = np.sqrt(xy + xyz[:,2]**2)
168
+ theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down
169
+ #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up
170
+ azimuth = np.arctan2(xyz[:,1], xyz[:,0])
171
+ return np.array([theta, azimuth, z])
172
+
173
+ def get_T(self, target_RT, cond_RT):
174
+ R, T = target_RT[:3, :3], target_RT[:, -1]
175
+ T_target = -R.T @ T # change to cam2world
176
+
177
+ R, T = cond_RT[:3, :3], cond_RT[:, -1]
178
+ T_cond = -R.T @ T
179
+
180
+ theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :])
181
+ theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :])
182
+
183
+ d_theta = theta_target - theta_cond
184
+ d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi)
185
+ d_z = z_target - z_cond
186
+
187
+ # d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()])
188
+ return d_theta, d_azimuth
189
+
190
+ def get_bg_color(self):
191
+ if self.bg_color == 'white':
192
+ bg_color = np.array([1., 1., 1.], dtype=np.float32)
193
+ elif self.bg_color == 'black':
194
+ bg_color = np.array([0., 0., 0.], dtype=np.float32)
195
+ elif self.bg_color == 'gray':
196
+ bg_color = np.array([0.5, 0.5, 0.5], dtype=np.float32)
197
+ elif self.bg_color == 'random':
198
+ bg_color = np.random.rand(3)
199
+ elif isinstance(self.bg_color, float):
200
+ bg_color = np.array([self.bg_color] * 3, dtype=np.float32)
201
+ else:
202
+ raise NotImplementedError
203
+ return bg_color
204
+
205
+
206
+ def load_image(self, img_path, bg_color, return_type='np', Imagefile=None):
207
+ # pil always returns uint8
208
+ if Imagefile is None:
209
+ image_input = Image.open(img_path)
210
+ else:
211
+ image_input = Imagefile
212
+ image_size = self.img_wh[0]
213
+
214
+
215
+ # if np.asarray(image_input).shape[-1] != 4:
216
+ # print('move background for:', image_input)
217
+ # image_input = remove(image_input)
218
+
219
+ if self.crop_size!=-1:
220
+ alpha_np = np.asarray(image_input)[:, :, 3]
221
+ coords = np.stack(np.nonzero(alpha_np), 1)[:, (1, 0)]
222
+ min_x, min_y = np.min(coords, 0)
223
+ max_x, max_y = np.max(coords, 0)
224
+ ref_img_ = image_input.crop((min_x, min_y, max_x, max_y))
225
+ h, w = ref_img_.height, ref_img_.width
226
+ scale = self.crop_size / max(h, w)
227
+ h_, w_ = int(scale * h), int(scale * w)
228
+ ref_img_ = ref_img_.resize((w_, h_))
229
+ image_input = add_margin(ref_img_, size=image_size)
230
+ else:
231
+ image_input = add_margin(image_input, size=max(image_input.height, image_input.width))
232
+ image_input = image_input.resize((image_size, image_size))
233
+
234
+ # img = scale_and_place_object(img, self.scale_ratio)
235
+ img = np.array(image_input)
236
+ img = img.astype(np.float32) / 255. # [0, 1]
237
+ assert img.shape[-1] == 4 # RGBA
238
+
239
+ alpha = img[...,3:4]
240
+ img = img[...,:3] * alpha + bg_color * (1 - alpha)
241
+
242
+ if return_type == "np":
243
+ pass
244
+ elif return_type == "pt":
245
+ img = torch.from_numpy(img)
246
+ alpha = torch.from_numpy(alpha)
247
+ else:
248
+ raise NotImplementedError
249
+
250
+ return img, alpha
251
+
252
+
253
+ def __len__(self):
254
+ return len(self.all_images)
255
+
256
+ def __getitem__(self, index):
257
+
258
+ image = self.all_images[index%len(self.all_images)]
259
+ alpha = self.all_alphas[index%len(self.all_images)]
260
+ if self.load_cam_type:
261
+ cam_type = self.cam_types[index%len(self.all_images)]
262
+ else:
263
+ cam_type = 'ortho'
264
+ if self.file_list is not None:
265
+ filename = self.file_list[index%len(self.all_images)].replace(".png", "")
266
+ else:
267
+ filename = 'null'
268
+
269
+ cond_w2c = self.fix_cam_poses['front']
270
+
271
+ tgt_w2cs = [self.fix_cam_poses[view] for view in self.view_types]
272
+
273
+ elevations = []
274
+ azimuths = []
275
+
276
+ img_tensors_in = [
277
+ image.permute(2, 0, 1)
278
+ ] * self.num_views
279
+
280
+ alpha_tensors_in = [
281
+ alpha.permute(2, 0, 1)
282
+ ] * self.num_views
283
+
284
+ for view, tgt_w2c in zip(self.view_types, tgt_w2cs):
285
+ # evelations, azimuths
286
+ elevation, azimuth = self.get_T(tgt_w2c, cond_w2c)
287
+ elevations.append(elevation)
288
+ azimuths.append(azimuth)
289
+
290
+ img_tensors_in = torch.stack(img_tensors_in, dim=0).float() # (Nv, 3, H, W)
291
+ alpha_tensors_in = torch.stack(alpha_tensors_in, dim=0).float() # (Nv, 3, H, W)
292
+
293
+ elevations = torch.as_tensor(elevations).float().squeeze(1)
294
+ azimuths = torch.as_tensor(azimuths).float().squeeze(1)
295
+ elevations_cond = torch.as_tensor([0] * self.num_views).float()
296
+
297
+ normal_class = torch.tensor([1, 0]).float()
298
+ normal_task_embeddings = torch.stack([normal_class]*self.num_views, dim=0) # (Nv, 2)
299
+ color_class = torch.tensor([0, 1]).float()
300
+ color_task_embeddings = torch.stack([color_class]*self.num_views, dim=0) # (Nv, 2)
301
+ depth_class = torch.tensor([1, 1]).float()
302
+ depth_task_embeddings = torch.stack([depth_class]*self.num_views, dim=0)
303
+
304
+ camera_embeddings = torch.stack([elevations_cond, elevations, azimuths], dim=-1) # (Nv, 3)
305
+
306
+ print("camera type:", cam_type)
307
+ if cam_type == 'ortho':
308
+ cam_type_emb = torch.tensor([0, 1]).expand(self.num_views, -1)
309
+ else:
310
+ cam_type_emb = torch.tensor([1, 0]).expand(self.num_views, -1)
311
+
312
+ if self.load_cam_type:
313
+ camera_embeddings = torch.cat((camera_embeddings, cam_type_emb), dim=-1) # (Nv, 5)
314
+
315
+ out = {
316
+ 'elevations_cond': elevations_cond,
317
+ 'elevations_cond_deg': torch.rad2deg(elevations_cond),
318
+ 'elevations': elevations,
319
+ 'azimuths': azimuths,
320
+ 'elevations_deg': torch.rad2deg(elevations),
321
+ 'azimuths_deg': torch.rad2deg(azimuths),
322
+ 'imgs_in': img_tensors_in,
323
+ 'alphas': alpha_tensors_in,
324
+ 'camera_embeddings': camera_embeddings,
325
+ 'normal_task_embeddings': normal_task_embeddings,
326
+ 'color_task_embeddings': color_task_embeddings,
327
+ 'depth_task_embeddings': depth_task_embeddings,
328
+ 'filename': filename,
329
+ 'cam_type': cam_type
330
+ }
331
+
332
+ return out
333
+
334
+
mv_diffusion_30/models/transformer_mv2d.py ADDED
@@ -0,0 +1,1093 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, Optional
16
+
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+ # from torch.nn.attention import SDPBackend, sdpa_kernel
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.models.embeddings import ImagePositionalEmbeddings
24
+ from diffusers.utils import BaseOutput, deprecate
25
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
26
+ from diffusers.models.attention import FeedForward, AdaLayerNorm, AdaLayerNormZero, Attention
27
+ from diffusers.models.embeddings import PatchEmbed
28
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.utils.import_utils import is_xformers_available
31
+
32
+ from einops import rearrange, repeat
33
+ import pdb
34
+ import random
35
+
36
+
37
+ # if is_xformers_available():
38
+ # import xformers
39
+ # import xformers.ops
40
+ # else:
41
+ # xformers = None
42
+
43
+ def my_repeat(tensor, num_repeats):
44
+ """
45
+ Repeat a tensor along a given dimension
46
+ """
47
+ if len(tensor.shape) == 3:
48
+ return repeat(tensor, "b d c -> (b v) d c", v=num_repeats)
49
+ elif len(tensor.shape) == 4:
50
+ return repeat(tensor, "a b d c -> (a v) b d c", v=num_repeats)
51
+
52
+
53
+ @dataclass
54
+ class TransformerMV2DModelOutput(BaseOutput):
55
+ """
56
+ The output of [`Transformer2DModel`].
57
+
58
+ Args:
59
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
60
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
61
+ distributions for the unnoised latent pixels.
62
+ """
63
+
64
+ sample: torch.FloatTensor
65
+
66
+
67
+ class TransformerMV2DModel(ModelMixin, ConfigMixin):
68
+ """
69
+ A 2D Transformer model for image-like data.
70
+
71
+ Parameters:
72
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
73
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
74
+ in_channels (`int`, *optional*):
75
+ The number of channels in the input and output (specify if the input is **continuous**).
76
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
77
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
78
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
79
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
80
+ This is fixed during training since it is used to learn a number of position embeddings.
81
+ num_vector_embeds (`int`, *optional*):
82
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
83
+ Includes the class for the masked latent pixel.
84
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
85
+ num_embeds_ada_norm ( `int`, *optional*):
86
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
87
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
88
+ added to the hidden states.
89
+
90
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
91
+ attention_bias (`bool`, *optional*):
92
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
93
+ """
94
+
95
+ @register_to_config
96
+ def __init__(
97
+ self,
98
+ num_attention_heads: int = 16,
99
+ attention_head_dim: int = 88,
100
+ in_channels: Optional[int] = None,
101
+ out_channels: Optional[int] = None,
102
+ num_layers: int = 1,
103
+ dropout: float = 0.0,
104
+ norm_num_groups: int = 32,
105
+ cross_attention_dim: Optional[int] = None,
106
+ attention_bias: bool = False,
107
+ sample_size: Optional[int] = None,
108
+ num_vector_embeds: Optional[int] = None,
109
+ patch_size: Optional[int] = None,
110
+ activation_fn: str = "geglu",
111
+ num_embeds_ada_norm: Optional[int] = None,
112
+ use_linear_projection: bool = False,
113
+ only_cross_attention: bool = False,
114
+ upcast_attention: bool = False,
115
+ norm_type: str = "layer_norm",
116
+ norm_elementwise_affine: bool = True,
117
+ num_views: int = 1,
118
+ cd_attention_last: bool=False,
119
+ cd_attention_mid: bool=False,
120
+ multiview_attention: bool=True,
121
+ sparse_mv_attention: bool = False,
122
+ mvcd_attention: bool=False
123
+ ):
124
+ super().__init__()
125
+ self.use_linear_projection = use_linear_projection
126
+ self.num_attention_heads = num_attention_heads
127
+ self.attention_head_dim = attention_head_dim
128
+ inner_dim = num_attention_heads * attention_head_dim
129
+
130
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
131
+ # Define whether input is continuous or discrete depending on configuration
132
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
133
+ self.is_input_vectorized = num_vector_embeds is not None
134
+ self.is_input_patches = in_channels is not None and patch_size is not None
135
+
136
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
137
+ deprecation_message = (
138
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
139
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
140
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
141
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
142
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
143
+ )
144
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
145
+ norm_type = "ada_norm"
146
+
147
+ if self.is_input_continuous and self.is_input_vectorized:
148
+ raise ValueError(
149
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
150
+ " sure that either `in_channels` or `num_vector_embeds` is None."
151
+ )
152
+ elif self.is_input_vectorized and self.is_input_patches:
153
+ raise ValueError(
154
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
155
+ " sure that either `num_vector_embeds` or `num_patches` is None."
156
+ )
157
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
158
+ raise ValueError(
159
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
160
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
161
+ )
162
+
163
+ # 2. Define input layers
164
+ if self.is_input_continuous:
165
+ self.in_channels = in_channels
166
+
167
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
168
+ if use_linear_projection:
169
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
170
+ else:
171
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
172
+ elif self.is_input_vectorized:
173
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
174
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
175
+
176
+ self.height = sample_size
177
+ self.width = sample_size
178
+ self.num_vector_embeds = num_vector_embeds
179
+ self.num_latent_pixels = self.height * self.width
180
+
181
+ self.latent_image_embedding = ImagePositionalEmbeddings(
182
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
183
+ )
184
+ elif self.is_input_patches:
185
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
186
+
187
+ self.height = sample_size
188
+ self.width = sample_size
189
+
190
+ self.patch_size = patch_size
191
+ self.pos_embed = PatchEmbed(
192
+ height=sample_size,
193
+ width=sample_size,
194
+ patch_size=patch_size,
195
+ in_channels=in_channels,
196
+ embed_dim=inner_dim,
197
+ )
198
+
199
+ # 3. Define transformers blocks
200
+ self.transformer_blocks = nn.ModuleList(
201
+ [
202
+ BasicMVTransformerBlock(
203
+ inner_dim,
204
+ num_attention_heads,
205
+ attention_head_dim,
206
+ dropout=dropout,
207
+ cross_attention_dim=cross_attention_dim,
208
+ activation_fn=activation_fn,
209
+ num_embeds_ada_norm=num_embeds_ada_norm,
210
+ attention_bias=attention_bias,
211
+ only_cross_attention=only_cross_attention,
212
+ upcast_attention=upcast_attention,
213
+ norm_type=norm_type,
214
+ norm_elementwise_affine=norm_elementwise_affine,
215
+ num_views=num_views,
216
+ cd_attention_last=cd_attention_last,
217
+ cd_attention_mid=cd_attention_mid,
218
+ multiview_attention=multiview_attention,
219
+ sparse_mv_attention=sparse_mv_attention,
220
+ mvcd_attention=mvcd_attention
221
+ )
222
+ for d in range(num_layers)
223
+ ]
224
+ )
225
+
226
+ # 4. Define output layers
227
+ self.out_channels = in_channels if out_channels is None else out_channels
228
+ if self.is_input_continuous:
229
+ # TODO: should use out_channels for continuous projections
230
+ if use_linear_projection:
231
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
232
+ else:
233
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
234
+ elif self.is_input_vectorized:
235
+ self.norm_out = nn.LayerNorm(inner_dim)
236
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
237
+ elif self.is_input_patches:
238
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
239
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
240
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
241
+
242
+ def forward(
243
+ self,
244
+ hidden_states: torch.Tensor,
245
+ encoder_hidden_states: Optional[torch.Tensor] = None,
246
+ timestep: Optional[torch.LongTensor] = None,
247
+ class_labels: Optional[torch.LongTensor] = None,
248
+ cross_attention_kwargs: Dict[str, Any] = None,
249
+ attention_mask: Optional[torch.Tensor] = None,
250
+ encoder_attention_mask: Optional[torch.Tensor] = None,
251
+ return_dict: bool = True,
252
+ ):
253
+ """
254
+ The [`Transformer2DModel`] forward method.
255
+
256
+ Args:
257
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
258
+ Input `hidden_states`.
259
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
260
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
261
+ self-attention.
262
+ timestep ( `torch.LongTensor`, *optional*):
263
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
264
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
265
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
266
+ `AdaLayerZeroNorm`.
267
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
268
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
269
+
270
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
271
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
272
+
273
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
274
+ above. This bias will be added to the cross-attention scores.
275
+ return_dict (`bool`, *optional*, defaults to `True`):
276
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
277
+ tuple.
278
+
279
+ Returns:
280
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
281
+ `tuple` where the first element is the sample tensor.
282
+ """
283
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
284
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
285
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
286
+ # expects mask of shape:
287
+ # [batch, key_tokens]
288
+ # adds singleton query_tokens dimension:
289
+ # [batch, 1, key_tokens]
290
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
291
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
292
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
293
+ if attention_mask is not None and attention_mask.ndim == 2:
294
+ # assume that mask is expressed as:
295
+ # (1 = keep, 0 = discard)
296
+ # convert mask into a bias that can be added to attention scores:
297
+ # (keep = +0, discard = -10000.0)
298
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
299
+ attention_mask = attention_mask.unsqueeze(1)
300
+
301
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
302
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
303
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
304
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
305
+
306
+ # 1. Input
307
+ if self.is_input_continuous:
308
+ batch, _, height, width = hidden_states.shape
309
+ residual = hidden_states
310
+
311
+ hidden_states = self.norm(hidden_states)
312
+ if not self.use_linear_projection:
313
+ hidden_states = self.proj_in(hidden_states)
314
+ inner_dim = hidden_states.shape[1]
315
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
316
+ else:
317
+ inner_dim = hidden_states.shape[1]
318
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
319
+ hidden_states = self.proj_in(hidden_states)
320
+ elif self.is_input_vectorized:
321
+ hidden_states = self.latent_image_embedding(hidden_states)
322
+ elif self.is_input_patches:
323
+ hidden_states = self.pos_embed(hidden_states)
324
+
325
+ # 2. Blocks
326
+ for block in self.transformer_blocks:
327
+ hidden_states = block(
328
+ hidden_states,
329
+ attention_mask=attention_mask,
330
+ encoder_hidden_states=encoder_hidden_states,
331
+ encoder_attention_mask=encoder_attention_mask,
332
+ timestep=timestep,
333
+ cross_attention_kwargs=cross_attention_kwargs,
334
+ class_labels=class_labels,
335
+ )
336
+
337
+ # 3. Output
338
+ if self.is_input_continuous:
339
+ if not self.use_linear_projection:
340
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
341
+ hidden_states = self.proj_out(hidden_states)
342
+ else:
343
+ hidden_states = self.proj_out(hidden_states)
344
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
345
+
346
+ output = hidden_states + residual
347
+ elif self.is_input_vectorized:
348
+ hidden_states = self.norm_out(hidden_states)
349
+ logits = self.out(hidden_states)
350
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
351
+ logits = logits.permute(0, 2, 1)
352
+
353
+ # log(p(x_0))
354
+ output = F.log_softmax(logits.double(), dim=1).float()
355
+ elif self.is_input_patches:
356
+ # TODO: cleanup!
357
+ conditioning = self.transformer_blocks[0].norm1.emb(
358
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
359
+ )
360
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
361
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
362
+ hidden_states = self.proj_out_2(hidden_states)
363
+
364
+ # unpatchify
365
+ height = width = int(hidden_states.shape[1] ** 0.5)
366
+ hidden_states = hidden_states.reshape(
367
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
368
+ )
369
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
370
+ output = hidden_states.reshape(
371
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
372
+ )
373
+
374
+ if not return_dict:
375
+ return (output,)
376
+
377
+ return TransformerMV2DModelOutput(sample=output)
378
+
379
+
380
+ @maybe_allow_in_graph
381
+ class BasicMVTransformerBlock(nn.Module):
382
+ r"""
383
+ A basic Transformer block.
384
+
385
+ Parameters:
386
+ dim (`int`): The number of channels in the input and output.
387
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
388
+ attention_head_dim (`int`): The number of channels in each head.
389
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
390
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
391
+ only_cross_attention (`bool`, *optional*):
392
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
393
+ double_self_attention (`bool`, *optional*):
394
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
395
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
396
+ num_embeds_ada_norm (:
397
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
398
+ attention_bias (:
399
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
400
+ """
401
+
402
+ def __init__(
403
+ self,
404
+ dim: int,
405
+ num_attention_heads: int,
406
+ attention_head_dim: int,
407
+ dropout=0.0,
408
+ cross_attention_dim: Optional[int] = None,
409
+ activation_fn: str = "geglu",
410
+ num_embeds_ada_norm: Optional[int] = None,
411
+ attention_bias: bool = False,
412
+ only_cross_attention: bool = False,
413
+ double_self_attention: bool = False,
414
+ upcast_attention: bool = False,
415
+ norm_elementwise_affine: bool = True,
416
+ norm_type: str = "layer_norm",
417
+ final_dropout: bool = False,
418
+ num_views: int = 1,
419
+ cd_attention_last: bool = False,
420
+ cd_attention_mid: bool = False,
421
+ multiview_attention: bool = True,
422
+ sparse_mv_attention: bool = False,
423
+ mvcd_attention: bool = False
424
+ ):
425
+ super().__init__()
426
+ self.only_cross_attention = only_cross_attention
427
+
428
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
429
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
430
+
431
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
432
+ raise ValueError(
433
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
434
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
435
+ )
436
+
437
+ # Define 3 blocks. Each block has its own normalization layer.
438
+ # 1. Self-Attn
439
+ if self.use_ada_layer_norm:
440
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
441
+ elif self.use_ada_layer_norm_zero:
442
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
443
+ else:
444
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
445
+
446
+ self.multiview_attention = multiview_attention
447
+ self.sparse_mv_attention = sparse_mv_attention
448
+ self.mvcd_attention = mvcd_attention
449
+
450
+ self.attn1 = CustomAttention(
451
+ query_dim=dim,
452
+ heads=num_attention_heads,
453
+ dim_head=attention_head_dim,
454
+ dropout=dropout,
455
+ bias=attention_bias,
456
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
457
+ upcast_attention=upcast_attention,
458
+ processor=MVAttnProcessor()
459
+ )
460
+
461
+ # 2. Cross-Attn
462
+ if cross_attention_dim is not None or double_self_attention:
463
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
464
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
465
+ # the second cross attention block.
466
+ self.norm2 = (
467
+ AdaLayerNorm(dim, num_embeds_ada_norm)
468
+ if self.use_ada_layer_norm
469
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
470
+ )
471
+ self.attn2 = Attention(
472
+ query_dim=dim,
473
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
474
+ heads=num_attention_heads,
475
+ dim_head=attention_head_dim,
476
+ dropout=dropout,
477
+ bias=attention_bias,
478
+ upcast_attention=upcast_attention,
479
+ # processor=CrossAttnProcessor()
480
+ ) # is self-attn if encoder_hidden_states is none
481
+ else:
482
+ self.norm2 = None
483
+ self.attn2 = None
484
+
485
+ # 3. Feed-forward
486
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
487
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
488
+
489
+ # let chunk size default to None
490
+ self._chunk_size = None
491
+ self._chunk_dim = 0
492
+
493
+ self.num_views = num_views
494
+
495
+ self.cd_attention_last = cd_attention_last
496
+
497
+ if self.cd_attention_last:
498
+ # Joint task -Attn
499
+ self.attn_joint_last = Attention(
500
+ query_dim=dim,
501
+ heads=num_attention_heads,
502
+ dim_head=attention_head_dim,
503
+ dropout=dropout,
504
+ bias=attention_bias,
505
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
506
+ upcast_attention=upcast_attention,
507
+ processor=JointAttnProcessor()
508
+ )
509
+ nn.init.zeros_(self.attn_joint_last.to_out[0].weight.data)
510
+ self.norm_joint_last = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
511
+
512
+
513
+ self.cd_attention_mid = cd_attention_mid
514
+
515
+ if self.cd_attention_mid:
516
+ # print("cross-domain attn in the middle")
517
+ # Joint task -Attn
518
+ self.attn_joint_mid = Attention(
519
+ query_dim=dim,
520
+ heads=num_attention_heads,
521
+ dim_head=attention_head_dim,
522
+ dropout=dropout,
523
+ bias=attention_bias,
524
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
525
+ upcast_attention=upcast_attention,
526
+ processor=JointAttnProcessor()
527
+ )
528
+ nn.init.zeros_(self.attn_joint_mid.to_out[0].weight.data)
529
+ self.norm_joint_mid = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
530
+
531
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
532
+ # Sets chunk feed-forward
533
+ self._chunk_size = chunk_size
534
+ self._chunk_dim = dim
535
+
536
+ def forward(
537
+ self,
538
+ hidden_states: torch.FloatTensor,
539
+ attention_mask: Optional[torch.FloatTensor] = None,
540
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
541
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
542
+ timestep: Optional[torch.LongTensor] = None,
543
+ cross_attention_kwargs: Dict[str, Any] = None,
544
+ class_labels: Optional[torch.LongTensor] = None,
545
+ ):
546
+ """
547
+
548
+ :type attention_mask: object
549
+ """
550
+ assert attention_mask is None # not supported yet
551
+ # Notice that normalization is always applied before the real computation in the following blocks.
552
+ # 1. Self-Attention
553
+ if self.use_ada_layer_norm:
554
+ norm_hidden_states = self.norm1(hidden_states, timestep)
555
+ elif self.use_ada_layer_norm_zero:
556
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
557
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
558
+ )
559
+ else:
560
+ norm_hidden_states = self.norm1(hidden_states)
561
+
562
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
563
+
564
+ attn_output = self.attn1(norm_hidden_states,
565
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
566
+ attention_mask=attention_mask,
567
+ num_views=self.num_views,
568
+ multiview_attention=self.multiview_attention,
569
+ sparse_mv_attention=self.sparse_mv_attention,
570
+ mvcd_attention=self.mvcd_attention,
571
+ **cross_attention_kwargs,
572
+ )
573
+
574
+
575
+ if self.use_ada_layer_norm_zero:
576
+ attn_output = gate_msa.unsqueeze(1) * attn_output
577
+ hidden_states = attn_output + hidden_states
578
+
579
+ # joint attention twice
580
+ if self.cd_attention_mid:
581
+ norm_hidden_states = (
582
+ self.norm_joint_mid(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_mid(hidden_states)
583
+ )
584
+ hidden_states = self.attn_joint_mid(norm_hidden_states) + hidden_states
585
+
586
+ # 2. Cross-Attention
587
+ if self.attn2 is not None:
588
+ norm_hidden_states = (
589
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
590
+ )
591
+
592
+ attn_output = self.attn2(
593
+ norm_hidden_states,
594
+ encoder_hidden_states=encoder_hidden_states,
595
+ attention_mask=encoder_attention_mask,
596
+ **cross_attention_kwargs,
597
+ )
598
+ hidden_states = attn_output + hidden_states
599
+
600
+ # 3. Feed-forward
601
+ norm_hidden_states = self.norm3(hidden_states)
602
+
603
+ if self.use_ada_layer_norm_zero:
604
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
605
+
606
+ if self._chunk_size is not None:
607
+ # "feed_forward_chunk_size" can be used to save memory
608
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
609
+ raise ValueError(
610
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
611
+ )
612
+
613
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
614
+ ff_output = torch.cat(
615
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
616
+ dim=self._chunk_dim,
617
+ )
618
+ else:
619
+ ff_output = self.ff(norm_hidden_states)
620
+
621
+ if self.use_ada_layer_norm_zero:
622
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
623
+
624
+ hidden_states = ff_output + hidden_states
625
+
626
+ if self.cd_attention_last:
627
+ norm_hidden_states = (
628
+ self.norm_joint_last(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_joint_last(hidden_states)
629
+ )
630
+ hidden_states = self.attn_joint_last(norm_hidden_states) + hidden_states
631
+
632
+ return hidden_states
633
+
634
+
635
+ class CustomAttention(Attention):
636
+ def set_use_memory_efficient_attention_xformers(
637
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
638
+ ):
639
+ processor = XFormersMVAttnProcessor()
640
+ self.set_processor(processor)
641
+ # print("using xformers attention processor")
642
+
643
+
644
+ class CustomJointAttention(Attention):
645
+ def set_use_memory_efficient_attention_xformers(
646
+ self, use_memory_efficient_attention_xformers: bool, *args, **kwargs
647
+ ):
648
+ processor = XFormersJointAttnProcessor()
649
+ self.set_processor(processor)
650
+ # print("using xformers attention processor")
651
+
652
+ class MVAttnProcessor:
653
+ r"""
654
+ Default processor for performing attention-related computations.
655
+ """
656
+
657
+ def __call__(
658
+ self,
659
+ attn: Attention,
660
+ hidden_states,
661
+ encoder_hidden_states=None,
662
+ attention_mask=None,
663
+ temb=None,
664
+ num_views=1,
665
+ multiview_attention=True,
666
+ sparse_mv_attention=False,
667
+ mvcd_attention=False,
668
+ ):
669
+ residual = hidden_states
670
+
671
+ if attn.spatial_norm is not None:
672
+ hidden_states = attn.spatial_norm(hidden_states, temb)
673
+
674
+ input_ndim = hidden_states.ndim
675
+
676
+ if input_ndim == 4:
677
+ batch_size, channel, height, width = hidden_states.shape
678
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
679
+
680
+ batch_size, sequence_length, input_dim = (
681
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
682
+ )
683
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
684
+
685
+ if attn.group_norm is not None:
686
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
687
+
688
+ query = attn.to_q(hidden_states)
689
+
690
+ if encoder_hidden_states is None:
691
+ encoder_hidden_states = hidden_states
692
+ elif attn.norm_cross:
693
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
694
+
695
+ key = attn.to_k(encoder_hidden_states)
696
+ value = attn.to_v(encoder_hidden_states)
697
+
698
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
699
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
700
+ # pdb.set_trace()
701
+ # multi-view self-attention
702
+ if multiview_attention:
703
+ key = rearrange(key, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
704
+ value = rearrange(value, "(b t) d c -> b (t d) c", t=num_views).repeat_interleave(num_views, dim=0)
705
+
706
+ # batch, n_heads, n_tokens, channel
707
+ query = attn.head_to_batch_dim(query, out_dim=4).contiguous()
708
+ key = attn.head_to_batch_dim(key, out_dim=4).contiguous()
709
+ value = attn.head_to_batch_dim(value, out_dim=4).contiguous()
710
+
711
+ with torch.backends.cuda.sdp_kernel(
712
+ enable_flash=True,
713
+ enable_math=False,
714
+ enable_mem_efficient=True
715
+ ):
716
+ hidden_states = F.scaled_dot_product_attention(query, key, value)
717
+
718
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, sequence_length, input_dim)
719
+
720
+ # linear proj
721
+ hidden_states = attn.to_out[0](hidden_states)
722
+ # dropout
723
+ hidden_states = attn.to_out[1](hidden_states)
724
+
725
+ if input_ndim == 4:
726
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
727
+
728
+ if attn.residual_connection:
729
+ hidden_states = hidden_states + residual
730
+
731
+ hidden_states = hidden_states / attn.rescale_output_factor
732
+
733
+ return hidden_states
734
+
735
+
736
+ class XFormersMVAttnProcessor:
737
+ r"""
738
+ Default processor for performing attention-related computations.
739
+ """
740
+
741
+ def __call__(
742
+ self,
743
+ attn: Attention,
744
+ hidden_states,
745
+ encoder_hidden_states=None,
746
+ attention_mask=None,
747
+ temb=None,
748
+ num_views=1.,
749
+ multiview_attention=True,
750
+ sparse_mv_attention=False,
751
+ mvcd_attention=False,
752
+ ):
753
+ residual = hidden_states
754
+
755
+ if attn.spatial_norm is not None:
756
+ hidden_states = attn.spatial_norm(hidden_states, temb)
757
+
758
+ input_ndim = hidden_states.ndim
759
+
760
+ if input_ndim == 4:
761
+ batch_size, channel, height, width = hidden_states.shape
762
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
763
+
764
+ batch_size, sequence_length, _ = (
765
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
766
+ )
767
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
768
+
769
+ # from yuancheng; here attention_mask is None
770
+ if attention_mask is not None:
771
+ # expand our mask's singleton query_tokens dimension:
772
+ # [batch*heads, 1, key_tokens] ->
773
+ # [batch*heads, query_tokens, key_tokens]
774
+ # so that it can be added as a bias onto the attention scores that xformers computes:
775
+ # [batch*heads, query_tokens, key_tokens]
776
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
777
+ _, query_tokens, _ = hidden_states.shape
778
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
779
+
780
+ if attn.group_norm is not None:
781
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
782
+
783
+ query = attn.to_q(hidden_states)
784
+
785
+ if encoder_hidden_states is None:
786
+ encoder_hidden_states = hidden_states
787
+ elif attn.norm_cross:
788
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
789
+
790
+ key_raw = attn.to_k(encoder_hidden_states)
791
+ value_raw = attn.to_v(encoder_hidden_states)
792
+
793
+ # print('query', query.shape, 'key', key.shape, 'value', value.shape)
794
+ #([bx4, 1024, 320]) key torch.Size([bx4, 1024, 320]) value torch.Size([bx4, 1024, 320])
795
+ # pdb.set_trace()
796
+ # multi-view self-attention
797
+ if multiview_attention:
798
+ if not sparse_mv_attention:
799
+ key = my_repeat(rearrange(key_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
800
+ value = my_repeat(rearrange(value_raw, "(b t) d c -> b (t d) c", t=num_views), num_views)
801
+ else:
802
+ key_front = my_repeat(rearrange(key_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views) # [(b t), d, c]
803
+ value_front = my_repeat(rearrange(value_raw, "(b t) d c -> b t d c", t=num_views)[:, 0, :, :], num_views)
804
+ key = torch.cat([key_front, key_raw], dim=1) # shape (b t) (2 d) c
805
+ value = torch.cat([value_front, value_raw], dim=1)
806
+
807
+ else:
808
+ # print("don't use multiview attention.")
809
+ key = key_raw
810
+ value = value_raw
811
+
812
+ query = attn.head_to_batch_dim(query)
813
+ key = attn.head_to_batch_dim(key)
814
+ value = attn.head_to_batch_dim(value)
815
+
816
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
817
+ # for flash attention implementation
818
+ # with torch.backends.cuda.sdp_kernel(enable_math=False):
819
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_bias=attention_mask)
820
+ # hidden_states = attn.batch_to_head_dim(hidden_states)
821
+
822
+ # linear proj
823
+ hidden_states = attn.to_out[0](hidden_states)
824
+ # dropout
825
+ hidden_states = attn.to_out[1](hidden_states)
826
+
827
+ if input_ndim == 4:
828
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
829
+
830
+ if attn.residual_connection:
831
+ hidden_states = hidden_states + residual
832
+
833
+ hidden_states = hidden_states / attn.rescale_output_factor
834
+
835
+ return hidden_states
836
+
837
+
838
+
839
+ class XFormersJointAttnProcessor:
840
+ r"""
841
+ Default processor for performing attention-related computations.
842
+ """
843
+
844
+ def __call__(
845
+ self,
846
+ attn: Attention,
847
+ hidden_states,
848
+ encoder_hidden_states=None,
849
+ attention_mask=None,
850
+ temb=None,
851
+ num_tasks=2
852
+ ):
853
+
854
+ residual = hidden_states
855
+
856
+ if attn.spatial_norm is not None:
857
+ hidden_states = attn.spatial_norm(hidden_states, temb)
858
+
859
+ input_ndim = hidden_states.ndim
860
+
861
+ if input_ndim == 4:
862
+ batch_size, channel, height, width = hidden_states.shape
863
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
864
+
865
+ batch_size, sequence_length, _ = (
866
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
867
+ )
868
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
869
+
870
+ # from yuancheng; here attention_mask is None
871
+ if attention_mask is not None:
872
+ # expand our mask's singleton query_tokens dimension:
873
+ # [batch*heads, 1, key_tokens] ->
874
+ # [batch*heads, query_tokens, key_tokens]
875
+ # so that it can be added as a bias onto the attention scores that xformers computes:
876
+ # [batch*heads, query_tokens, key_tokens]
877
+ # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
878
+ _, query_tokens, _ = hidden_states.shape
879
+ attention_mask = attention_mask.expand(-1, query_tokens, -1)
880
+
881
+ if attn.group_norm is not None:
882
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
883
+
884
+ query = attn.to_q(hidden_states)
885
+
886
+ if encoder_hidden_states is None:
887
+ encoder_hidden_states = hidden_states
888
+ elif attn.norm_cross:
889
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
890
+
891
+ key = attn.to_k(encoder_hidden_states)
892
+ value = attn.to_v(encoder_hidden_states)
893
+
894
+ assert num_tasks == 2 # only support two tasks now
895
+
896
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
897
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
898
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
899
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
900
+ key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
901
+ value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
902
+
903
+
904
+ query = attn.head_to_batch_dim(query).contiguous()
905
+ key = attn.head_to_batch_dim(key).contiguous()
906
+ value = attn.head_to_batch_dim(value).contiguous()
907
+
908
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
909
+ # for flash attention implementation
910
+ # with torch.backends.cuda.sdp_kernel(enable_math=False):
911
+ # hidden_states = F.scaled_dot_product_attention(query, key, value, attn_bias=attention_mask)
912
+ # hidden_states = attn.batch_to_head_dim(hidden_states)
913
+
914
+ # linear proj
915
+ hidden_states = attn.to_out[0](hidden_states)
916
+ # dropout
917
+ hidden_states = attn.to_out[1](hidden_states)
918
+
919
+ if input_ndim == 4:
920
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
921
+
922
+ if attn.residual_connection:
923
+ hidden_states = hidden_states + residual
924
+
925
+ hidden_states = hidden_states / attn.rescale_output_factor
926
+
927
+ return hidden_states
928
+
929
+
930
+ # class JointAttnProcessor:
931
+ # r"""
932
+ # Default processor for performing attention-related computations.
933
+ # """
934
+ #
935
+ # def __call__(
936
+ # self,
937
+ # attn: Attention,
938
+ # hidden_states,
939
+ # encoder_hidden_states=None,
940
+ # attention_mask=None,
941
+ # temb=None,
942
+ # num_tasks=2
943
+ # ):
944
+ #
945
+ # residual = hidden_states
946
+ #
947
+ # if attn.spatial_norm is not None:
948
+ # hidden_states = attn.spatial_norm(hidden_states, temb)
949
+ #
950
+ # input_ndim = hidden_states.ndim
951
+ #
952
+ # if input_ndim == 4:
953
+ # batch_size, channel, height, width = hidden_states.shape
954
+ # hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
955
+ #
956
+ # batch_size, sequence_length, input_dim = (
957
+ # hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
958
+ # )
959
+ # attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
960
+ #
961
+ #
962
+ # if attn.group_norm is not None:
963
+ # hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
964
+ #
965
+ # query = attn.to_q(hidden_states)
966
+ #
967
+ # if encoder_hidden_states is None:
968
+ # encoder_hidden_states = hidden_states
969
+ # elif attn.norm_cross:
970
+ # encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
971
+ #
972
+ # key = attn.to_k(encoder_hidden_states)
973
+ # value = attn.to_v(encoder_hidden_states)
974
+ #
975
+ # assert num_tasks == 2 # only support two tasks now
976
+ #
977
+ # key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
978
+ # value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
979
+ # key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
980
+ # value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
981
+ # key = torch.cat([key]*2, dim=0) # ( 2 b t) 2d c
982
+ # value = torch.cat([value]*2, dim=0) # (2 b t) 2d c
983
+ #
984
+ #
985
+ # # batch, n_heads, n_tokens, channel
986
+ # query = attn.head_to_batch_dim(query, out_dim=4).contiguous()
987
+ # key = attn.head_to_batch_dim(key, out_dim=4).contiguous()
988
+ # value = attn.head_to_batch_dim(value, out_dim=4).contiguous()
989
+ #
990
+ # # attention_probs = attn.get_attention_scores(query, key, attention_mask)
991
+ # # hidden_states = torch.bmm(attention_probs, value)
992
+ # # hidden_states = attn.batch_to_head_dim(hidden_states)
993
+ #
994
+ # # for flash attention implementation
995
+ # with torch.backends.cuda.sdp_kernel(
996
+ # enable_flash=True,
997
+ # enable_math=False,
998
+ # enable_mem_efficient=True
999
+ # ):
1000
+ # hidden_states = F.scaled_dot_product_attention(query, key, value)
1001
+ #
1002
+ # hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, sequence_length, input_dim)
1003
+ #
1004
+ # # linear proj
1005
+ # hidden_states = attn.to_out[0](hidden_states)
1006
+ # # dropout
1007
+ # hidden_states = attn.to_out[1](hidden_states)
1008
+ #
1009
+ # if input_ndim == 4:
1010
+ # hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1011
+ #
1012
+ # if attn.residual_connection:
1013
+ # hidden_states = hidden_states + residual
1014
+ #
1015
+ # hidden_states = hidden_states / attn.rescale_output_factor
1016
+ #
1017
+ # return hidden_states
1018
+
1019
+ class JointAttnProcessor:
1020
+ r"""
1021
+ Default processor for performing attention-related computations.
1022
+ """
1023
+
1024
+ def __call__(
1025
+ self,
1026
+ attn: Attention,
1027
+ hidden_states,
1028
+ encoder_hidden_states=None,
1029
+ attention_mask=None,
1030
+ temb=None,
1031
+ num_tasks=2
1032
+ ):
1033
+
1034
+ residual = hidden_states
1035
+
1036
+ if attn.spatial_norm is not None:
1037
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1038
+
1039
+ input_ndim = hidden_states.ndim
1040
+
1041
+ if input_ndim == 4:
1042
+ batch_size, channel, height, width = hidden_states.shape
1043
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1044
+
1045
+ batch_size, sequence_length, _ = (
1046
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1047
+ )
1048
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1049
+
1050
+ if attn.group_norm is not None:
1051
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1052
+
1053
+ query = attn.to_q(hidden_states)
1054
+
1055
+ if encoder_hidden_states is None:
1056
+ encoder_hidden_states = hidden_states
1057
+ elif attn.norm_cross:
1058
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1059
+
1060
+ key = attn.to_k(encoder_hidden_states)
1061
+ value = attn.to_v(encoder_hidden_states)
1062
+
1063
+ assert num_tasks == 2 # only support two tasks now
1064
+
1065
+ key_0, key_1 = torch.chunk(key, dim=0, chunks=2) # keys shape (b t) d c
1066
+ value_0, value_1 = torch.chunk(value, dim=0, chunks=2)
1067
+ key = torch.cat([key_0, key_1], dim=1) # (b t) 2d c
1068
+ value = torch.cat([value_0, value_1], dim=1) # (b t) 2d c
1069
+ key = torch.cat([key] * 2, dim=0) # ( 2 b t) 2d c
1070
+ value = torch.cat([value] * 2, dim=0) # (2 b t) 2d c
1071
+
1072
+ query = attn.head_to_batch_dim(query).contiguous()
1073
+ key = attn.head_to_batch_dim(key).contiguous()
1074
+ value = attn.head_to_batch_dim(value).contiguous()
1075
+
1076
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1077
+ hidden_states = torch.bmm(attention_probs, value)
1078
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1079
+
1080
+ # linear proj
1081
+ hidden_states = attn.to_out[0](hidden_states)
1082
+ # dropout
1083
+ hidden_states = attn.to_out[1](hidden_states)
1084
+
1085
+ if input_ndim == 4:
1086
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1087
+
1088
+ if attn.residual_connection:
1089
+ hidden_states = hidden_states + residual
1090
+
1091
+ hidden_states = hidden_states / attn.rescale_output_factor
1092
+
1093
+ return hidden_states
mv_diffusion_30/models/unet_mv2d_blocks.py ADDED
@@ -0,0 +1,922 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+
21
+ from diffusers.utils import is_torch_version, logging
22
+ # from diffusers.models.normalization import AdaGroupNorm
23
+ # from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
24
+ # from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
25
+ from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
26
+ from mv_diffusion_30.models.transformer_mv2d import TransformerMV2DModel
27
+
28
+ from diffusers.models.unets.unet_2d_blocks import DownBlock2D, ResnetDownsampleBlock2D, AttnDownBlock2D, CrossAttnDownBlock2D, SimpleCrossAttnDownBlock2D, SkipDownBlock2D, AttnSkipDownBlock2D, DownEncoderBlock2D, AttnDownEncoderBlock2D, KDownBlock2D, KCrossAttnDownBlock2D
29
+ from diffusers.models.unets.unet_2d_blocks import UpBlock2D, ResnetUpsampleBlock2D, CrossAttnUpBlock2D, SimpleCrossAttnUpBlock2D, AttnUpBlock2D, SkipUpBlock2D, AttnSkipUpBlock2D, UpDecoderBlock2D, AttnUpDecoderBlock2D, KUpBlock2D, KCrossAttnUpBlock2D
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ def get_down_block(
36
+ down_block_type,
37
+ num_layers,
38
+ in_channels,
39
+ out_channels,
40
+ temb_channels,
41
+ add_downsample,
42
+ resnet_eps,
43
+ resnet_act_fn,
44
+ transformer_layers_per_block=1,
45
+ num_attention_heads=None,
46
+ resnet_groups=None,
47
+ cross_attention_dim=None,
48
+ downsample_padding=None,
49
+ dual_cross_attention=False,
50
+ use_linear_projection=False,
51
+ only_cross_attention=False,
52
+ upcast_attention=False,
53
+ resnet_time_scale_shift="default",
54
+ resnet_skip_time_act=False,
55
+ resnet_out_scale_factor=1.0,
56
+ cross_attention_norm=None,
57
+ attention_head_dim=None,
58
+ downsample_type=None,
59
+ num_views=1,
60
+ cd_attention_last: bool = False,
61
+ cd_attention_mid: bool = False,
62
+ multiview_attention: bool = True,
63
+ sparse_mv_attention: bool = False,
64
+ mvcd_attention: bool=False
65
+ ):
66
+ # If attn head dim is not defined, we default it to the number of heads
67
+ if attention_head_dim is None:
68
+ logger.warn(
69
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
70
+ )
71
+ attention_head_dim = num_attention_heads
72
+
73
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
74
+ if down_block_type == "DownBlock2D":
75
+ return DownBlock2D(
76
+ num_layers=num_layers,
77
+ in_channels=in_channels,
78
+ out_channels=out_channels,
79
+ temb_channels=temb_channels,
80
+ add_downsample=add_downsample,
81
+ resnet_eps=resnet_eps,
82
+ resnet_act_fn=resnet_act_fn,
83
+ resnet_groups=resnet_groups,
84
+ downsample_padding=downsample_padding,
85
+ resnet_time_scale_shift=resnet_time_scale_shift,
86
+ )
87
+ elif down_block_type == "ResnetDownsampleBlock2D":
88
+ return ResnetDownsampleBlock2D(
89
+ num_layers=num_layers,
90
+ in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ temb_channels=temb_channels,
93
+ add_downsample=add_downsample,
94
+ resnet_eps=resnet_eps,
95
+ resnet_act_fn=resnet_act_fn,
96
+ resnet_groups=resnet_groups,
97
+ resnet_time_scale_shift=resnet_time_scale_shift,
98
+ skip_time_act=resnet_skip_time_act,
99
+ output_scale_factor=resnet_out_scale_factor,
100
+ )
101
+ elif down_block_type == "AttnDownBlock2D":
102
+ if add_downsample is False:
103
+ downsample_type = None
104
+ else:
105
+ downsample_type = downsample_type or "conv" # default to 'conv'
106
+ return AttnDownBlock2D(
107
+ num_layers=num_layers,
108
+ in_channels=in_channels,
109
+ out_channels=out_channels,
110
+ temb_channels=temb_channels,
111
+ resnet_eps=resnet_eps,
112
+ resnet_act_fn=resnet_act_fn,
113
+ resnet_groups=resnet_groups,
114
+ downsample_padding=downsample_padding,
115
+ attention_head_dim=attention_head_dim,
116
+ resnet_time_scale_shift=resnet_time_scale_shift,
117
+ downsample_type=downsample_type,
118
+ )
119
+ elif down_block_type == "CrossAttnDownBlock2D":
120
+ if cross_attention_dim is None:
121
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
122
+ return CrossAttnDownBlock2D(
123
+ num_layers=num_layers,
124
+ transformer_layers_per_block=transformer_layers_per_block,
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ temb_channels=temb_channels,
128
+ add_downsample=add_downsample,
129
+ resnet_eps=resnet_eps,
130
+ resnet_act_fn=resnet_act_fn,
131
+ resnet_groups=resnet_groups,
132
+ downsample_padding=downsample_padding,
133
+ cross_attention_dim=cross_attention_dim,
134
+ num_attention_heads=num_attention_heads,
135
+ dual_cross_attention=dual_cross_attention,
136
+ use_linear_projection=use_linear_projection,
137
+ only_cross_attention=only_cross_attention,
138
+ upcast_attention=upcast_attention,
139
+ resnet_time_scale_shift=resnet_time_scale_shift,
140
+ )
141
+ # custom MV2D attention block
142
+ elif down_block_type == "CrossAttnDownBlockMV2D":
143
+ if cross_attention_dim is None:
144
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMV2D")
145
+ return CrossAttnDownBlockMV2D(
146
+ num_layers=num_layers,
147
+ transformer_layers_per_block=transformer_layers_per_block,
148
+ in_channels=in_channels,
149
+ out_channels=out_channels,
150
+ temb_channels=temb_channels,
151
+ add_downsample=add_downsample,
152
+ resnet_eps=resnet_eps,
153
+ resnet_act_fn=resnet_act_fn,
154
+ resnet_groups=resnet_groups,
155
+ downsample_padding=downsample_padding,
156
+ cross_attention_dim=cross_attention_dim,
157
+ num_attention_heads=num_attention_heads,
158
+ dual_cross_attention=dual_cross_attention,
159
+ use_linear_projection=use_linear_projection,
160
+ only_cross_attention=only_cross_attention,
161
+ upcast_attention=upcast_attention,
162
+ resnet_time_scale_shift=resnet_time_scale_shift,
163
+ num_views=num_views,
164
+ cd_attention_last=cd_attention_last,
165
+ cd_attention_mid=cd_attention_mid,
166
+ multiview_attention=multiview_attention,
167
+ sparse_mv_attention=sparse_mv_attention,
168
+ mvcd_attention=mvcd_attention
169
+ )
170
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
171
+ if cross_attention_dim is None:
172
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
173
+ return SimpleCrossAttnDownBlock2D(
174
+ num_layers=num_layers,
175
+ in_channels=in_channels,
176
+ out_channels=out_channels,
177
+ temb_channels=temb_channels,
178
+ add_downsample=add_downsample,
179
+ resnet_eps=resnet_eps,
180
+ resnet_act_fn=resnet_act_fn,
181
+ resnet_groups=resnet_groups,
182
+ cross_attention_dim=cross_attention_dim,
183
+ attention_head_dim=attention_head_dim,
184
+ resnet_time_scale_shift=resnet_time_scale_shift,
185
+ skip_time_act=resnet_skip_time_act,
186
+ output_scale_factor=resnet_out_scale_factor,
187
+ only_cross_attention=only_cross_attention,
188
+ cross_attention_norm=cross_attention_norm,
189
+ )
190
+ elif down_block_type == "SkipDownBlock2D":
191
+ return SkipDownBlock2D(
192
+ num_layers=num_layers,
193
+ in_channels=in_channels,
194
+ out_channels=out_channels,
195
+ temb_channels=temb_channels,
196
+ add_downsample=add_downsample,
197
+ resnet_eps=resnet_eps,
198
+ resnet_act_fn=resnet_act_fn,
199
+ downsample_padding=downsample_padding,
200
+ resnet_time_scale_shift=resnet_time_scale_shift,
201
+ )
202
+ elif down_block_type == "AttnSkipDownBlock2D":
203
+ return AttnSkipDownBlock2D(
204
+ num_layers=num_layers,
205
+ in_channels=in_channels,
206
+ out_channels=out_channels,
207
+ temb_channels=temb_channels,
208
+ add_downsample=add_downsample,
209
+ resnet_eps=resnet_eps,
210
+ resnet_act_fn=resnet_act_fn,
211
+ attention_head_dim=attention_head_dim,
212
+ resnet_time_scale_shift=resnet_time_scale_shift,
213
+ )
214
+ elif down_block_type == "DownEncoderBlock2D":
215
+ return DownEncoderBlock2D(
216
+ num_layers=num_layers,
217
+ in_channels=in_channels,
218
+ out_channels=out_channels,
219
+ add_downsample=add_downsample,
220
+ resnet_eps=resnet_eps,
221
+ resnet_act_fn=resnet_act_fn,
222
+ resnet_groups=resnet_groups,
223
+ downsample_padding=downsample_padding,
224
+ resnet_time_scale_shift=resnet_time_scale_shift,
225
+ )
226
+ elif down_block_type == "AttnDownEncoderBlock2D":
227
+ return AttnDownEncoderBlock2D(
228
+ num_layers=num_layers,
229
+ in_channels=in_channels,
230
+ out_channels=out_channels,
231
+ add_downsample=add_downsample,
232
+ resnet_eps=resnet_eps,
233
+ resnet_act_fn=resnet_act_fn,
234
+ resnet_groups=resnet_groups,
235
+ downsample_padding=downsample_padding,
236
+ attention_head_dim=attention_head_dim,
237
+ resnet_time_scale_shift=resnet_time_scale_shift,
238
+ )
239
+ elif down_block_type == "KDownBlock2D":
240
+ return KDownBlock2D(
241
+ num_layers=num_layers,
242
+ in_channels=in_channels,
243
+ out_channels=out_channels,
244
+ temb_channels=temb_channels,
245
+ add_downsample=add_downsample,
246
+ resnet_eps=resnet_eps,
247
+ resnet_act_fn=resnet_act_fn,
248
+ )
249
+ elif down_block_type == "KCrossAttnDownBlock2D":
250
+ return KCrossAttnDownBlock2D(
251
+ num_layers=num_layers,
252
+ in_channels=in_channels,
253
+ out_channels=out_channels,
254
+ temb_channels=temb_channels,
255
+ add_downsample=add_downsample,
256
+ resnet_eps=resnet_eps,
257
+ resnet_act_fn=resnet_act_fn,
258
+ cross_attention_dim=cross_attention_dim,
259
+ attention_head_dim=attention_head_dim,
260
+ add_self_attention=True if not add_downsample else False,
261
+ )
262
+ raise ValueError(f"{down_block_type} does not exist.")
263
+
264
+
265
+ def get_up_block(
266
+ up_block_type,
267
+ num_layers,
268
+ in_channels,
269
+ out_channels,
270
+ prev_output_channel,
271
+ temb_channels,
272
+ add_upsample,
273
+ resnet_eps,
274
+ resnet_act_fn,
275
+ transformer_layers_per_block=1,
276
+ num_attention_heads=None,
277
+ resnet_groups=None,
278
+ cross_attention_dim=None,
279
+ dual_cross_attention=False,
280
+ use_linear_projection=False,
281
+ only_cross_attention=False,
282
+ upcast_attention=False,
283
+ resnet_time_scale_shift="default",
284
+ resnet_skip_time_act=False,
285
+ resnet_out_scale_factor=1.0,
286
+ cross_attention_norm=None,
287
+ attention_head_dim=None,
288
+ upsample_type=None,
289
+ num_views=1,
290
+ cd_attention_last: bool = False,
291
+ cd_attention_mid: bool = False,
292
+ multiview_attention: bool = True,
293
+ sparse_mv_attention: bool = False,
294
+ mvcd_attention: bool=False
295
+ ):
296
+ # If attn head dim is not defined, we default it to the number of heads
297
+ if attention_head_dim is None:
298
+ logger.warn(
299
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
300
+ )
301
+ attention_head_dim = num_attention_heads
302
+
303
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
304
+ if up_block_type == "UpBlock2D":
305
+ return UpBlock2D(
306
+ num_layers=num_layers,
307
+ in_channels=in_channels,
308
+ out_channels=out_channels,
309
+ prev_output_channel=prev_output_channel,
310
+ temb_channels=temb_channels,
311
+ add_upsample=add_upsample,
312
+ resnet_eps=resnet_eps,
313
+ resnet_act_fn=resnet_act_fn,
314
+ resnet_groups=resnet_groups,
315
+ resnet_time_scale_shift=resnet_time_scale_shift,
316
+ )
317
+ elif up_block_type == "ResnetUpsampleBlock2D":
318
+ return ResnetUpsampleBlock2D(
319
+ num_layers=num_layers,
320
+ in_channels=in_channels,
321
+ out_channels=out_channels,
322
+ prev_output_channel=prev_output_channel,
323
+ temb_channels=temb_channels,
324
+ add_upsample=add_upsample,
325
+ resnet_eps=resnet_eps,
326
+ resnet_act_fn=resnet_act_fn,
327
+ resnet_groups=resnet_groups,
328
+ resnet_time_scale_shift=resnet_time_scale_shift,
329
+ skip_time_act=resnet_skip_time_act,
330
+ output_scale_factor=resnet_out_scale_factor,
331
+ )
332
+ elif up_block_type == "CrossAttnUpBlock2D":
333
+ if cross_attention_dim is None:
334
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
335
+ return CrossAttnUpBlock2D(
336
+ num_layers=num_layers,
337
+ transformer_layers_per_block=transformer_layers_per_block,
338
+ in_channels=in_channels,
339
+ out_channels=out_channels,
340
+ prev_output_channel=prev_output_channel,
341
+ temb_channels=temb_channels,
342
+ add_upsample=add_upsample,
343
+ resnet_eps=resnet_eps,
344
+ resnet_act_fn=resnet_act_fn,
345
+ resnet_groups=resnet_groups,
346
+ cross_attention_dim=cross_attention_dim,
347
+ num_attention_heads=num_attention_heads,
348
+ dual_cross_attention=dual_cross_attention,
349
+ use_linear_projection=use_linear_projection,
350
+ only_cross_attention=only_cross_attention,
351
+ upcast_attention=upcast_attention,
352
+ resnet_time_scale_shift=resnet_time_scale_shift,
353
+ )
354
+ # custom MV2D attention block
355
+ elif up_block_type == "CrossAttnUpBlockMV2D":
356
+ if cross_attention_dim is None:
357
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMV2D")
358
+ return CrossAttnUpBlockMV2D(
359
+ num_layers=num_layers,
360
+ transformer_layers_per_block=transformer_layers_per_block,
361
+ in_channels=in_channels,
362
+ out_channels=out_channels,
363
+ prev_output_channel=prev_output_channel,
364
+ temb_channels=temb_channels,
365
+ add_upsample=add_upsample,
366
+ resnet_eps=resnet_eps,
367
+ resnet_act_fn=resnet_act_fn,
368
+ resnet_groups=resnet_groups,
369
+ cross_attention_dim=cross_attention_dim,
370
+ num_attention_heads=num_attention_heads,
371
+ dual_cross_attention=dual_cross_attention,
372
+ use_linear_projection=use_linear_projection,
373
+ only_cross_attention=only_cross_attention,
374
+ upcast_attention=upcast_attention,
375
+ resnet_time_scale_shift=resnet_time_scale_shift,
376
+ num_views=num_views,
377
+ cd_attention_last=cd_attention_last,
378
+ cd_attention_mid=cd_attention_mid,
379
+ multiview_attention=multiview_attention,
380
+ sparse_mv_attention=sparse_mv_attention,
381
+ mvcd_attention=mvcd_attention
382
+ )
383
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
384
+ if cross_attention_dim is None:
385
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
386
+ return SimpleCrossAttnUpBlock2D(
387
+ num_layers=num_layers,
388
+ in_channels=in_channels,
389
+ out_channels=out_channels,
390
+ prev_output_channel=prev_output_channel,
391
+ temb_channels=temb_channels,
392
+ add_upsample=add_upsample,
393
+ resnet_eps=resnet_eps,
394
+ resnet_act_fn=resnet_act_fn,
395
+ resnet_groups=resnet_groups,
396
+ cross_attention_dim=cross_attention_dim,
397
+ attention_head_dim=attention_head_dim,
398
+ resnet_time_scale_shift=resnet_time_scale_shift,
399
+ skip_time_act=resnet_skip_time_act,
400
+ output_scale_factor=resnet_out_scale_factor,
401
+ only_cross_attention=only_cross_attention,
402
+ cross_attention_norm=cross_attention_norm,
403
+ )
404
+ elif up_block_type == "AttnUpBlock2D":
405
+ if add_upsample is False:
406
+ upsample_type = None
407
+ else:
408
+ upsample_type = upsample_type or "conv" # default to 'conv'
409
+
410
+ return AttnUpBlock2D(
411
+ num_layers=num_layers,
412
+ in_channels=in_channels,
413
+ out_channels=out_channels,
414
+ prev_output_channel=prev_output_channel,
415
+ temb_channels=temb_channels,
416
+ resnet_eps=resnet_eps,
417
+ resnet_act_fn=resnet_act_fn,
418
+ resnet_groups=resnet_groups,
419
+ attention_head_dim=attention_head_dim,
420
+ resnet_time_scale_shift=resnet_time_scale_shift,
421
+ upsample_type=upsample_type,
422
+ )
423
+ elif up_block_type == "SkipUpBlock2D":
424
+ return SkipUpBlock2D(
425
+ num_layers=num_layers,
426
+ in_channels=in_channels,
427
+ out_channels=out_channels,
428
+ prev_output_channel=prev_output_channel,
429
+ temb_channels=temb_channels,
430
+ add_upsample=add_upsample,
431
+ resnet_eps=resnet_eps,
432
+ resnet_act_fn=resnet_act_fn,
433
+ resnet_time_scale_shift=resnet_time_scale_shift,
434
+ )
435
+ elif up_block_type == "AttnSkipUpBlock2D":
436
+ return AttnSkipUpBlock2D(
437
+ num_layers=num_layers,
438
+ in_channels=in_channels,
439
+ out_channels=out_channels,
440
+ prev_output_channel=prev_output_channel,
441
+ temb_channels=temb_channels,
442
+ add_upsample=add_upsample,
443
+ resnet_eps=resnet_eps,
444
+ resnet_act_fn=resnet_act_fn,
445
+ attention_head_dim=attention_head_dim,
446
+ resnet_time_scale_shift=resnet_time_scale_shift,
447
+ )
448
+ elif up_block_type == "UpDecoderBlock2D":
449
+ return UpDecoderBlock2D(
450
+ num_layers=num_layers,
451
+ in_channels=in_channels,
452
+ out_channels=out_channels,
453
+ add_upsample=add_upsample,
454
+ resnet_eps=resnet_eps,
455
+ resnet_act_fn=resnet_act_fn,
456
+ resnet_groups=resnet_groups,
457
+ resnet_time_scale_shift=resnet_time_scale_shift,
458
+ temb_channels=temb_channels,
459
+ )
460
+ elif up_block_type == "AttnUpDecoderBlock2D":
461
+ return AttnUpDecoderBlock2D(
462
+ num_layers=num_layers,
463
+ in_channels=in_channels,
464
+ out_channels=out_channels,
465
+ add_upsample=add_upsample,
466
+ resnet_eps=resnet_eps,
467
+ resnet_act_fn=resnet_act_fn,
468
+ resnet_groups=resnet_groups,
469
+ attention_head_dim=attention_head_dim,
470
+ resnet_time_scale_shift=resnet_time_scale_shift,
471
+ temb_channels=temb_channels,
472
+ )
473
+ elif up_block_type == "KUpBlock2D":
474
+ return KUpBlock2D(
475
+ num_layers=num_layers,
476
+ in_channels=in_channels,
477
+ out_channels=out_channels,
478
+ temb_channels=temb_channels,
479
+ add_upsample=add_upsample,
480
+ resnet_eps=resnet_eps,
481
+ resnet_act_fn=resnet_act_fn,
482
+ )
483
+ elif up_block_type == "KCrossAttnUpBlock2D":
484
+ return KCrossAttnUpBlock2D(
485
+ num_layers=num_layers,
486
+ in_channels=in_channels,
487
+ out_channels=out_channels,
488
+ temb_channels=temb_channels,
489
+ add_upsample=add_upsample,
490
+ resnet_eps=resnet_eps,
491
+ resnet_act_fn=resnet_act_fn,
492
+ cross_attention_dim=cross_attention_dim,
493
+ attention_head_dim=attention_head_dim,
494
+ )
495
+
496
+ raise ValueError(f"{up_block_type} does not exist.")
497
+
498
+
499
+ class UNetMidBlockMV2DCrossAttn(nn.Module):
500
+ def __init__(
501
+ self,
502
+ in_channels: int,
503
+ temb_channels: int,
504
+ dropout: float = 0.0,
505
+ num_layers: int = 1,
506
+ transformer_layers_per_block: int = 1,
507
+ resnet_eps: float = 1e-6,
508
+ resnet_time_scale_shift: str = "default",
509
+ resnet_act_fn: str = "swish",
510
+ resnet_groups: int = 32,
511
+ resnet_pre_norm: bool = True,
512
+ num_attention_heads=1,
513
+ output_scale_factor=1.0,
514
+ cross_attention_dim=1280,
515
+ dual_cross_attention=False,
516
+ use_linear_projection=False,
517
+ upcast_attention=False,
518
+ num_views: int = 1,
519
+ cd_attention_last: bool = False,
520
+ cd_attention_mid: bool = False,
521
+ multiview_attention: bool = True,
522
+ sparse_mv_attention: bool = False,
523
+ mvcd_attention: bool=False
524
+ ):
525
+ super().__init__()
526
+
527
+ self.has_cross_attention = True
528
+ self.num_attention_heads = num_attention_heads
529
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
530
+
531
+ # there is always at least one resnet
532
+ resnets = [
533
+ ResnetBlock2D(
534
+ in_channels=in_channels,
535
+ out_channels=in_channels,
536
+ temb_channels=temb_channels,
537
+ eps=resnet_eps,
538
+ groups=resnet_groups,
539
+ dropout=dropout,
540
+ time_embedding_norm=resnet_time_scale_shift,
541
+ non_linearity=resnet_act_fn,
542
+ output_scale_factor=output_scale_factor,
543
+ pre_norm=resnet_pre_norm,
544
+ )
545
+ ]
546
+ attentions = []
547
+
548
+ for _ in range(num_layers):
549
+ if not dual_cross_attention:
550
+ attentions.append(
551
+ TransformerMV2DModel(
552
+ num_attention_heads,
553
+ in_channels // num_attention_heads,
554
+ in_channels=in_channels,
555
+ num_layers=transformer_layers_per_block,
556
+ cross_attention_dim=cross_attention_dim,
557
+ norm_num_groups=resnet_groups,
558
+ use_linear_projection=use_linear_projection,
559
+ upcast_attention=upcast_attention,
560
+ num_views=num_views,
561
+ cd_attention_last=cd_attention_last,
562
+ cd_attention_mid=cd_attention_mid,
563
+ multiview_attention=multiview_attention,
564
+ sparse_mv_attention=sparse_mv_attention,
565
+ mvcd_attention=mvcd_attention
566
+ )
567
+ )
568
+ else:
569
+ raise NotImplementedError
570
+ resnets.append(
571
+ ResnetBlock2D(
572
+ in_channels=in_channels,
573
+ out_channels=in_channels,
574
+ temb_channels=temb_channels,
575
+ eps=resnet_eps,
576
+ groups=resnet_groups,
577
+ dropout=dropout,
578
+ time_embedding_norm=resnet_time_scale_shift,
579
+ non_linearity=resnet_act_fn,
580
+ output_scale_factor=output_scale_factor,
581
+ pre_norm=resnet_pre_norm,
582
+ )
583
+ )
584
+
585
+ self.attentions = nn.ModuleList(attentions)
586
+ self.resnets = nn.ModuleList(resnets)
587
+
588
+ def forward(
589
+ self,
590
+ hidden_states: torch.FloatTensor,
591
+ temb: Optional[torch.FloatTensor] = None,
592
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
593
+ attention_mask: Optional[torch.FloatTensor] = None,
594
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
595
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
596
+ ) -> torch.FloatTensor:
597
+ hidden_states = self.resnets[0](hidden_states, temb)
598
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
599
+ hidden_states = attn(
600
+ hidden_states,
601
+ encoder_hidden_states=encoder_hidden_states,
602
+ cross_attention_kwargs=cross_attention_kwargs,
603
+ attention_mask=attention_mask,
604
+ encoder_attention_mask=encoder_attention_mask,
605
+ return_dict=False,
606
+ )[0]
607
+ hidden_states = resnet(hidden_states, temb)
608
+
609
+ return hidden_states
610
+
611
+
612
+ class CrossAttnUpBlockMV2D(nn.Module):
613
+ def __init__(
614
+ self,
615
+ in_channels: int,
616
+ out_channels: int,
617
+ prev_output_channel: int,
618
+ temb_channels: int,
619
+ dropout: float = 0.0,
620
+ num_layers: int = 1,
621
+ transformer_layers_per_block: int = 1,
622
+ resnet_eps: float = 1e-6,
623
+ resnet_time_scale_shift: str = "default",
624
+ resnet_act_fn: str = "swish",
625
+ resnet_groups: int = 32,
626
+ resnet_pre_norm: bool = True,
627
+ num_attention_heads=1,
628
+ cross_attention_dim=1280,
629
+ output_scale_factor=1.0,
630
+ add_upsample=True,
631
+ dual_cross_attention=False,
632
+ use_linear_projection=False,
633
+ only_cross_attention=False,
634
+ upcast_attention=False,
635
+ num_views: int = 1,
636
+ cd_attention_last: bool = False,
637
+ cd_attention_mid: bool = False,
638
+ multiview_attention: bool = True,
639
+ sparse_mv_attention: bool = False,
640
+ mvcd_attention: bool=False
641
+ ):
642
+ super().__init__()
643
+ resnets = []
644
+ attentions = []
645
+
646
+ self.has_cross_attention = True
647
+ self.num_attention_heads = num_attention_heads
648
+
649
+ for i in range(num_layers):
650
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
651
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
652
+
653
+ resnets.append(
654
+ ResnetBlock2D(
655
+ in_channels=resnet_in_channels + res_skip_channels,
656
+ out_channels=out_channels,
657
+ temb_channels=temb_channels,
658
+ eps=resnet_eps,
659
+ groups=resnet_groups,
660
+ dropout=dropout,
661
+ time_embedding_norm=resnet_time_scale_shift,
662
+ non_linearity=resnet_act_fn,
663
+ output_scale_factor=output_scale_factor,
664
+ pre_norm=resnet_pre_norm,
665
+ )
666
+ )
667
+ if not dual_cross_attention:
668
+ attentions.append(
669
+ TransformerMV2DModel(
670
+ num_attention_heads,
671
+ out_channels // num_attention_heads,
672
+ in_channels=out_channels,
673
+ num_layers=transformer_layers_per_block,
674
+ cross_attention_dim=cross_attention_dim,
675
+ norm_num_groups=resnet_groups,
676
+ use_linear_projection=use_linear_projection,
677
+ only_cross_attention=only_cross_attention,
678
+ upcast_attention=upcast_attention,
679
+ num_views=num_views,
680
+ cd_attention_last=cd_attention_last,
681
+ cd_attention_mid=cd_attention_mid,
682
+ multiview_attention=multiview_attention,
683
+ sparse_mv_attention=sparse_mv_attention,
684
+ mvcd_attention=mvcd_attention
685
+ )
686
+ )
687
+ else:
688
+ raise NotImplementedError
689
+ self.attentions = nn.ModuleList(attentions)
690
+ self.resnets = nn.ModuleList(resnets)
691
+
692
+ if add_upsample:
693
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
694
+ else:
695
+ self.upsamplers = None
696
+
697
+ self.gradient_checkpointing = False
698
+
699
+ def forward(
700
+ self,
701
+ hidden_states: torch.FloatTensor,
702
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
703
+ temb: Optional[torch.FloatTensor] = None,
704
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
705
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
706
+ upsample_size: Optional[int] = None,
707
+ attention_mask: Optional[torch.FloatTensor] = None,
708
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
709
+ ):
710
+ for resnet, attn in zip(self.resnets, self.attentions):
711
+ # pop res hidden states
712
+ res_hidden_states = res_hidden_states_tuple[-1]
713
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
714
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
715
+
716
+ if self.training and self.gradient_checkpointing:
717
+
718
+ def create_custom_forward(module, return_dict=None):
719
+ def custom_forward(*inputs):
720
+ if return_dict is not None:
721
+ return module(*inputs, return_dict=return_dict)
722
+ else:
723
+ return module(*inputs)
724
+
725
+ return custom_forward
726
+
727
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
728
+ hidden_states = torch.utils.checkpoint.checkpoint(
729
+ create_custom_forward(resnet),
730
+ hidden_states,
731
+ temb,
732
+ **ckpt_kwargs,
733
+ )
734
+ hidden_states = torch.utils.checkpoint.checkpoint(
735
+ create_custom_forward(attn, return_dict=False),
736
+ hidden_states,
737
+ encoder_hidden_states,
738
+ None, # timestep
739
+ None, # class_labels
740
+ cross_attention_kwargs,
741
+ attention_mask,
742
+ encoder_attention_mask,
743
+ **ckpt_kwargs,
744
+ )[0]
745
+ else:
746
+ hidden_states = resnet(hidden_states, temb)
747
+ hidden_states = attn(
748
+ hidden_states,
749
+ encoder_hidden_states=encoder_hidden_states,
750
+ cross_attention_kwargs=cross_attention_kwargs,
751
+ attention_mask=attention_mask,
752
+ encoder_attention_mask=encoder_attention_mask,
753
+ return_dict=False,
754
+ )[0]
755
+
756
+ if self.upsamplers is not None:
757
+ for upsampler in self.upsamplers:
758
+ hidden_states = upsampler(hidden_states, upsample_size)
759
+
760
+ return hidden_states
761
+
762
+
763
+ class CrossAttnDownBlockMV2D(nn.Module):
764
+ def __init__(
765
+ self,
766
+ in_channels: int,
767
+ out_channels: int,
768
+ temb_channels: int,
769
+ dropout: float = 0.0,
770
+ num_layers: int = 1,
771
+ transformer_layers_per_block: int = 1,
772
+ resnet_eps: float = 1e-6,
773
+ resnet_time_scale_shift: str = "default",
774
+ resnet_act_fn: str = "swish",
775
+ resnet_groups: int = 32,
776
+ resnet_pre_norm: bool = True,
777
+ num_attention_heads=1,
778
+ cross_attention_dim=1280,
779
+ output_scale_factor=1.0,
780
+ downsample_padding=1,
781
+ add_downsample=True,
782
+ dual_cross_attention=False,
783
+ use_linear_projection=False,
784
+ only_cross_attention=False,
785
+ upcast_attention=False,
786
+ num_views: int = 1,
787
+ cd_attention_last: bool = False,
788
+ cd_attention_mid: bool = False,
789
+ multiview_attention: bool = True,
790
+ sparse_mv_attention: bool = False,
791
+ mvcd_attention: bool=False
792
+ ):
793
+ super().__init__()
794
+ resnets = []
795
+ attentions = []
796
+
797
+ self.has_cross_attention = True
798
+ self.num_attention_heads = num_attention_heads
799
+
800
+ for i in range(num_layers):
801
+ in_channels = in_channels if i == 0 else out_channels
802
+ resnets.append(
803
+ ResnetBlock2D(
804
+ in_channels=in_channels,
805
+ out_channels=out_channels,
806
+ temb_channels=temb_channels,
807
+ eps=resnet_eps,
808
+ groups=resnet_groups,
809
+ dropout=dropout,
810
+ time_embedding_norm=resnet_time_scale_shift,
811
+ non_linearity=resnet_act_fn,
812
+ output_scale_factor=output_scale_factor,
813
+ pre_norm=resnet_pre_norm,
814
+ )
815
+ )
816
+ if not dual_cross_attention:
817
+ attentions.append(
818
+ TransformerMV2DModel(
819
+ num_attention_heads,
820
+ out_channels // num_attention_heads,
821
+ in_channels=out_channels,
822
+ num_layers=transformer_layers_per_block,
823
+ cross_attention_dim=cross_attention_dim,
824
+ norm_num_groups=resnet_groups,
825
+ use_linear_projection=use_linear_projection,
826
+ only_cross_attention=only_cross_attention,
827
+ upcast_attention=upcast_attention,
828
+ num_views=num_views,
829
+ cd_attention_last=cd_attention_last,
830
+ cd_attention_mid=cd_attention_mid,
831
+ multiview_attention=multiview_attention,
832
+ sparse_mv_attention=sparse_mv_attention,
833
+ mvcd_attention=mvcd_attention
834
+ )
835
+ )
836
+ else:
837
+ raise NotImplementedError
838
+ self.attentions = nn.ModuleList(attentions)
839
+ self.resnets = nn.ModuleList(resnets)
840
+
841
+ if add_downsample:
842
+ self.downsamplers = nn.ModuleList(
843
+ [
844
+ Downsample2D(
845
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
846
+ )
847
+ ]
848
+ )
849
+ else:
850
+ self.downsamplers = None
851
+
852
+ self.gradient_checkpointing = False
853
+
854
+ def forward(
855
+ self,
856
+ hidden_states: torch.FloatTensor,
857
+ temb: Optional[torch.FloatTensor] = None,
858
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
859
+ attention_mask: Optional[torch.FloatTensor] = None,
860
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
861
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
862
+ additional_residuals=None,
863
+ ):
864
+ output_states = ()
865
+
866
+ blocks = list(zip(self.resnets, self.attentions))
867
+
868
+ for i, (resnet, attn) in enumerate(blocks):
869
+ if self.training and self.gradient_checkpointing:
870
+
871
+ def create_custom_forward(module, return_dict=None):
872
+ def custom_forward(*inputs):
873
+ if return_dict is not None:
874
+ return module(*inputs, return_dict=return_dict)
875
+ else:
876
+ return module(*inputs)
877
+
878
+ return custom_forward
879
+
880
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
881
+ hidden_states = torch.utils.checkpoint.checkpoint(
882
+ create_custom_forward(resnet),
883
+ hidden_states,
884
+ temb,
885
+ **ckpt_kwargs,
886
+ )
887
+ hidden_states = torch.utils.checkpoint.checkpoint(
888
+ create_custom_forward(attn, return_dict=False),
889
+ hidden_states,
890
+ encoder_hidden_states,
891
+ None, # timestep
892
+ None, # class_labels
893
+ cross_attention_kwargs,
894
+ attention_mask,
895
+ encoder_attention_mask,
896
+ **ckpt_kwargs,
897
+ )[0]
898
+ else:
899
+ hidden_states = resnet(hidden_states, temb)
900
+ hidden_states = attn(
901
+ hidden_states,
902
+ encoder_hidden_states=encoder_hidden_states,
903
+ cross_attention_kwargs=cross_attention_kwargs,
904
+ attention_mask=attention_mask,
905
+ encoder_attention_mask=encoder_attention_mask,
906
+ return_dict=False,
907
+ )[0]
908
+
909
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
910
+ if i == len(blocks) - 1 and additional_residuals is not None:
911
+ hidden_states = hidden_states + additional_residuals
912
+
913
+ output_states = output_states + (hidden_states,)
914
+
915
+ if self.downsamplers is not None:
916
+ for downsampler in self.downsamplers:
917
+ hidden_states = downsampler(hidden_states)
918
+
919
+ output_states = output_states + (hidden_states,)
920
+
921
+ return hidden_states, output_states
922
+
mv_diffusion_30/models/unet_mv2d_condition.py ADDED
@@ -0,0 +1,1501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.utils import BaseOutput, logging
25
+ from diffusers.models.activations import get_activation
26
+ from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor
27
+ from diffusers.models.embeddings import (
28
+ GaussianFourierProjection,
29
+ ImageHintTimeEmbedding,
30
+ ImageProjection,
31
+ ImageTimeEmbedding,
32
+ TextImageProjection,
33
+ TextImageTimeEmbedding,
34
+ TextTimeEmbedding,
35
+ TimestepEmbedding,
36
+ Timesteps,
37
+ )
38
+ from diffusers.models.modeling_utils import ModelMixin, load_state_dict, _load_state_dict_into_model
39
+ from diffusers.models.unets.unet_2d_blocks import (
40
+ CrossAttnDownBlock2D,
41
+ CrossAttnUpBlock2D,
42
+ DownBlock2D,
43
+ UNetMidBlock2DCrossAttn,
44
+ UNetMidBlock2DSimpleCrossAttn,
45
+ UpBlock2D,
46
+ )
47
+ from diffusers.utils import (
48
+ CONFIG_NAME,
49
+ HF_MODULES_CACHE,
50
+ FLAX_WEIGHTS_NAME,
51
+ SAFETENSORS_WEIGHTS_NAME,
52
+ WEIGHTS_NAME,
53
+ _add_variant,
54
+ _get_model_file,
55
+ deprecate,
56
+ is_accelerate_available,
57
+ is_safetensors_available,
58
+ is_torch_version,
59
+ logging,
60
+ )
61
+ from diffusers import __version__
62
+ from mv_diffusion_30.models.unet_mv2d_blocks import (
63
+ CrossAttnDownBlockMV2D,
64
+ CrossAttnUpBlockMV2D,
65
+ UNetMidBlockMV2DCrossAttn,
66
+ get_down_block,
67
+ get_up_block,
68
+ )
69
+ from huggingface_hub.constants import HF_HUB_OFFLINE
70
+
71
+
72
+ hf_cache_home = os.path.expanduser(
73
+ os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
74
+ )
75
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
76
+
77
+
78
+ DIFFUSERS_CACHE = HUGGINGFACE_HUB_CACHE
79
+
80
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
81
+
82
+
83
+ @dataclass
84
+ class UNetMV2DConditionOutput(BaseOutput):
85
+ """
86
+ The output of [`UNet2DConditionModel`].
87
+
88
+ Args:
89
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
90
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
91
+ """
92
+
93
+ sample: torch.FloatTensor = None
94
+
95
+
96
+ class UNetMV2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
97
+ r"""
98
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
99
+ shaped output.
100
+
101
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
102
+ for all models (such as downloading or saving).
103
+
104
+ Parameters:
105
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
106
+ Height and width of input/output sample.
107
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
108
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
109
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
110
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
111
+ Whether to flip the sin to cos in the time embedding.
112
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
113
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
114
+ The tuple of downsample blocks to use.
115
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
116
+ Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or
117
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
118
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
119
+ The tuple of upsample blocks to use.
120
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
121
+ Whether to include self-attention in the basic transformer blocks, see
122
+ [`~models.attention.BasicTransformerBlock`].
123
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
124
+ The tuple of output channels for each block.
125
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
126
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
127
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
128
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
129
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
130
+ If `None`, normalization and activation layers is skipped in post-processing.
131
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
132
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
133
+ The dimension of the cross attention features.
134
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
135
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
136
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
137
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
138
+ encoder_hid_dim (`int`, *optional*, defaults to None):
139
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
140
+ dimension to `cross_attention_dim`.
141
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
142
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
143
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
144
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
145
+ num_attention_heads (`int`, *optional*):
146
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
147
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
148
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
149
+ class_embed_type (`str`, *optional*, defaults to `None`):
150
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
151
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
152
+ addition_embed_type (`str`, *optional*, defaults to `None`):
153
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
154
+ "text". "text" will use the `TextTimeEmbedding` layer.
155
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
156
+ Dimension for the timestep embeddings.
157
+ num_class_embeds (`int`, *optional*, defaults to `None`):
158
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
159
+ class conditioning with `class_embed_type` equal to `None`.
160
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
161
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
162
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
163
+ An optional override for the dimension of the projected time embedding.
164
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
165
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
166
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
167
+ timestep_post_act (`str`, *optional*, defaults to `None`):
168
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
169
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
170
+ The dimension of `cond_proj` layer in the timestep embedding.
171
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
172
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
173
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
174
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
175
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
176
+ embeddings with the class embeddings.
177
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
178
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
179
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
180
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
181
+ otherwise.
182
+ """
183
+
184
+ _supports_gradient_checkpointing = True
185
+
186
+ @register_to_config
187
+ def __init__(
188
+ self,
189
+ sample_size: Optional[int] = None,
190
+ in_channels: int = 4,
191
+ out_channels: int = 4,
192
+ center_input_sample: bool = False,
193
+ flip_sin_to_cos: bool = True,
194
+ freq_shift: int = 0,
195
+ down_block_types: Tuple[str] = (
196
+ "CrossAttnDownBlockMV2D",
197
+ "CrossAttnDownBlockMV2D",
198
+ "CrossAttnDownBlockMV2D",
199
+ "DownBlock2D",
200
+ ),
201
+ mid_block_type: Optional[str] = "UNetMidBlockMV2DCrossAttn",
202
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D", "CrossAttnUpBlockMV2D"),
203
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
204
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
205
+ layers_per_block: Union[int, Tuple[int]] = 2,
206
+ downsample_padding: int = 1,
207
+ mid_block_scale_factor: float = 1,
208
+ act_fn: str = "silu",
209
+ norm_num_groups: Optional[int] = 32,
210
+ norm_eps: float = 1e-5,
211
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
212
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
213
+ encoder_hid_dim: Optional[int] = None,
214
+ encoder_hid_dim_type: Optional[str] = None,
215
+ attention_head_dim: Union[int, Tuple[int]] = 8,
216
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
217
+ dual_cross_attention: bool = False,
218
+ use_linear_projection: bool = False,
219
+ class_embed_type: Optional[str] = None,
220
+ addition_embed_type: Optional[str] = None,
221
+ addition_time_embed_dim: Optional[int] = None,
222
+ num_class_embeds: Optional[int] = None,
223
+ upcast_attention: bool = False,
224
+ resnet_time_scale_shift: str = "default",
225
+ resnet_skip_time_act: bool = False,
226
+ resnet_out_scale_factor: int = 1.0,
227
+ time_embedding_type: str = "positional",
228
+ time_embedding_dim: Optional[int] = None,
229
+ time_embedding_act_fn: Optional[str] = None,
230
+ timestep_post_act: Optional[str] = None,
231
+ time_cond_proj_dim: Optional[int] = None,
232
+ conv_in_kernel: int = 3,
233
+ conv_out_kernel: int = 3,
234
+ projection_class_embeddings_input_dim: Optional[int] = None,
235
+ class_embeddings_concat: bool = False,
236
+ mid_block_only_cross_attention: Optional[bool] = None,
237
+ cross_attention_norm: Optional[str] = None,
238
+ addition_embed_type_num_heads=64,
239
+ num_views: int = 1,
240
+ cd_attention_last: bool = False,
241
+ cd_attention_mid: bool = False,
242
+ multiview_attention: bool = True,
243
+ sparse_mv_attention: bool = False,
244
+ mvcd_attention: bool = False
245
+ ):
246
+ super().__init__()
247
+
248
+ self.sample_size = sample_size
249
+
250
+ if num_attention_heads is not None:
251
+ raise ValueError(
252
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
253
+ )
254
+
255
+ # If `num_attention_heads` is not defined (which is the case for most models)
256
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
257
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
258
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
259
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
260
+ # which is why we correct for the naming here.
261
+ num_attention_heads = num_attention_heads or attention_head_dim
262
+
263
+ # Check inputs
264
+ if len(down_block_types) != len(up_block_types):
265
+ raise ValueError(
266
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
267
+ )
268
+
269
+ if len(block_out_channels) != len(down_block_types):
270
+ raise ValueError(
271
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
272
+ )
273
+
274
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
275
+ raise ValueError(
276
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
277
+ )
278
+
279
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
280
+ raise ValueError(
281
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
282
+ )
283
+
284
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
285
+ raise ValueError(
286
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
287
+ )
288
+
289
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
290
+ raise ValueError(
291
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
292
+ )
293
+
294
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
295
+ raise ValueError(
296
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
297
+ )
298
+
299
+ # input
300
+ conv_in_padding = (conv_in_kernel - 1) // 2
301
+ self.conv_in = nn.Conv2d(
302
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
303
+ )
304
+
305
+ # time
306
+ if time_embedding_type == "fourier":
307
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
308
+ if time_embed_dim % 2 != 0:
309
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
310
+ self.time_proj = GaussianFourierProjection(
311
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
312
+ )
313
+ timestep_input_dim = time_embed_dim
314
+ elif time_embedding_type == "positional":
315
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
316
+
317
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
318
+ timestep_input_dim = block_out_channels[0]
319
+ else:
320
+ raise ValueError(
321
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
322
+ )
323
+
324
+ self.time_embedding = TimestepEmbedding(
325
+ timestep_input_dim,
326
+ time_embed_dim,
327
+ act_fn=act_fn,
328
+ post_act_fn=timestep_post_act,
329
+ cond_proj_dim=time_cond_proj_dim,
330
+ )
331
+
332
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
333
+ encoder_hid_dim_type = "text_proj"
334
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
335
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
336
+
337
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
338
+ raise ValueError(
339
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
340
+ )
341
+
342
+ if encoder_hid_dim_type == "text_proj":
343
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
344
+ elif encoder_hid_dim_type == "text_image_proj":
345
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
346
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
347
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
348
+ self.encoder_hid_proj = TextImageProjection(
349
+ text_embed_dim=encoder_hid_dim,
350
+ image_embed_dim=cross_attention_dim,
351
+ cross_attention_dim=cross_attention_dim,
352
+ )
353
+ elif encoder_hid_dim_type == "image_proj":
354
+ # Kandinsky 2.2
355
+ self.encoder_hid_proj = ImageProjection(
356
+ image_embed_dim=encoder_hid_dim,
357
+ cross_attention_dim=cross_attention_dim,
358
+ )
359
+ elif encoder_hid_dim_type is not None:
360
+ raise ValueError(
361
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
362
+ )
363
+ else:
364
+ self.encoder_hid_proj = None
365
+
366
+ # class embedding
367
+ if class_embed_type is None and num_class_embeds is not None:
368
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
369
+ elif class_embed_type == "timestep":
370
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
371
+ elif class_embed_type == "identity":
372
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
373
+ elif class_embed_type == "projection":
374
+ if projection_class_embeddings_input_dim is None:
375
+ raise ValueError(
376
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
377
+ )
378
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
379
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
380
+ # 2. it projects from an arbitrary input dimension.
381
+ #
382
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
383
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
384
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
385
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
386
+ elif class_embed_type == "simple_projection":
387
+ if projection_class_embeddings_input_dim is None:
388
+ raise ValueError(
389
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
390
+ )
391
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
392
+ else:
393
+ self.class_embedding = None
394
+
395
+ if addition_embed_type == "text":
396
+ if encoder_hid_dim is not None:
397
+ text_time_embedding_from_dim = encoder_hid_dim
398
+ else:
399
+ text_time_embedding_from_dim = cross_attention_dim
400
+
401
+ self.add_embedding = TextTimeEmbedding(
402
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
403
+ )
404
+ elif addition_embed_type == "text_image":
405
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
406
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
407
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
408
+ self.add_embedding = TextImageTimeEmbedding(
409
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
410
+ )
411
+ elif addition_embed_type == "text_time":
412
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
413
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
414
+ elif addition_embed_type == "image":
415
+ # Kandinsky 2.2
416
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
417
+ elif addition_embed_type == "image_hint":
418
+ # Kandinsky 2.2 ControlNet
419
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
420
+ elif addition_embed_type is not None:
421
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
422
+
423
+ if time_embedding_act_fn is None:
424
+ self.time_embed_act = None
425
+ else:
426
+ self.time_embed_act = get_activation(time_embedding_act_fn)
427
+
428
+ self.down_blocks = nn.ModuleList([])
429
+ self.up_blocks = nn.ModuleList([])
430
+
431
+ if isinstance(only_cross_attention, bool):
432
+ if mid_block_only_cross_attention is None:
433
+ mid_block_only_cross_attention = only_cross_attention
434
+
435
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
436
+
437
+ if mid_block_only_cross_attention is None:
438
+ mid_block_only_cross_attention = False
439
+
440
+ if isinstance(num_attention_heads, int):
441
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
442
+
443
+ if isinstance(attention_head_dim, int):
444
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
445
+
446
+ if isinstance(cross_attention_dim, int):
447
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
448
+
449
+ if isinstance(layers_per_block, int):
450
+ layers_per_block = [layers_per_block] * len(down_block_types)
451
+
452
+ if isinstance(transformer_layers_per_block, int):
453
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
454
+
455
+ if class_embeddings_concat:
456
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
457
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
458
+ # regular time embeddings
459
+ blocks_time_embed_dim = time_embed_dim * 2
460
+ else:
461
+ blocks_time_embed_dim = time_embed_dim
462
+
463
+ # down
464
+ output_channel = block_out_channels[0]
465
+ for i, down_block_type in enumerate(down_block_types):
466
+ input_channel = output_channel
467
+ output_channel = block_out_channels[i]
468
+ is_final_block = i == len(block_out_channels) - 1
469
+
470
+ down_block = get_down_block(
471
+ down_block_type,
472
+ num_layers=layers_per_block[i],
473
+ transformer_layers_per_block=transformer_layers_per_block[i],
474
+ in_channels=input_channel,
475
+ out_channels=output_channel,
476
+ temb_channels=blocks_time_embed_dim,
477
+ add_downsample=not is_final_block,
478
+ resnet_eps=norm_eps,
479
+ resnet_act_fn=act_fn,
480
+ resnet_groups=norm_num_groups,
481
+ cross_attention_dim=cross_attention_dim[i],
482
+ num_attention_heads=num_attention_heads[i],
483
+ downsample_padding=downsample_padding,
484
+ dual_cross_attention=dual_cross_attention,
485
+ use_linear_projection=use_linear_projection,
486
+ only_cross_attention=only_cross_attention[i],
487
+ upcast_attention=upcast_attention,
488
+ resnet_time_scale_shift=resnet_time_scale_shift,
489
+ resnet_skip_time_act=resnet_skip_time_act,
490
+ resnet_out_scale_factor=resnet_out_scale_factor,
491
+ cross_attention_norm=cross_attention_norm,
492
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
493
+ num_views=num_views,
494
+ cd_attention_last=cd_attention_last,
495
+ cd_attention_mid=cd_attention_mid,
496
+ multiview_attention=multiview_attention,
497
+ sparse_mv_attention=sparse_mv_attention,
498
+ mvcd_attention=mvcd_attention
499
+ )
500
+ self.down_blocks.append(down_block)
501
+
502
+ # mid
503
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
504
+ self.mid_block = UNetMidBlock2DCrossAttn(
505
+ transformer_layers_per_block=transformer_layers_per_block[-1],
506
+ in_channels=block_out_channels[-1],
507
+ temb_channels=blocks_time_embed_dim,
508
+ resnet_eps=norm_eps,
509
+ resnet_act_fn=act_fn,
510
+ output_scale_factor=mid_block_scale_factor,
511
+ resnet_time_scale_shift=resnet_time_scale_shift,
512
+ cross_attention_dim=cross_attention_dim[-1],
513
+ num_attention_heads=num_attention_heads[-1],
514
+ resnet_groups=norm_num_groups,
515
+ dual_cross_attention=dual_cross_attention,
516
+ use_linear_projection=use_linear_projection,
517
+ upcast_attention=upcast_attention,
518
+ )
519
+ # custom MV2D attention block
520
+ elif mid_block_type == "UNetMidBlockMV2DCrossAttn":
521
+ self.mid_block = UNetMidBlockMV2DCrossAttn(
522
+ transformer_layers_per_block=transformer_layers_per_block[-1],
523
+ in_channels=block_out_channels[-1],
524
+ temb_channels=blocks_time_embed_dim,
525
+ resnet_eps=norm_eps,
526
+ resnet_act_fn=act_fn,
527
+ output_scale_factor=mid_block_scale_factor,
528
+ resnet_time_scale_shift=resnet_time_scale_shift,
529
+ cross_attention_dim=cross_attention_dim[-1],
530
+ num_attention_heads=num_attention_heads[-1],
531
+ resnet_groups=norm_num_groups,
532
+ dual_cross_attention=dual_cross_attention,
533
+ use_linear_projection=use_linear_projection,
534
+ upcast_attention=upcast_attention,
535
+ num_views=num_views,
536
+ cd_attention_last=cd_attention_last,
537
+ cd_attention_mid=cd_attention_mid,
538
+ multiview_attention=multiview_attention,
539
+ sparse_mv_attention=sparse_mv_attention,
540
+ mvcd_attention=mvcd_attention
541
+ )
542
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
543
+ self.mid_block = UNetMidBlock2DSimpleCrossAttn(
544
+ in_channels=block_out_channels[-1],
545
+ temb_channels=blocks_time_embed_dim,
546
+ resnet_eps=norm_eps,
547
+ resnet_act_fn=act_fn,
548
+ output_scale_factor=mid_block_scale_factor,
549
+ cross_attention_dim=cross_attention_dim[-1],
550
+ attention_head_dim=attention_head_dim[-1],
551
+ resnet_groups=norm_num_groups,
552
+ resnet_time_scale_shift=resnet_time_scale_shift,
553
+ skip_time_act=resnet_skip_time_act,
554
+ only_cross_attention=mid_block_only_cross_attention,
555
+ cross_attention_norm=cross_attention_norm,
556
+ )
557
+ elif mid_block_type is None:
558
+ self.mid_block = None
559
+ else:
560
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
561
+
562
+ # count how many layers upsample the images
563
+ self.num_upsamplers = 0
564
+
565
+ # up
566
+ reversed_block_out_channels = list(reversed(block_out_channels))
567
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
568
+ reversed_layers_per_block = list(reversed(layers_per_block))
569
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
570
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
571
+ only_cross_attention = list(reversed(only_cross_attention))
572
+
573
+ output_channel = reversed_block_out_channels[0]
574
+ for i, up_block_type in enumerate(up_block_types):
575
+ is_final_block = i == len(block_out_channels) - 1
576
+
577
+ prev_output_channel = output_channel
578
+ output_channel = reversed_block_out_channels[i]
579
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
580
+
581
+ # add upsample block for all BUT final layer
582
+ if not is_final_block:
583
+ add_upsample = True
584
+ self.num_upsamplers += 1
585
+ else:
586
+ add_upsample = False
587
+
588
+ up_block = get_up_block(
589
+ up_block_type,
590
+ num_layers=reversed_layers_per_block[i] + 1,
591
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
592
+ in_channels=input_channel,
593
+ out_channels=output_channel,
594
+ prev_output_channel=prev_output_channel,
595
+ temb_channels=blocks_time_embed_dim,
596
+ add_upsample=add_upsample,
597
+ resnet_eps=norm_eps,
598
+ resnet_act_fn=act_fn,
599
+ resnet_groups=norm_num_groups,
600
+ cross_attention_dim=reversed_cross_attention_dim[i],
601
+ num_attention_heads=reversed_num_attention_heads[i],
602
+ dual_cross_attention=dual_cross_attention,
603
+ use_linear_projection=use_linear_projection,
604
+ only_cross_attention=only_cross_attention[i],
605
+ upcast_attention=upcast_attention,
606
+ resnet_time_scale_shift=resnet_time_scale_shift,
607
+ resnet_skip_time_act=resnet_skip_time_act,
608
+ resnet_out_scale_factor=resnet_out_scale_factor,
609
+ cross_attention_norm=cross_attention_norm,
610
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
611
+ num_views=num_views,
612
+ cd_attention_last=cd_attention_last,
613
+ cd_attention_mid=cd_attention_mid,
614
+ multiview_attention=multiview_attention,
615
+ sparse_mv_attention=sparse_mv_attention,
616
+ mvcd_attention=mvcd_attention
617
+ )
618
+ self.up_blocks.append(up_block)
619
+ prev_output_channel = output_channel
620
+
621
+ # out
622
+ if norm_num_groups is not None:
623
+ self.conv_norm_out = nn.GroupNorm(
624
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
625
+ )
626
+
627
+ self.conv_act = get_activation(act_fn)
628
+
629
+ else:
630
+ self.conv_norm_out = None
631
+ self.conv_act = None
632
+
633
+ conv_out_padding = (conv_out_kernel - 1) // 2
634
+ self.conv_out = nn.Conv2d(
635
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
636
+ )
637
+
638
+ @property
639
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
640
+ r"""
641
+ Returns:
642
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
643
+ indexed by its weight name.
644
+ """
645
+ # set recursively
646
+ processors = {}
647
+
648
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
649
+ if hasattr(module, "set_processor"):
650
+ processors[f"{name}.processor"] = module.processor
651
+
652
+ for sub_name, child in module.named_children():
653
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
654
+
655
+ return processors
656
+
657
+ for name, module in self.named_children():
658
+ fn_recursive_add_processors(name, module, processors)
659
+
660
+ return processors
661
+
662
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
663
+ r"""
664
+ Sets the attention processor to use to compute attention.
665
+
666
+ Parameters:
667
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
668
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
669
+ for **all** `Attention` layers.
670
+
671
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
672
+ processor. This is strongly recommended when setting trainable attention processors.
673
+
674
+ """
675
+ count = len(self.attn_processors.keys())
676
+
677
+ if isinstance(processor, dict) and len(processor) != count:
678
+ raise ValueError(
679
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
680
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
681
+ )
682
+
683
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
684
+ if hasattr(module, "set_processor"):
685
+ if not isinstance(processor, dict):
686
+ module.set_processor(processor)
687
+ else:
688
+ module.set_processor(processor.pop(f"{name}.processor"))
689
+
690
+ for sub_name, child in module.named_children():
691
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
692
+
693
+ for name, module in self.named_children():
694
+ fn_recursive_attn_processor(name, module, processor)
695
+
696
+ def set_default_attn_processor(self):
697
+ """
698
+ Disables custom attention processors and sets the default attention implementation.
699
+ """
700
+ self.set_attn_processor(AttnProcessor())
701
+
702
+ def set_attention_slice(self, slice_size):
703
+ r"""
704
+ Enable sliced attention computation.
705
+
706
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
707
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
708
+
709
+ Args:
710
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
711
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
712
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
713
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
714
+ must be a multiple of `slice_size`.
715
+ """
716
+ sliceable_head_dims = []
717
+
718
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
719
+ if hasattr(module, "set_attention_slice"):
720
+ sliceable_head_dims.append(module.sliceable_head_dim)
721
+
722
+ for child in module.children():
723
+ fn_recursive_retrieve_sliceable_dims(child)
724
+
725
+ # retrieve number of attention layers
726
+ for module in self.children():
727
+ fn_recursive_retrieve_sliceable_dims(module)
728
+
729
+ num_sliceable_layers = len(sliceable_head_dims)
730
+
731
+ if slice_size == "auto":
732
+ # half the attention head size is usually a good trade-off between
733
+ # speed and memory
734
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
735
+ elif slice_size == "max":
736
+ # make smallest slice possible
737
+ slice_size = num_sliceable_layers * [1]
738
+
739
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
740
+
741
+ if len(slice_size) != len(sliceable_head_dims):
742
+ raise ValueError(
743
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
744
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
745
+ )
746
+
747
+ for i in range(len(slice_size)):
748
+ size = slice_size[i]
749
+ dim = sliceable_head_dims[i]
750
+ if size is not None and size > dim:
751
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
752
+
753
+ # Recursively walk through all the children.
754
+ # Any children which exposes the set_attention_slice method
755
+ # gets the message
756
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
757
+ if hasattr(module, "set_attention_slice"):
758
+ module.set_attention_slice(slice_size.pop())
759
+
760
+ for child in module.children():
761
+ fn_recursive_set_attention_slice(child, slice_size)
762
+
763
+ reversed_slice_size = list(reversed(slice_size))
764
+ for module in self.children():
765
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
766
+
767
+ def _set_gradient_checkpointing(self, module, value=False):
768
+ if isinstance(module, (CrossAttnDownBlock2D, CrossAttnDownBlockMV2D, DownBlock2D, CrossAttnUpBlock2D, CrossAttnUpBlockMV2D, UpBlock2D)):
769
+ module.gradient_checkpointing = value
770
+
771
+ def forward(
772
+ self,
773
+ sample: torch.FloatTensor,
774
+ timestep: Union[torch.Tensor, float, int],
775
+ encoder_hidden_states: torch.Tensor,
776
+ class_labels: Optional[torch.Tensor] = None,
777
+ timestep_cond: Optional[torch.Tensor] = None,
778
+ attention_mask: Optional[torch.Tensor] = None,
779
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
780
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
781
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
782
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
783
+ encoder_attention_mask: Optional[torch.Tensor] = None,
784
+ return_dict: bool = True,
785
+ ) -> Union[UNetMV2DConditionOutput, Tuple]:
786
+ r"""
787
+ The [`UNet2DConditionModel`] forward method.
788
+
789
+ Args:
790
+ sample (`torch.FloatTensor`):
791
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
792
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
793
+ encoder_hidden_states (`torch.FloatTensor`):
794
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
795
+ encoder_attention_mask (`torch.Tensor`):
796
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
797
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
798
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
799
+ return_dict (`bool`, *optional*, defaults to `True`):
800
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
801
+ tuple.
802
+ cross_attention_kwargs (`dict`, *optional*):
803
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
804
+ added_cond_kwargs: (`dict`, *optional*):
805
+ A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
806
+ are passed along to the UNet blocks.
807
+
808
+ Returns:
809
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
810
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
811
+ a `tuple` is returned where the first element is the sample tensor.
812
+ """
813
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
814
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
815
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
816
+ # on the fly if necessary.
817
+ default_overall_up_factor = 2**self.num_upsamplers
818
+
819
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
820
+ forward_upsample_size = False
821
+ upsample_size = None
822
+
823
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
824
+ logger.info("Forward upsample size to force interpolation output size.")
825
+ forward_upsample_size = True
826
+
827
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
828
+ # expects mask of shape:
829
+ # [batch, key_tokens]
830
+ # adds singleton query_tokens dimension:
831
+ # [batch, 1, key_tokens]
832
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
833
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
834
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
835
+ if attention_mask is not None:
836
+ # assume that mask is expressed as:
837
+ # (1 = keep, 0 = discard)
838
+ # convert mask into a bias that can be added to attention scores:
839
+ # (keep = +0, discard = -10000.0)
840
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
841
+ attention_mask = attention_mask.unsqueeze(1)
842
+
843
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
844
+ if encoder_attention_mask is not None:
845
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
846
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
847
+
848
+ # 0. center input if necessary
849
+ if self.config.center_input_sample:
850
+ sample = 2 * sample - 1.0
851
+
852
+ # 1. time
853
+ timesteps = timestep
854
+ if not torch.is_tensor(timesteps):
855
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
856
+ # This would be a good case for the `match` statement (Python 3.10+)
857
+ is_mps = sample.device.type == "mps"
858
+ if isinstance(timestep, float):
859
+ dtype = torch.float32 if is_mps else torch.float64
860
+ else:
861
+ dtype = torch.int32 if is_mps else torch.int64
862
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
863
+ elif len(timesteps.shape) == 0:
864
+ timesteps = timesteps[None].to(sample.device)
865
+
866
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
867
+ timesteps = timesteps.expand(sample.shape[0])
868
+
869
+ t_emb = self.time_proj(timesteps)
870
+
871
+ # `Timesteps` does not contain any weights and will always return f32 tensors
872
+ # but time_embedding might actually be running in fp16. so we need to cast here.
873
+ # there might be better ways to encapsulate this.
874
+ t_emb = t_emb.to(dtype=sample.dtype)
875
+
876
+ # self.time_embedding.to(dtype=t_emb.dtype)
877
+ emb = self.time_embedding(t_emb, timestep_cond)
878
+ aug_emb = None
879
+
880
+ if self.class_embedding is not None:
881
+ if class_labels is None:
882
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
883
+
884
+ if self.config.class_embed_type == "timestep":
885
+ class_labels = self.time_proj(class_labels)
886
+
887
+ # `Timesteps` does not contain any weights and will always return f32 tensors
888
+ # there might be better ways to encapsulate this.
889
+ class_labels = class_labels.to(dtype=sample.dtype)
890
+
891
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
892
+
893
+ if self.config.class_embeddings_concat:
894
+ emb = torch.cat([emb, class_emb], dim=-1)
895
+ else:
896
+ emb = emb + class_emb
897
+
898
+ if self.config.addition_embed_type == "text":
899
+ aug_emb = self.add_embedding(encoder_hidden_states)
900
+ elif self.config.addition_embed_type == "text_image":
901
+ # Kandinsky 2.1 - style
902
+ if "image_embeds" not in added_cond_kwargs:
903
+ raise ValueError(
904
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
905
+ )
906
+
907
+ image_embs = added_cond_kwargs.get("image_embeds")
908
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
909
+ aug_emb = self.add_embedding(text_embs, image_embs)
910
+ elif self.config.addition_embed_type == "text_time":
911
+ # SDXL - style
912
+ if "text_embeds" not in added_cond_kwargs:
913
+ raise ValueError(
914
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
915
+ )
916
+ text_embeds = added_cond_kwargs.get("text_embeds")
917
+ if "time_ids" not in added_cond_kwargs:
918
+ raise ValueError(
919
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
920
+ )
921
+ time_ids = added_cond_kwargs.get("time_ids")
922
+ time_embeds = self.add_time_proj(time_ids.flatten())
923
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
924
+
925
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
926
+ add_embeds = add_embeds.to(emb.dtype)
927
+ aug_emb = self.add_embedding(add_embeds)
928
+ elif self.config.addition_embed_type == "image":
929
+ # Kandinsky 2.2 - style
930
+ if "image_embeds" not in added_cond_kwargs:
931
+ raise ValueError(
932
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
933
+ )
934
+ image_embs = added_cond_kwargs.get("image_embeds")
935
+ aug_emb = self.add_embedding(image_embs)
936
+ elif self.config.addition_embed_type == "image_hint":
937
+ # Kandinsky 2.2 - style
938
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
939
+ raise ValueError(
940
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
941
+ )
942
+ image_embs = added_cond_kwargs.get("image_embeds")
943
+ hint = added_cond_kwargs.get("hint")
944
+ aug_emb, hint = self.add_embedding(image_embs, hint)
945
+ sample = torch.cat([sample, hint], dim=1)
946
+
947
+ emb = emb + aug_emb if aug_emb is not None else emb
948
+
949
+ if self.time_embed_act is not None:
950
+ emb = self.time_embed_act(emb)
951
+
952
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
953
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
954
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
955
+ # Kadinsky 2.1 - style
956
+ if "image_embeds" not in added_cond_kwargs:
957
+ raise ValueError(
958
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
959
+ )
960
+
961
+ image_embeds = added_cond_kwargs.get("image_embeds")
962
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
963
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
964
+ # Kandinsky 2.2 - style
965
+ if "image_embeds" not in added_cond_kwargs:
966
+ raise ValueError(
967
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
968
+ )
969
+ image_embeds = added_cond_kwargs.get("image_embeds")
970
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
971
+ # 2. pre-process
972
+ sample = self.conv_in(sample)
973
+
974
+ # 3. down
975
+
976
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
977
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
978
+
979
+ down_block_res_samples = (sample,)
980
+ for downsample_block in self.down_blocks:
981
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
982
+ # For t2i-adapter CrossAttnDownBlock2D
983
+ additional_residuals = {}
984
+ if is_adapter and len(down_block_additional_residuals) > 0:
985
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
986
+
987
+ sample, res_samples = downsample_block(
988
+ hidden_states=sample,
989
+ temb=emb,
990
+ encoder_hidden_states=encoder_hidden_states,
991
+ attention_mask=attention_mask,
992
+ cross_attention_kwargs=cross_attention_kwargs,
993
+ encoder_attention_mask=encoder_attention_mask,
994
+ **additional_residuals,
995
+ )
996
+ else:
997
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
998
+
999
+ if is_adapter and len(down_block_additional_residuals) > 0:
1000
+ sample += down_block_additional_residuals.pop(0)
1001
+
1002
+ down_block_res_samples += res_samples
1003
+
1004
+ if is_controlnet:
1005
+ new_down_block_res_samples = ()
1006
+
1007
+ for down_block_res_sample, down_block_additional_residual in zip(
1008
+ down_block_res_samples, down_block_additional_residuals
1009
+ ):
1010
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1011
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1012
+
1013
+ down_block_res_samples = new_down_block_res_samples
1014
+
1015
+ # 4. mid
1016
+ if self.mid_block is not None:
1017
+ sample = self.mid_block(
1018
+ sample,
1019
+ emb,
1020
+ encoder_hidden_states=encoder_hidden_states,
1021
+ attention_mask=attention_mask,
1022
+ cross_attention_kwargs=cross_attention_kwargs,
1023
+ encoder_attention_mask=encoder_attention_mask,
1024
+ )
1025
+
1026
+ if is_controlnet:
1027
+ sample = sample + mid_block_additional_residual
1028
+
1029
+ # 5. up
1030
+ for i, upsample_block in enumerate(self.up_blocks):
1031
+ is_final_block = i == len(self.up_blocks) - 1
1032
+
1033
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1034
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1035
+
1036
+ # if we have not reached the final block and need to forward the
1037
+ # upsample size, we do it here
1038
+ if not is_final_block and forward_upsample_size:
1039
+ upsample_size = down_block_res_samples[-1].shape[2:]
1040
+
1041
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1042
+ sample = upsample_block(
1043
+ hidden_states=sample,
1044
+ temb=emb,
1045
+ res_hidden_states_tuple=res_samples,
1046
+ encoder_hidden_states=encoder_hidden_states,
1047
+ cross_attention_kwargs=cross_attention_kwargs,
1048
+ upsample_size=upsample_size,
1049
+ attention_mask=attention_mask,
1050
+ encoder_attention_mask=encoder_attention_mask,
1051
+ )
1052
+ else:
1053
+ sample = upsample_block(
1054
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
1055
+ )
1056
+
1057
+ # 6. post-process
1058
+ if self.conv_norm_out:
1059
+ sample = self.conv_norm_out(sample)
1060
+ sample = self.conv_act(sample)
1061
+ sample = self.conv_out(sample)
1062
+
1063
+ if not return_dict:
1064
+ return (sample,)
1065
+
1066
+ return UNetMV2DConditionOutput(sample=sample)
1067
+
1068
+ @classmethod
1069
+ def from_pretrained_2d(
1070
+ cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
1071
+ camera_embedding_type: str, num_views: int, sample_size: int,
1072
+ zero_init_conv_in: bool = True, zero_init_camera_projection: bool = False,
1073
+ projection_class_embeddings_input_dim: int=6, cd_attention_last: bool = False,
1074
+ cd_attention_mid: bool = False, multiview_attention: bool = True,
1075
+ sparse_mv_attention: bool = False, mvcd_attention: bool = False,
1076
+ in_channels: int = 8, out_channels: int = 4,
1077
+ **kwargs
1078
+ ):
1079
+ r"""
1080
+ Instantiate a pretrained PyTorch model from a pretrained model configuration.
1081
+
1082
+ The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
1083
+ train the model, set it back in training mode with `model.train()`.
1084
+
1085
+ Parameters:
1086
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
1087
+ Can be either:
1088
+
1089
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
1090
+ the Hub.
1091
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
1092
+ with [`~ModelMixin.save_pretrained`].
1093
+
1094
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
1095
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
1096
+ is not used.
1097
+ torch_dtype (`str` or `torch.dtype`, *optional*):
1098
+ Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
1099
+ dtype is automatically derived from the model's weights.
1100
+ force_download (`bool`, *optional*, defaults to `False`):
1101
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
1102
+ cached versions if they exist.
1103
+ resume_download (`bool`, *optional*, defaults to `False`):
1104
+ Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
1105
+ incompletely downloaded files are deleted.
1106
+ proxies (`Dict[str, str]`, *optional*):
1107
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
1108
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
1109
+ output_loading_info (`bool`, *optional*, defaults to `False`):
1110
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
1111
+ local_files_only(`bool`, *optional*, defaults to `False`):
1112
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
1113
+ won't be downloaded from the Hub.
1114
+ use_auth_token (`str` or *bool*, *optional*):
1115
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
1116
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
1117
+ revision (`str`, *optional*, defaults to `"main"`):
1118
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
1119
+ allowed by Git.
1120
+ from_flax (`bool`, *optional*, defaults to `False`):
1121
+ Load the model weights from a Flax checkpoint save file.
1122
+ subfolder (`str`, *optional*, defaults to `""`):
1123
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
1124
+ mirror (`str`, *optional*):
1125
+ Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
1126
+ guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
1127
+ information.
1128
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
1129
+ A map that specifies where each submodule should go. It doesn't need to be defined for each
1130
+ parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
1131
+ same device.
1132
+
1133
+ Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
1134
+ more information about each option see [designing a device
1135
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
1136
+ max_memory (`Dict`, *optional*):
1137
+ A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
1138
+ each GPU and the available CPU RAM if unset.
1139
+ offload_folder (`str` or `os.PathLike`, *optional*):
1140
+ The path to offload weights if `device_map` contains the value `"disk"`.
1141
+ offload_state_dict (`bool`, *optional*):
1142
+ If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
1143
+ the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
1144
+ when there is some disk offload.
1145
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
1146
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
1147
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
1148
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
1149
+ argument to `True` will raise an error.
1150
+ variant (`str`, *optional*):
1151
+ Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
1152
+ loading `from_flax`.
1153
+ use_safetensors (`bool`, *optional*, defaults to `None`):
1154
+ If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
1155
+ `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
1156
+ weights. If set to `False`, `safetensors` weights are not loaded.
1157
+
1158
+ <Tip>
1159
+
1160
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
1161
+ `huggingface-cli login`. You can also activate the special
1162
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
1163
+ firewalled environment.
1164
+
1165
+ </Tip>
1166
+
1167
+ Example:
1168
+
1169
+ ```py
1170
+ from diffusers import UNet2DConditionModel
1171
+
1172
+ unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
1173
+ ```
1174
+
1175
+ If you get the error message below, you need to finetune the weights for your downstream task:
1176
+
1177
+ ```bash
1178
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
1179
+ - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
1180
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
1181
+ ```
1182
+ """
1183
+ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
1184
+ ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
1185
+ force_download = kwargs.pop("force_download", False)
1186
+ from_flax = kwargs.pop("from_flax", False)
1187
+ resume_download = kwargs.pop("resume_download", False)
1188
+ proxies = kwargs.pop("proxies", None)
1189
+ output_loading_info = kwargs.pop("output_loading_info", False)
1190
+ local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
1191
+ use_auth_token = kwargs.pop("use_auth_token", None)
1192
+ revision = kwargs.pop("revision", None)
1193
+ torch_dtype = kwargs.pop("torch_dtype", None)
1194
+ subfolder = kwargs.pop("subfolder", None)
1195
+ device_map = kwargs.pop("device_map", None)
1196
+ max_memory = kwargs.pop("max_memory", None)
1197
+ offload_folder = kwargs.pop("offload_folder", None)
1198
+ offload_state_dict = kwargs.pop("offload_state_dict", False)
1199
+ variant = kwargs.pop("variant", None)
1200
+ use_safetensors = kwargs.pop("use_safetensors", None)
1201
+
1202
+ if use_safetensors and not is_safetensors_available():
1203
+ raise ValueError(
1204
+ "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
1205
+ )
1206
+
1207
+ allow_pickle = False
1208
+ if use_safetensors is None:
1209
+ use_safetensors = is_safetensors_available()
1210
+ allow_pickle = True
1211
+
1212
+ if device_map is not None and not is_accelerate_available():
1213
+ raise NotImplementedError(
1214
+ "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
1215
+ " `device_map=None`. You can install accelerate with `pip install accelerate`."
1216
+ )
1217
+
1218
+ # Check if we can handle device_map and dispatching the weights
1219
+ if device_map is not None and not is_torch_version(">=", "1.9.0"):
1220
+ raise NotImplementedError(
1221
+ "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
1222
+ " `device_map=None`."
1223
+ )
1224
+
1225
+ # Load config if we don't provide a configuration
1226
+ config_path = pretrained_model_name_or_path
1227
+
1228
+ user_agent = {
1229
+ "diffusers": __version__,
1230
+ "file_type": "model",
1231
+ "framework": "pytorch",
1232
+ }
1233
+
1234
+ # load config
1235
+ config, unused_kwargs, commit_hash = cls.load_config(
1236
+ config_path,
1237
+ cache_dir=cache_dir,
1238
+ return_unused_kwargs=True,
1239
+ return_commit_hash=True,
1240
+ force_download=force_download,
1241
+ resume_download=resume_download,
1242
+ proxies=proxies,
1243
+ local_files_only=local_files_only,
1244
+ use_auth_token=use_auth_token,
1245
+ revision=revision,
1246
+ subfolder=subfolder,
1247
+ device_map=device_map,
1248
+ max_memory=max_memory,
1249
+ offload_folder=offload_folder,
1250
+ offload_state_dict=offload_state_dict,
1251
+ user_agent=user_agent,
1252
+ **kwargs,
1253
+ )
1254
+
1255
+ # modify config
1256
+ config["_class_name"] = cls.__name__
1257
+ config['in_channels'] = in_channels
1258
+ config['out_channels'] = out_channels
1259
+ config['sample_size'] = sample_size # training resolution
1260
+ config['num_views'] = num_views
1261
+ config['cd_attention_last'] = cd_attention_last
1262
+ config['cd_attention_mid'] = cd_attention_mid
1263
+ config['multiview_attention'] = multiview_attention
1264
+ config['sparse_mv_attention'] = sparse_mv_attention
1265
+ config['mvcd_attention'] = mvcd_attention
1266
+ config["down_block_types"] = [
1267
+ "CrossAttnDownBlockMV2D",
1268
+ "CrossAttnDownBlockMV2D",
1269
+ "CrossAttnDownBlockMV2D",
1270
+ "DownBlock2D"
1271
+ ]
1272
+ config['mid_block_type'] = "UNetMidBlockMV2DCrossAttn"
1273
+ config["up_block_types"] = [
1274
+ "UpBlock2D",
1275
+ "CrossAttnUpBlockMV2D",
1276
+ "CrossAttnUpBlockMV2D",
1277
+ "CrossAttnUpBlockMV2D"
1278
+ ]
1279
+ config['class_embed_type'] = 'projection'
1280
+ if camera_embedding_type == 'e_de_da_sincos':
1281
+ config['projection_class_embeddings_input_dim'] = projection_class_embeddings_input_dim # default 6
1282
+ else:
1283
+ raise NotImplementedError
1284
+
1285
+ # load model
1286
+ model_file = None
1287
+ if from_flax:
1288
+ raise NotImplementedError
1289
+ else:
1290
+ if use_safetensors:
1291
+ try:
1292
+ model_file = _get_model_file(
1293
+ pretrained_model_name_or_path,
1294
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
1295
+ cache_dir=cache_dir,
1296
+ force_download=force_download,
1297
+ # resume_download=resume_download,
1298
+ proxies=proxies,
1299
+ local_files_only=local_files_only,
1300
+ use_auth_token=use_auth_token,
1301
+ revision=revision,
1302
+ subfolder=subfolder,
1303
+ user_agent=user_agent,
1304
+ commit_hash=commit_hash,
1305
+ )
1306
+ except IOError as e:
1307
+ if not allow_pickle:
1308
+ raise e
1309
+ pass
1310
+ if model_file is None:
1311
+ model_file = _get_model_file(
1312
+ pretrained_model_name_or_path,
1313
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
1314
+ cache_dir=cache_dir,
1315
+ force_download=force_download,
1316
+ # resume_download=resume_download,
1317
+ proxies=proxies,
1318
+ local_files_only=local_files_only,
1319
+ use_auth_token=use_auth_token,
1320
+ revision=revision,
1321
+ subfolder=subfolder,
1322
+ user_agent=user_agent,
1323
+ commit_hash=commit_hash,
1324
+ )
1325
+
1326
+ model = cls.from_config(config, **unused_kwargs)
1327
+ import copy
1328
+ state_dict_v0 = load_state_dict(model_file, variant=variant)
1329
+ state_dict = copy.deepcopy(state_dict_v0)
1330
+ # attn_joint -> attn_joint_last; norm_joint -> norm_joint_last
1331
+ # attn_joint_twice -> attn_joint_mid; norm_joint_twice -> norm_joint_mid
1332
+ for key in state_dict_v0:
1333
+ if 'attn_joint.' in key:
1334
+ tmp = copy.deepcopy(key)
1335
+ state_dict[key.replace("attn_joint.", "attn_joint_last.")] = state_dict.pop(tmp)
1336
+ if 'norm_joint.' in key:
1337
+ tmp = copy.deepcopy(key)
1338
+ state_dict[key.replace("norm_joint.", "norm_joint_last.")] = state_dict.pop(tmp)
1339
+ if 'attn_joint_twice.' in key:
1340
+ tmp = copy.deepcopy(key)
1341
+ state_dict[key.replace("attn_joint_twice.", "attn_joint_mid.")] = state_dict.pop(tmp)
1342
+ if 'norm_joint_twice.' in key:
1343
+ tmp = copy.deepcopy(key)
1344
+ state_dict[key.replace("norm_joint_twice.", "norm_joint_mid.")] = state_dict.pop(tmp)
1345
+
1346
+ model._convert_deprecated_attention_blocks(state_dict)
1347
+
1348
+ conv_in_weight = state_dict['conv_in.weight']
1349
+ conv_out_weight = state_dict['conv_out.weight']
1350
+ model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model_2d(
1351
+ model,
1352
+ state_dict,
1353
+ model_file,
1354
+ pretrained_model_name_or_path,
1355
+ ignore_mismatched_sizes=True,
1356
+ )
1357
+ if any([key == 'conv_in.weight' for key, _, _ in mismatched_keys]):
1358
+ # initialize from the original SD structure
1359
+ model.conv_in.weight.data[:,:4] = conv_in_weight
1360
+
1361
+ # whether to place all zero to new layers?
1362
+ if zero_init_conv_in:
1363
+ model.conv_in.weight.data[:,4:] = 0.
1364
+
1365
+ if any([key == 'conv_out.weight' for key, _, _ in mismatched_keys]):
1366
+ # initialize from the original SD structure
1367
+ model.conv_out.weight.data[:,:4] = conv_out_weight
1368
+ if out_channels == 8: # copy for the last 4 channels
1369
+ model.conv_out.weight.data[:, 4:] = conv_out_weight
1370
+
1371
+ # if zero_init_camera_projection:
1372
+ # for p in model.class_embedding.parameters():
1373
+ # torch.nn.init.zeros_(p)
1374
+
1375
+ loading_info = {
1376
+ "missing_keys": missing_keys,
1377
+ "unexpected_keys": unexpected_keys,
1378
+ "mismatched_keys": mismatched_keys,
1379
+ "error_msgs": error_msgs,
1380
+ }
1381
+
1382
+ if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
1383
+ raise ValueError(
1384
+ f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
1385
+ )
1386
+ elif torch_dtype is not None:
1387
+ model = model.to(torch_dtype)
1388
+
1389
+ model.register_to_config(_name_or_path=pretrained_model_name_or_path)
1390
+
1391
+ # Set model in evaluation mode to deactivate DropOut modules by default
1392
+ model.eval()
1393
+ if output_loading_info:
1394
+ return model, loading_info
1395
+
1396
+ return model
1397
+
1398
+ @classmethod
1399
+ def _load_pretrained_model_2d(
1400
+ cls,
1401
+ model,
1402
+ state_dict,
1403
+ resolved_archive_file,
1404
+ pretrained_model_name_or_path,
1405
+ ignore_mismatched_sizes=False,
1406
+ ):
1407
+ # Retrieve missing & unexpected_keys
1408
+ model_state_dict = model.state_dict()
1409
+ loaded_keys = list(state_dict.keys())
1410
+
1411
+ expected_keys = list(model_state_dict.keys())
1412
+
1413
+ original_loaded_keys = loaded_keys
1414
+
1415
+ missing_keys = list(set(expected_keys) - set(loaded_keys))
1416
+ unexpected_keys = list(set(loaded_keys) - set(expected_keys))
1417
+
1418
+ # Make sure we are able to load base models as well as derived models (with heads)
1419
+ model_to_load = model
1420
+
1421
+ def _find_mismatched_keys(
1422
+ state_dict,
1423
+ model_state_dict,
1424
+ loaded_keys,
1425
+ ignore_mismatched_sizes,
1426
+ ):
1427
+ mismatched_keys = []
1428
+ if ignore_mismatched_sizes:
1429
+ for checkpoint_key in loaded_keys:
1430
+ model_key = checkpoint_key
1431
+
1432
+ if (
1433
+ model_key in model_state_dict
1434
+ and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
1435
+ ):
1436
+ mismatched_keys.append(
1437
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
1438
+ )
1439
+ del state_dict[checkpoint_key]
1440
+ return mismatched_keys
1441
+
1442
+ if state_dict is not None:
1443
+ # Whole checkpoint
1444
+ mismatched_keys = _find_mismatched_keys(
1445
+ state_dict,
1446
+ model_state_dict,
1447
+ original_loaded_keys,
1448
+ ignore_mismatched_sizes,
1449
+ )
1450
+ error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
1451
+
1452
+ if len(error_msgs) > 0:
1453
+ error_msg = "\n\t".join(error_msgs)
1454
+ if "size mismatch" in error_msg:
1455
+ error_msg += (
1456
+ "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
1457
+ )
1458
+ raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
1459
+
1460
+ if len(unexpected_keys) > 0:
1461
+ logger.warning(
1462
+ f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
1463
+ f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
1464
+ f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
1465
+ " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
1466
+ " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
1467
+ f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
1468
+ " identical (initializing a BertForSequenceClassification model from a"
1469
+ " BertForSequenceClassification model)."
1470
+ )
1471
+ else:
1472
+ logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
1473
+ if len(missing_keys) > 0:
1474
+ logger.warning(
1475
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1476
+ f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
1477
+ " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
1478
+ )
1479
+ elif len(mismatched_keys) == 0:
1480
+ logger.info(
1481
+ f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
1482
+ f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
1483
+ f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
1484
+ " without further training."
1485
+ )
1486
+ if len(mismatched_keys) > 0:
1487
+ mismatched_warning = "\n".join(
1488
+ [
1489
+ f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
1490
+ for key, shape1, shape2 in mismatched_keys
1491
+ ]
1492
+ )
1493
+ logger.warning(
1494
+ f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
1495
+ f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
1496
+ f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
1497
+ " able to use it for predictions and inference."
1498
+ )
1499
+
1500
+ return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
1501
+
mv_diffusion_30/pipelines/pipeline_mvdiffusion_image.py ADDED
@@ -0,0 +1,555 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import warnings
17
+ from typing import Callable, List, Optional, Union
18
+
19
+ import PIL
20
+ import torch
21
+ import torchvision.transforms.functional as TF
22
+ from packaging import version
23
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
24
+
25
+ from diffusers.configuration_utils import FrozenDict
26
+ from diffusers.image_processor import VaeImageProcessor
27
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
28
+ from diffusers.schedulers import KarrasDiffusionSchedulers
29
+ from diffusers.utils.torch_utils import logging, randn_tensor
30
+ from diffusers.utils.deprecation_utils import deprecate
31
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
33
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
34
+ from einops import rearrange, repeat
35
+
36
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
+
38
+
39
+ class MVDiffusionImagePipeline(DiffusionPipeline):
40
+ r"""
41
+ Pipeline to generate image variations from an input image using Stable Diffusion.
42
+
43
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
44
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
45
+
46
+ Args:
47
+ vae ([`AutoencoderKL`]):
48
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
49
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
50
+ Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
51
+ text_encoder ([`~transformers.CLIPTextModel`]):
52
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
53
+ tokenizer ([`~transformers.CLIPTokenizer`]):
54
+ A `CLIPTokenizer` to tokenize text.
55
+ unet ([`UNet2DConditionModel`]):
56
+ A `UNet2DConditionModel` to denoise the encoded image latents.
57
+ scheduler ([`SchedulerMixin`]):
58
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
59
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
60
+ safety_checker ([`StableDiffusionSafetyChecker`]):
61
+ Classification module that estimates whether generated images could be considered offensive or harmful.
62
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
63
+ about a model's potential harms.
64
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
65
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
66
+ """
67
+ # TODO: feature_extractor is required to encode images (if they are in PIL format),
68
+ # we should give a descriptive message if the pipeline doesn't have one.
69
+ _optional_components = ["safety_checker"]
70
+
71
+ def __init__(
72
+ self,
73
+ vae: AutoencoderKL,
74
+ image_encoder: CLIPVisionModelWithProjection,
75
+ unet: UNet2DConditionModel,
76
+ scheduler: KarrasDiffusionSchedulers,
77
+ safety_checker: StableDiffusionSafetyChecker,
78
+ feature_extractor: CLIPImageProcessor,
79
+ requires_safety_checker: bool = True,
80
+ camera_embedding_type: str = 'e_de_da_sincos',
81
+ num_views: int = 6,
82
+ pred_type: str = 'color',
83
+ ):
84
+ super().__init__()
85
+
86
+ if safety_checker is None and requires_safety_checker:
87
+ logger.warn(
88
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
89
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
90
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
91
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
92
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
93
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
94
+ )
95
+
96
+ if safety_checker is not None and feature_extractor is None:
97
+ raise ValueError(
98
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
99
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
100
+ )
101
+
102
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
103
+ version.parse(unet.config._diffusers_version).base_version
104
+ ) < version.parse("0.9.0.dev0")
105
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
106
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
107
+ deprecation_message = (
108
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
109
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
110
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
111
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
112
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
113
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
114
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
115
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
116
+ " the `unet/config.json` file"
117
+ )
118
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
119
+ new_config = dict(unet.config)
120
+ new_config["sample_size"] = 64
121
+ unet._internal_dict = FrozenDict(new_config)
122
+
123
+ self.register_modules(
124
+ vae=vae,
125
+ image_encoder=image_encoder,
126
+ unet=unet,
127
+ scheduler=scheduler,
128
+ safety_checker=safety_checker,
129
+ feature_extractor=feature_extractor,
130
+ )
131
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
132
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
133
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
134
+
135
+ self.camera_embedding_type: str = camera_embedding_type
136
+ self.num_views: int = num_views
137
+ self.pred_type = pred_type
138
+
139
+ self.camera_embedding = torch.tensor(
140
+ [[ 0.0000, 0.0000, 0.0000, 1.0000, 0.0000],
141
+ [ 0.0000, -0.2362, 0.8125, 1.0000, 0.0000],
142
+ [ 0.0000, -0.1686, 1.6934, 1.0000, 0.0000],
143
+ [ 0.0000, 0.5220, 3.1406, 1.0000, 0.0000],
144
+ [ 0.0000, 0.6904, 4.8359, 1.0000, 0.0000],
145
+ [ 0.0000, 0.3733, 5.5859, 1.0000, 0.0000],
146
+ [ 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
147
+ [ 0.0000, -0.2362, 0.8125, 0.0000, 1.0000],
148
+ [ 0.0000, -0.1686, 1.6934, 0.0000, 1.0000],
149
+ [ 0.0000, 0.5220, 3.1406, 0.0000, 1.0000],
150
+ [ 0.0000, 0.6904, 4.8359, 0.0000, 1.0000],
151
+ [ 0.0000, 0.3733, 5.5859, 0.0000, 1.0000]], dtype=torch.float16)
152
+
153
+ def _encode_image(self, image_pil, device, num_images_per_prompt, do_classifier_free_guidance):
154
+ dtype = next(self.image_encoder.parameters()).dtype
155
+
156
+ image_pt = self.feature_extractor(images=image_pil, return_tensors="pt").pixel_values
157
+ image_pt = image_pt.to(device=device, dtype=dtype)
158
+ image_embeddings = self.image_encoder(image_pt).image_embeds
159
+ image_embeddings = image_embeddings.unsqueeze(1)
160
+
161
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
162
+ # Note: repeat differently from official pipelines
163
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
164
+ bs_embed, seq_len, _ = image_embeddings.shape
165
+ image_embeddings = image_embeddings.repeat(num_images_per_prompt, 1, 1)
166
+
167
+ if do_classifier_free_guidance:
168
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
169
+
170
+ # For classifier free guidance, we need to do two forward passes.
171
+ # Here we concatenate the unconditional and text embeddings into a single batch
172
+ # to avoid doing two forward passes
173
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
174
+
175
+ image_pt = torch.stack([TF.to_tensor(img) for img in image_pil], dim=0).to(device).to(dtype)
176
+ image_pt = image_pt * 2.0 - 1.0
177
+ image_latents = self.vae.encode(image_pt).latent_dist.mode() * self.vae.config.scaling_factor
178
+ # Note: repeat differently from official pipelines
179
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
180
+ image_latents = image_latents.repeat(num_images_per_prompt, 1, 1, 1)
181
+
182
+ if do_classifier_free_guidance:
183
+ image_latents = torch.cat([torch.zeros_like(image_latents), image_latents])
184
+
185
+ return image_embeddings, image_latents
186
+
187
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
188
+ def run_safety_checker(self, image, device, dtype):
189
+ if self.safety_checker is None:
190
+ has_nsfw_concept = None
191
+ else:
192
+ if torch.is_tensor(image):
193
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
194
+ else:
195
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
196
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
197
+ image, has_nsfw_concept = self.safety_checker(
198
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
199
+ )
200
+ return image, has_nsfw_concept
201
+
202
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
203
+ def decode_latents(self, latents):
204
+ warnings.warn(
205
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
206
+ " use VaeImageProcessor instead",
207
+ FutureWarning,
208
+ )
209
+ latents = 1 / self.vae.config.scaling_factor * latents
210
+ image = self.vae.decode(latents, return_dict=False)[0]
211
+ image = (image / 2 + 0.5).clamp(0, 1)
212
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
213
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
214
+ return image
215
+
216
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
217
+ def prepare_extra_step_kwargs(self, generator, eta):
218
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
219
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
220
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
221
+ # and should be between [0, 1]
222
+
223
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
224
+ extra_step_kwargs = {}
225
+ if accepts_eta:
226
+ extra_step_kwargs["eta"] = eta
227
+
228
+ # check if the scheduler accepts generator
229
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
230
+ if accepts_generator:
231
+ extra_step_kwargs["generator"] = generator
232
+ return extra_step_kwargs
233
+
234
+ def check_inputs(self, image, height, width, callback_steps):
235
+ if (
236
+ not isinstance(image, torch.Tensor)
237
+ and not isinstance(image, PIL.Image.Image)
238
+ and not isinstance(image, list)
239
+ ):
240
+ raise ValueError(
241
+ "`image` has to be of type `torch.FloatTensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
242
+ f" {type(image)}"
243
+ )
244
+
245
+ if height % 8 != 0 or width % 8 != 0:
246
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
247
+
248
+ if (callback_steps is None) or (
249
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
250
+ ):
251
+ raise ValueError(
252
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
253
+ f" {type(callback_steps)}."
254
+ )
255
+
256
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
257
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, cross_domain_latnte=True):
258
+ if cross_domain_latnte:
259
+ # generate cross-domain initial latents
260
+ # for cross-domain task, make sure the two domain are start from a same initial latents
261
+ assert batch_size % 2 == 0
262
+ batch_size = batch_size // 2
263
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
264
+ if isinstance(generator, list) and len(generator) != batch_size:
265
+ raise ValueError(
266
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
267
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
268
+ )
269
+
270
+ if latents is None:
271
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
272
+ else:
273
+ latents = latents.to(device)
274
+
275
+ # scale the initial noise by the standard deviation required by the scheduler
276
+ latents = latents * self.scheduler.init_noise_sigma
277
+ if cross_domain_latnte:
278
+ latents = torch.cat([latents] * 2)
279
+ return latents
280
+
281
+ def prepare_camera_embedding(self, camera_embedding: Union[float, torch.Tensor], do_classifier_free_guidance, num_images_per_prompt=1):
282
+ # (B, 3)
283
+ camera_embedding = camera_embedding.to(dtype=self.unet.dtype, device=self.unet.device)
284
+
285
+ if self.camera_embedding_type == 'e_de_da_sincos':
286
+ # (B, 6)
287
+ camera_embedding = torch.cat([
288
+ torch.sin(camera_embedding),
289
+ torch.cos(camera_embedding)
290
+ ], dim=-1)
291
+ assert self.unet.config.class_embed_type == 'projection'
292
+ assert self.unet.config.projection_class_embeddings_input_dim == 14 or self.unet.config.projection_class_embeddings_input_dim == 10
293
+ else:
294
+ raise NotImplementedError
295
+
296
+ # Note: repeat differently from official pipelines
297
+ # B1B2B3B4 -> B1B2B3B4B1B2B3B4
298
+ camera_embedding = camera_embedding.repeat(num_images_per_prompt, 1)
299
+
300
+ if do_classifier_free_guidance:
301
+ camera_embedding = torch.cat([
302
+ camera_embedding,
303
+ camera_embedding
304
+ ], dim=0)
305
+
306
+ return camera_embedding
307
+
308
+ def reshape_to_cd_input(self, input):
309
+ # reshape input for cross-domain attention
310
+ input_norm_uc, input_rgb_uc, input_norm_cond, input_rgb_cond = torch.chunk(
311
+ input, dim=0, chunks=4)
312
+ input = torch.cat(
313
+ [input_norm_uc, input_norm_cond, input_rgb_uc, input_rgb_cond], dim=0)
314
+ return input
315
+
316
+ def reshape_to_cfg_output(self, output):
317
+ # reshape input for cfg
318
+ output_norm_uc, output_norm_cond, output_rgb_uc, output_rgb_cond = torch.chunk(
319
+ output, dim=0, chunks=4)
320
+ output = torch.cat(
321
+ [output_norm_uc, output_rgb_uc, output_norm_cond, output_rgb_cond],
322
+ dim=0)
323
+ return output
324
+
325
+ @torch.no_grad()
326
+ def __call__(
327
+ self,
328
+ image: Union[List[PIL.Image.Image], torch.FloatTensor],
329
+ # elevation_cond: torch.FloatTensor,
330
+ # elevation: torch.FloatTensor,
331
+ # azimuth: torch.FloatTensor,
332
+ camera_embedding: Optional[torch.FloatTensor]=None,
333
+ height: Optional[int] = None,
334
+ width: Optional[int] = None,
335
+ num_inference_steps: int = 50,
336
+ guidance_scale: float = 7.5,
337
+ num_images_per_prompt: Optional[int] = 1,
338
+ eta: float = 0.0,
339
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
340
+ latents: Optional[torch.FloatTensor] = None,
341
+ output_type: Optional[str] = "pil",
342
+ return_dict: bool = True,
343
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
344
+ callback_steps: int = 1,
345
+ normal_cond: Optional[Union[List[PIL.Image.Image], torch.FloatTensor]] = None,
346
+ ):
347
+ r"""
348
+ The call function to the pipeline for generation.
349
+
350
+ Args:
351
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
352
+ Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
353
+ [`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
354
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
355
+ The height in pixels of the generated image.
356
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
357
+ The width in pixels of the generated image.
358
+ num_inference_steps (`int`, *optional*, defaults to 50):
359
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
360
+ expense of slower inference. This parameter is modulated by `strength`.
361
+ guidance_scale (`float`, *optional*, defaults to 7.5):
362
+ A higher guidance scale value encourages the model to generate images closely linked to the text
363
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
364
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
365
+ The number of images to generate per prompt.
366
+ eta (`float`, *optional*, defaults to 0.0):
367
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
368
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
369
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
370
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
371
+ generation deterministic.
372
+ latents (`torch.FloatTensor`, *optional*):
373
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
374
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
375
+ tensor is generated by sampling using the supplied random `generator`.
376
+ output_type (`str`, *optional*, defaults to `"pil"`):
377
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
378
+ return_dict (`bool`, *optional*, defaults to `True`):
379
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
380
+ plain tuple.
381
+ callback (`Callable`, *optional*):
382
+ A function that calls every `callback_steps` steps during inference. The function is called with the
383
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
384
+ callback_steps (`int`, *optional*, defaults to 1):
385
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
386
+ every step.
387
+
388
+ Returns:
389
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
390
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
391
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
392
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
393
+ "not-safe-for-work" (nsfw) content.
394
+
395
+ Examples:
396
+
397
+ ```py
398
+ from diffusers import StableDiffusionImageVariationPipeline
399
+ from PIL import Image
400
+ from io import BytesIO
401
+ import requests
402
+
403
+ pipe = StableDiffusionImageVariationPipeline.from_pretrained(
404
+ "lambdalabs/sd-image-variations-diffusers", revision="v2.0"
405
+ )
406
+ pipe = pipe.to("cuda")
407
+
408
+ url = "https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200"
409
+
410
+ response = requests.get(url)
411
+ image = Image.open(BytesIO(response.content)).convert("RGB")
412
+
413
+ out = pipe(image, num_images_per_prompt=3, guidance_scale=15)
414
+ out["images"][0].save("result.jpg")
415
+ ```
416
+ """
417
+ # 0. Default height and width to unet
418
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
419
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
420
+
421
+ # 1. Check inputs. Raise error if not correct
422
+ self.check_inputs(image, height, width, callback_steps)
423
+
424
+
425
+ # 2. Define call parameters
426
+ if isinstance(image, list):
427
+ batch_size = len(image)
428
+ elif isinstance(image, torch.Tensor):
429
+ batch_size = image.shape[0]
430
+ assert batch_size >= self.num_views and batch_size % self.num_views == 0
431
+ elif isinstance(image, PIL.Image.Image):
432
+ image = [image]*self.num_views*2
433
+ batch_size = self.num_views*2
434
+
435
+ device = self._execution_device
436
+ dtype = self.vae.dtype
437
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
438
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
439
+ # corresponds to doing no classifier free guidance.
440
+ do_classifier_free_guidance = guidance_scale != 1.0
441
+
442
+ # 3. Encode input image
443
+ if isinstance(image, list):
444
+ image_pil = image
445
+ elif isinstance(image, torch.Tensor):
446
+ image_pil = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
447
+ image_embeddings, image_latents = self._encode_image(image_pil, device, num_images_per_prompt, do_classifier_free_guidance)
448
+
449
+ if normal_cond is not None:
450
+ if isinstance(normal_cond, list):
451
+ normal_cond_pil = normal_cond
452
+ elif isinstance(normal_cond, torch.Tensor):
453
+ normal_cond_pil = [TF.to_pil_image(normal_cond[i]) for i in range(normal_cond.shape[0])]
454
+ _, image_latents = self._encode_image(normal_cond_pil, device, num_images_per_prompt, do_classifier_free_guidance)
455
+
456
+
457
+ # assert len(elevation_cond) == batch_size and len(elevation) == batch_size and len(azimuth) == batch_size
458
+ # camera_embeddings = self.prepare_camera_condition(elevation_cond, elevation, azimuth, do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt)
459
+
460
+ if camera_embedding is not None:
461
+ assert len(camera_embedding) == batch_size
462
+ else:
463
+ camera_embedding = self.camera_embedding.to(dtype)
464
+ camera_embedding = repeat(camera_embedding, "Nv Nce -> (B Nv) Nce", B=batch_size//len(camera_embedding))
465
+ camera_embeddings = self.prepare_camera_embedding(camera_embedding, do_classifier_free_guidance=do_classifier_free_guidance, num_images_per_prompt=num_images_per_prompt)
466
+
467
+ # 4. Prepare timesteps
468
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
469
+ timesteps = self.scheduler.timesteps
470
+
471
+ # 5. Prepare latent variables
472
+ num_channels_latents = self.unet.config.out_channels
473
+ latents = self.prepare_latents(
474
+ batch_size * num_images_per_prompt,
475
+ num_channels_latents,
476
+ height,
477
+ width,
478
+ image_embeddings.dtype,
479
+ device,
480
+ generator,
481
+ latents,
482
+ cross_domain_latnte=True
483
+ )
484
+
485
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
486
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
487
+
488
+ # 7. Denoising loop
489
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
490
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
491
+ if do_classifier_free_guidance and self.pred_type == 'joint_color_normal':
492
+ print("reshape the input to cross-domain format")
493
+ image_embeddings = self.reshape_to_cd_input(image_embeddings)
494
+ camera_embeddings = self.reshape_to_cd_input(camera_embeddings)
495
+ image_latents = self.reshape_to_cd_input(image_latents)
496
+ for i, t in enumerate(timesteps):
497
+ # expand the latents if we are doing classifier free guidance
498
+ # latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
499
+ if do_classifier_free_guidance and self.pred_type == 'joint_color_normal':
500
+ latent_model_input = torch.cat([latents] * 2)
501
+ latent_model_input = self.reshape_to_cd_input(latent_model_input)
502
+ elif do_classifier_free_guidance and self.pred_type != 'joint_color_normal':
503
+ latent_model_input = torch.cat([latents] * 2)
504
+ else:
505
+ latent_model_input = latents
506
+
507
+ latent_model_input = torch.cat([
508
+ latent_model_input, image_latents
509
+ ], dim=1)
510
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
511
+
512
+ # predict the noise residual
513
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=image_embeddings,
514
+ class_labels=camera_embeddings).sample
515
+
516
+ # perform guidance
517
+ if do_classifier_free_guidance and self.pred_type != 'joint_color_normal':
518
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
519
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
520
+ elif do_classifier_free_guidance and self.pred_type == 'joint_color_normal':
521
+ noise_pred = self.reshape_to_cfg_output(noise_pred)
522
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
523
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
524
+
525
+ # compute the previous noisy sample x_t -> x_t-1
526
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
527
+
528
+ # call the callback, if provided
529
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
530
+ progress_bar.update()
531
+ if callback is not None and i % callback_steps == 0:
532
+ callback(i, t, latents)
533
+
534
+ if not output_type == "latent":
535
+ if num_channels_latents == 8:
536
+ latents = torch.cat([latents[:, :4], latents[:, 4:]], dim=0)
537
+
538
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
539
+ image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
540
+ else:
541
+ image = latents
542
+ has_nsfw_concept = None
543
+
544
+ if has_nsfw_concept is None:
545
+ do_denormalize = [True] * image.shape[0]
546
+ else:
547
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
548
+
549
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
550
+
551
+ if not return_dict:
552
+ return (image, has_nsfw_concept)
553
+
554
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
555
+