diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..dd1285d1cba924e32b1580a148943fb416073f7b --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +__pycache__/ +*.DS_Store +*.ipynb +*.egg-info/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..84930f96a1348853fb54f55a528b9651f230a347 --- /dev/null +++ b/README.md @@ -0,0 +1,34 @@ +--- +license: mit +--- + + +# One-2-3-45's Inference Model + +
+ + + +
+ +This inference model supports the demo for [One-2-3-45](http://One-2-3-45.com). + +Try out our 🤗 Hugging Face Demo: + + Open in HuggingFace + + +Please refer to our [GitHub repo](https://github.com/One-2-3-45/One-2-3-45) for full code release and local deployment. + +## Citation + +```bibtex +@misc{liu2023one2345, + title={One-2-3-45: Any Single Image to 3D Mesh in 45 Seconds without Per-Shape Optimization}, + author={Minghua Liu and Chao Xu and Haian Jin and Linghao Chen and Mukund Varma T and Zexiang Xu and Hao Su}, + year={2023}, + eprint={2306.16928}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` \ No newline at end of file diff --git a/SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf b/SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf new file mode 100644 index 0000000000000000000000000000000000000000..e591ac038b854140efc81cdad3c8dc7838f03a83 --- /dev/null +++ b/SparseNeuS_demo_v1/confs/one2345_lod0_val_demo.conf @@ -0,0 +1,135 @@ +# - for the lod1 geometry network, using adaptive cost for sparse cost regularization network +#- for lod1 rendering network, using depth-adaptive render + +general { + + base_exp_dir = exp/lod0 # !!! where you store the results and checkpoints to be used + recording = [ + ./, + ./data + ./ops + ./models + ./loss + ] +} + +dataset { + trainpath = ../ + valpath = ../ # !!! where you store the validation data + testpath = ../ + + imgScale_train = 1.0 + imgScale_test = 1.0 + nviews = 5 + clean_image = True + importance_sample = True + test_ref_views = [23] + + # test dataset + test_n_views = 2 + test_img_wh = [256, 256] + test_clip_wh = [0, 0] + test_scan_id = scan110 + train_img_idx = [49, 50, 52, 53, 54, 56, 58] #[21, 22, 23, 24, 25] # + test_img_idx = [51, 55, 57] #[32, 33, 34] # + + test_dir_comment = train +} + +train { + learning_rate = 2e-4 + learning_rate_milestone = [100000, 150000, 200000] + learning_rate_factor = 0.5 + end_iter = 200000 + save_freq = 5000 + val_freq = 1 + val_mesh_freq = 1 + report_freq = 100 + + N_rays = 512 + + validate_resolution_level = 4 + anneal_start = 0 + anneal_end = 25000 + anneal_start_lod1 = 0 + anneal_end_lod1 = 15000 + + use_white_bkgd = True + + # Loss + # ! for training the lod1 network, don't use this regularization in first 10k steps; then use the regularization + sdf_igr_weight = 0.1 + sdf_sparse_weight = 0.02 # 0.002 for lod1 network; 0.02 for lod0 network + sdf_decay_param = 100 # cannot be too large, which decide the tsdf range + fg_bg_weight = 0.01 # first 0.01 + bg_ratio = 0.3 + + if_fix_lod0_networks = False +} + +model { + num_lods = 1 + + sdf_network_lod0 { + lod = 0, + ch_in = 56, # the channel num of fused pyramid features + voxel_size = 0.02105263, # 0.02083333, should be 2/95 + vol_dims = [96, 96, 96], + hidden_dim = 128, + cost_type = variance_mean + d_pyramid_feature_compress = 16, + regnet_d_out = 16, + num_sdf_layers = 4, + # position embedding + multires = 6 + } + + + sdf_network_lod1 { + lod = 1, + ch_in = 56, # the channel num of fused pyramid features + voxel_size = 0.0104712, #0.01041667, should be 2/191 + vol_dims = [192, 192, 192], + hidden_dim = 128, + cost_type = variance_mean + d_pyramid_feature_compress = 8, + regnet_d_out = 16, + num_sdf_layers = 4, + + # position embedding + multires = 6 + } + + + variance_network { + init_val = 0.2 + } + + variance_network_lod1 { + init_val = 0.2 + } + + rendering_network { + in_geometry_feat_ch = 16 + in_rendering_feat_ch = 56 + anti_alias_pooling = True + } + + rendering_network_lod1 { + in_geometry_feat_ch = 16 # default 8 + in_rendering_feat_ch = 56 + anti_alias_pooling = True + + } + + + trainer { + n_samples_lod0 = 64 + n_importance_lod0 = 64 + n_samples_lod1 = 64 + n_importance_lod1 = 64 + n_outside = 0 # 128 if render_outside_uniform_sampling + perturb = 1.0 + alpha_type = div + } +} diff --git a/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py new file mode 100644 index 0000000000000000000000000000000000000000..530a434828d4fdb1c4d2439ea9fbdcc40d449ef6 --- /dev/null +++ b/SparseNeuS_demo_v1/data/blender_general_narrow_all_eval_new_data.py @@ -0,0 +1,394 @@ +from torch.utils.data import Dataset +import os +import json +import numpy as np +import cv2 +from PIL import Image +import torch +from torchvision import transforms as T +from data.scene import get_boundingbox + +from models.rays import gen_rays_from_single_image, gen_random_rays_from_single_image +from kornia import create_meshgrid + +def get_ray_directions(H, W, focal, center=None): + """ + Get ray directions for all pixels in camera coordinate. + Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/ + ray-tracing-generating-camera-rays/standard-coordinate-systems + Inputs: + H, W, focal: image height, width and focal length + Outputs: + directions: (H, W, 3), the direction of the rays in camera coordinate + """ + grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5 # 1xHxWx2 + + i, j = grid.unbind(-1) + # the direction here is without +0.5 pixel centering as calibration is not so accurate + # see https://github.com/bmild/nerf/issues/24 + cent = center if center is not None else [W / 2, H / 2] + directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3) + + return directions + +def load_K_Rt_from_P(filename, P=None): + if P is None: + lines = open(filename).read().splitlines() + if len(lines) == 4: + lines = lines[1:] + lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] + P = np.asarray(lines).astype(np.float32).squeeze() + + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() # ? why need transpose here + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose # ! return cam2world matrix here + + +# ! load one ref-image with multiple src-images in camera coordinate system +class BlenderPerView(Dataset): + def __init__(self, root_dir, split, n_views=3, img_wh=(256, 256), downSample=1.0, + split_filepath=None, pair_filepath=None, + N_rays=512, + vol_dims=[128, 128, 128], batch_size=1, + clean_image=False, importance_sample=False, test_ref_views=[], + specific_dataset_name = 'GSO' + ): + + # print("root_dir: ", root_dir) + self.root_dir = root_dir + self.split = split + + self.specific_dataset_name = specific_dataset_name + self.n_views = n_views + self.N_rays = N_rays + self.batch_size = batch_size # - used for construct new metas for gru fusion training + + self.clean_image = clean_image + self.importance_sample = importance_sample + self.test_ref_views = test_ref_views # used for testing + self.scale_factor = 1.0 + self.scale_mat = np.float32(np.diag([1, 1, 1, 1.0])) + assert self.split == 'val' or 'export_mesh', 'only support val or export_mesh' + # find all subfolders + main_folder = os.path.join(root_dir, self.specific_dataset_name) + self.shape_list = [""] # os.listdir(main_folder) # MODIFIED + self.shape_list.sort() + + # self.shape_list = ['barrel_render'] + # self.shape_list = ["barrel", "bag", "mailbox", "shoe", "chair", "car", "dog", "teddy"] # TO BE DELETED + + + self.lvis_paths = [] + for shape_name in self.shape_list: + self.lvis_paths.append(os.path.join(main_folder, shape_name)) + + if img_wh is not None: + assert img_wh[0] % 32 == 0 and img_wh[1] % 32 == 0, \ + 'img_wh must both be multiples of 32!' + + + # * bounding box for rendering + self.bbox_min = np.array([-1.0, -1.0, -1.0]) + self.bbox_max = np.array([1.0, 1.0, 1.0]) + + # - used for cost volume regularization + self.voxel_dims = torch.tensor(vol_dims, dtype=torch.float32) + self.partial_vol_origin = torch.tensor([-1., -1., -1.], dtype=torch.float32) + + + def define_transforms(self): + self.transform = T.Compose([T.ToTensor()]) + + + + def load_cam_info(self): + for vid, img_id in enumerate(self.img_ids): + intrinsic, extrinsic, near_far = self.intrinsic, np.linalg.inv(self.c2ws[vid]), self.near_far + self.all_intrinsics.append(intrinsic) + self.all_extrinsics.append(extrinsic) + self.all_near_fars.append(near_far) + + def read_mask(self, filename): + mask_h = cv2.imread(filename, 0) + mask_h = cv2.resize(mask_h, None, fx=self.downSample, fy=self.downSample, + interpolation=cv2.INTER_NEAREST) + mask = cv2.resize(mask_h, None, fx=0.25, fy=0.25, + interpolation=cv2.INTER_NEAREST) + + mask[mask > 0] = 1 # the masks stored in png are not binary + mask_h[mask_h > 0] = 1 + + return mask, mask_h + + def cal_scale_mat(self, img_hw, intrinsics, extrinsics, near_fars, factor=1.): + + center, radius, bounds = get_boundingbox(img_hw, intrinsics, extrinsics, near_fars) + + radius = radius * factor + scale_mat = np.diag([radius, radius, radius, 1.0]) + scale_mat[:3, 3] = center.cpu().numpy() + scale_mat = scale_mat.astype(np.float32) + + return scale_mat, 1. / radius.cpu().numpy() + + def __len__(self): + # return 8*len(self.lvis_paths) + return len(self.lvis_paths) + + def __getitem__(self, idx): + sample = {} + idx = idx * 8 # to be deleted + origin_idx = idx + imgs, depths_h, masks_h = [], [], [] # full size (256, 256) + intrinsics, w2cs, c2ws, near_fars = [], [], [], [] # record proj-mats between views + + folder_path = self.lvis_paths[idx//8] + idx = idx % 8 # [0, 7] + + # last subdir name + shape_name = os.path.split(folder_path)[-1] + + pose_json_path = os.path.join(folder_path, "pose.json") + with open(pose_json_path, 'r') as f: + meta = json.load(f) + + self.img_ids = list(meta["c2ws"].keys()) # e.g. "view_0", "view_7", "view_0_2_10" + self.img_wh = (256, 256) + self.input_poses = np.array(list(meta["c2ws"].values())) + intrinsic = np.eye(4) + intrinsic[:3, :3] = np.array(meta["intrinsics"]) + self.intrinsic = intrinsic + self.near_far = np.array(meta["near_far"]) + self.near_far[1] = 1.8 + self.define_transforms() + self.blender2opencv = np.array( + [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] + ) + + self.c2ws = [] + self.w2cs = [] + self.near_fars = [] + for image_idx, img_id in enumerate(self.img_ids): + pose = self.input_poses[image_idx] + c2w = pose @ self.blender2opencv + self.c2ws.append(c2w) + self.w2cs.append(np.linalg.inv(c2w)) + self.near_fars.append(self.near_far) + self.c2ws = np.stack(self.c2ws, axis=0) + self.w2cs = np.stack(self.w2cs, axis=0) + + + self.all_intrinsics = [] # the cam info of the whole scene + self.all_extrinsics = [] + self.all_near_fars = [] + self.load_cam_info() + + + # target view + c2w = self.c2ws[idx] + w2c = np.linalg.inv(c2w) + w2c_ref = w2c + w2c_ref_inv = np.linalg.inv(w2c_ref) + + w2cs.append(w2c @ w2c_ref_inv) + c2ws.append(np.linalg.inv(w2c @ w2c_ref_inv)) + + img_filename = os.path.join(folder_path, 'stage1_8', f'{self.img_ids[idx]}') + + img = Image.open(img_filename) + img = self.transform(img) # (4, h, w) + + + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + imgs += [img] + + + depth_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.float32) + depth_h = depth_h.fill_(-1.0) + mask_h = torch.ones((img.shape[1], img.shape[2]), dtype=torch.int32) + + + depths_h.append(depth_h) + masks_h.append(mask_h) + + intrinsic = self.intrinsic + intrinsics.append(intrinsic) + + + near_fars.append(self.near_fars[idx]) + image_perm = 0 # only supervised on reference view + + mask_dilated = None + + + src_views = range(8, 8 + 8 * 4) + + for vid in src_views: + + img_filename = os.path.join(folder_path, 'stage2_8', f'{self.img_ids[vid]}') + img = Image.open(img_filename) + img_wh = self.img_wh + + img = self.transform(img) + if img.shape[0] == 4: + img = img[:3] * img[-1:] + (1 - img[-1:]) # blend A to RGB + + imgs += [img] + depth_h = np.ones(img.shape[1:], dtype=np.float32) + depths_h.append(depth_h) + masks_h.append(np.ones(img.shape[1:], dtype=np.int32)) + + near_fars.append(self.all_near_fars[vid]) + intrinsics.append(self.all_intrinsics[vid]) + + w2cs.append(self.all_extrinsics[vid] @ w2c_ref_inv) + + + # ! estimate scale_mat + scale_mat, scale_factor = self.cal_scale_mat( + img_hw=[img_wh[1], img_wh[0]], + intrinsics=intrinsics, extrinsics=w2cs, + near_fars=near_fars, factor=1.1 + ) + + + new_near_fars = [] + new_w2cs = [] + new_c2ws = [] + new_affine_mats = [] + new_depths_h = [] + for intrinsic, extrinsic, near_far, depth in zip(intrinsics, w2cs, near_fars, depths_h): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_w2cs.append(w2c) + new_c2ws.append(c2w) + affine_mat = np.eye(4) + affine_mat[:3, :4] = intrinsic[:3, :3] @ w2c[:3, :4] + new_affine_mats.append(affine_mat) + + camera_o = c2w[:3, 3] + dist = np.sqrt(np.sum(camera_o ** 2)) + near = dist - 1 + far = dist + 1 + + new_near_fars.append([0.95 * near, 1.05 * far]) + new_depths_h.append(depth * scale_factor) + + imgs = torch.stack(imgs).float() + depths_h = np.stack(new_depths_h) + masks_h = np.stack(masks_h) + + affine_mats = np.stack(new_affine_mats) + intrinsics, w2cs, c2ws, near_fars = np.stack(intrinsics), np.stack(new_w2cs), np.stack(new_c2ws), np.stack( + new_near_fars) + + if self.split == 'train': + start_idx = 0 + else: + start_idx = 1 + + + + target_w2cs = [] + target_intrinsics = [] + new_target_w2cs = [] + for i_idx in range(8): + target_w2cs.append(self.all_extrinsics[i_idx] @ w2c_ref_inv) + target_intrinsics.append(self.all_intrinsics[i_idx]) + + for intrinsic, extrinsic in zip(target_intrinsics, target_w2cs): + + P = intrinsic @ extrinsic @ scale_mat + P = P[:3, :4] + # - should use load_K_Rt_from_P() to obtain c2w + c2w = load_K_Rt_from_P(None, P)[1] + w2c = np.linalg.inv(c2w) + new_target_w2cs.append(w2c) + target_w2cs = np.stack(new_target_w2cs) + + + + view_ids = [idx] + list(src_views) + sample['origin_idx'] = origin_idx + sample['images'] = imgs # (V, 3, H, W) + sample['depths_h'] = torch.from_numpy(depths_h.astype(np.float32)) # (V, H, W) + sample['masks_h'] = torch.from_numpy(masks_h.astype(np.float32)) # (V, H, W) + sample['w2cs'] = torch.from_numpy(w2cs.astype(np.float32)) # (V, 4, 4) + sample['c2ws'] = torch.from_numpy(c2ws.astype(np.float32)) # (V, 4, 4) + sample['target_candidate_w2cs'] = torch.from_numpy(target_w2cs.astype(np.float32)) # (8, 4, 4) + sample['near_fars'] = torch.from_numpy(near_fars.astype(np.float32)) # (V, 2) + sample['intrinsics'] = torch.from_numpy(intrinsics.astype(np.float32))[:, :3, :3] # (V, 3, 3) + sample['view_ids'] = torch.from_numpy(np.array(view_ids)) + sample['affine_mats'] = torch.from_numpy(affine_mats.astype(np.float32)) # ! in world space + + sample['scan'] = shape_name + + sample['scale_factor'] = torch.tensor(scale_factor) + sample['img_wh'] = torch.from_numpy(np.array(img_wh)) + sample['render_img_idx'] = torch.tensor(image_perm) + sample['partial_vol_origin'] = self.partial_vol_origin + sample['meta'] = str(self.specific_dataset_name) + '_' + str(shape_name) + "_refview" + str(view_ids[0]) + # print("meta: ", sample['meta']) + + # - image to render + sample['query_image'] = sample['images'][0] + sample['query_c2w'] = sample['c2ws'][0] + sample['query_w2c'] = sample['w2cs'][0] + sample['query_intrinsic'] = sample['intrinsics'][0] + sample['query_depth'] = sample['depths_h'][0] + sample['query_mask'] = sample['masks_h'][0] + sample['query_near_far'] = sample['near_fars'][0] + + sample['images'] = sample['images'][start_idx:] # (V, 3, H, W) + sample['depths_h'] = sample['depths_h'][start_idx:] # (V, H, W) + sample['masks_h'] = sample['masks_h'][start_idx:] # (V, H, W) + sample['w2cs'] = sample['w2cs'][start_idx:] # (V, 4, 4) + sample['c2ws'] = sample['c2ws'][start_idx:] # (V, 4, 4) + sample['intrinsics'] = sample['intrinsics'][start_idx:] # (V, 3, 3) + sample['view_ids'] = sample['view_ids'][start_idx:] + sample['affine_mats'] = sample['affine_mats'][start_idx:] # ! in world space + + sample['scale_mat'] = torch.from_numpy(scale_mat) + sample['trans_mat'] = torch.from_numpy(w2c_ref_inv) + + # - generate rays + if ('val' in self.split) or ('test' in self.split): + sample_rays = gen_rays_from_single_image( + img_wh[1], img_wh[0], + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None) + else: + sample_rays = gen_random_rays_from_single_image( + img_wh[1], img_wh[0], + self.N_rays, + sample['query_image'], + sample['query_intrinsic'], + sample['query_c2w'], + depth=sample['query_depth'], + mask=sample['query_mask'] if self.clean_image else None, + dilated_mask=mask_dilated, + importance_sample=self.importance_sample) + + + sample['rays'] = sample_rays + + return sample diff --git a/SparseNeuS_demo_v1/data/scene.py b/SparseNeuS_demo_v1/data/scene.py new file mode 100644 index 0000000000000000000000000000000000000000..5f34f4abf9977fba8a3f8785ef4f0c95dbd9fa1b --- /dev/null +++ b/SparseNeuS_demo_v1/data/scene.py @@ -0,0 +1,101 @@ +import numpy as np +import torch + + +def rigid_transform(xyz, transform): + """Applies a rigid transform (c2w) to an (N, 3) pointcloud. + """ + device = xyz.device + xyz_h = torch.cat([xyz, torch.ones((len(xyz), 1)).to(device)], dim=1) # (N, 4) + xyz_t_h = (transform @ xyz_h.T).T # * checked: the same with the below + + return xyz_t_h[:, :3] + + +def get_view_frustum(min_depth, max_depth, size, cam_intr, c2w): + """Get corners of 3D camera view frustum of depth image + """ + device = cam_intr.device + im_h, im_w = size + im_h = int(im_h) + im_w = int(im_w) + view_frust_pts = torch.stack([ + (torch.tensor([0, 0, im_w, im_w, 0, 0, im_w, im_w]).to(device) - cam_intr[0, 2]) * torch.tensor( + [min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to(device) / + cam_intr[0, 0], + (torch.tensor([0, im_h, 0, im_h, 0, im_h, 0, im_h]).to(device) - cam_intr[1, 2]) * torch.tensor( + [min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to(device) / + cam_intr[1, 1], + torch.tensor([min_depth, min_depth, min_depth, min_depth, max_depth, max_depth, max_depth, max_depth]).to( + device) + ]) + view_frust_pts = view_frust_pts.type(torch.float32) + c2w = c2w.type(torch.float32) + view_frust_pts = rigid_transform(view_frust_pts.T, c2w).T + return view_frust_pts + + +def set_pixel_coords(h, w): + i_range = torch.arange(0, h).view(1, h, 1).expand(1, h, w).type(torch.float32) # [1, H, W] + j_range = torch.arange(0, w).view(1, 1, w).expand(1, h, w).type(torch.float32) # [1, H, W] + ones = torch.ones(1, h, w).type(torch.float32) + + pixel_coords = torch.stack((j_range, i_range, ones), dim=1) # [1, 3, H, W] + + return pixel_coords + + +def get_boundingbox(img_hw, intrinsics, extrinsics, near_fars): + """ + # get the minimum bounding box of all visual hulls + :param img_hw: + :param intrinsics: + :param extrinsics: + :param near_fars: + :return: + """ + + bnds = torch.zeros((3, 2)) + bnds[:, 0] = np.inf + bnds[:, 1] = -np.inf + + if isinstance(intrinsics, list): + num = len(intrinsics) + else: + num = intrinsics.shape[0] + # print("num: ", num) + view_frust_pts_list = [] + for i in range(num): + if not isinstance(intrinsics[i], torch.Tensor): + cam_intr = torch.tensor(intrinsics[i]) + w2c = torch.tensor(extrinsics[i]) + c2w = torch.inverse(w2c) + else: + cam_intr = intrinsics[i] + w2c = extrinsics[i] + c2w = torch.inverse(w2c) + min_depth, max_depth = near_fars[i][0], near_fars[i][1] + # todo: check the coresponding points are matched + + view_frust_pts = get_view_frustum(min_depth, max_depth, img_hw, cam_intr, c2w) + bnds[:, 0] = torch.min(bnds[:, 0], torch.min(view_frust_pts, dim=1)[0]) + bnds[:, 1] = torch.max(bnds[:, 1], torch.max(view_frust_pts, dim=1)[0]) + view_frust_pts_list.append(view_frust_pts) + all_view_frust_pts = torch.cat(view_frust_pts_list, dim=1) + + # print("all_view_frust_pts: ", all_view_frust_pts.shape) + # distance = torch.norm(all_view_frust_pts, dim=0) + # print("distance: ", distance) + + # print("all_view_frust_pts_z: ", all_view_frust_pts[2, :]) + + center = torch.tensor(((bnds[0, 1] + bnds[0, 0]) / 2, (bnds[1, 1] + bnds[1, 0]) / 2, + (bnds[2, 1] + bnds[2, 0]) / 2)) + + lengths = bnds[:, 1] - bnds[:, 0] + + max_length, _ = torch.max(lengths, dim=0) + radius = max_length / 2 + + # print("radius: ", radius) + return center, radius, bnds diff --git a/SparseNeuS_demo_v1/exp/lod0/.gitignore b/SparseNeuS_demo_v1/exp/lod0/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..35c54109136367b098bb5112c0b87cee09444c0b --- /dev/null +++ b/SparseNeuS_demo_v1/exp/lod0/.gitignore @@ -0,0 +1 @@ +checkpoints_*/ \ No newline at end of file diff --git a/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_215000.pth b/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_215000.pth new file mode 100644 index 0000000000000000000000000000000000000000..e293399afc6fce93faab470c22e88de0e78e841d --- /dev/null +++ b/SparseNeuS_demo_v1/exp/lod0/checkpoints/ckpt_215000.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:888aaa8abde948358c26e4ef63df99f666438345c1dee301059967c5ce77b6ea +size 5312111 diff --git a/SparseNeuS_demo_v1/exp_runner_generic_blender_val.py b/SparseNeuS_demo_v1/exp_runner_generic_blender_val.py new file mode 100644 index 0000000000000000000000000000000000000000..7d09a56f3d66935ca26b2690ed637dfb6f51049c --- /dev/null +++ b/SparseNeuS_demo_v1/exp_runner_generic_blender_val.py @@ -0,0 +1,629 @@ +import os +import logging +import argparse +import numpy as np +from shutil import copyfile +import torch +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +from rich import print +from tqdm import tqdm +from pyhocon import ConfigFactory + +import sys +sys.path.append(os.path.dirname(__file__)) + +from models.fields import SingleVarianceNetwork +from models.featurenet import FeatureNet +from models.trainer_generic import GenericTrainer +from models.sparse_sdf_network import SparseSdfNetwork +from models.rendering_network import GeneralRenderingNetwork +from data.blender_general_narrow_all_eval_new_data import BlenderPerView + + +from datetime import datetime + +class Runner: + def __init__(self, conf_path, mode='train', is_continue=False, + is_restore=False, restore_lod0=False, local_rank=0): + + # Initial setting + self.device = torch.device('cuda:%d' % local_rank) + # self.device = torch.device('cuda') + self.num_devices = torch.cuda.device_count() + self.is_continue = is_continue or (mode == "export_mesh") + self.is_restore = is_restore + self.restore_lod0 = restore_lod0 + self.mode = mode + self.model_list = [] + self.logger = logging.getLogger('exp_logger') + + print("detected %d GPUs" % self.num_devices) + + self.conf_path = conf_path + self.conf = ConfigFactory.parse_file(conf_path) + self.timestamp = None + if not self.is_continue: + self.timestamp = '_{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now()) + self.base_exp_dir = self.conf['general.base_exp_dir'] + self.timestamp # jha comment this when testing and use this when training + else: + self.base_exp_dir = self.conf['general.base_exp_dir'] + self.conf['general.base_exp_dir'] = self.base_exp_dir # jha use this when testing + print("base_exp_dir: " + self.base_exp_dir) + os.makedirs(self.base_exp_dir, exist_ok=True) + self.iter_step = 0 + self.val_step = 0 + + # trainning parameters + self.end_iter = self.conf.get_int('train.end_iter') + self.save_freq = self.conf.get_int('train.save_freq') + self.report_freq = self.conf.get_int('train.report_freq') + self.val_freq = self.conf.get_int('train.val_freq') + self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') + self.batch_size = self.num_devices # use DataParallel to warp + self.validate_resolution_level = self.conf.get_int('train.validate_resolution_level') + self.learning_rate = self.conf.get_float('train.learning_rate') + self.learning_rate_milestone = self.conf.get_list('train.learning_rate_milestone') + self.learning_rate_factor = self.conf.get_float('train.learning_rate_factor') + self.use_white_bkgd = self.conf.get_bool('train.use_white_bkgd') + self.N_rays = self.conf.get_int('train.N_rays') + + # warmup params for sdf gradient + self.anneal_start_lod0 = self.conf.get_float('train.anneal_start', default=0) + self.anneal_end_lod0 = self.conf.get_float('train.anneal_end', default=0) + self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0) + self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0) + + self.writer = None + + # Networks + self.num_lods = self.conf.get_int('model.num_lods') + + self.rendering_network_outside = None + self.sdf_network_lod0 = None + self.sdf_network_lod1 = None + self.variance_network_lod0 = None + self.variance_network_lod1 = None + self.rendering_network_lod0 = None + self.rendering_network_lod1 = None + self.pyramid_feature_network = None # extract 2d pyramid feature maps from images, used for geometry + self.pyramid_feature_network_lod1 = None # may use different feature network for different lod + + # * pyramid_feature_network + self.pyramid_feature_network = FeatureNet().to(self.device) + self.sdf_network_lod0 = SparseSdfNetwork(**self.conf['model.sdf_network_lod0']).to(self.device) + self.variance_network_lod0 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device) + + if self.num_lods > 1: + self.sdf_network_lod1 = SparseSdfNetwork(**self.conf['model.sdf_network_lod1']).to(self.device) + self.variance_network_lod1 = SingleVarianceNetwork(**self.conf['model.variance_network']).to(self.device) + + self.rendering_network_lod0 = GeneralRenderingNetwork(**self.conf['model.rendering_network']).to( + self.device) + + if self.num_lods > 1: + self.pyramid_feature_network_lod1 = FeatureNet().to(self.device) + self.rendering_network_lod1 = GeneralRenderingNetwork( + **self.conf['model.rendering_network_lod1']).to(self.device) + if self.mode == 'export_mesh' or self.mode == 'val': + # base_exp_dir_to_store = os.path.join(self.base_exp_dir, '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now())) + base_exp_dir_to_store = os.path.join("../", args.specific_dataset_name) #"../gradio_tmp" # MODIFIED + else: + base_exp_dir_to_store = self.base_exp_dir + + print(f"Store in: {base_exp_dir_to_store}") + # Renderer model + self.trainer = GenericTrainer( + self.rendering_network_outside, + self.pyramid_feature_network, + self.pyramid_feature_network_lod1, + self.sdf_network_lod0, + self.sdf_network_lod1, + self.variance_network_lod0, + self.variance_network_lod1, + self.rendering_network_lod0, + self.rendering_network_lod1, + **self.conf['model.trainer'], + timestamp=self.timestamp, + base_exp_dir=base_exp_dir_to_store, + conf=self.conf) + + self.data_setup() # * data setup + + self.optimizer_setup() + + # Load checkpoint + latest_model_name = None + if self.is_continue: + model_list_raw = os.listdir(os.path.join(self.base_exp_dir, 'checkpoints')) + model_list = [] + for model_name in model_list_raw: + if model_name.startswith('ckpt'): + if model_name[-3:] == 'pth': # and int(model_name[5:-4]) <= self.end_iter: + model_list.append(model_name) + model_list.sort() + latest_model_name = model_list[-1] + + if latest_model_name is not None: + self.logger.info('Find checkpoint: {}'.format(latest_model_name)) + self.load_checkpoint(latest_model_name) + + self.trainer = torch.nn.DataParallel(self.trainer).to(self.device) + + if self.mode[:5] == 'train': + self.file_backup() + + def optimizer_setup(self): + self.params_to_train = self.trainer.get_trainable_params() + self.optimizer = torch.optim.Adam(self.params_to_train, lr=self.learning_rate) + + def data_setup(self): + """ + if use ddp, use setup() not prepare_data(), + prepare_data() only called on 1 GPU/TPU in distributed + :return: + """ + + self.train_dataset = BlenderPerView( + root_dir=self.conf['dataset.trainpath'], + split=self.conf.get_string('dataset.train_split', default='train'), + split_filepath=self.conf.get_string('dataset.train_split_filepath', default=None), + n_views=self.conf['dataset.nviews'], + downSample=self.conf['dataset.imgScale_train'], + N_rays=self.N_rays, + batch_size=self.batch_size, + clean_image=True, # True for training + importance_sample=self.conf.get_bool('dataset.importance_sample', default=False), + specific_dataset_name = args.specific_dataset_name + ) + + self.val_dataset = BlenderPerView( + root_dir=self.conf['dataset.valpath'], + split=self.conf.get_string('dataset.test_split', default='test'), + split_filepath=self.conf.get_string('dataset.val_split_filepath', default=None), + n_views=3, + downSample=self.conf['dataset.imgScale_test'], + N_rays=self.N_rays, + batch_size=self.batch_size, + clean_image=self.conf.get_bool('dataset.mask_out_image', + default=False) if self.mode != 'train' else False, + importance_sample=self.conf.get_bool('dataset.importance_sample', default=False), + test_ref_views=self.conf.get_list('dataset.test_ref_views', default=[]), + specific_dataset_name = args.specific_dataset_name + ) + + # item = self.train_dataset.__getitem__(0) + self.train_dataloader = DataLoader(self.train_dataset, + shuffle=True, + num_workers=4 * self.batch_size, + # num_workers=1, + batch_size=self.batch_size, + pin_memory=True, + drop_last=True + ) + + self.val_dataloader = DataLoader(self.val_dataset, + # shuffle=False if self.mode == 'train' else True, + shuffle=False, + num_workers=4 * self.batch_size, + # num_workers=1, + batch_size=self.batch_size, + pin_memory=True, + drop_last=False + ) + + self.val_dataloader_iterator = iter(self.val_dataloader) # - should be after "reconstruct_metas_for_gru_fusion" + + def train(self): + self.writer = SummaryWriter(log_dir=os.path.join(self.base_exp_dir, 'logs')) + res_step = self.end_iter - self.iter_step + + dataloader = self.train_dataloader + + epochs = int(1 + res_step // len(dataloader)) + + self.adjust_learning_rate() + print("starting training learning rate: {:.5f}".format(self.optimizer.param_groups[0]['lr'])) + + background_rgb = None + if self.use_white_bkgd: + # background_rgb = torch.ones([1, 3]).to(self.device) + background_rgb = 1.0 + + for epoch_i in range(epochs): + + print("current epoch %d" % epoch_i) + dataloader = tqdm(dataloader) + + for batch in dataloader: + # print("Checker1:, fetch data") + batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)]) # used to get meta + + # - warmup params + if self.num_lods == 1: + alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0) + else: + alpha_inter_ratio_lod0 = 1. + alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1) + + losses = self.trainer( + batch, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=self.iter_step, + mode='train', + ) + + loss_types = ['loss_lod0', 'loss_lod1'] + # print("[TEST]: weights_sum in trainer return", losses['losses_lod0']['weights_sum'].mean()) + + losses_lod0 = losses['losses_lod0'] + losses_lod1 = losses['losses_lod1'] + # import ipdb; ipdb.set_trace() + loss = 0 + for loss_type in loss_types: + if losses[loss_type] is not None: + loss = loss + losses[loss_type].mean() + # print("Checker4:, begin BP") + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.params_to_train, 1.0) + self.optimizer.step() + # print("Checker5:, end BP") + self.iter_step += 1 + + if self.iter_step % self.report_freq == 0: + self.writer.add_scalar('Loss/loss', loss, self.iter_step) + + if losses_lod0 is not None: + self.writer.add_scalar('Loss/d_loss_lod0', + losses_lod0['depth_loss'].mean() if losses_lod0 is not None else 0, + self.iter_step) + self.writer.add_scalar('Loss/sparse_loss_lod0', + losses_lod0[ + 'sparse_loss'].mean() if losses_lod0 is not None else 0, + self.iter_step) + self.writer.add_scalar('Loss/color_loss_lod0', + losses_lod0['color_fine_loss'].mean() + if losses_lod0['color_fine_loss'] is not None else 0, + self.iter_step) + + self.writer.add_scalar('statis/psnr_lod0', + losses_lod0['psnr'].mean() + if losses_lod0['psnr'] is not None else 0, + self.iter_step) + + self.writer.add_scalar('param/variance_lod0', + 1. / torch.exp(self.variance_network_lod0.variance * 10), + self.iter_step) + self.writer.add_scalar('param/eikonal_loss', losses_lod0['gradient_error_loss'].mean() if losses_lod0 is not None else 0, + self.iter_step) + + ######## - lod 1 + if self.num_lods > 1: + self.writer.add_scalar('Loss/d_loss_lod1', + losses_lod1['depth_loss'].mean() if losses_lod1 is not None else 0, + self.iter_step) + self.writer.add_scalar('Loss/sparse_loss_lod1', + losses_lod1[ + 'sparse_loss'].mean() if losses_lod1 is not None else 0, + self.iter_step) + self.writer.add_scalar('Loss/color_loss_lod1', + losses_lod1['color_fine_loss'].mean() + if losses_lod1['color_fine_loss'] is not None else 0, + self.iter_step) + self.writer.add_scalar('statis/sdf_mean_lod1', + losses_lod1['sdf_mean'].mean() if losses_lod1 is not None else 0, + self.iter_step) + self.writer.add_scalar('statis/psnr_lod1', + losses_lod1['psnr'].mean() + if losses_lod1['psnr'] is not None else 0, + self.iter_step) + self.writer.add_scalar('statis/sparseness_0.01_lod1', + losses_lod1['sparseness_1'].mean() + if losses_lod1['sparseness_1'] is not None else 0, + self.iter_step) + self.writer.add_scalar('statis/sparseness_0.02_lod1', + losses_lod1['sparseness_2'].mean() + if losses_lod1['sparseness_2'] is not None else 0, + self.iter_step) + self.writer.add_scalar('param/variance_lod1', + 1. / torch.exp(self.variance_network_lod1.variance * 10), + self.iter_step) + + print(self.base_exp_dir) + print( + 'iter:{:8>d} ' + 'loss = {:.4f} ' + 'd_loss_lod0 = {:.4f} ' + 'color_loss_lod0 = {:.4f} ' + 'sparse_loss_lod0= {:.4f} ' + 'd_loss_lod1 = {:.4f} ' + 'color_loss_lod1 = {:.4f} ' + ' lr = {:.5f}'.format( + self.iter_step, loss, + losses_lod0['depth_loss'].mean() if losses_lod0 is not None else 0, + losses_lod0['color_fine_loss'].mean() if losses_lod0 is not None else 0, + losses_lod0['sparse_loss'].mean() if losses_lod0 is not None else 0, + losses_lod1['depth_loss'].mean() if losses_lod1 is not None else 0, + losses_lod1['color_fine_loss'].mean() if losses_lod1 is not None else 0, + self.optimizer.param_groups[0]['lr'])) + + print('alpha_inter_ratio_lod0 = {:.4f} alpha_inter_ratio_lod1 = {:.4f}\n'.format( + alpha_inter_ratio_lod0, alpha_inter_ratio_lod1)) + + if losses_lod0 is not None: + # print("[TEST]: weights_sum in print", losses_lod0['weights_sum'].mean()) + # import ipdb; ipdb.set_trace() + print( + 'iter:{:8>d} ' + 'variance = {:.5f} ' + 'weights_sum = {:.4f} ' + 'weights_sum_fg = {:.4f} ' + 'alpha_sum = {:.4f} ' + 'sparse_weight= {:.4f} ' + 'background_loss = {:.4f} ' + 'background_weight = {:.4f} ' + .format( + self.iter_step, + losses_lod0['variance'].mean(), + losses_lod0['weights_sum'].mean(), + losses_lod0['weights_sum_fg'].mean(), + losses_lod0['alpha_sum'].mean(), + losses_lod0['sparse_weight'].mean(), + losses_lod0['fg_bg_loss'].mean(), + losses_lod0['fg_bg_weight'].mean(), + )) + + if losses_lod1 is not None: + print( + 'iter:{:8>d} ' + 'variance = {:.5f} ' + ' weights_sum = {:.4f} ' + 'alpha_sum = {:.4f} ' + 'fg_bg_loss = {:.4f} ' + 'fg_bg_weight = {:.4f} ' + 'sparse_weight= {:.4f} ' + 'fg_bg_loss = {:.4f} ' + 'fg_bg_weight = {:.4f} ' + .format( + self.iter_step, + losses_lod1['variance'].mean(), + losses_lod1['weights_sum'].mean(), + losses_lod1['alpha_sum'].mean(), + losses_lod1['fg_bg_loss'].mean(), + losses_lod1['fg_bg_weight'].mean(), + losses_lod1['sparse_weight'].mean(), + losses_lod1['fg_bg_loss'].mean(), + losses_lod1['fg_bg_weight'].mean(), + )) + + if self.iter_step % self.save_freq == 0: + self.save_checkpoint() + + if self.iter_step % self.val_freq == 0: + self.validate() + + # - ajust learning rate + self.adjust_learning_rate() + + def adjust_learning_rate(self): + # - ajust learning rate, cosine learning schedule + learning_rate = (np.cos(np.pi * self.iter_step / self.end_iter) + 1.0) * 0.5 * 0.9 + 0.1 + learning_rate = self.learning_rate * learning_rate + for g in self.optimizer.param_groups: + g['lr'] = learning_rate + + def get_alpha_inter_ratio(self, start, end): + if end == 0.0: + return 1.0 + elif self.iter_step < start: + return 0.0 + else: + return np.min([1.0, (self.iter_step - start) / (end - start)]) + + def file_backup(self): + # copy python file + dir_lis = self.conf['general.recording'] + os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True) + for dir_name in dir_lis: + cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name) + os.makedirs(cur_dir, exist_ok=True) + files = os.listdir(dir_name) + for f_name in files: + if f_name[-3:] == '.py': + copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name)) + + # copy configs + copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf')) + + def load_checkpoint(self, checkpoint_name): + + def load_state_dict(network, checkpoint, comment): + if network is not None: + try: + pretrained_dict = checkpoint[comment] + + model_dict = network.state_dict() + + # 1. filter out unnecessary keys + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} + # 2. overwrite entries in the existing state dict + model_dict.update(pretrained_dict) + # 3. load the new state dict + network.load_state_dict(pretrained_dict) + except: + print(comment + " load fails") + + checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), + map_location=self.device) + + load_state_dict(self.rendering_network_outside, checkpoint, 'rendering_network_outside') + + load_state_dict(self.sdf_network_lod0, checkpoint, 'sdf_network_lod0') + load_state_dict(self.sdf_network_lod1, checkpoint, 'sdf_network_lod1') + + load_state_dict(self.pyramid_feature_network, checkpoint, 'pyramid_feature_network') + load_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network_lod1') + + load_state_dict(self.variance_network_lod0, checkpoint, 'variance_network_lod0') + load_state_dict(self.variance_network_lod1, checkpoint, 'variance_network_lod1') + + load_state_dict(self.rendering_network_lod0, checkpoint, 'rendering_network_lod0') + load_state_dict(self.rendering_network_lod1, checkpoint, 'rendering_network_lod1') + + if self.restore_lod0: # use the trained lod0 networks to initialize lod1 networks + load_state_dict(self.sdf_network_lod1, checkpoint, 'sdf_network_lod0') + load_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network') + load_state_dict(self.rendering_network_lod1, checkpoint, 'rendering_network_lod0') + + if self.is_continue and (not self.restore_lod0): + try: + self.optimizer.load_state_dict(checkpoint['optimizer']) + except: + print("load optimizer fails") + self.iter_step = checkpoint['iter_step'] + self.val_step = checkpoint['val_step'] if 'val_step' in checkpoint.keys() else 0 + + self.logger.info('End') + + def save_checkpoint(self): + + def save_state_dict(network, checkpoint, comment): + if network is not None: + checkpoint[comment] = network.state_dict() + + checkpoint = { + 'optimizer': self.optimizer.state_dict(), + 'iter_step': self.iter_step, + 'val_step': self.val_step, + } + + save_state_dict(self.sdf_network_lod0, checkpoint, "sdf_network_lod0") + save_state_dict(self.sdf_network_lod1, checkpoint, "sdf_network_lod1") + + save_state_dict(self.rendering_network_outside, checkpoint, 'rendering_network_outside') + save_state_dict(self.rendering_network_lod0, checkpoint, "rendering_network_lod0") + save_state_dict(self.rendering_network_lod1, checkpoint, "rendering_network_lod1") + + save_state_dict(self.variance_network_lod0, checkpoint, 'variance_network_lod0') + save_state_dict(self.variance_network_lod1, checkpoint, 'variance_network_lod1') + + save_state_dict(self.pyramid_feature_network, checkpoint, 'pyramid_feature_network') + save_state_dict(self.pyramid_feature_network_lod1, checkpoint, 'pyramid_feature_network_lod1') + + os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True) + torch.save(checkpoint, + os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step))) + + def validate(self, resolution_level=-1): + # validate image + print("iter_step: ", self.iter_step) + self.logger.info('Validate begin') + self.val_step += 1 + + try: + batch = next(self.val_dataloader_iterator) + except: + self.val_dataloader_iterator = iter(self.val_dataloader) # reset + + batch = next(self.val_dataloader_iterator) + + + background_rgb = None + if self.use_white_bkgd: + # background_rgb = torch.ones([1, 3]).to(self.device) + background_rgb = 1.0 + + batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)]) + + # - warmup params + if self.num_lods == 1: + alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0) + else: + alpha_inter_ratio_lod0 = 1. + alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1) + + self.trainer( + batch, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=self.iter_step, + save_vis=True, + mode='val', + ) + + + def export_mesh(self, resolution_level=-1): + print("iter_step: ", self.iter_step) + self.logger.info('Validate begin') + self.val_step += 1 + + try: + batch = next(self.val_dataloader_iterator) + except: + self.val_dataloader_iterator = iter(self.val_dataloader) # reset + + batch = next(self.val_dataloader_iterator) + + + background_rgb = None + if self.use_white_bkgd: + background_rgb = 1.0 + + batch['batch_idx'] = torch.tensor([x for x in range(self.batch_size)]) + + # - warmup params + if self.num_lods == 1: + alpha_inter_ratio_lod0 = self.get_alpha_inter_ratio(self.anneal_start_lod0, self.anneal_end_lod0) + else: + alpha_inter_ratio_lod0 = 1. + alpha_inter_ratio_lod1 = self.get_alpha_inter_ratio(self.anneal_start_lod1, self.anneal_end_lod1) + self.trainer( + batch, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=self.iter_step, + save_vis=True, + mode='export_mesh', + ) + + +if __name__ == '__main__': + # torch.set_default_tensor_type('torch.cuda.FloatTensor') + torch.set_default_dtype(torch.float32) + FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" + logging.basicConfig(level=logging.INFO, format=FORMAT) + + parser = argparse.ArgumentParser() + parser.add_argument('--conf', type=str, default='./confs/base.conf') + parser.add_argument('--mode', type=str, default='train') + parser.add_argument('--threshold', type=float, default=0.0) + parser.add_argument('--is_continue', default=False, action="store_true") + parser.add_argument('--is_restore', default=False, action="store_true") + parser.add_argument('--is_finetune', default=False, action="store_true") + parser.add_argument('--train_from_scratch', default=False, action="store_true") + parser.add_argument('--restore_lod0', default=False, action="store_true") + parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--specific_dataset_name', type=str, default='GSO') + + + args = parser.parse_args() + + torch.cuda.set_device(args.local_rank) + torch.backends.cudnn.benchmark = True # ! make training 2x faster + + runner = Runner(args.conf, args.mode, args.is_continue, args.is_restore, args.restore_lod0, + args.local_rank) + + if args.mode == 'train': + runner.train() + elif args.mode == 'val': + for i in range(len(runner.val_dataset)): + runner.validate() + elif args.mode == 'export_mesh': + for i in range(len(runner.val_dataset)): + runner.export_mesh() diff --git a/SparseNeuS_demo_v1/loss/__init__.py b/SparseNeuS_demo_v1/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/loss/color_loss.py b/SparseNeuS_demo_v1/loss/color_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..abf3f0eb51c6ed29799a870d5833b23c4c41dde8 --- /dev/null +++ b/SparseNeuS_demo_v1/loss/color_loss.py @@ -0,0 +1,152 @@ +import torch +import torch.nn as nn +from loss.ncc import NCC + + +class Normalize(nn.Module): + def __init__(self): + super(Normalize, self).__init__() + + def forward(self, bottom): + qn = torch.norm(bottom, p=2, dim=1).unsqueeze(dim=1) + 1e-12 + top = bottom.div(qn) + + return top + + +class OcclusionColorLoss(nn.Module): + def __init__(self, alpha=1, beta=0.025, gama=0.01, occlusion_aware=True, weight_thred=[0.6]): + super(OcclusionColorLoss, self).__init__() + self.alpha = alpha + self.beta = beta + self.gama = gama + self.occlusion_aware = occlusion_aware + self.eps = 1e-4 + + self.weight_thred = weight_thred + self.adjuster = ParamAdjuster(self.weight_thred, self.beta) + + def forward(self, pred, gt, weight, mask, detach=False, occlusion_aware=True): + """ + + :param pred: [N_pts, 3] + :param gt: [N_pts, 3] + :param weight: [N_pts] + :param mask: [N_pts] + :return: + """ + if detach: + weight = weight.detach() + + error = torch.abs(pred - gt).sum(dim=-1, keepdim=False) # [N_pts] + error = error[mask] + + if not (self.occlusion_aware and occlusion_aware): + return torch.mean(error), torch.mean(error) + + beta = self.adjuster(weight.mean()) + + # weight = weight[mask] + weight = weight.clamp(0.0, 1.0) + term1 = self.alpha * torch.mean(weight[mask] * error) + term2 = beta * torch.log(1 - weight + self.eps).mean() + term3 = self.gama * torch.log(weight + self.eps).mean() + + return term1 + term2 + term3, term1 + + +class OcclusionColorPatchLoss(nn.Module): + def __init__(self, alpha=1, beta=0.025, gama=0.015, + occlusion_aware=True, type='l1', h_patch_size=3, weight_thred=[0.6]): + super(OcclusionColorPatchLoss, self).__init__() + self.alpha = alpha + self.beta = beta + self.gama = gama + self.occlusion_aware = occlusion_aware + self.type = type # 'l1' or 'ncc' loss + self.ncc = NCC(h_patch_size=h_patch_size) + self.eps = 1e-4 + self.weight_thred = weight_thred + + self.adjuster = ParamAdjuster(self.weight_thred, self.beta) + + print("type {} patch_size {} beta {} gama {} weight_thred {}".format(type, h_patch_size, beta, gama, + weight_thred)) + + def forward(self, pred, gt, weight, mask, penalize_ratio=0.9, detach=False, occlusion_aware=True): + """ + + :param pred: [N_pts, Npx, 3] + :param gt: [N_pts, Npx, 3] + :param weight: [N_pts] + :param mask: [N_pts] + :return: + """ + + if detach: + weight = weight.detach() + + if self.type == 'l1': + error = torch.abs(pred - gt).mean(dim=-1, keepdim=False).sum(dim=-1, keepdim=False) # [N_pts] + elif self.type == 'ncc': + error = 1 - self.ncc(pred[:, None, :, :], gt)[:, 0] # ncc 1 positive, -1 negative + error, indices = torch.sort(error) + mask = torch.index_select(mask, 0, index=indices) + mask[int(penalize_ratio * mask.shape[0]):] = False # can help boundaries + elif self.type == 'ssd': + error = ((pred - gt) ** 2).mean(dim=-1, keepdim=False).sum(dim=-1, keepdims=False) + + error = error[mask] + if not (self.occlusion_aware and occlusion_aware): + return torch.mean(error), torch.mean(error), 0. + + # * weight adjuster + beta = self.adjuster(weight.mean()) + + # weight = weight[mask] + weight = weight.clamp(0.0, 1.0) + + term1 = self.alpha * torch.mean(weight[mask] * error) + term2 = beta * torch.log(1 - weight + self.eps).mean() + term3 = self.gama * torch.log(weight + self.eps).mean() + + return term1 + term2 + term3, term1, beta + + +class ParamAdjuster(nn.Module): + def __init__(self, weight_thred, param): + super(ParamAdjuster, self).__init__() + self.weight_thred = weight_thred + self.thred_num = len(weight_thred) + self.param = param + self.global_step = 0 + self.statis_window = 100 + self.counter = 0 + self.adjusted = False + self.adjusted_step = 0 + self.thred_idx = 0 + + def reset(self): + self.counter = 0 + self.adjusted = False + + def adjust(self): + if (self.counter / self.statis_window) > 0.3: + self.param = self.param + 0.005 + self.adjusted = True + self.adjusted_step = self.global_step + self.thred_idx += 1 + print("adjusted param, now {}".format(self.param)) + + def forward(self, weight_mean): + self.global_step += 1 + + if (self.global_step % self.statis_window == 0) and self.adjusted is False: + self.adjust() + self.reset() + + if self.thred_idx < self.thred_num: + if weight_mean < self.weight_thred[self.thred_idx] and (not self.adjusted): + self.counter += 1 + + return self.param diff --git a/SparseNeuS_demo_v1/loss/depth_loss.py b/SparseNeuS_demo_v1/loss/depth_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..cba92851a79857ff6edd5c2f2eb12a2972b85bdc --- /dev/null +++ b/SparseNeuS_demo_v1/loss/depth_loss.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class DepthLoss(nn.Module): + def __init__(self, type='l1'): + super(DepthLoss, self).__init__() + self.type = type + + + def forward(self, depth_pred, depth_gt, mask=None): + if (depth_gt < 0).sum() > 0: + # print("no depth loss") + return torch.tensor(0.0).to(depth_pred.device) + if mask is not None: + mask_d = (depth_gt > 0).float() + + mask = mask * mask_d + + mask_sum = mask.sum() + 1e-5 + depth_error = (depth_pred - depth_gt) * mask + depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device), + reduction='sum') / mask_sum + else: + depth_error = depth_pred - depth_gt + depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device), + reduction='mean') + return depth_loss + +def forward(self, depth_pred, depth_gt, mask=None): + if mask is not None: + mask_d = (depth_gt > 0).float() + + mask = mask * mask_d + + mask_sum = mask.sum() + 1e-5 + depth_error = (depth_pred - depth_gt) * mask + depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device), + reduction='sum') / mask_sum + else: + depth_error = depth_pred - depth_gt + depth_loss = F.l1_loss(depth_error, torch.zeros_like(depth_error).to(depth_error.device), + reduction='mean') + return depth_loss + +class DepthSmoothLoss(nn.Module): + def __init__(self): + super(DepthSmoothLoss, self).__init__() + + def forward(self, disp, img, mask): + """ + Computes the smoothness loss for a disparity image + The color image is used for edge-aware smoothness + :param disp: [B, 1, H, W] + :param img: [B, 1, H, W] + :param mask: [B, 1, H, W] + :return: + """ + grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) + grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) + + grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) + grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) + + grad_disp_x *= torch.exp(-grad_img_x) + grad_disp_y *= torch.exp(-grad_img_y) + + grad_disp = (grad_disp_x * mask[:, :, :, :-1]).mean() + (grad_disp_y * mask[:, :, :-1, :]).mean() + + return grad_disp diff --git a/SparseNeuS_demo_v1/loss/depth_metric.py b/SparseNeuS_demo_v1/loss/depth_metric.py new file mode 100644 index 0000000000000000000000000000000000000000..e8b6249ac6a06906e20a344f468fc1c6e4b992ae --- /dev/null +++ b/SparseNeuS_demo_v1/loss/depth_metric.py @@ -0,0 +1,240 @@ +import numpy as np + + +def l1(depth1, depth2): + """ + Computes the l1 errors between the two depth maps. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + L1(log) + + """ + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + diff = depth1 - depth2 + num_pixels = float(diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sum(np.absolute(diff)) / num_pixels + + +def l1_inverse(depth1, depth2): + """ + Computes the l1 errors between inverses of two depth maps. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + L1(log) + + """ + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + diff = np.reciprocal(depth1) - np.reciprocal(depth2) + num_pixels = float(diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sum(np.absolute(diff)) / num_pixels + + +def rmse_log(depth1, depth2): + """ + Computes the root min square errors between the logs of two depth maps. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + RMSE(log) + + """ + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + log_diff = np.log(depth1) - np.log(depth2) + num_pixels = float(log_diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sqrt(np.sum(np.square(log_diff)) / num_pixels) + + +def rmse(depth1, depth2): + """ + Computes the root min square errors between the two depth maps. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + RMSE(log) + + """ + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + diff = depth1 - depth2 + num_pixels = float(diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sqrt(np.sum(np.square(diff)) / num_pixels) + + +def scale_invariant(depth1, depth2): + """ + Computes the scale invariant loss based on differences of logs of depth maps. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + scale_invariant_distance + + """ + # sqrt(Eq. 3) + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + log_diff = np.log(depth1) - np.log(depth2) + num_pixels = float(log_diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sqrt(np.sum(np.square(log_diff)) / num_pixels - np.square(np.sum(log_diff)) / np.square(num_pixels)) + + +def abs_relative(depth_pred, depth_gt): + """ + Computes relative absolute distance. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth_pred: depth map prediction + depth_gt: depth map ground truth + + Returns: + abs_relative_distance + + """ + assert (np.all(np.isfinite(depth_pred) & np.isfinite(depth_gt) & (depth_pred >= 0) & (depth_gt >= 0))) + diff = depth_pred - depth_gt + num_pixels = float(diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sum(np.absolute(diff) / depth_gt) / num_pixels + + +def avg_log10(depth1, depth2): + """ + Computes average log_10 error (Liu, Neural Fields, 2015). + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + abs_relative_distance + + """ + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + log_diff = np.log10(depth1) - np.log10(depth2) + num_pixels = float(log_diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sum(np.absolute(log_diff)) / num_pixels + + +def sq_relative(depth_pred, depth_gt): + """ + Computes relative squared distance. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth_pred: depth map prediction + depth_gt: depth map ground truth + + Returns: + squared_relative_distance + + """ + assert (np.all(np.isfinite(depth_pred) & np.isfinite(depth_gt) & (depth_pred >= 0) & (depth_gt >= 0))) + diff = depth_pred - depth_gt + num_pixels = float(diff.size) + + if num_pixels == 0: + return np.nan + else: + return np.sum(np.square(diff) / depth_gt) / num_pixels + + +def ratio_threshold(depth1, depth2, threshold): + """ + Computes the percentage of pixels for which the ratio of the two depth maps is less than a given threshold. + Takes preprocessed depths (no nans, infs and non-positive values) + + depth1: one depth map + depth2: another depth map + + Returns: + percentage of pixels with ratio less than the threshold + + """ + assert (threshold > 0.) + assert (np.all(np.isfinite(depth1) & np.isfinite(depth2) & (depth1 >= 0) & (depth2 >= 0))) + log_diff = np.log(depth1) - np.log(depth2) + num_pixels = float(log_diff.size) + + if num_pixels == 0: + return np.nan + else: + return float(np.sum(np.absolute(log_diff) < np.log(threshold))) / num_pixels + + +def compute_depth_errors(depth_pred, depth_gt, valid_mask): + """ + Computes different distance measures between two depth maps. + + depth_pred: depth map prediction + depth_gt: depth map ground truth + distances_to_compute: which distances to compute + + Returns: + a dictionary with computed distances, and the number of valid pixels + + """ + depth_pred = depth_pred[valid_mask] + depth_gt = depth_gt[valid_mask] + num_valid = np.sum(valid_mask) + + distances_to_compute = ['l1', + 'l1_inverse', + 'scale_invariant', + 'abs_relative', + 'sq_relative', + 'avg_log10', + 'rmse_log', + 'rmse', + 'ratio_threshold_1.25', + 'ratio_threshold_1.5625', + 'ratio_threshold_1.953125'] + + results = {'num_valid': num_valid} + for dist in distances_to_compute: + if dist.startswith('ratio_threshold'): + threshold = float(dist.split('_')[-1]) + results[dist] = ratio_threshold(depth_pred, depth_gt, threshold) + else: + results[dist] = globals()[dist](depth_pred, depth_gt) + + return results diff --git a/SparseNeuS_demo_v1/loss/ncc.py b/SparseNeuS_demo_v1/loss/ncc.py new file mode 100644 index 0000000000000000000000000000000000000000..768fcefc3aab55d8e3fed49f23ffb4a974eec4ec --- /dev/null +++ b/SparseNeuS_demo_v1/loss/ncc.py @@ -0,0 +1,65 @@ +import torch +import torch.nn.functional as F +import numpy as np +from math import exp, sqrt + + +class NCC(torch.nn.Module): + def __init__(self, h_patch_size, mode='rgb'): + super(NCC, self).__init__() + self.window_size = 2 * h_patch_size + 1 + self.mode = mode # 'rgb' or 'gray' + self.channel = 3 + self.register_buffer("window", create_window(self.window_size, self.channel)) + + def forward(self, img_pred, img_gt): + """ + :param img_pred: [Npx, nviews, npatch, c] + :param img_gt: [Npx, npatch, c] + :return: + """ + ntotpx, nviews, npatch, channels = img_pred.shape + + patch_size = int(sqrt(npatch)) + patch_img_pred = img_pred.reshape(ntotpx, nviews, patch_size, patch_size, channels).permute(0, 1, 4, 2, + 3).contiguous() + patch_img_gt = img_gt.reshape(ntotpx, patch_size, patch_size, channels).permute(0, 3, 1, 2) + + return _ncc(patch_img_pred, patch_img_gt, self.window, self.channel) + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel, std=1.5): + _1D_window = gaussian(window_size, std).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0) + window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() + return window + + +def _ncc(pred, gt, window, channel): + ntotpx, nviews, nc, h, w = pred.shape + flat_pred = pred.view(-1, nc, h, w) + mu1 = F.conv2d(flat_pred, window, padding=0, groups=channel).view(ntotpx, nviews, nc) + mu2 = F.conv2d(gt, window, padding=0, groups=channel).view(ntotpx, nc) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2).unsqueeze(1) # (ntotpx, 1, nc) + + sigma1_sq = F.conv2d(flat_pred * flat_pred, window, padding=0, groups=channel).view(ntotpx, nviews, nc) - mu1_sq + sigma2_sq = F.conv2d(gt * gt, window, padding=0, groups=channel).view(ntotpx, 1, 3) - mu2_sq + + sigma1 = torch.sqrt(sigma1_sq + 1e-4) + sigma2 = torch.sqrt(sigma2_sq + 1e-4) + + pred_norm = (pred - mu1[:, :, :, None, None]) / (sigma1[:, :, :, None, None] + 1e-8) # [ntotpx, nviews, nc, h, w] + gt_norm = (gt[:, None, :, :, :] - mu2[:, None, :, None, None]) / ( + sigma2[:, :, :, None, None] + 1e-8) # ntotpx, nc, h, w + + ncc = F.conv2d((pred_norm * gt_norm).view(-1, nc, h, w), window, padding=0, groups=channel).view( + ntotpx, nviews, nc) + + return torch.mean(ncc, dim=2) diff --git a/SparseNeuS_demo_v1/models/__init__.py b/SparseNeuS_demo_v1/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/models/embedder.py b/SparseNeuS_demo_v1/models/embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..d327d92d9f64c0b32908dbee864160b65daa450e --- /dev/null +++ b/SparseNeuS_demo_v1/models/embedder.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn + +""" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """ + + +class Embedder: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.create_embedding_fn() + + def create_embedding_fn(self): + embed_fns = [] + d = self.kwargs['input_dims'] + out_dim = 0 + if self.kwargs['include_input']: + embed_fns.append(lambda x: x) + out_dim += d + + max_freq = self.kwargs['max_freq_log2'] + N_freqs = self.kwargs['num_freqs'] + + if self.kwargs['log_sampling']: + freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) + else: + freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs) + + for freq in freq_bands: + for p_fn in self.kwargs['periodic_fns']: + if self.kwargs['normalize']: + embed_fns.append(lambda x, p_fn=p_fn, + freq=freq: p_fn(x * freq) / freq) + else: + embed_fns.append(lambda x, p_fn=p_fn, + freq=freq: p_fn(x * freq)) + out_dim += d + + self.embed_fns = embed_fns + self.out_dim = out_dim + + def embed(self, inputs): + return torch.cat([fn(inputs) for fn in self.embed_fns], -1) + + +def get_embedder(multires, normalize=False, input_dims=3): + embed_kwargs = { + 'include_input': True, + 'input_dims': input_dims, + 'max_freq_log2': multires - 1, + 'num_freqs': multires, + 'normalize': normalize, + 'log_sampling': True, + 'periodic_fns': [torch.sin, torch.cos], + } + + embedder_obj = Embedder(**embed_kwargs) + + def embed(x, eo=embedder_obj): return eo.embed(x) + + return embed, embedder_obj.out_dim + + +class Embedding(nn.Module): + def __init__(self, in_channels, N_freqs, logscale=True, normalize=False): + """ + Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) + in_channels: number of input channels (3 for both xyz and direction) + """ + super(Embedding, self).__init__() + self.N_freqs = N_freqs + self.in_channels = in_channels + self.funcs = [torch.sin, torch.cos] + self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1) + self.normalize = normalize + + if logscale: + self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs) + else: + self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs) + + def forward(self, x): + """ + Embeds x to (x, sin(2^k x), cos(2^k x), ...) + Different from the paper, "x" is also in the output + See https://github.com/bmild/nerf/issues/12 + + Inputs: + x: (B, self.in_channels) + + Outputs: + out: (B, self.out_channels) + """ + out = [x] + for freq in self.freq_bands: + for func in self.funcs: + if self.normalize: + out += [func(freq * x) / freq] + else: + out += [func(freq * x)] + + return torch.cat(out, -1) diff --git a/SparseNeuS_demo_v1/models/fast_renderer.py b/SparseNeuS_demo_v1/models/fast_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..1faeba85e5b156d0de12e430287d90f4a803aa92 --- /dev/null +++ b/SparseNeuS_demo_v1/models/fast_renderer.py @@ -0,0 +1,316 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from icecream import ic + + +# - neus: use sphere-tracing to speed up depth maps extraction +# This code snippet is heavily borrowed from IDR. +class FastRenderer(nn.Module): + def __init__(self): + super(FastRenderer, self).__init__() + + self.sdf_threshold = 5e-5 + self.line_search_step = 0.5 + self.line_step_iters = 1 + self.sphere_tracing_iters = 10 + self.n_steps = 100 + self.n_secant_steps = 8 + + # - use sdf_network to inference sdf value or directly interpolate sdf value from precomputed sdf_volume + self.network_inference = False + + def extract_depth_maps(self, rays_o, rays_d, near, far, sdf_network, conditional_volume): + with torch.no_grad(): + curr_start_points, network_object_mask, acc_start_dis = self.get_intersection( + rays_o, rays_d, near, far, + sdf_network, conditional_volume) + + network_object_mask = network_object_mask.reshape(-1) + + return network_object_mask, acc_start_dis + + def get_intersection(self, rays_o, rays_d, near, far, sdf_network, conditional_volume): + device = rays_o.device + num_pixels, _ = rays_d.shape + + curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis = \ + self.sphere_tracing(rays_o, rays_d, near, far, sdf_network, conditional_volume) + + network_object_mask = (acc_start_dis < acc_end_dis) + + # The non convergent rays should be handled by the sampler + sampler_mask = unfinished_mask_start + sampler_net_obj_mask = torch.zeros_like(sampler_mask).bool().to(device) + if sampler_mask.sum() > 0: + # sampler_min_max = torch.zeros((num_pixels, 2)).to(device) + # sampler_min_max[sampler_mask, 0] = acc_start_dis[sampler_mask] + # sampler_min_max[sampler_mask, 1] = acc_end_dis[sampler_mask] + + # ray_sampler(self, rays_o, rays_d, near, far, sampler_mask): + sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(rays_o, + rays_d, + acc_start_dis, + acc_end_dis, + sampler_mask, + sdf_network, + conditional_volume + ) + + curr_start_points[sampler_mask] = sampler_pts[sampler_mask] + acc_start_dis[sampler_mask] = sampler_dists[sampler_mask][:, None] + network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask][:, None] + + # print('----------------------------------------------------------------') + # print('RayTracing: object = {0}/{1}, secant on {2}/{3}.' + # .format(network_object_mask.sum(), len(network_object_mask), sampler_net_obj_mask.sum(), + # sampler_mask.sum())) + # print('----------------------------------------------------------------') + + return curr_start_points, network_object_mask, acc_start_dis + + def sphere_tracing(self, rays_o, rays_d, near, far, sdf_network, conditional_volume): + ''' Run sphere tracing algorithm for max iterations from both sides of unit sphere intersection ''' + + device = rays_o.device + + unfinished_mask_start = (near < far).reshape(-1).clone() + unfinished_mask_end = (near < far).reshape(-1).clone() + + # Initialize start current points + curr_start_points = rays_o + rays_d * near + acc_start_dis = near.clone() + + # Initialize end current points + curr_end_points = rays_o + rays_d * far + acc_end_dis = far.clone() + + # Initizlize min and max depth + min_dis = acc_start_dis.clone() + max_dis = acc_end_dis.clone() + + # Iterate on the rays (from both sides) till finding a surface + iters = 0 + + next_sdf_start = torch.zeros_like(acc_start_dis).to(device) + + if self.network_inference: + sdf_func = sdf_network.sdf + else: + sdf_func = sdf_network.sdf_from_sdfvolume + + next_sdf_start[unfinished_mask_start] = sdf_func( + curr_start_points[unfinished_mask_start], + conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0] + + next_sdf_end = torch.zeros_like(acc_end_dis).to(device) + next_sdf_end[unfinished_mask_end] = sdf_func(curr_end_points[unfinished_mask_end], + conditional_volume, lod=0, gru_fusion=False)[ + 'sdf_pts_scale%d' % 0] + + while True: + # Update sdf + curr_sdf_start = torch.zeros_like(acc_start_dis).to(device) + curr_sdf_start[unfinished_mask_start] = next_sdf_start[unfinished_mask_start] + curr_sdf_start[curr_sdf_start <= self.sdf_threshold] = 0 + + curr_sdf_end = torch.zeros_like(acc_end_dis).to(device) + curr_sdf_end[unfinished_mask_end] = next_sdf_end[unfinished_mask_end] + curr_sdf_end[curr_sdf_end <= self.sdf_threshold] = 0 + + # Update masks + unfinished_mask_start = unfinished_mask_start & (curr_sdf_start > self.sdf_threshold).reshape(-1) + unfinished_mask_end = unfinished_mask_end & (curr_sdf_end > self.sdf_threshold).reshape(-1) + + if ( + unfinished_mask_start.sum() == 0 and unfinished_mask_end.sum() == 0) or iters == self.sphere_tracing_iters: + break + iters += 1 + + # Make step + # Update distance + acc_start_dis = acc_start_dis + curr_sdf_start + acc_end_dis = acc_end_dis - curr_sdf_end + + # Update points + curr_start_points = rays_o + acc_start_dis * rays_d + curr_end_points = rays_o + acc_end_dis * rays_d + + # Fix points which wrongly crossed the surface + next_sdf_start = torch.zeros_like(acc_start_dis).to(device) + if unfinished_mask_start.sum() > 0: + next_sdf_start[unfinished_mask_start] = sdf_func(curr_start_points[unfinished_mask_start], + conditional_volume, lod=0, gru_fusion=False)[ + 'sdf_pts_scale%d' % 0] + + next_sdf_end = torch.zeros_like(acc_end_dis).to(device) + if unfinished_mask_end.sum() > 0: + next_sdf_end[unfinished_mask_end] = sdf_func(curr_end_points[unfinished_mask_end], + conditional_volume, lod=0, gru_fusion=False)[ + 'sdf_pts_scale%d' % 0] + + not_projected_start = (next_sdf_start < 0).reshape(-1) + not_projected_end = (next_sdf_end < 0).reshape(-1) + not_proj_iters = 0 + + while ( + not_projected_start.sum() > 0 or not_projected_end.sum() > 0) and not_proj_iters < self.line_step_iters: + # Step backwards + if not_projected_start.sum() > 0: + acc_start_dis[not_projected_start] -= ((1 - self.line_search_step) / (2 ** not_proj_iters)) * \ + curr_sdf_start[not_projected_start] + curr_start_points[not_projected_start] = (rays_o + acc_start_dis * rays_d)[not_projected_start] + + next_sdf_start[not_projected_start] = sdf_func( + curr_start_points[not_projected_start], + conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0] + + if not_projected_end.sum() > 0: + acc_end_dis[not_projected_end] += ((1 - self.line_search_step) / (2 ** not_proj_iters)) * \ + curr_sdf_end[ + not_projected_end] + curr_end_points[not_projected_end] = (rays_o + acc_end_dis * rays_d)[not_projected_end] + + # Calc sdf + + next_sdf_end[not_projected_end] = sdf_func( + curr_end_points[not_projected_end], + conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0] + + # Update mask + not_projected_start = (next_sdf_start < 0).reshape(-1) + not_projected_end = (next_sdf_end < 0).reshape(-1) + not_proj_iters += 1 + + unfinished_mask_start = unfinished_mask_start & (acc_start_dis < acc_end_dis).reshape(-1) + unfinished_mask_end = unfinished_mask_end & (acc_start_dis < acc_end_dis).reshape(-1) + + return curr_start_points, unfinished_mask_start, acc_start_dis, acc_end_dis, min_dis, max_dis + + def ray_sampler(self, rays_o, rays_d, near, far, sampler_mask, sdf_network, conditional_volume): + ''' Sample the ray in a given range and run secant on rays which have sign transition ''' + device = rays_o.device + num_pixels, _ = rays_d.shape + sampler_pts = torch.zeros(num_pixels, 3).to(device).float() + sampler_dists = torch.zeros(num_pixels).to(device).float() + + intervals_dist = torch.linspace(0, 1, steps=self.n_steps).to(device).view(1, -1) + + pts_intervals = near + intervals_dist * (far - near) + points = rays_o[:, None, :] + pts_intervals[:, :, None] * rays_d[:, None, :] + + # Get the non convergent rays + mask_intersect_idx = torch.nonzero(sampler_mask).flatten() + points = points.reshape((-1, self.n_steps, 3))[sampler_mask, :, :] + pts_intervals = pts_intervals.reshape((-1, self.n_steps))[sampler_mask] + + if self.network_inference: + sdf_func = sdf_network.sdf + else: + sdf_func = sdf_network.sdf_from_sdfvolume + + sdf_val_all = [] + for pnts in torch.split(points.reshape(-1, 3), 100000, dim=0): + sdf_val_all.append(sdf_func(pnts, + conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0]) + sdf_val = torch.cat(sdf_val_all).reshape(-1, self.n_steps) + + tmp = torch.sign(sdf_val) * torch.arange(self.n_steps, 0, -1).to(device).float().reshape( + (1, self.n_steps)) # Force argmin to return the first min value + sampler_pts_ind = torch.argmin(tmp, -1) + sampler_pts[mask_intersect_idx] = points[torch.arange(points.shape[0]), sampler_pts_ind, :] + sampler_dists[mask_intersect_idx] = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind] + + net_surface_pts = (sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind] < 0) + + # take points with minimal SDF value for P_out pixels + p_out_mask = ~net_surface_pts + n_p_out = p_out_mask.sum() + if n_p_out > 0: + out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1) + sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][torch.arange(n_p_out), out_pts_idx, + :] + sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[p_out_mask, :][ + torch.arange(n_p_out), out_pts_idx] + + # Get Network object mask + sampler_net_obj_mask = sampler_mask.clone() + sampler_net_obj_mask[mask_intersect_idx[~net_surface_pts]] = False + + # Run Secant method + secant_pts = net_surface_pts + n_secant_pts = secant_pts.sum() + if n_secant_pts > 0: + # Get secant z predictions + z_high = pts_intervals[torch.arange(pts_intervals.shape[0]), sampler_pts_ind][secant_pts] + sdf_high = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind][secant_pts] + z_low = pts_intervals[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1] + sdf_low = sdf_val[secant_pts][torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1] + + cam_loc_secant = rays_o[mask_intersect_idx[secant_pts]] + ray_directions_secant = rays_d[mask_intersect_idx[secant_pts]] + z_pred_secant = self.secant(sdf_low, sdf_high, z_low, z_high, cam_loc_secant, ray_directions_secant, + sdf_network, conditional_volume) + + # Get points + sampler_pts[mask_intersect_idx[secant_pts]] = cam_loc_secant + z_pred_secant[:, + None] * ray_directions_secant + sampler_dists[mask_intersect_idx[secant_pts]] = z_pred_secant + + return sampler_pts, sampler_net_obj_mask, sampler_dists + + def secant(self, sdf_low, sdf_high, z_low, z_high, rays_o, rays_d, sdf_network, conditional_volume): + ''' Runs the secant method for interval [z_low, z_high] for n_secant_steps ''' + + if self.network_inference: + sdf_func = sdf_network.sdf + else: + sdf_func = sdf_network.sdf_from_sdfvolume + + z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low + for i in range(self.n_secant_steps): + p_mid = rays_o + z_pred[:, None] * rays_d + sdf_mid = sdf_func(p_mid, + conditional_volume, lod=0, gru_fusion=False)['sdf_pts_scale%d' % 0].reshape(-1) + ind_low = (sdf_mid > 0).reshape(-1) + if ind_low.sum() > 0: + z_low[ind_low] = z_pred[ind_low] + sdf_low[ind_low] = sdf_mid[ind_low] + ind_high = sdf_mid < 0 + if ind_high.sum() > 0: + z_high[ind_high] = z_pred[ind_high] + sdf_high[ind_high] = sdf_mid[ind_high] + + z_pred = - sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low + + return z_pred # 1D tensor + + def minimal_sdf_points(self, num_pixels, sdf, cam_loc, ray_directions, mask, min_dis, max_dis): + ''' Find points with minimal SDF value on rays for P_out pixels ''' + device = sdf.device + n_mask_points = mask.sum() + + n = self.n_steps + # steps = torch.linspace(0.0, 1.0,n).to(device) + steps = torch.empty(n).uniform_(0.0, 1.0).to(device) + mask_max_dis = max_dis[mask].unsqueeze(-1) + mask_min_dis = min_dis[mask].unsqueeze(-1) + steps = steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis) + mask_min_dis + + mask_points = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)[mask] + mask_rays = ray_directions[mask, :] + + mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + steps.unsqueeze(-1) * mask_rays.unsqueeze( + 1).repeat(1, n, 1) + points = mask_points_all.reshape(-1, 3) + + mask_sdf_all = [] + for pnts in torch.split(points, 100000, dim=0): + mask_sdf_all.append(sdf(pnts)) + + mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n) + min_vals, min_idx = mask_sdf_all.min(-1) + min_mask_points = mask_points_all.reshape(-1, n, 3)[torch.arange(0, n_mask_points), min_idx] + min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx] + + return min_mask_points, min_mask_dist diff --git a/SparseNeuS_demo_v1/models/featurenet.py b/SparseNeuS_demo_v1/models/featurenet.py new file mode 100644 index 0000000000000000000000000000000000000000..652e65967708f57a1722c5951d53e72f05ddf1d3 --- /dev/null +++ b/SparseNeuS_demo_v1/models/featurenet.py @@ -0,0 +1,91 @@ +import torch + +# ! amazing!!!! autograd.grad with set_detect_anomaly(True) will cause memory leak +# ! https://github.com/pytorch/pytorch/issues/51349 +# torch.autograd.set_detect_anomaly(True) +import torch.nn as nn +import torch.nn.functional as F +from inplace_abn import InPlaceABN + + +############################################# MVS Net models ################################################ +class ConvBnReLU(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size=3, stride=1, pad=1, + norm_act=InPlaceABN): + super(ConvBnReLU, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, + kernel_size, stride=stride, padding=pad, bias=False) + self.bn = norm_act(out_channels) + + def forward(self, x): + return self.bn(self.conv(x)) + + +class ConvBnReLU3D(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size=3, stride=1, pad=1, + norm_act=InPlaceABN): + super(ConvBnReLU3D, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, + kernel_size, stride=stride, padding=pad, bias=False) + self.bn = norm_act(out_channels) + # self.bn = nn.ReLU() + + def forward(self, x): + return self.bn(self.conv(x)) + + +################################### feature net ###################################### +class FeatureNet(nn.Module): + """ + output 3 levels of features using a FPN structure + """ + + def __init__(self, norm_act=InPlaceABN): + super(FeatureNet, self).__init__() + + self.conv0 = nn.Sequential( + ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act), + ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act)) + + self.conv1 = nn.Sequential( + ConvBnReLU(8, 16, 5, 2, 2, norm_act=norm_act), + ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act), + ConvBnReLU(16, 16, 3, 1, 1, norm_act=norm_act)) + + self.conv2 = nn.Sequential( + ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act), + ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act), + ConvBnReLU(32, 32, 3, 1, 1, norm_act=norm_act)) + + self.toplayer = nn.Conv2d(32, 32, 1) + self.lat1 = nn.Conv2d(16, 32, 1) + self.lat0 = nn.Conv2d(8, 32, 1) + + # to reduce channel size of the outputs from FPN + self.smooth1 = nn.Conv2d(32, 16, 3, padding=1) + self.smooth0 = nn.Conv2d(32, 8, 3, padding=1) + + def _upsample_add(self, x, y): + return F.interpolate(x, scale_factor=2, + mode="bilinear", align_corners=True) + y + + def forward(self, x): + # x: (B, 3, H, W) + conv0 = self.conv0(x) # (B, 8, H, W) + conv1 = self.conv1(conv0) # (B, 16, H//2, W//2) + conv2 = self.conv2(conv1) # (B, 32, H//4, W//4) + feat2 = self.toplayer(conv2) # (B, 32, H//4, W//4) + feat1 = self._upsample_add(feat2, self.lat1(conv1)) # (B, 32, H//2, W//2) + feat0 = self._upsample_add(feat1, self.lat0(conv0)) # (B, 32, H, W) + + # reduce output channels + feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2) + feat0 = self.smooth0(feat0) # (B, 8, H, W) + + # feats = {"level_0": feat0, + # "level_1": feat1, + # "level_2": feat2} + + return [feat2, feat1, feat0] # coarser to finer features diff --git a/SparseNeuS_demo_v1/models/fields.py b/SparseNeuS_demo_v1/models/fields.py new file mode 100644 index 0000000000000000000000000000000000000000..184e4a55399f56f8f505379ce4a14add8821c4c4 --- /dev/null +++ b/SparseNeuS_demo_v1/models/fields.py @@ -0,0 +1,333 @@ +# The codes are from NeuS + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from models.embedder import get_embedder + + +class SDFNetwork(nn.Module): + def __init__(self, + d_in, + d_out, + d_hidden, + n_layers, + skip_in=(4,), + multires=0, + bias=0.5, + scale=1, + geometric_init=True, + weight_norm=True, + activation='softplus', + conditional_type='multiply'): + super(SDFNetwork, self).__init__() + + dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] + + self.embed_fn_fine = None + + if multires > 0: + embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False) + self.embed_fn_fine = embed_fn + dims[0] = input_ch + + self.num_layers = len(dims) + self.skip_in = skip_in + self.scale = scale + + for l in range(0, self.num_layers - 1): + if l + 1 in self.skip_in: + out_dim = dims[l + 1] - dims[0] + else: + out_dim = dims[l + 1] + + lin = nn.Linear(dims[l], out_dim) + + if geometric_init: + if l == self.num_layers - 2: + torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) + torch.nn.init.constant_(lin.bias, -bias) + elif multires > 0 and l == 0: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.constant_(lin.weight[:, 3:], 0.0) + torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) + elif multires > 0 and l in self.skip_in: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) # ? why dims[0] - 3 + else: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + if activation == 'softplus': + self.activation = nn.Softplus(beta=100) + else: + assert activation == 'relu' + self.activation = nn.ReLU() + + def forward(self, inputs): + inputs = inputs * self.scale + if self.embed_fn_fine is not None: + inputs = self.embed_fn_fine(inputs) + + x = inputs + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + if l in self.skip_in: + x = torch.cat([x, inputs], 1) / np.sqrt(2) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.activation(x) + return torch.cat([x[:, :1] / self.scale, x[:, 1:]], dim=-1) + + def sdf(self, x): + return self.forward(x)[:, :1] + + def sdf_hidden_appearance(self, x): + return self.forward(x) + + def gradient(self, x): + x.requires_grad_(True) + y = self.sdf(x) + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return gradients.unsqueeze(1) + + +class VarianceNetwork(nn.Module): + def __init__(self, d_in, d_out, d_hidden, n_layers, skip_in=(4,), multires=0): + super(VarianceNetwork, self).__init__() + + dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] + + self.embed_fn_fine = None + + if multires > 0: + embed_fn, input_ch = get_embedder(multires, normalize=False) + self.embed_fn_fine = embed_fn + dims[0] = input_ch + + self.num_layers = len(dims) + self.skip_in = skip_in + + for l in range(0, self.num_layers - 1): + if l + 1 in self.skip_in: + out_dim = dims[l + 1] - dims[0] + else: + out_dim = dims[l + 1] + + lin = nn.Linear(dims[l], out_dim) + setattr(self, "lin" + str(l), lin) + + self.relu = nn.ReLU() + self.softplus = nn.Softplus(beta=100) + + def forward(self, inputs): + if self.embed_fn_fine is not None: + inputs = self.embed_fn_fine(inputs) + + x = inputs + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + if l in self.skip_in: + x = torch.cat([x, inputs], 1) / np.sqrt(2) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.relu(x) + + # return torch.exp(x) + return 1.0 / (self.softplus(x + 0.5) + 1e-3) + + def coarse(self, inputs): + return self.forward(inputs)[:, :1] + + def fine(self, inputs): + return self.forward(inputs)[:, 1:] + + +class FixVarianceNetwork(nn.Module): + def __init__(self, base): + super(FixVarianceNetwork, self).__init__() + self.base = base + self.iter_step = 0 + + def set_iter_step(self, iter_step): + self.iter_step = iter_step + + def forward(self, x): + return torch.ones([len(x), 1]) * np.exp(-self.iter_step / self.base) + + +class SingleVarianceNetwork(nn.Module): + def __init__(self, init_val=1.0): + super(SingleVarianceNetwork, self).__init__() + self.register_parameter('variance', nn.Parameter(torch.tensor(init_val))) + + def forward(self, x): + return torch.ones([len(x), 1]).to(x.device) * torch.exp(self.variance * 10.0) + + + +class RenderingNetwork(nn.Module): + def __init__( + self, + d_feature, + mode, + d_in, + d_out, + d_hidden, + n_layers, + weight_norm=True, + multires_view=0, + squeeze_out=True, + d_conditional_colors=0 + ): + super().__init__() + + self.mode = mode + self.squeeze_out = squeeze_out + dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] + + self.embedview_fn = None + if multires_view > 0: + embedview_fn, input_ch = get_embedder(multires_view) + self.embedview_fn = embedview_fn + dims[0] += (input_ch - 3) + + self.num_layers = len(dims) + + for l in range(0, self.num_layers - 1): + out_dim = dims[l + 1] + lin = nn.Linear(dims[l], out_dim) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + self.relu = nn.ReLU() + + def forward(self, points, normals, view_dirs, feature_vectors): + if self.embedview_fn is not None: + view_dirs = self.embedview_fn(view_dirs) + + rendering_input = None + + if self.mode == 'idr': + rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1) + elif self.mode == 'no_view_dir': + rendering_input = torch.cat([points, normals, feature_vectors], dim=-1) + elif self.mode == 'no_normal': + rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1) + elif self.mode == 'no_points': + rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1) + elif self.mode == 'no_points_no_view_dir': + rendering_input = torch.cat([normals, feature_vectors], dim=-1) + + x = rendering_input + + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.relu(x) + + if self.squeeze_out: + x = torch.sigmoid(x) + return x + + +# Code from nerf-pytorch +class NeRF(nn.Module): + def __init__(self, D=8, W=256, d_in=3, d_in_view=3, multires=0, multires_view=0, output_ch=4, skips=[4], + use_viewdirs=False): + """ + """ + super(NeRF, self).__init__() + self.D = D + self.W = W + self.d_in = d_in + self.d_in_view = d_in_view + self.input_ch = 3 + self.input_ch_view = 3 + self.embed_fn = None + self.embed_fn_view = None + + if multires > 0: + embed_fn, input_ch = get_embedder(multires, input_dims=d_in, normalize=False) + self.embed_fn = embed_fn + self.input_ch = input_ch + + if multires_view > 0: + embed_fn_view, input_ch_view = get_embedder(multires_view, input_dims=d_in_view, normalize=False) + self.embed_fn_view = embed_fn_view + self.input_ch_view = input_ch_view + + self.skips = skips + self.use_viewdirs = use_viewdirs + + self.pts_linears = nn.ModuleList( + [nn.Linear(self.input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + self.input_ch, W) + for i in + range(D - 1)]) + + ### Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105) + self.views_linears = nn.ModuleList([nn.Linear(self.input_ch_view + W, W // 2)]) + + ### Implementation according to the paper + # self.views_linears = nn.ModuleList( + # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)]) + + if use_viewdirs: + self.feature_linear = nn.Linear(W, W) + self.alpha_linear = nn.Linear(W, 1) + self.rgb_linear = nn.Linear(W // 2, 3) + else: + self.output_linear = nn.Linear(W, output_ch) + + def forward(self, input_pts, input_views): + if self.embed_fn is not None: + input_pts = self.embed_fn(input_pts) + if self.embed_fn_view is not None: + input_views = self.embed_fn_view(input_views) + + h = input_pts + for i, l in enumerate(self.pts_linears): + h = self.pts_linears[i](h) + h = F.relu(h) + if i in self.skips: + h = torch.cat([input_pts, h], -1) + + if self.use_viewdirs: + alpha = self.alpha_linear(h) + feature = self.feature_linear(h) + h = torch.cat([feature, input_views], -1) + + for i, l in enumerate(self.views_linears): + h = self.views_linears[i](h) + h = F.relu(h) + + rgb = self.rgb_linear(h) + return alpha + 1.0, rgb + else: + assert False diff --git a/SparseNeuS_demo_v1/models/patch_projector.py b/SparseNeuS_demo_v1/models/patch_projector.py new file mode 100644 index 0000000000000000000000000000000000000000..24bb64527a1f9a9a1c6db8cd290d38f65b63b6d4 --- /dev/null +++ b/SparseNeuS_demo_v1/models/patch_projector.py @@ -0,0 +1,211 @@ +""" +Patch Projector +""" +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from models.render_utils import sample_ptsFeatures_from_featureMaps + + +class PatchProjector(): + def __init__(self, patch_size): + self.h_patch_size = patch_size + self.offsets = build_patch_offset(patch_size) # the warping patch offsets index + + self.z_axis = torch.tensor([0, 0, 1]).float() + + self.plane_dist_thresh = 0.001 + + # * correctness checked + def pixel_warp(self, pts, imgs, intrinsics, + w2cs, img_wh=None): + """ + + :param pts: [N_rays, n_samples, 3] + :param imgs: [N_views, 3, H, W] + :param intrinsics: [N_views, 4, 4] + :param c2ws: [N_views, 4, 4] + :param img_wh: + :return: + """ + if img_wh is None: + N_views, _, sizeH, sizeW = imgs.shape + img_wh = [sizeW, sizeH] + + pts_color, valid_mask = sample_ptsFeatures_from_featureMaps( + pts, imgs, w2cs, intrinsics, img_wh, + proj_matrix=None, return_mask=True) # [N_views, c, N_rays, n_samples], [N_views, N_rays, n_samples] + + pts_color = pts_color.permute(2, 3, 0, 1) + valid_mask = valid_mask.permute(1, 2, 0) + + return pts_color, valid_mask # [N_rays, n_samples, N_views, 3] , [N_rays, n_samples, N_views] + + def patch_warp(self, pts, uv, normals, src_imgs, + ref_intrinsic, src_intrinsics, + ref_c2w, src_c2ws, img_wh=None + ): + """ + + :param pts: [N_rays, n_samples, 3] + :param uv : [N_rays, 2] normalized in (-1, 1) + :param normals: [N_rays, n_samples, 3] The normal of pt in world space + :param src_imgs: [N_src, 3, h, w] + :param ref_intrinsic: [4,4] + :param src_intrinsics: [N_src, 4, 4] + :param ref_c2w: [4,4] + :param src_c2ws: [N_src, 4, 4] + :return: + """ + device = pts.device + + N_rays, n_samples, _ = pts.shape + N_pts = N_rays * n_samples + + N_src, _, sizeH, sizeW = src_imgs.shape + + if img_wh is not None: + sizeW, sizeH = img_wh[0], img_wh[1] + + # scale uv from (-1, 1) to (0, W/H) + uv[:, 0] = (uv[:, 0] + 1) / 2. * (sizeW - 1) + uv[:, 1] = (uv[:, 1] + 1) / 2. * (sizeH - 1) + + ref_intr = ref_intrinsic[:3, :3] + inv_ref_intr = torch.inverse(ref_intr) + src_intrs = src_intrinsics[:, :3, :3] + inv_src_intrs = torch.inverse(src_intrs) + + ref_pose = ref_c2w + inv_ref_pose = torch.inverse(ref_pose) + src_poses = src_c2ws + inv_src_poses = torch.inverse(src_poses) + + ref_cam_loc = ref_pose[:3, 3].unsqueeze(0) # [1, 3] + sampled_dists = torch.norm(pts - ref_cam_loc, dim=-1) # [N_pts, 1] + + relative_proj = inv_src_poses @ ref_pose + R_rel = relative_proj[:, :3, :3] + t_rel = relative_proj[:, :3, 3:] + R_ref = inv_ref_pose[:3, :3] + t_ref = inv_ref_pose[:3, 3:] + + pts = pts.view(-1, 3) + normals = normals.view(-1, 3) + + with torch.no_grad(): + rot_normals = R_ref @ normals.unsqueeze(-1) # [N_pts, 3, 1] + points_in_ref = R_ref @ pts.unsqueeze( + -1) + t_ref # [N_pts, 3, 1] points in the reference frame coordiantes system + d1 = torch.sum(rot_normals * points_in_ref, dim=1).unsqueeze( + 1) # distance from the plane to ref camera center + + d2 = torch.sum(rot_normals.unsqueeze(1) * (-R_rel.transpose(1, 2) @ t_rel).unsqueeze(0), + dim=2) # distance from the plane to src camera center + valid_hom = (torch.abs(d1) > self.plane_dist_thresh) & ( + torch.abs(d1 - d2) > self.plane_dist_thresh) & ((d2 / d1) < 1) + + d1 = d1.squeeze() + sign = torch.sign(d1) + sign[sign == 0] = 1 + d = torch.clamp(torch.abs(d1), 1e-8) * sign + + H = src_intrs.unsqueeze(1) @ ( + R_rel.unsqueeze(1) + t_rel.unsqueeze(1) @ rot_normals.view(1, N_pts, 1, 3) / d.view(1, + N_pts, + 1, 1) + ) @ inv_ref_intr.view(1, 1, 3, 3) + + # replace invalid homs with fronto-parallel homographies + H_invalid = src_intrs.unsqueeze(1) @ ( + R_rel.unsqueeze(1) + t_rel.unsqueeze(1) @ self.z_axis.to(device).view(1, 1, 1, 3).expand(-1, N_pts, + -1, + -1) / sampled_dists.view( + 1, N_pts, 1, 1) + ) @ inv_ref_intr.view(1, 1, 3, 3) + tmp_m = ~valid_hom.view(-1, N_src).t() + H[tmp_m] = H_invalid[tmp_m] + + pixels = uv.view(N_rays, 1, 2) + self.offsets.float().to(device) + Npx = pixels.shape[1] + grid, warp_mask_full = self.patch_homography(H, pixels) + + warp_mask_full = warp_mask_full & (grid[..., 0] < (sizeW - self.h_patch_size)) & ( + grid[..., 1] < (sizeH - self.h_patch_size)) & (grid >= self.h_patch_size).all(dim=-1) + warp_mask_full = warp_mask_full.view(N_src, N_rays, n_samples, Npx) + + grid = torch.clamp(normalize(grid, sizeH, sizeW), -10, 10) + + sampled_rgb_val = F.grid_sample(src_imgs, grid.view(N_src, -1, 1, 2), align_corners=True).squeeze( + -1).transpose(1, 2) + sampled_rgb_val = sampled_rgb_val.view(N_src, N_rays, n_samples, Npx, 3) + + warp_mask_full = warp_mask_full.permute(1, 2, 0, 3).contiguous() # (N_rays, n_samples, N_src, Npx) + sampled_rgb_val = sampled_rgb_val.permute(1, 2, 0, 3, 4).contiguous() # (N_rays, n_samples, N_src, Npx, 3) + + return sampled_rgb_val, warp_mask_full + + def patch_homography(self, H, uv): + N, Npx = uv.shape[:2] + Nsrc = H.shape[0] + H = H.view(Nsrc, N, -1, 3, 3) + hom_uv = add_hom(uv) + + # einsum is 30 times faster + # tmp = (H.view(Nsrc, N, -1, 1, 3, 3) @ hom_uv.view(1, N, 1, -1, 3, 1)).squeeze(-1).view(Nsrc, -1, 3) + tmp = torch.einsum("vprik,pok->vproi", H, hom_uv).reshape(Nsrc, -1, 3) + + grid = tmp[..., :2] / torch.clamp(tmp[..., 2:], 1e-8) + mask = tmp[..., 2] > 0 + return grid, mask + + +def add_hom(pts): + try: + dev = pts.device + ones = torch.ones(pts.shape[:-1], device=dev).unsqueeze(-1) + return torch.cat((pts, ones), dim=-1) + + except AttributeError: + ones = np.ones((pts.shape[0], 1)) + return np.concatenate((pts, ones), axis=1) + + +def normalize(flow, h, w, clamp=None): + # either h and w are simple float or N torch.tensor where N batch size + try: + h.device + + except AttributeError: + h = torch.tensor(h, device=flow.device).float().unsqueeze(0) + w = torch.tensor(w, device=flow.device).float().unsqueeze(0) + + if len(flow.shape) == 4: + w = w.unsqueeze(1).unsqueeze(2) + h = h.unsqueeze(1).unsqueeze(2) + elif len(flow.shape) == 3: + w = w.unsqueeze(1) + h = h.unsqueeze(1) + elif len(flow.shape) == 5: + w = w.unsqueeze(0).unsqueeze(2).unsqueeze(2) + h = h.unsqueeze(0).unsqueeze(2).unsqueeze(2) + + res = torch.empty_like(flow) + if res.shape[-1] == 3: + res[..., 2] = 1 + + # for grid_sample with align_corners=True + # https://github.com/pytorch/pytorch/blob/c371542efc31b1abfe6f388042aa3ab0cef935f2/aten/src/ATen/native/GridSampler.h#L33 + res[..., 0] = 2 * flow[..., 0] / (w - 1) - 1 + res[..., 1] = 2 * flow[..., 1] / (h - 1) - 1 + + if clamp: + return torch.clamp(res, -clamp, clamp) + else: + return res + + +def build_patch_offset(h_patch_size): + offsets = torch.arange(-h_patch_size, h_patch_size + 1) + return torch.stack(torch.meshgrid(offsets, offsets, indexing="ij")[::-1], dim=-1).view(1, -1, 2) # nb_pixels_patch * 2 diff --git a/SparseNeuS_demo_v1/models/projector.py b/SparseNeuS_demo_v1/models/projector.py new file mode 100644 index 0000000000000000000000000000000000000000..aa58d3f896edefff25cbb6fa713e7342d9b84a1d --- /dev/null +++ b/SparseNeuS_demo_v1/models/projector.py @@ -0,0 +1,425 @@ +# The codes are partly from IBRNet + +import torch +import torch.nn.functional as F +from models.render_utils import sample_ptsFeatures_from_featureMaps, sample_ptsFeatures_from_featureVolume + +def safe_l2_normalize(x, dim=None, eps=1e-6): + return F.normalize(x, p=2, dim=dim, eps=eps) + +class Projector(): + """ + Obtain features from geometryVolume and rendering_feature_maps for generalized rendering + """ + + def compute_angle(self, xyz, query_c2w, supporting_c2ws): + """ + + :param xyz: [N_rays, n_samples,3 ] + :param query_c2w: [1,4,4] + :param supporting_c2ws: [n,4,4] + :return: + """ + N_rays, n_samples, _ = xyz.shape + num_views = supporting_c2ws.shape[0] + xyz = xyz.reshape(-1, 3) + + ray2tar_pose = (query_c2w[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0)) + ray2tar_pose /= (torch.norm(ray2tar_pose, dim=-1, keepdim=True) + 1e-6) + ray2support_pose = (supporting_c2ws[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0)) + ray2support_pose /= (torch.norm(ray2support_pose, dim=-1, keepdim=True) + 1e-6) + ray_diff = ray2tar_pose - ray2support_pose + ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True) + ray_diff_dot = torch.sum(ray2tar_pose * ray2support_pose, dim=-1, keepdim=True) + ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6) + ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1) + ray_diff = ray_diff.reshape((num_views, N_rays, n_samples, 4)) # the last dimension (4) is dot-product + return ray_diff.detach() + + + def compute_angle_view_independent(self, xyz, surface_normals, supporting_c2ws): + """ + + :param xyz: [N_rays, n_samples,3 ] + :param surface_normals: [N_rays, n_samples,3 ] + :param supporting_c2ws: [n,4,4] + :return: + """ + N_rays, n_samples, _ = xyz.shape + num_views = supporting_c2ws.shape[0] + xyz = xyz.reshape(-1, 3) + + ray2tar_pose = surface_normals + ray2support_pose = (supporting_c2ws[:, :3, 3].unsqueeze(1) - xyz.unsqueeze(0)) + ray2support_pose /= (torch.norm(ray2support_pose, dim=-1, keepdim=True) + 1e-6) + ray_diff = ray2tar_pose - ray2support_pose + ray_diff_norm = torch.norm(ray_diff, dim=-1, keepdim=True) + ray_diff_dot = torch.sum(ray2tar_pose * ray2support_pose, dim=-1, keepdim=True) + ray_diff_direction = ray_diff / torch.clamp(ray_diff_norm, min=1e-6) + ray_diff = torch.cat([ray_diff_direction, ray_diff_dot], dim=-1) + ray_diff = ray_diff.reshape((num_views, N_rays, n_samples, 4)) # the last dimension (4) is dot-product, + # and the first three dimensions is the normalized ray diff vector + return ray_diff.detach() + + @torch.no_grad() + def compute_z_diff(self, xyz, w2cs, intrinsics, pred_depth_values): + """ + compute the depth difference of query pts projected on the image and the predicted depth values of the image + :param xyz: [N_rays, n_samples,3 ] + :param w2cs: [N_views, 4, 4] + :param intrinsics: [N_views, 3, 3] + :param pred_depth_values: [N_views, N_rays, n_samples,1 ] + :param pred_depth_masks: [N_views, N_rays, n_samples] + :return: + """ + device = xyz.device + N_views = w2cs.shape[0] + N_rays, n_samples, _ = xyz.shape + proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :]) + + proj_rot = proj_matrix[:, :3, :3] + proj_trans = proj_matrix[:, :3, 3:] + + batch_xyz = xyz.permute(2, 0, 1).contiguous().view(1, 3, N_rays * n_samples).repeat(N_views, 1, 1) + + proj_xyz = proj_rot.bmm(batch_xyz) + proj_trans + + # X = proj_xyz[:, 0] + # Y = proj_xyz[:, 1] + Z = proj_xyz[:, 2].clamp(min=1e-3) # [N_views, N_rays*n_samples] + proj_z = Z.view(N_views, N_rays, n_samples, 1) + + z_diff = proj_z - pred_depth_values # [N_views, N_rays, n_samples,1 ] + + return z_diff + + def compute(self, + pts, + # * 3d geometry feature volumes + geometryVolume=None, + geometryVolumeMask=None, + vol_dims=None, + partial_vol_origin=None, + vol_size=None, + # * 2d rendering feature maps + rendering_feature_maps=None, + color_maps=None, + w2cs=None, + intrinsics=None, + img_wh=None, + query_img_idx=0, # the index of the N_views dim for rendering + query_c2w=None, + pred_depth_maps=None, # no use here + pred_depth_masks=None # no use here + ): + """ + extract features of pts for rendering + :param pts: + :param geometryVolume: + :param vol_dims: + :param partial_vol_origin: + :param vol_size: + :param rendering_feature_maps: + :param color_maps: + :param w2cs: + :param intrinsics: + :param img_wh: + :param rendering_img_idx: by default, we render the first view of w2cs + :return: + """ + device = pts.device + c2ws = torch.inverse(w2cs) + + if len(pts.shape) == 2: + pts = pts[None, :, :] + + N_rays, n_samples, _ = pts.shape + N_views = rendering_feature_maps.shape[0] # shape (N_views, C, H, W) + + supporting_img_idxs = torch.LongTensor([x for x in range(N_views) if x != query_img_idx]).to(device) + query_img_idx = torch.LongTensor([query_img_idx]).to(device) + + if query_c2w is None and query_img_idx > -1: + query_c2w = torch.index_select(c2ws, 0, query_img_idx) + supporting_c2ws = torch.index_select(c2ws, 0, supporting_img_idxs) + supporting_w2cs = torch.index_select(w2cs, 0, supporting_img_idxs) + supporting_rendering_feature_maps = torch.index_select(rendering_feature_maps, 0, supporting_img_idxs) + supporting_color_maps = torch.index_select(color_maps, 0, supporting_img_idxs) + supporting_intrinsics = torch.index_select(intrinsics, 0, supporting_img_idxs) + + if pred_depth_maps is not None: + supporting_depth_maps = torch.index_select(pred_depth_maps, 0, supporting_img_idxs) + supporting_depth_masks = torch.index_select(pred_depth_masks, 0, supporting_img_idxs) + # print("N_supporting_views: ", N_views - 1) + N_supporting_views = N_views - 1 + else: + supporting_c2ws = c2ws + supporting_w2cs = w2cs + supporting_rendering_feature_maps = rendering_feature_maps + supporting_color_maps = color_maps + supporting_intrinsics = intrinsics + supporting_depth_maps = pred_depth_masks + supporting_depth_masks = pred_depth_masks + # print("N_supporting_views: ", N_views) + N_supporting_views = N_views + # import ipdb; ipdb.set_trace() + if geometryVolume is not None: + # * sample feature of pts from 3D feature volume + pts_geometry_feature, pts_geometry_masks_0 = sample_ptsFeatures_from_featureVolume( + pts, geometryVolume, vol_dims, + partial_vol_origin, vol_size) # [N_rays, n_samples, C], [N_rays, n_samples] + + if len(geometryVolumeMask.shape) == 3: + geometryVolumeMask = geometryVolumeMask[None, :, :, :] + + pts_geometry_masks_1, _ = sample_ptsFeatures_from_featureVolume( + pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims, + partial_vol_origin, vol_size) # [N_rays, n_samples, C] + + pts_geometry_masks = pts_geometry_masks_0 & (pts_geometry_masks_1[..., 0] > 0) + else: + pts_geometry_feature = None + pts_geometry_masks = None + + # * sample feature of pts from 2D feature maps + pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps( + pts, supporting_rendering_feature_maps, supporting_w2cs, + supporting_intrinsics, img_wh, + return_mask=True) # [N_views, C, N_rays, n_samples], # [N_views, N_rays, n_samples] + # import ipdb; ipdb.set_trace() + # * size (N_views, N_rays*n_samples, c) + pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous() + + pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, supporting_color_maps, supporting_w2cs, + supporting_intrinsics, img_wh) + # * size (N_views, N_rays*n_samples, c) + pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous() + + rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1) # [N_views, N_rays, n_samples, 3+c] + + + ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws) # [N_views, N_rays, n_samples, 4] + # import ipdb; ipdb.set_trace() + if pts_geometry_masks is not None: + final_mask = pts_geometry_masks[None, :, :].repeat(N_supporting_views, 1, 1) & \ + pts_rendering_mask # [N_views, N_rays, n_samples] + else: + final_mask = pts_rendering_mask + # import ipdb; ipdb.set_trace() + z_diff, pts_pred_depth_masks = None, None + + if pred_depth_maps is not None: + pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_maps, supporting_w2cs, + supporting_intrinsics, img_wh) + pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3, + 1).contiguous() # (N_views, N_rays*n_samples, 1) + + # - pts_pred_depth_masks are critical than final_mask, + # - the ray containing few invalid pts will be treated invalid + pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_masks.float(), + supporting_w2cs, + supporting_intrinsics, img_wh) + + pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :, + 0] # (N_views, N_rays*n_samples) + + z_diff = self.compute_z_diff(pts, supporting_w2cs, supporting_intrinsics, pts_pred_depth_values) + # import ipdb; ipdb.set_trace() + return pts_geometry_feature, rgb_feats, ray_diff, final_mask, z_diff, pts_pred_depth_masks + + + def compute_view_independent( + self, + pts, + # * 3d geometry feature volumes + geometryVolume=None, + geometryVolumeMask=None, + sdf_network=None, + lod=0, + vol_dims=None, + partial_vol_origin=None, + vol_size=None, + # * 2d rendering feature maps + rendering_feature_maps=None, + color_maps=None, + w2cs=None, + target_candidate_w2cs=None, + intrinsics=None, + img_wh=None, + query_img_idx=0, # the index of the N_views dim for rendering + query_c2w=None, + pred_depth_maps=None, # no use here + pred_depth_masks=None # no use here + ): + """ + extract features of pts for rendering + :param pts: + :param geometryVolume: + :param vol_dims: + :param partial_vol_origin: + :param vol_size: + :param rendering_feature_maps: + :param color_maps: + :param w2cs: + :param intrinsics: + :param img_wh: + :param rendering_img_idx: by default, we render the first view of w2cs + :return: + """ + device = pts.device + c2ws = torch.inverse(w2cs) + + if len(pts.shape) == 2: + pts = pts[None, :, :] + + N_rays, n_samples, _ = pts.shape + N_views = rendering_feature_maps.shape[0] # shape (N_views, C, H, W) + + supporting_img_idxs = torch.LongTensor([x for x in range(N_views) if x != query_img_idx]).to(device) + query_img_idx = torch.LongTensor([query_img_idx]).to(device) + + if query_c2w is None and query_img_idx > -1: + query_c2w = torch.index_select(c2ws, 0, query_img_idx) + supporting_c2ws = torch.index_select(c2ws, 0, supporting_img_idxs) + supporting_w2cs = torch.index_select(w2cs, 0, supporting_img_idxs) + supporting_rendering_feature_maps = torch.index_select(rendering_feature_maps, 0, supporting_img_idxs) + supporting_color_maps = torch.index_select(color_maps, 0, supporting_img_idxs) + supporting_intrinsics = torch.index_select(intrinsics, 0, supporting_img_idxs) + + if pred_depth_maps is not None: + supporting_depth_maps = torch.index_select(pred_depth_maps, 0, supporting_img_idxs) + supporting_depth_masks = torch.index_select(pred_depth_masks, 0, supporting_img_idxs) + # print("N_supporting_views: ", N_views - 1) + N_supporting_views = N_views - 1 + else: + supporting_c2ws = c2ws + supporting_w2cs = w2cs + supporting_rendering_feature_maps = rendering_feature_maps + supporting_color_maps = color_maps + supporting_intrinsics = intrinsics + supporting_depth_maps = pred_depth_masks + supporting_depth_masks = pred_depth_masks + # print("N_supporting_views: ", N_views) + N_supporting_views = N_views + # import ipdb; ipdb.set_trace() + if geometryVolume is not None: + # * sample feature of pts from 3D feature volume + pts_geometry_feature, pts_geometry_masks_0 = sample_ptsFeatures_from_featureVolume( + pts, geometryVolume, vol_dims, + partial_vol_origin, vol_size) # [N_rays, n_samples, C], [N_rays, n_samples] + + if len(geometryVolumeMask.shape) == 3: + geometryVolumeMask = geometryVolumeMask[None, :, :, :] + + pts_geometry_masks_1, _ = sample_ptsFeatures_from_featureVolume( + pts, geometryVolumeMask.to(geometryVolume.dtype), vol_dims, + partial_vol_origin, vol_size) # [N_rays, n_samples, C] + + pts_geometry_masks = pts_geometry_masks_0 & (pts_geometry_masks_1[..., 0] > 0) + else: + pts_geometry_feature = None + pts_geometry_masks = None + + # * sample feature of pts from 2D feature maps + pts_rendering_feats, pts_rendering_mask = sample_ptsFeatures_from_featureMaps( + pts, supporting_rendering_feature_maps, supporting_w2cs, + supporting_intrinsics, img_wh, + return_mask=True) # [N_views, C, N_rays, n_samples], # [N_views, N_rays, n_samples] + + # * size (N_views, N_rays*n_samples, c) + pts_rendering_feats = pts_rendering_feats.permute(0, 2, 3, 1).contiguous() + + pts_rendering_colors = sample_ptsFeatures_from_featureMaps(pts, supporting_color_maps, supporting_w2cs, + supporting_intrinsics, img_wh) + # * size (N_views, N_rays*n_samples, c) + pts_rendering_colors = pts_rendering_colors.permute(0, 2, 3, 1).contiguous() + + rgb_feats = torch.cat([pts_rendering_colors, pts_rendering_feats], dim=-1) # [N_views, N_rays, n_samples, 3+c] + + # import ipdb; ipdb.set_trace() + + gradients = sdf_network.gradient( + pts.reshape(-1, 3), # pts.squeeze(0), + geometryVolume.unsqueeze(0), + lod=lod + ).squeeze() + + surface_normals = safe_l2_normalize(gradients, dim=-1) # [npts, 3] + # input normals + ren_ray_diff = self.compute_angle_view_independent( + xyz=pts, + surface_normals=surface_normals, + supporting_c2ws=supporting_c2ws + ) + + # # choose closest target view direction from 32 candidate views + # # choose the closest source view as view direction instead of the normals vectors + # pts2src_centers = safe_l2_normalize((supporting_c2ws[:, :3, 3].unsqueeze(1) - pts)) # [N_views, npts, 3] + + # cosine_distance = torch.sum(pts2src_centers * surface_normals, dim=-1, keepdim=True) # [N_views, npts, 1] + # # choose the largest cosine distance as the view direction + # max_idx = torch.argmax(cosine_distance, dim=0) # [npts, 1] + + # chosen_view_direction = pts2src_centers[max_idx.squeeze(), torch.arange(pts.shape[1]), :] # [npts, 3] + # ren_ray_diff = self.compute_angle_view_independent( + # xyz=pts, + # surface_normals=chosen_view_direction, + # supporting_c2ws=supporting_c2ws + # ) + + + + # # choose closest target view direction from 8 candidate views + # # choose the closest source view as view direction instead of the normals vectors + # target_candidate_c2ws = torch.inverse(target_candidate_w2cs) + # pts2src_centers = safe_l2_normalize((target_candidate_c2ws[:, :3, 3].unsqueeze(1) - pts)) # [N_views, npts, 3] + + # cosine_distance = torch.sum(pts2src_centers * surface_normals, dim=-1, keepdim=True) # [N_views, npts, 1] + # # choose the largest cosine distance as the view direction + # max_idx = torch.argmax(cosine_distance, dim=0) # [npts, 1] + + # chosen_view_direction = pts2src_centers[max_idx.squeeze(), torch.arange(pts.shape[1]), :] # [npts, 3] + # ren_ray_diff = self.compute_angle_view_independent( + # xyz=pts, + # surface_normals=chosen_view_direction, + # supporting_c2ws=supporting_c2ws + # ) + + + # ray_diff = self.compute_angle(pts, query_c2w, supporting_c2ws) # [N_views, N_rays, n_samples, 4] + # import ipdb; ipdb.set_trace() + + + # input_directions = safe_l2_normalize(pts) + # ren_ray_diff = self.compute_angle_view_independent( + # xyz=pts, + # surface_normals=input_directions, + # supporting_c2ws=supporting_c2ws + # ) + + if pts_geometry_masks is not None: + final_mask = pts_geometry_masks[None, :, :].repeat(N_supporting_views, 1, 1) & \ + pts_rendering_mask # [N_views, N_rays, n_samples] + else: + final_mask = pts_rendering_mask + # import ipdb; ipdb.set_trace() + z_diff, pts_pred_depth_masks = None, None + + if pred_depth_maps is not None: + pts_pred_depth_values = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_maps, supporting_w2cs, + supporting_intrinsics, img_wh) + pts_pred_depth_values = pts_pred_depth_values.permute(0, 2, 3, + 1).contiguous() # (N_views, N_rays*n_samples, 1) + + # - pts_pred_depth_masks are critical than final_mask, + # - the ray containing few invalid pts will be treated invalid + pts_pred_depth_masks = sample_ptsFeatures_from_featureMaps(pts, supporting_depth_masks.float(), + supporting_w2cs, + supporting_intrinsics, img_wh) + + pts_pred_depth_masks = pts_pred_depth_masks.permute(0, 2, 3, 1).contiguous()[:, :, :, + 0] # (N_views, N_rays*n_samples) + + z_diff = self.compute_z_diff(pts, supporting_w2cs, supporting_intrinsics, pts_pred_depth_values) + # import ipdb; ipdb.set_trace() + return pts_geometry_feature, rgb_feats, ren_ray_diff, final_mask, z_diff, pts_pred_depth_masks diff --git a/SparseNeuS_demo_v1/models/rays.py b/SparseNeuS_demo_v1/models/rays.py new file mode 100644 index 0000000000000000000000000000000000000000..aa45b18df32adc34124687fb06495c1652cb1678 --- /dev/null +++ b/SparseNeuS_demo_v1/models/rays.py @@ -0,0 +1,320 @@ +import os, torch +import numpy as np + +import torch.nn.functional as F + +def build_patch_offset(h_patch_size): + offsets = torch.arange(-h_patch_size, h_patch_size + 1) + return torch.stack(torch.meshgrid(offsets, offsets)[::-1], dim=-1).view(1, -1, 2) # nb_pixels_patch * 2 + + +def gen_rays_from_single_image(H, W, image, intrinsic, c2w, depth=None, mask=None): + """ + generate rays in world space, for image image + :param H: + :param W: + :param intrinsics: [3,3] + :param c2ws: [4,4] + :return: + """ + device = image.device + ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H), + torch.linspace(0, W - 1, W), indexing="ij") # pytorch's meshgrid has indexing='ij' + p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3 + + # normalized ndc uv coordinates, (-1, 1) + ndc_u = 2 * xs / (W - 1) - 1 + ndc_v = 2 * ys / (H - 1) - 1 + rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device) + + intrinsic_inv = torch.inverse(intrinsic) + + p = p.view(-1, 3).float().to(device) # N_rays, 3 + p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays, 3 + rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays, 3 + rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays, 3 + + image = image.permute(1, 2, 0) + color = image.view(-1, 3) + depth = depth.view(-1, 1) if depth is not None else None + mask = mask.view(-1, 1) if mask is not None else torch.ones([H * W, 1]).to(device) + sample = { + 'rays_o': rays_o, + 'rays_v': rays_v, + 'rays_ndc_uv': rays_ndc_uv, + 'rays_color': color, + # 'rays_depth': depth, + 'rays_mask': mask, + 'rays_norm_XYZ_cam': p # - XYZ_cam, before multiply depth + } + if depth is not None: + sample['rays_depth'] = depth + + return sample + + +def gen_random_rays_from_single_image(H, W, N_rays, image, intrinsic, c2w, depth=None, mask=None, dilated_mask=None, + importance_sample=False, h_patch_size=3): + """ + generate random rays in world space, for a single image + :param H: + :param W: + :param N_rays: + :param image: [3, H, W] + :param intrinsic: [3,3] + :param c2w: [4,4] + :param depth: [H, W] + :param mask: [H, W] + :return: + """ + device = image.device + + if dilated_mask is None: + dilated_mask = mask + + if not importance_sample: + pixels_x = torch.randint(low=0, high=W, size=[N_rays]) + pixels_y = torch.randint(low=0, high=H, size=[N_rays]) + elif importance_sample and dilated_mask is not None: # sample more pts in the valid mask regions + pixels_x_1 = torch.randint(low=0, high=W, size=[N_rays // 4]) + pixels_y_1 = torch.randint(low=0, high=H, size=[N_rays // 4]) + + ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H), + torch.linspace(0, W - 1, W), indexing="ij") # pytorch's meshgrid has indexing='ij' + p = torch.stack([xs, ys], dim=-1) # H, W, 2 + + try: + p_valid = p[dilated_mask > 0] # [num, 2] + random_idx = torch.randint(low=0, high=p_valid.shape[0], size=[N_rays // 4 * 3]) + except: + print("dilated_mask.shape: ", dilated_mask.shape) + print("dilated_mask valid number", dilated_mask.sum()) + + raise ValueError("hhhh") + p_select = p_valid[random_idx] # [N_rays//2, 2] + pixels_x_2 = p_select[:, 0] + pixels_y_2 = p_select[:, 1] + + pixels_x = torch.cat([pixels_x_1, pixels_x_2], dim=0).to(torch.int64) + pixels_y = torch.cat([pixels_y_1, pixels_y_2], dim=0).to(torch.int64) + + # - crop patch from images + offsets = build_patch_offset(h_patch_size).to(device) + grid_patch = torch.stack([pixels_x, pixels_y], dim=-1).view(-1, 1, 2) + offsets.float() # [N_pts, Npx, 2] + patch_mask = (pixels_x > h_patch_size) * (pixels_x < (W - h_patch_size)) * (pixels_y > h_patch_size) * ( + pixels_y < H - h_patch_size) # [N_pts] + grid_patch_u = 2 * grid_patch[:, :, 0] / (W - 1) - 1 + grid_patch_v = 2 * grid_patch[:, :, 1] / (H - 1) - 1 + grid_patch_uv = torch.stack([grid_patch_u, grid_patch_v], dim=-1) # [N_pts, Npx, 2] + patch_color = F.grid_sample(image[None, :, :, :], grid_patch_uv[None, :, :, :], mode='bilinear', + padding_mode='zeros',align_corners=True)[0] # [3, N_pts, Npx] + patch_color = patch_color.permute(1, 2, 0).contiguous() + + # normalized ndc uv coordinates, (-1, 1) + ndc_u = 2 * pixels_x / (W - 1) - 1 + ndc_v = 2 * pixels_y / (H - 1) - 1 + rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float().to(device) + + image = image.permute(1, 2, 0) # H ,W, C + color = image[(pixels_y, pixels_x)] # N_rays, 3 + + if mask is not None: + mask = mask[(pixels_y, pixels_x)] # N_rays + patch_mask = patch_mask * mask # N_rays + mask = mask.view(-1, 1) + else: + mask = torch.ones([N_rays, 1]) + + if depth is not None: + depth = depth[(pixels_y, pixels_x)] # N_rays + depth = depth.view(-1, 1) + + intrinsic_inv = torch.inverse(intrinsic) + + p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device) # N_rays, 3 + p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays, 3 + rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays, 3 + rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays, 3 + + sample = { + 'rays_o': rays_o, + 'rays_v': rays_v, + 'rays_ndc_uv': rays_ndc_uv, + 'rays_color': color, + # 'rays_depth': depth, + 'rays_mask': mask, + 'rays_norm_XYZ_cam': p, # - XYZ_cam, before multiply depth, + 'rays_patch_color': patch_color, + 'rays_patch_mask': patch_mask.view(-1, 1) + } + + if depth is not None: + sample['rays_depth'] = depth + + return sample + + +def gen_random_rays_of_patch_from_single_image(H, W, N_rays, num_neighboring_pts, patch_size, + image, intrinsic, c2w, depth=None, mask=None): + """ + generate random rays in world space, for a single image + sample rays from local patches + :param H: + :param W: + :param N_rays: the number of center rays of patches + :param image: [3, H, W] + :param intrinsic: [3,3] + :param c2w: [4,4] + :param depth: [H, W] + :param mask: [H, W] + :return: + """ + device = image.device + patch_radius_max = patch_size // 2 + + unit_u = 2 / (W - 1) + unit_v = 2 / (H - 1) + + pixels_x_center = torch.randint(low=patch_size, high=W - patch_size, size=[N_rays]) + pixels_y_center = torch.randint(low=patch_size, high=H - patch_size, size=[N_rays]) + + # normalized ndc uv coordinates, (-1, 1) + ndc_u_center = 2 * pixels_x_center / (W - 1) - 1 + ndc_v_center = 2 * pixels_y_center / (H - 1) - 1 + ndc_uv_center = torch.stack([ndc_u_center, ndc_v_center], dim=-1).view(-1, 2).float().to(device)[:, None, + :] # [N_rays, 1, 2] + + shift_u, shift_v = torch.rand([N_rays, num_neighboring_pts, 1]), torch.rand( + [N_rays, num_neighboring_pts, 1]) # uniform distribution of [0,1) + shift_u = 2 * (shift_u - 0.5) # mapping to [-1, 1) + shift_v = 2 * (shift_v - 0.5) + + # - avoid sample points which are too close to center point + shift_uv = torch.cat([(shift_u * patch_radius_max) * unit_u, (shift_v * patch_radius_max) * unit_v], + dim=-1) # [N_rays, num_npts, 2] + neighboring_pts_uv = ndc_uv_center + shift_uv # [N_rays, num_npts, 2] + + sampled_pts_uv = torch.cat([ndc_uv_center, neighboring_pts_uv], dim=1) # concat the center point + + # sample the gts + color = F.grid_sample(image[None, :, :, :], sampled_pts_uv[None, :, :, :], mode='bilinear', + align_corners=True)[0] # [3, N_rays, num_npts] + depth = F.grid_sample(depth[None, None, :, :], sampled_pts_uv[None, :, :, :], mode='bilinear', + align_corners=True)[0] # [1, N_rays, num_npts] + + mask = F.grid_sample(mask[None, None, :, :].to(torch.float32), sampled_pts_uv[None, :, :, :], mode='nearest', + align_corners=True).to(torch.int64)[0] # [1, N_rays, num_npts] + + intrinsic_inv = torch.inverse(intrinsic) + + sampled_pts_uv = sampled_pts_uv.view(N_rays * (1 + num_neighboring_pts), 2) + color = color.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 3) + depth = depth.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 1) + mask = mask.permute(1, 2, 0).contiguous().view(N_rays * (1 + num_neighboring_pts), 1) + + pixels_x = (sampled_pts_uv[:, 0] + 1) * (W - 1) / 2 + pixels_y = (sampled_pts_uv[:, 1] + 1) * (H - 1) / 2 + p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).float().to(device) # N_rays*num_pts, 3 + p = torch.matmul(intrinsic_inv[None, :3, :3], p[:, :, None]).squeeze() # N_rays*num_pts, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # N_rays*num_pts, 3 + rays_v = torch.matmul(c2w[None, :3, :3], rays_v[:, :, None]).squeeze() # N_rays*num_pts, 3 + rays_o = c2w[None, :3, 3].expand(rays_v.shape) # N_rays*num_pts, 3 + + sample = { + 'rays_o': rays_o, + 'rays_v': rays_v, + 'rays_ndc_uv': sampled_pts_uv, + 'rays_color': color, + 'rays_depth': depth, + 'rays_mask': mask, + # 'rays_norm_XYZ_cam': p # - XYZ_cam, before multiply depth + } + + return sample + + +def gen_random_rays_from_batch_images(H, W, N_rays, images, intrinsics, c2ws, depths=None, masks=None): + """ + + :param H: + :param W: + :param N_rays: + :param images: [B,3,H,W] + :param intrinsics: [B, 3, 3] + :param c2ws: [B, 4, 4] + :param depths: [B,H,W] + :param masks: [B,H,W] + :return: + """ + assert len(images.shape) == 4 + + rays_o = [] + rays_v = [] + rays_color = [] + rays_depth = [] + rays_mask = [] + for i in range(images.shape[0]): + sample = gen_random_rays_from_single_image(H, W, N_rays, images[i], intrinsics[i], c2ws[i], + depth=depths[i] if depths is not None else None, + mask=masks[i] if masks is not None else None) + rays_o.append(sample['rays_o']) + rays_v.append(sample['rays_v']) + rays_color.append(sample['rays_color']) + if depths is not None: + rays_depth.append(sample['rays_depth']) + if masks is not None: + rays_mask.append(sample['rays_mask']) + + sample = { + 'rays_o': torch.stack(rays_o, dim=0), # [batch, N_rays, 3] + 'rays_v': torch.stack(rays_v, dim=0), + 'rays_color': torch.stack(rays_color, dim=0), + 'rays_depth': torch.stack(rays_depth, dim=0) if depths is not None else None, + 'rays_mask': torch.stack(rays_mask, dim=0) if masks is not None else None + } + return sample + + +from scipy.spatial.transform import Rotation as Rot +from scipy.spatial.transform import Slerp + + +def gen_rays_between(c2w_0, c2w_1, intrinsic, ratio, H, W, resolution_level=1): + device = c2w_0.device + + l = resolution_level + tx = torch.linspace(0, W - 1, W // l) + ty = torch.linspace(0, H - 1, H // l) + pixels_x, pixels_y = torch.meshgrid(tx, ty, indexing="ij") + p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1).to(device) # W, H, 3 + + intrinsic_inv = torch.inverse(intrinsic[:3, :3]) + p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze() # W, H, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 + trans = c2w_0[:3, 3] * (1.0 - ratio) + c2w_1[:3, 3] * ratio + + pose_0 = c2w_0.detach().cpu().numpy() + pose_1 = c2w_1.detach().cpu().numpy() + pose_0 = np.linalg.inv(pose_0) + pose_1 = np.linalg.inv(pose_1) + rot_0 = pose_0[:3, :3] + rot_1 = pose_1[:3, :3] + rots = Rot.from_matrix(np.stack([rot_0, rot_1])) + key_times = [0, 1] + key_rots = [rot_0, rot_1] + slerp = Slerp(key_times, rots) + rot = slerp(ratio) + pose = np.diag([1.0, 1.0, 1.0, 1.0]) + pose = pose.astype(np.float32) + pose[:3, :3] = rot.as_matrix() + pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3] + pose = np.linalg.inv(pose) + + c2w = torch.from_numpy(pose).to(device) + rot = torch.from_numpy(pose[:3, :3]).cuda() + trans = torch.from_numpy(pose[:3, 3]).cuda() + rays_v = torch.matmul(rot[None, None, :3, :3], rays_v[:, :, :, None]).squeeze() # W, H, 3 + rays_o = trans[None, None, :3].expand(rays_v.shape) # W, H, 3 + return c2w, rays_o.transpose(0, 1).contiguous().view(-1, 3), rays_v.transpose(0, 1).contiguous().view(-1, 3) diff --git a/SparseNeuS_demo_v1/models/render_utils.py b/SparseNeuS_demo_v1/models/render_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c14d5761234a16a19ed10509f9f0972adaf04c9a --- /dev/null +++ b/SparseNeuS_demo_v1/models/render_utils.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ops.back_project import cam2pixel + + +def sample_pdf(bins, weights, n_samples, det=False): + ''' + :param bins: tensor of shape [N_rays, M+1], M is the number of bins + :param weights: tensor of shape [N_rays, M] + :param N_samples: number of samples along each ray + :param det: if True, will perform deterministic sampling + :return: [N_rays, N_samples] + ''' + device = weights.device + + weights = weights + 1e-5 # prevent nans + pdf = weights / torch.sum(weights, -1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]).to(device), cdf], -1) + + # if bins.shape[1] != weights.shape[1]: # - minor modification, add this constraint + # cdf = torch.cat([torch.zeros_like(cdf[..., :1]).to(device), cdf], -1) + # Take uniform samples + if det: + u = torch.linspace(0. + 0.5 / n_samples, 1. - 0.5 / n_samples, steps=n_samples).to(device) + u = u.expand(list(cdf.shape[:-1]) + [n_samples]) + else: + u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(device) + + # Invert CDF + u = u.contiguous() + # inds = searchsorted(cdf, u, side='right') + inds = torch.searchsorted(cdf, u, right=True) + + below = torch.max(torch.zeros_like(inds - 1), inds - 1) + above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds) + inds_g = torch.stack([below, above], -1) # (batch, n_samples, 2) + + matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] + cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) + bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) + + denom = (cdf_g[..., 1] - cdf_g[..., 0]) + denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom) + t = (u - cdf_g[..., 0]) / denom + samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0]) + + # pdb.set_trace() + return samples + + +def sample_ptsFeatures_from_featureVolume(pts, featureVolume, vol_dims=None, partial_vol_origin=None, vol_size=None): + """ + sample feature of pts_wrd from featureVolume, all in world space + :param pts: [N_rays, n_samples, 3] + :param featureVolume: [C,wX,wY,wZ] + :param vol_dims: [3] "3" for dimX, dimY, dimZ + :param partial_vol_origin: [3] + :return: pts_feature: [N_rays, n_samples, C] + :return: valid_mask: [N_rays] + """ + + N_rays, n_samples, _ = pts.shape + + if vol_dims is None: + pts_normalized = pts + else: + # normalized to (-1, 1) + pts_normalized = 2 * (pts - partial_vol_origin[None, None, :]) / (vol_size * (vol_dims[None, None, :] - 1)) - 1 + + valid_mask = (torch.abs(pts_normalized[:, :, 0]) < 1.0) & ( + torch.abs(pts_normalized[:, :, 1]) < 1.0) & ( + torch.abs(pts_normalized[:, :, 2]) < 1.0) # (N_rays, n_samples) + + pts_normalized = torch.flip(pts_normalized, dims=[-1]) # ! reverse the xyz for grid_sample + + # ! checked grid_sample, (x,y,z) is for (D,H,W), reverse for (W,H,D) + pts_feature = F.grid_sample(featureVolume[None, :, :, :, :], pts_normalized[None, None, :, :, :], + padding_mode='zeros', + align_corners=True).view(-1, N_rays, n_samples) # [C, N_rays, n_samples] + + pts_feature = pts_feature.permute(1, 2, 0) # [N_rays, n_samples, C] + return pts_feature, valid_mask + + +def sample_ptsFeatures_from_featureMaps(pts, featureMaps, w2cs, intrinsics, WH, proj_matrix=None, return_mask=False): + """ + sample features of pts from 2d feature maps + :param pts: [N_rays, N_samples, 3] + :param featureMaps: [N_views, C, H, W] + :param w2cs: [N_views, 4, 4] + :param intrinsics: [N_views, 3, 3] + :param proj_matrix: [N_views, 4, 4] + :param HW: + :return: + """ + # normalized to (-1, 1) + N_rays, n_samples, _ = pts.shape + N_views = featureMaps.shape[0] + + if proj_matrix is None: + proj_matrix = torch.matmul(intrinsics, w2cs[:, :3, :]) + + pts = pts.permute(2, 0, 1).contiguous().view(1, 3, N_rays, n_samples).repeat(N_views, 1, 1, 1) + pixel_grids = cam2pixel(pts, proj_matrix[:, :3, :3], proj_matrix[:, :3, 3:], + 'zeros', sizeH=WH[1], sizeW=WH[0]) # (nviews, N_rays, n_samples, 2) + + valid_mask = (torch.abs(pixel_grids[:, :, :, 0]) < 1.0) & ( + torch.abs(pixel_grids[:, :, :, 1]) < 1.00) # (nviews, N_rays, n_samples) + + pts_feature = F.grid_sample(featureMaps, pixel_grids, + padding_mode='zeros', + align_corners=True) # [N_views, C, N_rays, n_samples] + + if return_mask: + return pts_feature, valid_mask + else: + return pts_feature diff --git a/SparseNeuS_demo_v1/models/rendering_network.py b/SparseNeuS_demo_v1/models/rendering_network.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc984223450a609024a65956439ff741a6b133d --- /dev/null +++ b/SparseNeuS_demo_v1/models/rendering_network.py @@ -0,0 +1,129 @@ +# the codes are partly borrowed from IBRNet + +import torch +import torch.nn as nn +import torch.nn.functional as F + +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) + + +# default tensorflow initialization of linear layers +def weights_init(m): + if isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias.data) + + +@torch.jit.script +def fused_mean_variance(x, weight): + mean = torch.sum(x * weight, dim=2, keepdim=True) + var = torch.sum(weight * (x - mean) ** 2, dim=2, keepdim=True) + return mean, var + + +class GeneralRenderingNetwork(nn.Module): + """ + This model is not sensitive to finetuning + """ + + def __init__(self, in_geometry_feat_ch=8, in_rendering_feat_ch=56, anti_alias_pooling=True): + super(GeneralRenderingNetwork, self).__init__() + + self.in_geometry_feat_ch = in_geometry_feat_ch + self.in_rendering_feat_ch = in_rendering_feat_ch + self.anti_alias_pooling = anti_alias_pooling + + if self.anti_alias_pooling: + self.s = nn.Parameter(torch.tensor(0.2), requires_grad=True) + activation_func = nn.ELU(inplace=True) + + self.ray_dir_fc = nn.Sequential(nn.Linear(4, 16), + activation_func, + nn.Linear(16, in_rendering_feat_ch + 3), + activation_func) + + self.base_fc = nn.Sequential(nn.Linear((in_rendering_feat_ch + 3) * 3 + in_geometry_feat_ch, 64), + activation_func, + nn.Linear(64, 32), + activation_func) + + self.vis_fc = nn.Sequential(nn.Linear(32, 32), + activation_func, + nn.Linear(32, 33), + activation_func, + ) + + self.vis_fc2 = nn.Sequential(nn.Linear(32, 32), + activation_func, + nn.Linear(32, 1), + nn.Sigmoid() + ) + + self.rgb_fc = nn.Sequential(nn.Linear(32 + 1 + 4, 16), + activation_func, + nn.Linear(16, 8), + activation_func, + nn.Linear(8, 1)) + + self.base_fc.apply(weights_init) + self.vis_fc2.apply(weights_init) + self.vis_fc.apply(weights_init) + self.rgb_fc.apply(weights_init) + + def forward(self, geometry_feat, rgb_feat, ray_diff, mask): + ''' + :param geometry_feat: geometry features indicates sdf [n_rays, n_samples, n_feat] + :param rgb_feat: rgbs and image features [n_views, n_rays, n_samples, n_feat] + :param ray_diff: ray direction difference [n_views, n_rays, n_samples, 4], first 3 channels are directions, + last channel is inner product + :param mask: mask for whether each projection is valid or not. [n_views, n_rays, n_samples] + :return: rgb and density output, [n_rays, n_samples, 4] + ''' + + rgb_feat = rgb_feat.permute(1, 2, 0, 3).contiguous() + ray_diff = ray_diff.permute(1, 2, 0, 3).contiguous() + mask = mask[:, :, :, None].permute(1, 2, 0, 3).contiguous() + num_views = rgb_feat.shape[2] + geometry_feat = geometry_feat[:, :, None, :].repeat(1, 1, num_views, 1) + + direction_feat = self.ray_dir_fc(ray_diff) + rgb_in = rgb_feat[..., :3] + rgb_feat = rgb_feat + direction_feat + + if self.anti_alias_pooling: + _, dot_prod = torch.split(ray_diff, [3, 1], dim=-1) + exp_dot_prod = torch.exp(torch.abs(self.s) * (dot_prod - 1)) + weight = (exp_dot_prod - torch.min(exp_dot_prod, dim=2, keepdim=True)[0]) * mask + weight = weight / (torch.sum(weight, dim=2, keepdim=True) + 1e-8) + else: + weight = mask / (torch.sum(mask, dim=2, keepdim=True) + 1e-8) + + # compute mean and variance across different views for each point + mean, var = fused_mean_variance(rgb_feat, weight) # [n_rays, n_samples, 1, n_feat] + globalfeat = torch.cat([mean, var], dim=-1) # [n_rays, n_samples, 1, 2*n_feat] + + x = torch.cat([geometry_feat, globalfeat.expand(-1, -1, num_views, -1), rgb_feat], + dim=-1) # [n_rays, n_samples, n_views, 3*n_feat+n_geo_feat] + x = self.base_fc(x) + + x_vis = self.vis_fc(x * weight) + x_res, vis = torch.split(x_vis, [x_vis.shape[-1] - 1, 1], dim=-1) + vis = F.sigmoid(vis) * mask + x = x + x_res + vis = self.vis_fc2(x * vis) * mask + + # rgb computation + x = torch.cat([x, vis, ray_diff], dim=-1) + x = self.rgb_fc(x) + x = x.masked_fill(mask == 0, -1e9) + blending_weights_valid = F.softmax(x, dim=2) # color blending + rgb_out = torch.sum(rgb_in * blending_weights_valid, dim=2) + + mask = mask.detach().to(rgb_out.dtype) # [n_rays, n_samples, n_views, 1] + mask = torch.sum(mask, dim=2, keepdim=False) + mask = mask >= 2 # more than 2 views see the point + mask = torch.sum(mask.to(rgb_out.dtype), dim=1, keepdim=False) + valid_mask = mask > 8 # valid rays, more than 8 valid samples + return rgb_out, valid_mask # (N_rays, n_samples, 3), (N_rays, 1) diff --git a/SparseNeuS_demo_v1/models/sparse_neus_renderer.py b/SparseNeuS_demo_v1/models/sparse_neus_renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..96ffc7b547e0f83a177a81f36be38375d9cd26fb --- /dev/null +++ b/SparseNeuS_demo_v1/models/sparse_neus_renderer.py @@ -0,0 +1,985 @@ +""" +The codes are heavily borrowed from NeuS +""" + +import os +import cv2 as cv +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import logging +import mcubes +from icecream import ic +from models.render_utils import sample_pdf + +from models.projector import Projector +from tsparse.torchsparse_utils import sparse_to_dense_channel + +from models.fast_renderer import FastRenderer + +from models.patch_projector import PatchProjector + + +class SparseNeuSRenderer(nn.Module): + """ + conditional neus render; + optimize on normalized world space; + warped by nn.Module to support DataParallel traning + """ + + def __init__(self, + rendering_network_outside, + sdf_network, + variance_network, + rendering_network, + n_samples, + n_importance, + n_outside, + perturb, + alpha_type='div', + conf=None + ): + super(SparseNeuSRenderer, self).__init__() + + self.conf = conf + self.base_exp_dir = conf['general.base_exp_dir'] + + # network setups + self.rendering_network_outside = rendering_network_outside + self.sdf_network = sdf_network + self.variance_network = variance_network + self.rendering_network = rendering_network + + self.n_samples = n_samples + self.n_importance = n_importance + self.n_outside = n_outside + self.perturb = perturb + self.alpha_type = alpha_type + + self.rendering_projector = Projector() # used to obtain features for generalized rendering + + self.h_patch_size = self.conf.get_int('model.h_patch_size', default=3) + self.patch_projector = PatchProjector(self.h_patch_size) + + self.ray_tracer = FastRenderer() # ray_tracer to extract depth maps from sdf_volume + + # - fitted rendering or general rendering + try: + self.if_fitted_rendering = self.sdf_network.if_fitted_rendering + except: + self.if_fitted_rendering = False + + def up_sample(self, rays_o, rays_d, z_vals, sdf, n_importance, inv_variance, + conditional_valid_mask_volume=None): + device = rays_o.device + batch_size, n_samples = z_vals.shape + pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] # n_rays, n_samples, 3 + + if conditional_valid_mask_volume is not None: + pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume) + pts_mask = pts_mask.reshape(batch_size, n_samples) + pts_mask = pts_mask[:, :-1] * pts_mask[:, 1:] # [batch_size, n_samples-1] + else: + pts_mask = torch.ones([batch_size, n_samples]).to(pts.device) + + sdf = sdf.reshape(batch_size, n_samples) + prev_sdf, next_sdf = sdf[:, :-1], sdf[:, 1:] + prev_z_vals, next_z_vals = z_vals[:, :-1], z_vals[:, 1:] + mid_sdf = (prev_sdf + next_sdf) * 0.5 + dot_val = None + if self.alpha_type == 'uniform': + dot_val = torch.ones([batch_size, n_samples - 1]) * -1.0 + else: + dot_val = (next_sdf - prev_sdf) / (next_z_vals - prev_z_vals + 1e-5) + prev_dot_val = torch.cat([torch.zeros([batch_size, 1]).to(device), dot_val[:, :-1]], dim=-1) + dot_val = torch.stack([prev_dot_val, dot_val], dim=-1) + dot_val, _ = torch.min(dot_val, dim=-1, keepdim=False) + dot_val = dot_val.clip(-10.0, 0.0) * pts_mask + dist = (next_z_vals - prev_z_vals) + prev_esti_sdf = mid_sdf - dot_val * dist * 0.5 + next_esti_sdf = mid_sdf + dot_val * dist * 0.5 + prev_cdf = torch.sigmoid(prev_esti_sdf * inv_variance) + next_cdf = torch.sigmoid(next_esti_sdf * inv_variance) + alpha_sdf = (prev_cdf - next_cdf + 1e-5) / (prev_cdf + 1e-5) + + alpha = alpha_sdf + + # - apply pts_mask + alpha = pts_mask * alpha + + weights = alpha * torch.cumprod( + torch.cat([torch.ones([batch_size, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, :-1] + + z_samples = sample_pdf(z_vals, weights, n_importance, det=True).detach() + return z_samples + + def cat_z_vals(self, rays_o, rays_d, z_vals, new_z_vals, sdf, lod, + sdf_network, gru_fusion, + # * related to conditional feature + conditional_volume=None, + conditional_valid_mask_volume=None + ): + device = rays_o.device + batch_size, n_samples = z_vals.shape + _, n_importance = new_z_vals.shape + pts = rays_o[:, None, :] + rays_d[:, None, :] * new_z_vals[..., :, None] + + if conditional_valid_mask_volume is not None: + pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), conditional_valid_mask_volume) + pts_mask = pts_mask.reshape(batch_size, n_importance) + pts_mask_bool = (pts_mask > 0).view(-1) + else: + pts_mask = torch.ones([batch_size, n_importance]).to(pts.device) + + new_sdf = torch.ones([batch_size * n_importance, 1]).to(pts.dtype).to(device) * 100 + + if torch.sum(pts_mask) > 1: + new_outputs = sdf_network.sdf(pts.reshape(-1, 3)[pts_mask_bool], conditional_volume, lod=lod) + new_sdf[pts_mask_bool] = new_outputs['sdf_pts_scale%d' % lod] # .reshape(batch_size, n_importance) + + new_sdf = new_sdf.view(batch_size, n_importance) + + z_vals = torch.cat([z_vals, new_z_vals], dim=-1) + sdf = torch.cat([sdf, new_sdf], dim=-1) + + z_vals, index = torch.sort(z_vals, dim=-1) + xx = torch.arange(batch_size)[:, None].expand(batch_size, n_samples + n_importance).reshape(-1) + index = index.reshape(-1) + sdf = sdf[(xx, index)].reshape(batch_size, n_samples + n_importance) + + return z_vals, sdf + + @torch.no_grad() + def get_pts_mask_for_conditional_volume(self, pts, mask_volume): + """ + + :param pts: [N, 3] + :param mask_volume: [1, 1, X, Y, Z] + :return: + """ + num_pts = pts.shape[0] + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + + pts_mask = F.grid_sample(mask_volume, pts, mode='nearest') # [1, c, 1, 1, num_pts] + pts_mask = pts_mask.view(-1, num_pts).permute(1, 0).contiguous() # [num_pts, 1] + + return pts_mask + + def render_core(self, + rays_o, + rays_d, + z_vals, + sample_dist, + lod, + sdf_network, + rendering_network, + background_alpha=None, # - no use here + background_sampled_color=None, # - no use here + background_rgb=None, # - no use here + alpha_inter_ratio=0.0, + # * related to conditional feature + conditional_volume=None, + conditional_valid_mask_volume=None, + # * 2d feature maps + feature_maps=None, + color_maps=None, + w2cs=None, + intrinsics=None, + img_wh=None, + query_c2w=None, # - used for testing + if_general_rendering=True, + if_render_with_grad=True, + # * used for blending mlp rendering network + img_index=None, + rays_uv=None, + # * used for clear bg and fg + bg_num=0 + ): + device = rays_o.device + N_rays = rays_o.shape[0] + _, n_samples = z_vals.shape + dists = z_vals[..., 1:] - z_vals[..., :-1] + dists = torch.cat([dists, torch.Tensor([sample_dist]).expand(dists[..., :1].shape).to(device)], -1) + + mid_z_vals = z_vals + dists * 0.5 + mid_dists = mid_z_vals[..., 1:] - mid_z_vals[..., :-1] + + pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] # n_rays, n_samples, 3 + dirs = rays_d[:, None, :].expand(pts.shape) + + pts = pts.reshape(-1, 3) + dirs = dirs.reshape(-1, 3) + + # * if conditional_volume is restored from sparse volume, need mask for pts + if conditional_valid_mask_volume is not None: + pts_mask = self.get_pts_mask_for_conditional_volume(pts, conditional_valid_mask_volume) + pts_mask = pts_mask.reshape(N_rays, n_samples).float().detach() + pts_mask_bool = (pts_mask > 0).view(-1) + + if torch.sum(pts_mask_bool.float()) < 1: # ! when render out image, may meet this problem + pts_mask_bool[:100] = True + + else: + pts_mask = torch.ones([N_rays, n_samples]).to(pts.device) + # import ipdb; ipdb.set_trace() + # pts_valid = pts[pts_mask_bool] + sdf_nn_output = sdf_network.sdf(pts[pts_mask_bool], conditional_volume, lod=lod) + + sdf = torch.ones([N_rays * n_samples, 1]).to(pts.dtype).to(device) * 100 + sdf[pts_mask_bool] = sdf_nn_output['sdf_pts_scale%d' % lod] # [N_rays*n_samples, 1] + feature_vector_valid = sdf_nn_output['sdf_features_pts_scale%d' % lod] + feature_vector = torch.zeros([N_rays * n_samples, feature_vector_valid.shape[1]]).to(pts.dtype).to(device) + feature_vector[pts_mask_bool] = feature_vector_valid + + # * estimate alpha from sdf + gradients = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device) + # import ipdb; ipdb.set_trace() + gradients[pts_mask_bool] = sdf_network.gradient( + pts[pts_mask_bool], conditional_volume, lod=lod).squeeze() + + sampled_color_mlp = None + rendering_valid_mask_mlp = None + sampled_color_patch = None + rendering_patch_mask = None + + if self.if_fitted_rendering: # used for fine-tuning + position_latent = sdf_nn_output['sampled_latent_scale%d' % lod] + sampled_color_mlp = torch.zeros([N_rays * n_samples, 3]).to(pts.dtype).to(device) + sampled_color_mlp_mask = torch.zeros([N_rays * n_samples, 1]).to(pts.dtype).to(device) + + # - extract pixel + pts_pixel_color, pts_pixel_mask = self.patch_projector.pixel_warp( + pts[pts_mask_bool][:, None, :], color_maps, intrinsics, + w2cs, img_wh=None) # [N_rays * n_samples,1, N_views, 3] , [N_rays*n_samples, 1, N_views] + pts_pixel_color = pts_pixel_color[:, 0, :, :] # [N_rays * n_samples, N_views, 3] + pts_pixel_mask = pts_pixel_mask[:, 0, :] # [N_rays*n_samples, N_views] + + # - extract patch + if_patch_blending = False if rays_uv is None else True + pts_patch_color, pts_patch_mask = None, None + if if_patch_blending: + pts_patch_color, pts_patch_mask = self.patch_projector.patch_warp( + pts.reshape([N_rays, n_samples, 3]), + rays_uv, gradients.reshape([N_rays, n_samples, 3]), + color_maps, + intrinsics[0], intrinsics, + query_c2w[0], torch.inverse(w2cs), img_wh=None + ) # (N_rays, n_samples, N_src, Npx, 3), (N_rays, n_samples, N_src, Npx) + N_src, Npx = pts_patch_mask.shape[2:] + pts_patch_color = pts_patch_color.view(N_rays * n_samples, N_src, Npx, 3)[pts_mask_bool] + pts_patch_mask = pts_patch_mask.view(N_rays * n_samples, N_src, Npx)[pts_mask_bool] + + sampled_color_patch = torch.zeros([N_rays * n_samples, Npx, 3]).to(device) + sampled_color_patch_mask = torch.zeros([N_rays * n_samples, 1]).to(device) + + sampled_color_mlp_, sampled_color_mlp_mask_, \ + sampled_color_patch_, sampled_color_patch_mask_ = sdf_network.color_blend( + pts[pts_mask_bool], + position_latent, + gradients[pts_mask_bool], + dirs[pts_mask_bool], + feature_vector[pts_mask_bool], + img_index=img_index, + pts_pixel_color=pts_pixel_color, + pts_pixel_mask=pts_pixel_mask, + pts_patch_color=pts_patch_color, + pts_patch_mask=pts_patch_mask + + ) # [n, 3], [n, 1] + sampled_color_mlp[pts_mask_bool] = sampled_color_mlp_ + sampled_color_mlp_mask[pts_mask_bool] = sampled_color_mlp_mask_.float() + sampled_color_mlp = sampled_color_mlp.view(N_rays, n_samples, 3) + sampled_color_mlp_mask = sampled_color_mlp_mask.view(N_rays, n_samples) + rendering_valid_mask_mlp = torch.mean(pts_mask * sampled_color_mlp_mask, dim=-1, keepdim=True) > 0.5 + + # patch blending + if if_patch_blending: + sampled_color_patch[pts_mask_bool] = sampled_color_patch_ + sampled_color_patch_mask[pts_mask_bool] = sampled_color_patch_mask_.float() + sampled_color_patch = sampled_color_patch.view(N_rays, n_samples, Npx, 3) + sampled_color_patch_mask = sampled_color_patch_mask.view(N_rays, n_samples) + rendering_patch_mask = torch.mean(pts_mask * sampled_color_patch_mask, dim=-1, + keepdim=True) > 0.5 # [N_rays, 1] + else: + sampled_color_patch, rendering_patch_mask = None, None + + if if_general_rendering: # used for general training + # [512, 128, 16]; [4, 512, 128, 59]; [4, 512, 128, 4] + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = self.rendering_projector.compute( + pts.view(N_rays, n_samples, 3), + # * 3d geometry feature volumes + geometryVolume=conditional_volume[0], + geometryVolumeMask=conditional_valid_mask_volume[0], + # * 2d rendering feature maps + rendering_feature_maps=feature_maps, # [n_views, 56, 256, 256] + color_maps=color_maps, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=img_wh, + query_img_idx=0, # the index of the N_views dim for rendering + query_c2w=query_c2w, + ) + + # (N_rays, n_samples, 3) + if if_render_with_grad: + # import ipdb; ipdb.set_trace() + # [nrays, 3] [nrays, 1] + sampled_color, rendering_valid_mask = rendering_network( + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) + # import ipdb; ipdb.set_trace() + else: + with torch.no_grad(): + sampled_color, rendering_valid_mask = rendering_network( + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) + else: + sampled_color, rendering_valid_mask = None, None + + inv_variance = self.variance_network(feature_vector)[:, :1].clip(1e-6, 1e6) + + true_dot_val = (dirs * gradients).sum(-1, keepdim=True) # * calculate + + iter_cos = -(F.relu(-true_dot_val * 0.5 + 0.5) * (1.0 - alpha_inter_ratio) + F.relu( + -true_dot_val) * alpha_inter_ratio) # always non-positive + + iter_cos = iter_cos * pts_mask.view(-1, 1) + + true_estimate_sdf_half_next = sdf + iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5 + true_estimate_sdf_half_prev = sdf - iter_cos.clip(-10.0, 10.0) * dists.reshape(-1, 1) * 0.5 + + prev_cdf = torch.sigmoid(true_estimate_sdf_half_prev * inv_variance) + next_cdf = torch.sigmoid(true_estimate_sdf_half_next * inv_variance) + + p = prev_cdf - next_cdf + c = prev_cdf + + if self.alpha_type == 'div': + alpha_sdf = ((p + 1e-5) / (c + 1e-5)).reshape(N_rays, n_samples).clip(0.0, 1.0) + elif self.alpha_type == 'uniform': + uniform_estimate_sdf_half_next = sdf - dists.reshape(-1, 1) * 0.5 + uniform_estimate_sdf_half_prev = sdf + dists.reshape(-1, 1) * 0.5 + uniform_prev_cdf = torch.sigmoid(uniform_estimate_sdf_half_prev * inv_variance) + uniform_next_cdf = torch.sigmoid(uniform_estimate_sdf_half_next * inv_variance) + uniform_alpha = F.relu( + (uniform_prev_cdf - uniform_next_cdf + 1e-5) / (uniform_prev_cdf + 1e-5)).reshape( + N_rays, n_samples).clip(0.0, 1.0) + alpha_sdf = uniform_alpha + else: + assert False + + alpha = alpha_sdf + + # - apply pts_mask + alpha = alpha * pts_mask + + # pts_radius = torch.linalg.norm(pts, ord=2, dim=-1, keepdim=True).reshape(N_rays, n_samples) + # inside_sphere = (pts_radius < 1.0).float().detach() + # relax_inside_sphere = (pts_radius < 1.2).float().detach() + inside_sphere = pts_mask + relax_inside_sphere = pts_mask + + weights = alpha * torch.cumprod(torch.cat([torch.ones([N_rays, 1]).to(device), 1. - alpha + 1e-7], -1), -1)[:, + :-1] # n_rays, n_samples + weights_sum = weights.sum(dim=-1, keepdim=True) + alpha_sum = alpha.sum(dim=-1, keepdim=True) + + if bg_num > 0: + weights_sum_fg = weights[:, :-bg_num].sum(dim=-1, keepdim=True) + else: + weights_sum_fg = weights_sum + + if sampled_color is not None: + color = (sampled_color * weights[:, :, None]).sum(dim=1) + else: + color = None + # import ipdb; ipdb.set_trace() + + if background_rgb is not None and color is not None: + color = color + background_rgb * (1.0 - weights_sum) + # print("color device:" + str(color.device)) + # if color is not None: + # # import ipdb; ipdb.set_trace() + # color = color + (1.0 - weights_sum) + + + ###################* mlp color rendering ##################### + color_mlp = None + # import ipdb; ipdb.set_trace() + if sampled_color_mlp is not None: + color_mlp = (sampled_color_mlp * weights[:, :, None]).sum(dim=1) + + if background_rgb is not None and color_mlp is not None: + color_mlp = color_mlp + background_rgb * (1.0 - weights_sum) + + ############################ * patch blending ################ + blended_color_patch = None + if sampled_color_patch is not None: + blended_color_patch = (sampled_color_patch * weights[:, :, None, None]).sum(dim=1) # [N_rays, Npx, 3] + + ###################################################### + + gradient_error = (torch.linalg.norm(gradients.reshape(N_rays, n_samples, 3), ord=2, + dim=-1) - 1.0) ** 2 + # ! the gradient normal should be masked out, the pts out of the bounding box should also be penalized + gradient_error = (pts_mask * gradient_error).sum() / ( + (pts_mask).sum() + 1e-5) + + depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True) + # print("[TEST]: weights_sum in render_core", weights_sum.mean()) + # print("[TEST]: weights_sum in render_core NAN number", weights_sum.isnan().sum()) + # if weights_sum.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + return { + 'color': color, + 'color_mask': rendering_valid_mask, # (N_rays, 1) + 'color_mlp': color_mlp, + 'color_mlp_mask': rendering_valid_mask_mlp, + 'sdf': sdf, # (N_rays, n_samples) + 'depth': depth, # (N_rays, 1) + 'dists': dists, + 'gradients': gradients.reshape(N_rays, n_samples, 3), + 'variance': 1.0 / inv_variance, + 'mid_z_vals': mid_z_vals, + 'weights': weights, + 'weights_sum': weights_sum, + 'alpha_sum': alpha_sum, + 'alpha_mean': alpha.mean(), + 'cdf': c.reshape(N_rays, n_samples), + 'gradient_error': gradient_error, + 'inside_sphere': inside_sphere, + 'blended_color_patch': blended_color_patch, + 'blended_color_patch_mask': rendering_patch_mask, + 'weights_sum_fg': weights_sum_fg + } + + def render(self, rays_o, rays_d, near, far, sdf_network, rendering_network, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio=0.0, + # * related to conditional feature + lod=None, + conditional_volume=None, + conditional_valid_mask_volume=None, + # * 2d feature maps + feature_maps=None, + color_maps=None, + w2cs=None, + intrinsics=None, + img_wh=None, + query_c2w=None, # -used for testing + if_general_rendering=True, + if_render_with_grad=True, + # * used for blending mlp rendering network + img_index=None, + rays_uv=None, + # * importance sample for second lod network + pre_sample=False, # no use here + # * for clear foreground + bg_ratio=0.0 + ): + device = rays_o.device + N_rays = len(rays_o) + # sample_dist = 2.0 / self.n_samples + sample_dist = ((far - near) / self.n_samples).mean().item() + z_vals = torch.linspace(0.0, 1.0, self.n_samples).to(device) + z_vals = near + (far - near) * z_vals[None, :] + + bg_num = int(self.n_samples * bg_ratio) + + if z_vals.shape[0] == 1: + z_vals = z_vals.repeat(N_rays, 1) + + if bg_num > 0: + z_vals_bg = z_vals[:, self.n_samples - bg_num:] + z_vals = z_vals[:, :self.n_samples - bg_num] + + n_samples = self.n_samples - bg_num + perturb = self.perturb + + # - significantly speed up training, for the second lod network + if pre_sample: + z_vals = self.sample_z_vals_from_maskVolume(rays_o, rays_d, near, far, + conditional_valid_mask_volume) + + if perturb_overwrite >= 0: + perturb = perturb_overwrite + if perturb > 0: + # get intervals between samples + mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1]) + upper = torch.cat([mids, z_vals[..., -1:]], -1) + lower = torch.cat([z_vals[..., :1], mids], -1) + # stratified samples in those intervals + t_rand = torch.rand(z_vals.shape).to(device) + z_vals = lower + (upper - lower) * t_rand + + background_alpha = None + background_sampled_color = None + z_val_before = z_vals.clone() + # Up sample + if self.n_importance > 0: + with torch.no_grad(): + pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] + + sdf_outputs = sdf_network.sdf( + pts.reshape(-1, 3), conditional_volume, lod=lod) + # pdb.set_trace() + sdf = sdf_outputs['sdf_pts_scale%d' % lod].reshape(N_rays, self.n_samples - bg_num) + + n_steps = 4 + for i in range(n_steps): + new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_importance // n_steps, + 64 * 2 ** i, + conditional_valid_mask_volume=conditional_valid_mask_volume, + ) + + # if new_z_vals.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + + z_vals, sdf = self.cat_z_vals( + rays_o, rays_d, z_vals, new_z_vals, sdf, lod, + sdf_network, gru_fusion=False, + conditional_volume=conditional_volume, + conditional_valid_mask_volume=conditional_valid_mask_volume, + ) + + del sdf + + n_samples = self.n_samples + self.n_importance + + # Background + ret_outside = None + + # Render + if bg_num > 0: + z_vals = torch.cat([z_vals, z_vals_bg], dim=1) + # if z_vals.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + ret_fine = self.render_core(rays_o, + rays_d, + z_vals, + sample_dist, + lod, + sdf_network, + rendering_network, + background_rgb=background_rgb, + background_alpha=background_alpha, + background_sampled_color=background_sampled_color, + alpha_inter_ratio=alpha_inter_ratio, + # * related to conditional feature + conditional_volume=conditional_volume, + conditional_valid_mask_volume=conditional_valid_mask_volume, + # * 2d feature maps + feature_maps=feature_maps, + color_maps=color_maps, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=img_wh, + query_c2w=query_c2w, + if_general_rendering=if_general_rendering, + if_render_with_grad=if_render_with_grad, + # * used for blending mlp rendering network + img_index=img_index, + rays_uv=rays_uv + ) + + color_fine = ret_fine['color'] + + if self.n_outside > 0: + color_fine_mask = torch.logical_or(ret_fine['color_mask'], ret_outside['color_mask']) + else: + color_fine_mask = ret_fine['color_mask'] + + weights = ret_fine['weights'] + weights_sum = ret_fine['weights_sum'] + + gradients = ret_fine['gradients'] + mid_z_vals = ret_fine['mid_z_vals'] + + # depth = (mid_z_vals * weights[:, :n_samples]).sum(dim=1, keepdim=True) + depth = ret_fine['depth'] + depth_varaince = ((mid_z_vals - depth) ** 2 * weights[:, :n_samples]).sum(dim=-1, keepdim=True) + variance = ret_fine['variance'].reshape(N_rays, n_samples).mean(dim=-1, keepdim=True) + + # - randomly sample points from the volume, and maximize the sdf + pts_random = torch.rand([1024, 3]).float().to(device) * 2 - 1 # normalized to (-1, 1) + sdf_random = sdf_network.sdf(pts_random, conditional_volume, lod=lod)['sdf_pts_scale%d' % lod] + + result = { + 'depth': depth, + 'color_fine': color_fine, + 'color_fine_mask': color_fine_mask, + 'color_outside': ret_outside['color'] if ret_outside is not None else None, + 'color_outside_mask': ret_outside['color_mask'] if ret_outside is not None else None, + 'color_mlp': ret_fine['color_mlp'], + 'color_mlp_mask': ret_fine['color_mlp_mask'], + 'variance': variance.mean(), + 'cdf_fine': ret_fine['cdf'], + 'depth_variance': depth_varaince, + 'weights_sum': weights_sum, + 'weights_max': torch.max(weights, dim=-1, keepdim=True)[0], + 'alpha_sum': ret_fine['alpha_sum'].mean(), + 'alpha_mean': ret_fine['alpha_mean'], + 'gradients': gradients, + 'weights': weights, + 'gradient_error_fine': ret_fine['gradient_error'], + 'inside_sphere': ret_fine['inside_sphere'], + 'sdf': ret_fine['sdf'], + 'sdf_random': sdf_random, + 'blended_color_patch': ret_fine['blended_color_patch'], + 'blended_color_patch_mask': ret_fine['blended_color_patch_mask'], + 'weights_sum_fg': ret_fine['weights_sum_fg'] + } + + return result + + @torch.no_grad() + def sample_z_vals_from_sdfVolume(self, rays_o, rays_d, near, far, sdf_volume, mask_volume): + # ? based on sdf to do importance sampling, seems that too biased on pre-estimation + device = rays_o.device + N_rays = len(rays_o) + n_samples = self.n_samples * 2 + + z_vals = torch.linspace(0.0, 1.0, n_samples).to(device) + z_vals = near + (far - near) * z_vals[None, :] + + if z_vals.shape[0] == 1: + z_vals = z_vals.repeat(N_rays, 1) + + pts = rays_o[:, None, :] + rays_d[:, None, :] * z_vals[..., :, None] + + sdf = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), sdf_volume).reshape([N_rays, n_samples]) + + new_z_vals = self.up_sample(rays_o, rays_d, z_vals, sdf, self.n_samples, + 200, + conditional_valid_mask_volume=mask_volume, + ) + return new_z_vals + + @torch.no_grad() + def sample_z_vals_from_maskVolume(self, rays_o, rays_d, near, far, mask_volume): # don't use + device = rays_o.device + N_rays = len(rays_o) + n_samples = self.n_samples * 2 + + z_vals = torch.linspace(0.0, 1.0, n_samples).to(device) + z_vals = near + (far - near) * z_vals[None, :] + + if z_vals.shape[0] == 1: + z_vals = z_vals.repeat(N_rays, 1) + + mid_z_vals = (z_vals[:, 1:] + z_vals[:, :-1]) * 0.5 + + pts = rays_o[:, None, :] + rays_d[:, None, :] * mid_z_vals[..., :, None] + + pts_mask = self.get_pts_mask_for_conditional_volume(pts.view(-1, 3), mask_volume).reshape( + [N_rays, n_samples - 1]) + + # empty voxel set to 0.1, non-empty voxel set to 1 + weights = torch.where(pts_mask > 0, torch.ones_like(pts_mask).to(device), + 0.1 * torch.ones_like(pts_mask).to(device)) + + # sample more pts in non-empty voxels + z_samples = sample_pdf(z_vals, weights, self.n_samples, det=True).detach() + return z_samples + + @torch.no_grad() + def filter_pts_by_depthmaps(self, coords, pred_depth_maps, proj_matrices, + partial_vol_origin, voxel_size, + near, far, depth_interval, d_plane_nums): + """ + Use the pred_depthmaps to remove redundant pts (pruned by sdf, sdf always have two sides, the back side is useless) + :param coords: [n, 3] int coords + :param pred_depth_maps: [N_views, 1, h, w] + :param proj_matrices: [N_views, 4, 4] + :param partial_vol_origin: [3] + :param voxel_size: 1 + :param near: 1 + :param far: 1 + :param depth_interval: 1 + :param d_plane_nums: 1 + :return: + """ + device = pred_depth_maps.device + n_views, _, sizeH, sizeW = pred_depth_maps.shape + + if len(partial_vol_origin.shape) == 1: + partial_vol_origin = partial_vol_origin[None, :] + pts = coords * voxel_size + partial_vol_origin + + rs_grid = pts.unsqueeze(0).expand(n_views, -1, -1) + rs_grid = rs_grid.permute(0, 2, 1).contiguous() # [n_views, 3, n_pts] + nV = rs_grid.shape[-1] + rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1) # [n_views, 4, n_pts] + + # Project grid + im_p = proj_matrices @ rs_grid # - transform world pts to image UV space # [n_views, 4, n_pts] + im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2] + im_x = im_x / im_z + im_y = im_y / im_z + + im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1) + + im_grid = im_grid.view(n_views, 1, -1, 2) + sampled_depths = torch.nn.functional.grid_sample(pred_depth_maps, im_grid, mode='bilinear', + padding_mode='zeros', + align_corners=True)[:, 0, 0, :] # [n_views, n_pts] + sampled_depths_valid = (sampled_depths > 0.5 * near).float() + valid_d_min = (sampled_depths - d_plane_nums * depth_interval).clamp(near.item(), + far.item()) * sampled_depths_valid + valid_d_max = (sampled_depths + d_plane_nums * depth_interval).clamp(near.item(), + far.item()) * sampled_depths_valid + + mask = im_grid.abs() <= 1 + mask = mask[:, 0] # [n_views, n_pts, 2] + mask = (mask.sum(dim=-1) == 2) & (im_z > valid_d_min) & (im_z < valid_d_max) + + mask = mask.view(n_views, -1) + mask = mask.permute(1, 0).contiguous() # [num_pts, nviews] + + mask_final = torch.sum(mask.float(), dim=1, keepdim=False) > 0 + + return mask_final + + @torch.no_grad() + def get_valid_sparse_coords_by_sdf_depthfilter(self, sdf_volume, coords_volume, mask_volume, feature_volume, + pred_depth_maps, proj_matrices, + partial_vol_origin, voxel_size, + near, far, depth_interval, d_plane_nums, + threshold=0.02, maximum_pts=110000): + """ + assume batch size == 1, from the first lod to get sparse voxels + :param sdf_volume: [1, X, Y, Z] + :param coords_volume: [3, X, Y, Z] + :param mask_volume: [1, X, Y, Z] + :param feature_volume: [C, X, Y, Z] + :param threshold: + :return: + """ + device = coords_volume.device + _, dX, dY, dZ = coords_volume.shape + + def prune(sdf_pts, coords_pts, mask_volume, threshold): + occupancy_mask = (torch.abs(sdf_pts) < threshold).squeeze(1) # [num_pts] + valid_coords = coords_pts[occupancy_mask] + + # - filter backside surface by depth maps + mask_filtered = self.filter_pts_by_depthmaps(valid_coords, pred_depth_maps, proj_matrices, + partial_vol_origin, voxel_size, + near, far, depth_interval, d_plane_nums) + valid_coords = valid_coords[mask_filtered] + + # - dilate + occupancy_mask = sparse_to_dense_channel(valid_coords, 1, [dX, dY, dZ], 1, 0, device) # [dX, dY, dZ, 1] + + # - dilate + occupancy_mask = occupancy_mask.float() + occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) + occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) + occupancy_mask = occupancy_mask.view(-1, 1) > 0 + + final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] # [num_pts] + + return final_mask, torch.sum(final_mask.float()) + + C, dX, dY, dZ = feature_volume.shape + sdf_volume = sdf_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) + coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3) + mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) + feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C) + + # - for check + # sdf_volume = torch.rand_like(sdf_volume).float().to(sdf_volume.device) * 0.02 + + final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold) + + while (valid_num > maximum_pts) and (threshold > 0.003): + threshold = threshold - 0.002 + final_mask, valid_num = prune(sdf_volume, coords_volume, mask_volume, threshold) + + valid_coords = coords_volume[final_mask] # [N, 3] + valid_feature = feature_volume[final_mask] # [N, C] + + valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0, + valid_coords], dim=1) # [N, 4], append batch idx + + # ! if the valid_num is still larger than maximum_pts, sample part of pts + if valid_num > maximum_pts: + valid_num = valid_num.long() + occupancy = torch.ones([valid_num]).to(device) > 0 + choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts, + replace=False) + ind = torch.nonzero(occupancy).to(device) + occupancy[ind[choice]] = False + valid_coords = valid_coords[occupancy] + valid_feature = valid_feature[occupancy] + + print(threshold, "randomly sample to save memory") + + return valid_coords, valid_feature + + @torch.no_grad() + def get_valid_sparse_coords_by_sdf(self, sdf_volume, coords_volume, mask_volume, feature_volume, threshold=0.02, + maximum_pts=110000): + """ + assume batch size == 1, from the first lod to get sparse voxels + :param sdf_volume: [num_pts, 1] + :param coords_volume: [3, X, Y, Z] + :param mask_volume: [1, X, Y, Z] + :param feature_volume: [C, X, Y, Z] + :param threshold: + :return: + """ + + def prune(sdf_volume, mask_volume, threshold): + occupancy_mask = torch.abs(sdf_volume) < threshold # [num_pts, 1] + + # - dilate + occupancy_mask = occupancy_mask.float() + occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) + occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) + occupancy_mask = occupancy_mask.view(-1, 1) > 0 + + final_mask = torch.logical_and(mask_volume, occupancy_mask)[:, 0] # [num_pts] + + return final_mask, torch.sum(final_mask.float()) + + C, dX, dY, dZ = feature_volume.shape + coords_volume = coords_volume.permute(1, 2, 3, 0).contiguous().view(-1, 3) + mask_volume = mask_volume.permute(1, 2, 3, 0).contiguous().view(-1, 1) + feature_volume = feature_volume.permute(1, 2, 3, 0).contiguous().view(-1, C) + + final_mask, valid_num = prune(sdf_volume, mask_volume, threshold) + + while (valid_num > maximum_pts) and (threshold > 0.003): + threshold = threshold - 0.002 + final_mask, valid_num = prune(sdf_volume, mask_volume, threshold) + + valid_coords = coords_volume[final_mask] # [N, 3] + valid_feature = feature_volume[final_mask] # [N, C] + + valid_coords = torch.cat([torch.ones([valid_coords.shape[0], 1]).to(valid_coords.device) * 0, + valid_coords], dim=1) # [N, 4], append batch idx + + # ! if the valid_num is still larger than maximum_pts, sample part of pts + if valid_num > maximum_pts: + device = sdf_volume.device + valid_num = valid_num.long() + occupancy = torch.ones([valid_num]).to(device) > 0 + choice = np.random.choice(valid_num.cpu().numpy(), valid_num.cpu().numpy() - maximum_pts, + replace=False) + ind = torch.nonzero(occupancy).to(device) + occupancy[ind[choice]] = False + valid_coords = valid_coords[occupancy] + valid_feature = valid_feature[occupancy] + + print(threshold, "randomly sample to save memory") + + return valid_coords, valid_feature + + @torch.no_grad() + def extract_fields(self, bound_min, bound_max, resolution, query_func, device, + # * related to conditional feature + **kwargs + ): + N = 64 + X = torch.linspace(bound_min[0], bound_max[0], resolution).to(device).split(N) + Y = torch.linspace(bound_min[1], bound_max[1], resolution).to(device).split(N) + Z = torch.linspace(bound_min[2], bound_max[2], resolution).to(device).split(N) + + u = np.zeros([resolution, resolution, resolution], dtype=np.float32) + with torch.no_grad(): + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij") + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) + + # ! attention, the query function is different for extract geometry and fields + output = query_func(pts, **kwargs) + sdf = output['sdf_pts_scale%d' % kwargs['lod']].reshape(len(xs), len(ys), + len(zs)).detach().cpu().numpy() + + u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = -1 * sdf + return u + + @torch.no_grad() + def extract_geometry(self, sdf_network, bound_min, bound_max, resolution, threshold, device, occupancy_mask=None, + # * 3d feature volume + **kwargs + ): + # logging.info('threshold: {}'.format(threshold)) + + u = self.extract_fields(bound_min, bound_max, resolution, + lambda pts, **kwargs: sdf_network.sdf(pts, **kwargs), + # - sdf need to be multiplied by -1 + device, + # * 3d feature volume + **kwargs + ) + if occupancy_mask is not None: + dX, dY, dZ = occupancy_mask.shape + empty_mask = 1 - occupancy_mask + empty_mask = empty_mask.view(1, 1, dX, dY, dZ) + # - dilation + # empty_mask = F.avg_pool3d(empty_mask, kernel_size=7, stride=1, padding=3) + empty_mask = F.interpolate(empty_mask, [resolution, resolution, resolution], mode='nearest') + empty_mask = empty_mask.view(resolution, resolution, resolution).cpu().numpy() > 0 + u[empty_mask] = -100 + del empty_mask + + vertices, triangles = mcubes.marching_cubes(u, threshold) + b_max_np = bound_max.detach().cpu().numpy() + b_min_np = bound_min.detach().cpu().numpy() + + vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] + return vertices, triangles, u + + @torch.no_grad() + def extract_depth_maps(self, sdf_network, con_volume, intrinsics, c2ws, H, W, near, far): + """ + extract depth maps from the density volume + :param con_volume: [1, 1+C, dX, dY, dZ] can by con_volume or sdf_volume + :param c2ws: [B, 4, 4] + :param H: + :param W: + :param near: + :param far: + :return: + """ + device = con_volume.device + batch_size = intrinsics.shape[0] + + with torch.no_grad(): + ys, xs = torch.meshgrid(torch.linspace(0, H - 1, H), + torch.linspace(0, W - 1, W), indexing="ij") # pytorch's meshgrid has indexing='ij' + p = torch.stack([xs, ys, torch.ones_like(ys)], dim=-1) # H, W, 3 + + intrinsics_inv = torch.inverse(intrinsics) + + p = p.view(-1, 3).float().to(device) # N_rays, 3 + p = torch.matmul(intrinsics_inv[:, None, :3, :3], p[:, :, None]).squeeze() # Batch, N_rays, 3 + rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # Batch, N_rays, 3 + rays_v = torch.matmul(c2ws[:, None, :3, :3], rays_v[:, :, :, None]).squeeze() # Batch, N_rays, 3 + rays_o = c2ws[:, None, :3, 3].expand(rays_v.shape) # Batch, N_rays, 3 + rays_d = rays_v + + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + ################## - sphere tracer to extract depth maps ###################### + depth_masks_sphere, depth_maps_sphere = self.ray_tracer.extract_depth_maps( + rays_o, rays_d, + near[None, :].repeat(rays_o.shape[0], 1), + far[None, :].repeat(rays_o.shape[0], 1), + sdf_network, con_volume + ) + + depth_maps = depth_maps_sphere.view(batch_size, 1, H, W) + depth_masks = depth_masks_sphere.view(batch_size, 1, H, W) + + depth_maps = torch.where(depth_masks, depth_maps, + torch.zeros_like(depth_masks.float()).to(device)) # fill invalid pixels by 0 + + return depth_maps, depth_masks diff --git a/SparseNeuS_demo_v1/models/sparse_sdf_network.py b/SparseNeuS_demo_v1/models/sparse_sdf_network.py new file mode 100644 index 0000000000000000000000000000000000000000..817f40ed08b7cb65fb284a4666d6f6a4a3c52683 --- /dev/null +++ b/SparseNeuS_demo_v1/models/sparse_sdf_network.py @@ -0,0 +1,907 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchsparse.tensor import PointTensor, SparseTensor +import torchsparse.nn as spnn + +from tsparse.modules import SparseCostRegNet +from tsparse.torchsparse_utils import sparse_to_dense_channel +from ops.grid_sampler import grid_sample_3d, tricubic_sample_3d + +# from .gru_fusion import GRUFusion +from ops.back_project import back_project_sparse_type +from ops.generate_grids import generate_grid + +from inplace_abn import InPlaceABN + +from models.embedder import Embedding +from models.featurenet import ConvBnReLU + +import pdb +import random + +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) + + +@torch.jit.script +def fused_mean_variance(x, weight): + mean = torch.sum(x * weight, dim=1, keepdim=True) + var = torch.sum(weight * (x - mean) ** 2, dim=1, keepdim=True) + return mean, var + + +class LatentSDFLayer(nn.Module): + def __init__(self, + d_in=3, + d_out=129, + d_hidden=128, + n_layers=4, + skip_in=(4,), + multires=0, + bias=0.5, + geometric_init=True, + weight_norm=True, + activation='softplus', + d_conditional_feature=16): + super(LatentSDFLayer, self).__init__() + + self.d_conditional_feature = d_conditional_feature + + # concat latent code for ench layer input excepting the first layer and the last layer + dims_in = [d_in] + [d_hidden + d_conditional_feature for _ in range(n_layers - 2)] + [d_hidden] + dims_out = [d_hidden for _ in range(n_layers - 1)] + [d_out] + + self.embed_fn_fine = None + + if multires > 0: + embed_fn = Embedding(in_channels=d_in, N_freqs=multires) # * include the input + self.embed_fn_fine = embed_fn + dims_in[0] = embed_fn.out_channels + + self.num_layers = n_layers + self.skip_in = skip_in + + for l in range(0, self.num_layers - 1): + if l in self.skip_in: + in_dim = dims_in[l] + dims_in[0] + else: + in_dim = dims_in[l] + + out_dim = dims_out[l] + lin = nn.Linear(in_dim, out_dim) + + if geometric_init: # - from IDR code, + if l == self.num_layers - 2: + torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(in_dim), std=0.0001) + torch.nn.init.constant_(lin.bias, -bias) + # the channels for latent codes are set to 0 + torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0) + torch.nn.init.constant_(lin.bias[-d_conditional_feature:], 0.0) + + elif multires > 0 and l == 0: # the first layer + torch.nn.init.constant_(lin.bias, 0.0) + # * the channels for position embeddings are set to 0 + torch.nn.init.constant_(lin.weight[:, 3:], 0.0) + # * the channels for the xyz coordinate (3 channels) for initialized by normal distribution + torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) + elif multires > 0 and l in self.skip_in: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + # * the channels for position embeddings (and conditional_feature) are initialized to 0 + torch.nn.init.constant_(lin.weight[:, -(dims_in[0] - 3 + d_conditional_feature):], 0.0) + else: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) + # the channels for latent code are initialized to 0 + torch.nn.init.constant_(lin.weight[:, -d_conditional_feature:], 0.0) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + if activation == 'softplus': + self.activation = nn.Softplus(beta=100) + else: + assert activation == 'relu' + self.activation = nn.ReLU() + + def forward(self, inputs, latent): + inputs = inputs + if self.embed_fn_fine is not None: + inputs = self.embed_fn_fine(inputs) + + # - only for lod1 network can use the pretrained params of lod0 network + if latent.shape[1] != self.d_conditional_feature: + latent = torch.cat([latent, latent], dim=1) + + x = inputs + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + # * due to the conditional bias, different from original neus version + if l in self.skip_in: + x = torch.cat([x, inputs], 1) / np.sqrt(2) + + if 0 < l < self.num_layers - 1: + x = torch.cat([x, latent], 1) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.activation(x) + + return x + + +class SparseSdfNetwork(nn.Module): + ''' + Coarse-to-fine sparse cost regularization network + return sparse volume feature for extracting sdf + ''' + + def __init__(self, lod, ch_in, voxel_size, vol_dims, + hidden_dim=128, activation='softplus', + cost_type='variance_mean', + d_pyramid_feature_compress=16, + regnet_d_out=8, num_sdf_layers=4, + multires=6, + ): + super(SparseSdfNetwork, self).__init__() + + self.lod = lod # - gradually training, the current regularization lod + self.ch_in = ch_in + self.voxel_size = voxel_size # - the voxel size of the current volume + self.vol_dims = torch.tensor(vol_dims) # - the dims of the current volume + + self.selected_views_num = 2 # the number of selected views for feature aggregation + self.hidden_dim = hidden_dim + self.activation = activation + self.cost_type = cost_type + self.d_pyramid_feature_compress = d_pyramid_feature_compress + self.gru_fusion = None + + self.regnet_d_out = regnet_d_out + self.multires = multires + + self.pos_embedder = Embedding(3, self.multires) + + self.compress_layer = ConvBnReLU( + self.ch_in, self.d_pyramid_feature_compress, 3, 1, 1, + norm_act=InPlaceABN) + sparse_ch_in = self.d_pyramid_feature_compress * 2 + + sparse_ch_in = sparse_ch_in + 16 if self.lod > 0 else sparse_ch_in + self.sparse_costreg_net = SparseCostRegNet( + d_in=sparse_ch_in, d_out=self.regnet_d_out) + # self.regnet_d_out = self.sparse_costreg_net.d_out + + if activation == 'softplus': + self.activation = nn.Softplus(beta=100) + else: + assert activation == 'relu' + self.activation = nn.ReLU() + + self.sdf_layer = LatentSDFLayer(d_in=3, + d_out=self.hidden_dim + 1, + d_hidden=self.hidden_dim, + n_layers=num_sdf_layers, + multires=multires, + geometric_init=True, + weight_norm=True, + activation=activation, + d_conditional_feature=16 # self.regnet_d_out + ) + + def upsample(self, pre_feat, pre_coords, interval, num=8): + ''' + + :param pre_feat: (Tensor), features from last level, (N, C) + :param pre_coords: (Tensor), coordinates from last level, (N, 4) (4 : Batch ind, x, y, z) + :param interval: interval of voxels, interval = scale ** 2 + :param num: 1 -> 8 + :return: up_feat : (Tensor), upsampled features, (N*8, C) + :return: up_coords: (N*8, 4), upsampled coordinates, (4 : Batch ind, x, y, z) + ''' + with torch.no_grad(): + pos_list = [1, 2, 3, [1, 2], [1, 3], [2, 3], [1, 2, 3]] + n, c = pre_feat.shape + up_feat = pre_feat.unsqueeze(1).expand(-1, num, -1).contiguous() + up_coords = pre_coords.unsqueeze(1).repeat(1, num, 1).contiguous() + for i in range(num - 1): + up_coords[:, i + 1, pos_list[i]] += interval + + up_feat = up_feat.view(-1, c) + up_coords = up_coords.view(-1, 4) + + return up_feat, up_coords + + def aggregate_multiview_features(self, multiview_features, multiview_masks): + """ + aggregate mutli-view features by compute their cost variance + :param multiview_features: (num of voxels, num_of_views, c) + :param multiview_masks: (num of voxels, num_of_views) + :return: + """ + num_pts, n_views, C = multiview_features.shape + + counts = torch.sum(multiview_masks, dim=1, keepdim=False) # [num_pts] + + assert torch.all(counts > 0) # the point is visible for at least 1 view + + volume_sum = torch.sum(multiview_features, dim=1, keepdim=False) # [num_pts, C] + volume_sq_sum = torch.sum(multiview_features ** 2, dim=1, keepdim=False) + + if volume_sum.isnan().sum() > 0: + import ipdb; ipdb.set_trace() + + del multiview_features + + counts = 1. / (counts + 1e-5) + costvar = volume_sq_sum * counts[:, None] - (volume_sum * counts[:, None]) ** 2 + + costvar_mean = torch.cat([costvar, volume_sum * counts[:, None]], dim=1) + del volume_sum, volume_sq_sum, counts + + + + return costvar_mean + + def sparse_to_dense_volume(self, coords, feature, vol_dims, interval, device=None): + """ + convert the sparse volume into dense volume to enable trilinear sampling + to save GPU memory; + :param coords: [num_pts, 3] + :param feature: [num_pts, C] + :param vol_dims: [3] dX, dY, dZ + :param interval: + :return: + """ + + # * assume batch size is 1 + if device is None: + device = feature.device + + coords_int = (coords / interval).to(torch.int64) + vol_dims = (vol_dims / interval).to(torch.int64) + + # - if stored in CPU, too slow + dense_volume = sparse_to_dense_channel( + coords_int.to(device), feature.to(device), vol_dims.to(device), + feature.shape[1], 0, device) # [X, Y, Z, C] + + valid_mask_volume = sparse_to_dense_channel( + coords_int.to(device), + torch.ones([feature.shape[0], 1]).to(feature.device), + vol_dims.to(device), + 1, 0, device) # [X, Y, Z, 1] + + dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, C, X, Y, Z] + valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, 1, X, Y, Z] + + return dense_volume, valid_mask_volume + + def get_conditional_volume(self, feature_maps, partial_vol_origin, proj_mats, sizeH=None, sizeW=None, lod=0, + pre_coords=None, pre_feats=None, + ): + """ + + :param feature_maps: pyramid features (B,V,C0+C1+C2,H,W) fused pyramid features + :param partial_vol_origin: [B, 3] the world coordinates of the volume origin (0,0,0) + :param proj_mats: projection matrix transform world pts into image space [B,V,4,4] suitable for original image size + :param sizeH: the H of original image size + :param sizeW: the W of original image size + :param pre_coords: the coordinates of sparse volume from the prior lod + :param pre_feats: the features of sparse volume from the prior lod + :return: + """ + device = proj_mats.device + bs = feature_maps.shape[0] + N_views = feature_maps.shape[1] + minimum_visible_views = np.min([1, N_views - 1]) + # import ipdb; ipdb.set_trace() + outputs = {} + pts_samples = [] + + # ----coarse to fine---- + + # * use fused pyramid feature maps are very important + if self.compress_layer is not None: + feats = self.compress_layer(feature_maps[0]) + else: + feats = feature_maps[0] + feats = feats[:, None, :, :, :] # [V, B, C, H, W] + KRcam = proj_mats.permute(1, 0, 2, 3).contiguous() # [V, B, 4, 4] + interval = 1 + + if self.lod == 0: + # ----generate new coords---- + coords = generate_grid(self.vol_dims, 1)[0] + coords = coords.view(3, -1).to(device) # [3, num_pts] + up_coords = [] + for b in range(bs): + up_coords.append(torch.cat([torch.ones(1, coords.shape[-1]).to(coords.device) * b, coords])) + up_coords = torch.cat(up_coords, dim=1).permute(1, 0).contiguous() + # * since we only estimate the geometry of input reference image at one time; + # * mask the outside of the camera frustum + # import ipdb; ipdb.set_trace() + frustum_mask = back_project_sparse_type( + up_coords, partial_vol_origin, self.voxel_size, + feats, KRcam, sizeH=sizeH, sizeW=sizeW, only_mask=True) # [num_pts, n_views] + frustum_mask = torch.sum(frustum_mask, dim=-1) > minimum_visible_views # ! here should be large + up_coords = up_coords[frustum_mask] # [num_pts_valid, 4] + + else: + # ----upsample coords---- + assert pre_feats is not None + assert pre_coords is not None + up_feat, up_coords = self.upsample(pre_feats, pre_coords, 1) + + # ----back project---- + # give each valid 3d grid point all valid 2D features and masks + multiview_features, multiview_masks = back_project_sparse_type( + up_coords, partial_vol_origin, self.voxel_size, feats, + KRcam, sizeH=sizeH, sizeW=sizeW) # (num of voxels, num_of_views, c), (num of voxels, num_of_views) + # num_of_views = all views + + # if multiview_features.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + + # import ipdb; ipdb.set_trace() + if self.lod > 0: + # ! need another invalid voxels filtering + frustum_mask = torch.sum(multiview_masks, dim=-1) > 1 + up_feat = up_feat[frustum_mask] + up_coords = up_coords[frustum_mask] + multiview_features = multiview_features[frustum_mask] + multiview_masks = multiview_masks[frustum_mask] + # if multiview_features.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + volume = self.aggregate_multiview_features(multiview_features, multiview_masks) # compute variance for all images features + # import ipdb; ipdb.set_trace() + + # if volume.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + + del multiview_features, multiview_masks + + # ----concat feature from last stage---- + if self.lod != 0: + feat = torch.cat([volume, up_feat], dim=1) + else: + feat = volume + + # batch index is in the last position + r_coords = up_coords[:, [1, 2, 3, 0]] + + # if feat.isnan().sum() > 0: + # print('feat has nan:', feat.isnan().sum()) + # import ipdb; ipdb.set_trace() + + sparse_feat = SparseTensor(feat, r_coords.to( + torch.int32)) # - directly use sparse tensor to avoid point2voxel operations + # import ipdb; ipdb.set_trace() + feat = self.sparse_costreg_net(sparse_feat) + + dense_volume, valid_mask_volume = self.sparse_to_dense_volume(up_coords[:, 1:], feat, self.vol_dims, interval, + device=None) # [1, C/1, X, Y, Z] + + # if dense_volume.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + + + outputs['dense_volume_scale%d' % self.lod] = dense_volume # [1, 16, 96, 96, 96] + outputs['valid_mask_volume_scale%d' % self.lod] = valid_mask_volume # [1, 1, 96, 96, 96] + outputs['visible_mask_scale%d' % self.lod] = valid_mask_volume # [1, 1, 96, 96, 96] + outputs['coords_scale%d' % self.lod] = generate_grid(self.vol_dims, interval).to(device) + # import ipdb; ipdb.set_trace() + return outputs + + def sdf(self, pts, conditional_volume, lod): + num_pts = pts.shape[0] + device = pts.device + pts_ = pts.clone() + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + # import ipdb; ipdb.set_trace() + sampled_feature = grid_sample_3d(conditional_volume, pts) # [1, c, 1, 1, num_pts] + sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous().to(device) + + sdf_pts = self.sdf_layer(pts_, sampled_feature) + + outputs = {} + outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1] + outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:] + outputs['sampled_latent_scale%d' % lod] = sampled_feature + + return outputs + + @torch.no_grad() + def sdf_from_sdfvolume(self, pts, sdf_volume, lod=0): + num_pts = pts.shape[0] + device = pts.device + pts_ = pts.clone() + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + + sdf = torch.nn.functional.grid_sample(sdf_volume, pts, mode='bilinear', align_corners=True, + padding_mode='border') + sdf = sdf.view(-1, num_pts).permute(1, 0).contiguous().to(device) + + outputs = {} + outputs['sdf_pts_scale%d' % lod] = sdf + + return outputs + + @torch.no_grad() + def get_sdf_volume(self, conditional_volume, mask_volume, coords_volume, partial_origin): + """ + + :param conditional_volume: [1,C, dX,dY,dZ] + :param mask_volume: [1,1, dX,dY,dZ] + :param coords_volume: [1,3, dX,dY,dZ] + :return: + """ + device = conditional_volume.device + chunk_size = 10240 + + _, C, dX, dY, dZ = conditional_volume.shape + conditional_volume = conditional_volume.view(C, dX * dY * dZ).permute(1, 0).contiguous() + mask_volume = mask_volume.view(-1) + coords_volume = coords_volume.view(3, dX * dY * dZ).permute(1, 0).contiguous() + + pts = coords_volume * self.voxel_size + partial_origin # [dX*dY*dZ, 3] + + sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(device) + + conditional_volume = conditional_volume[mask_volume > 0] + pts = pts[mask_volume > 0] + conditional_volume = conditional_volume.split(chunk_size) + pts = pts.split(chunk_size) + + sdf_all = [] + for pts_part, feature_part in zip(pts, conditional_volume): + sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1] + sdf_all.append(sdf_part) + + sdf_all = torch.cat(sdf_all, dim=0) + sdf_volume[mask_volume > 0] = sdf_all + sdf_volume = sdf_volume.view(1, 1, dX, dY, dZ) + return sdf_volume + + def gradient(self, x, conditional_volume, lod): + """ + return the gradient of specific lod + :param x: + :param lod: + :return: + """ + x.requires_grad_(True) + # import ipdb; ipdb.set_trace() + with torch.enable_grad(): + output = self.sdf(x, conditional_volume, lod) + y = output['sdf_pts_scale%d' % lod] + + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + # ! Distributed Data Parallel doesn’t work with torch.autograd.grad() + # ! (i.e. it will only work if gradients are to be accumulated in .grad attributes of parameters). + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return gradients.unsqueeze(1) + + +def sparse_to_dense_volume(coords, feature, vol_dims, interval, device=None): + """ + convert the sparse volume into dense volume to enable trilinear sampling + to save GPU memory; + :param coords: [num_pts, 3] + :param feature: [num_pts, C] + :param vol_dims: [3] dX, dY, dZ + :param interval: + :return: + """ + + # * assume batch size is 1 + if device is None: + device = feature.device + + coords_int = (coords / interval).to(torch.int64) + vol_dims = (vol_dims / interval).to(torch.int64) + + # - if stored in CPU, too slow + dense_volume = sparse_to_dense_channel( + coords_int.to(device), feature.to(device), vol_dims.to(device), + feature.shape[1], 0, device) # [X, Y, Z, C] + + valid_mask_volume = sparse_to_dense_channel( + coords_int.to(device), + torch.ones([feature.shape[0], 1]).to(feature.device), + vol_dims.to(device), + 1, 0, device) # [X, Y, Z, 1] + + dense_volume = dense_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, C, X, Y, Z] + valid_mask_volume = valid_mask_volume.permute(3, 0, 1, 2).contiguous().unsqueeze(0) # [1, 1, X, Y, Z] + + return dense_volume, valid_mask_volume + + +class SdfVolume(nn.Module): + def __init__(self, volume, coords=None, type='dense'): + super(SdfVolume, self).__init__() + self.volume = torch.nn.Parameter(volume, requires_grad=True) + self.coords = coords + self.type = type + + def forward(self): + return self.volume + + +class FinetuneOctreeSdfNetwork(nn.Module): + ''' + After obtain the conditional volume from generalized network; + directly optimize the conditional volume + The conditional volume is still sparse + ''' + + def __init__(self, voxel_size, vol_dims, + origin=[-1., -1., -1.], + hidden_dim=128, activation='softplus', + regnet_d_out=8, + multires=6, + if_fitted_rendering=True, + num_sdf_layers=4, + ): + super(FinetuneOctreeSdfNetwork, self).__init__() + + self.voxel_size = voxel_size # - the voxel size of the current volume + self.vol_dims = torch.tensor(vol_dims) # - the dims of the current volume + + self.origin = torch.tensor(origin).to(torch.float32) + + self.hidden_dim = hidden_dim + self.activation = activation + + self.regnet_d_out = regnet_d_out + + self.if_fitted_rendering = if_fitted_rendering + self.multires = multires + # d_in_embedding = self.regnet_d_out if self.pos_add_type == 'latent' else 3 + # self.pos_embedder = Embedding(d_in_embedding, self.multires) + + # - the optimized parameters + self.sparse_volume_lod0 = None + self.sparse_coords_lod0 = None + + if activation == 'softplus': + self.activation = nn.Softplus(beta=100) + else: + assert activation == 'relu' + self.activation = nn.ReLU() + + self.sdf_layer = LatentSDFLayer(d_in=3, + d_out=self.hidden_dim + 1, + d_hidden=self.hidden_dim, + n_layers=num_sdf_layers, + multires=multires, + geometric_init=True, + weight_norm=True, + activation=activation, + d_conditional_feature=16 # self.regnet_d_out + ) + + # - add mlp rendering when finetuning + self.renderer = None + + d_in_renderer = 3 + self.regnet_d_out + 3 + 3 + self.renderer = BlendingRenderingNetwork( + d_feature=self.hidden_dim - 1, + mode='idr', # ! the view direction influence a lot + d_in=d_in_renderer, + d_out=50, # maximum 50 images + d_hidden=self.hidden_dim, + n_layers=3, + weight_norm=True, + multires_view=4, + squeeze_out=True, + ) + + def initialize_conditional_volumes(self, dense_volume_lod0, dense_volume_mask_lod0, + sparse_volume_lod0=None, sparse_coords_lod0=None): + """ + + :param dense_volume_lod0: [1,C,dX,dY,dZ] + :param dense_volume_mask_lod0: [1,1,dX,dY,dZ] + :param dense_volume_lod1: + :param dense_volume_mask_lod1: + :return: + """ + + if sparse_volume_lod0 is None: + device = dense_volume_lod0.device + _, C, dX, dY, dZ = dense_volume_lod0.shape + + dense_volume_lod0 = dense_volume_lod0.view(C, dX * dY * dZ).permute(1, 0).contiguous() + mask_lod0 = dense_volume_mask_lod0.view(dX * dY * dZ) > 0 + + self.sparse_volume_lod0 = SdfVolume(dense_volume_lod0[mask_lod0], type='sparse') + + coords = generate_grid(self.vol_dims, 1)[0] # [3, dX, dY, dZ] + coords = coords.view(3, dX * dY * dZ).permute(1, 0).to(device) + self.sparse_coords_lod0 = torch.nn.Parameter(coords[mask_lod0], requires_grad=False) + else: + self.sparse_volume_lod0 = SdfVolume(sparse_volume_lod0, type='sparse') + self.sparse_coords_lod0 = torch.nn.Parameter(sparse_coords_lod0, requires_grad=False) + + def get_conditional_volume(self): + dense_volume, valid_mask_volume = sparse_to_dense_volume( + self.sparse_coords_lod0, + self.sparse_volume_lod0(), self.vol_dims, interval=1, + device=None) # [1, C/1, X, Y, Z] + + # valid_mask_volume = self.dense_volume_mask_lod0 + + outputs = {} + outputs['dense_volume_scale%d' % 0] = dense_volume + outputs['valid_mask_volume_scale%d' % 0] = valid_mask_volume + + return outputs + + def tv_regularizer(self): + dense_volume, valid_mask_volume = sparse_to_dense_volume( + self.sparse_coords_lod0, + self.sparse_volume_lod0(), self.vol_dims, interval=1, + device=None) # [1, C/1, X, Y, Z] + + dx = (dense_volume[:, :, 1:, :, :] - dense_volume[:, :, :-1, :, :]) ** 2 # [1, C/1, X-1, Y, Z] + dy = (dense_volume[:, :, :, 1:, :] - dense_volume[:, :, :, :-1, :]) ** 2 # [1, C/1, X, Y-1, Z] + dz = (dense_volume[:, :, :, :, 1:] - dense_volume[:, :, :, :, :-1]) ** 2 # [1, C/1, X, Y, Z-1] + + tv = dx[:, :, :, :-1, :-1] + dy[:, :, :-1, :, :-1] + dz[:, :, :-1, :-1, :] # [1, C/1, X-1, Y-1, Z-1] + + mask = valid_mask_volume[:, :, :-1, :-1, :-1] * valid_mask_volume[:, :, 1:, :-1, :-1] * \ + valid_mask_volume[:, :, :-1, 1:, :-1] * valid_mask_volume[:, :, :-1, :-1, 1:] + + tv = torch.sqrt(tv + 1e-6).mean(dim=1, keepdim=True) * mask + # tv = tv.mean(dim=1, keepdim=True) * mask + + assert torch.all(~torch.isnan(tv)) + + return torch.mean(tv) + + def sdf(self, pts, conditional_volume, lod): + + outputs = {} + + num_pts = pts.shape[0] + device = pts.device + pts_ = pts.clone() + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + + sampled_feature = grid_sample_3d(conditional_volume, pts) # [1, c, 1, 1, num_pts] + sampled_feature = sampled_feature.view(-1, num_pts).permute(1, 0).contiguous() + outputs['sampled_latent_scale%d' % lod] = sampled_feature + + sdf_pts = self.sdf_layer(pts_, sampled_feature) + + lod = 0 + outputs['sdf_pts_scale%d' % lod] = sdf_pts[:, :1] + outputs['sdf_features_pts_scale%d' % lod] = sdf_pts[:, 1:] + + return outputs + + def color_blend(self, pts, position, normals, view_dirs, feature_vectors, img_index, + pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None): + + return self.renderer(torch.cat([pts, position], dim=-1), normals, view_dirs, feature_vectors, + img_index, pts_pixel_color, pts_pixel_mask, + pts_patch_color=pts_patch_color, pts_patch_mask=pts_patch_mask) + + def gradient(self, x, conditional_volume, lod): + """ + return the gradient of specific lod + :param x: + :param lod: + :return: + """ + x.requires_grad_(True) + output = self.sdf(x, conditional_volume, lod) + y = output['sdf_pts_scale%d' % 0] + + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return gradients.unsqueeze(1) + + @torch.no_grad() + def prune_dense_mask(self, threshold=0.02): + """ + Just gradually prune the mask of dense volume to decrease the number of sdf network inference + :return: + """ + chunk_size = 10240 + coords = generate_grid(self.vol_dims_lod0, 1)[0] # [3, dX, dY, dZ] + + _, dX, dY, dZ = coords.shape + + pts = coords.view(3, -1).permute(1, + 0).contiguous() * self.voxel_size_lod0 + self.origin[None, :] # [dX*dY*dZ, 3] + + # dense_volume = self.dense_volume_lod0() # [1,C,dX,dY,dZ] + dense_volume, _ = sparse_to_dense_volume( + self.sparse_coords_lod0, + self.sparse_volume_lod0(), self.vol_dims_lod0, interval=1, + device=None) # [1, C/1, X, Y, Z] + + sdf_volume = torch.ones([dX * dY * dZ, 1]).float().to(dense_volume.device) * 100 + + mask = self.dense_volume_mask_lod0.view(-1) > 0 + + pts_valid = pts[mask].to(dense_volume.device) + feature_valid = dense_volume.view(self.regnet_d_out, -1).permute(1, 0).contiguous()[mask] + + pts_valid = pts_valid.split(chunk_size) + feature_valid = feature_valid.split(chunk_size) + + sdf_list = [] + + for pts_part, feature_part in zip(pts_valid, feature_valid): + sdf_part = self.sdf_layer(pts_part, feature_part)[:, :1] + sdf_list.append(sdf_part) + + sdf_list = torch.cat(sdf_list, dim=0) + + sdf_volume[mask] = sdf_list + + occupancy_mask = torch.abs(sdf_volume) < threshold # [num_pts, 1] + + # - dilate + occupancy_mask = occupancy_mask.float() + occupancy_mask = occupancy_mask.view(1, 1, dX, dY, dZ) + occupancy_mask = F.avg_pool3d(occupancy_mask, kernel_size=7, stride=1, padding=3) + occupancy_mask = occupancy_mask > 0 + + self.dense_volume_mask_lod0 = torch.logical_and(self.dense_volume_mask_lod0, + occupancy_mask).float() # (1, 1, dX, dY, dZ) + + +class BlendingRenderingNetwork(nn.Module): + def __init__( + self, + d_feature, + mode, + d_in, + d_out, + d_hidden, + n_layers, + weight_norm=True, + multires_view=0, + squeeze_out=True, + ): + super(BlendingRenderingNetwork, self).__init__() + + self.mode = mode + self.squeeze_out = squeeze_out + dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] + + self.embedder = None + if multires_view > 0: + self.embedder = Embedding(3, multires_view) + dims[0] += (self.embedder.out_channels - 3) + + self.num_layers = len(dims) + + for l in range(0, self.num_layers - 1): + out_dim = dims[l + 1] + lin = nn.Linear(dims[l], out_dim) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + setattr(self, "lin" + str(l), lin) + + self.relu = nn.ReLU() + + self.color_volume = None + + self.softmax = nn.Softmax(dim=1) + + self.type = 'blending' + + def sample_pts_from_colorVolume(self, pts): + device = pts.device + num_pts = pts.shape[0] + pts_ = pts.clone() + pts = pts.view(1, 1, 1, num_pts, 3) # - should be in range (-1, 1) + + pts = torch.flip(pts, dims=[-1]) + + sampled_color = grid_sample_3d(self.color_volume, pts) # [1, c, 1, 1, num_pts] + sampled_color = sampled_color.view(-1, num_pts).permute(1, 0).contiguous().to(device) + + return sampled_color + + def forward(self, position, normals, view_dirs, feature_vectors, img_index, + pts_pixel_color, pts_pixel_mask, pts_patch_color=None, pts_patch_mask=None): + """ + + :param position: can be 3d coord or interpolated volume latent + :param normals: + :param view_dirs: + :param feature_vectors: + :param img_index: [N_views], used to extract corresponding weights + :param pts_pixel_color: [N_pts, N_views, 3] + :param pts_pixel_mask: [N_pts, N_views] + :param pts_patch_color: [N_pts, N_views, Npx, 3] + :return: + """ + if self.embedder is not None: + view_dirs = self.embedder(view_dirs) + + rendering_input = None + + if self.mode == 'idr': + rendering_input = torch.cat([position, view_dirs, normals, feature_vectors], dim=-1) + elif self.mode == 'no_view_dir': + rendering_input = torch.cat([position, normals, feature_vectors], dim=-1) + elif self.mode == 'no_normal': + rendering_input = torch.cat([position, view_dirs, feature_vectors], dim=-1) + elif self.mode == 'no_points': + rendering_input = torch.cat([view_dirs, normals, feature_vectors], dim=-1) + elif self.mode == 'no_points_no_view_dir': + rendering_input = torch.cat([normals, feature_vectors], dim=-1) + + x = rendering_input + + for l in range(0, self.num_layers - 1): + lin = getattr(self, "lin" + str(l)) + + x = lin(x) + + if l < self.num_layers - 2: + x = self.relu(x) # [n_pts, d_out] + + ## extract value based on img_index + x_extracted = torch.index_select(x, 1, img_index.long()) + + weights_pixel = self.softmax(x_extracted) # [n_pts, N_views] + weights_pixel = weights_pixel * pts_pixel_mask + weights_pixel = weights_pixel / ( + torch.sum(weights_pixel.float(), dim=1, keepdim=True) + 1e-8) # [n_pts, N_views] + final_pixel_color = torch.sum(pts_pixel_color * weights_pixel[:, :, None], dim=1, + keepdim=False) # [N_pts, 3] + + final_pixel_mask = torch.sum(pts_pixel_mask.float(), dim=1, keepdim=True) > 0 # [N_pts, 1] + + final_patch_color, final_patch_mask = None, None + # pts_patch_color [N_pts, N_views, Npx, 3]; pts_patch_mask [N_pts, N_views, Npx] + if pts_patch_color is not None: + N_pts, N_views, Npx, _ = pts_patch_color.shape + patch_mask = torch.sum(pts_patch_mask, dim=-1, keepdim=False) > Npx - 1 # [N_pts, N_views] + + weights_patch = self.softmax(x_extracted) # [N_pts, N_views] + weights_patch = weights_patch * patch_mask + weights_patch = weights_patch / ( + torch.sum(weights_patch.float(), dim=1, keepdim=True) + 1e-8) # [n_pts, N_views] + + final_patch_color = torch.sum(pts_patch_color * weights_patch[:, :, None, None], dim=1, + keepdim=False) # [N_pts, Npx, 3] + final_patch_mask = torch.sum(patch_mask, dim=1, keepdim=True) > 0 # [N_pts, 1] at least one image sees + + return final_pixel_color, final_pixel_mask, final_patch_color, final_patch_mask diff --git a/SparseNeuS_demo_v1/models/trainer_generic.py b/SparseNeuS_demo_v1/models/trainer_generic.py new file mode 100644 index 0000000000000000000000000000000000000000..786ccfd0f84f45ec395db8831b78cecbda803139 --- /dev/null +++ b/SparseNeuS_demo_v1/models/trainer_generic.py @@ -0,0 +1,1207 @@ +""" +decouple the trainer with the renderer +""" +import os +import cv2 as cv +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np +import trimesh +from icecream import ic + +from utils.misc_utils import visualize_depth_numpy + +from loss.depth_metric import compute_depth_errors + +from loss.depth_loss import DepthLoss, DepthSmoothLoss + +from models.sparse_neus_renderer import SparseNeuSRenderer + +class GenericTrainer(nn.Module): + def __init__(self, + rendering_network_outside, + pyramid_feature_network_lod0, + pyramid_feature_network_lod1, + sdf_network_lod0, + sdf_network_lod1, + variance_network_lod0, + variance_network_lod1, + rendering_network_lod0, + rendering_network_lod1, + n_samples_lod0, + n_importance_lod0, + n_samples_lod1, + n_importance_lod1, + n_outside, + perturb, + alpha_type='div', + conf=None, + timestamp="", + mode='train', + base_exp_dir=None, + ): + super(GenericTrainer, self).__init__() + + self.conf = conf + self.timestamp = timestamp + + + self.base_exp_dir = base_exp_dir + + + self.anneal_start = self.conf.get_float('train.anneal_start', default=0.0) + self.anneal_end = self.conf.get_float('train.anneal_end', default=0.0) + self.anneal_start_lod1 = self.conf.get_float('train.anneal_start_lod1', default=0.0) + self.anneal_end_lod1 = self.conf.get_float('train.anneal_end_lod1', default=0.0) + + # network setups + self.rendering_network_outside = rendering_network_outside + self.pyramid_feature_network_geometry_lod0 = pyramid_feature_network_lod0 # 2D pyramid feature network for geometry + self.pyramid_feature_network_geometry_lod1 = pyramid_feature_network_lod1 # use differnet networks for the two lods + + # when num_lods==2, may consume too much memeory + self.sdf_network_lod0 = sdf_network_lod0 + self.sdf_network_lod1 = sdf_network_lod1 + + # - warpped by ModuleList to support DataParallel + self.variance_network_lod0 = variance_network_lod0 + self.variance_network_lod1 = variance_network_lod1 + + self.rendering_network_lod0 = rendering_network_lod0 + self.rendering_network_lod1 = rendering_network_lod1 + + self.n_samples_lod0 = n_samples_lod0 + self.n_importance_lod0 = n_importance_lod0 + self.n_samples_lod1 = n_samples_lod1 + self.n_importance_lod1 = n_importance_lod1 + self.n_outside = n_outside + self.num_lods = conf.get_int('model.num_lods') # the number of octree lods + self.perturb = perturb + self.alpha_type = alpha_type + + # - the two renderers + self.sdf_renderer_lod0 = SparseNeuSRenderer( + self.rendering_network_outside, + self.sdf_network_lod0, + self.variance_network_lod0, + self.rendering_network_lod0, + self.n_samples_lod0, + self.n_importance_lod0, + self.n_outside, + self.perturb, + alpha_type='div', + conf=self.conf) + + self.sdf_renderer_lod1 = SparseNeuSRenderer( + self.rendering_network_outside, + self.sdf_network_lod1, + self.variance_network_lod1, + self.rendering_network_lod1, + self.n_samples_lod1, + self.n_importance_lod1, + self.n_outside, + self.perturb, + alpha_type='div', + conf=self.conf) + + self.if_fix_lod0_networks = self.conf.get_bool('train.if_fix_lod0_networks') + + # sdf network weights + self.sdf_igr_weight = self.conf.get_float('train.sdf_igr_weight') + self.sdf_sparse_weight = self.conf.get_float('train.sdf_sparse_weight', default=0) + self.sdf_decay_param = self.conf.get_float('train.sdf_decay_param', default=100) + self.fg_bg_weight = self.conf.get_float('train.fg_bg_weight', default=0.00) + self.bg_ratio = self.conf.get_float('train.bg_ratio', default=0.0) + + self.depth_criterion = DepthLoss() + + # - DataParallel mode, cannot modify attributes in forward() + # self.iter_step = 0 + self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') + + # - True for finetuning; False for general training + self.if_fitted_rendering = self.conf.get_bool('train.if_fitted_rendering', default=False) + + self.prune_depth_filter = self.conf.get_bool('model.prune_depth_filter', default=False) + + def get_trainable_params(self): + # set trainable params + + self.params_to_train = [] + + if not self.if_fix_lod0_networks: + # load pretrained featurenet + self.params_to_train += list(self.pyramid_feature_network_geometry_lod0.parameters()) + self.params_to_train += list(self.sdf_network_lod0.parameters()) + self.params_to_train += list(self.variance_network_lod0.parameters()) + + if self.rendering_network_lod0 is not None: + self.params_to_train += list(self.rendering_network_lod0.parameters()) + + if self.sdf_network_lod1 is not None: + # load pretrained featurenet + self.params_to_train += list(self.pyramid_feature_network_geometry_lod1.parameters()) + + self.params_to_train += list(self.sdf_network_lod1.parameters()) + self.params_to_train += list(self.variance_network_lod1.parameters()) + if self.rendering_network_lod1 is not None: + self.params_to_train += list(self.rendering_network_lod1.parameters()) + + return self.params_to_train + + def train_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['near_fars'][0, 0, :1], sample['near_fars'][0, 0, 1:] + + # the full-size ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + proj_matrices = sample['affine_mats'] + scale_mat = sample['scale_mat'] + trans_mat = sample['trans_mat'] + + # *********************** Lod==0 *********************** + if not self.if_fix_lod0_networks: + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs) + + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, 1:, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices[:,1:], + # proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + + else: + with torch.no_grad(): + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + # geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, 1:, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices[:,1:], + # proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + # print("Checker2:, construct cost volume") + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + # * extract depth maps for all the images + depth_maps_lod0, depth_masks_lod0 = None, None + if self.num_lods > 1: + sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( + con_volume_lod0, con_valid_mask_volume_lod0, + coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] + + if self.prune_depth_filter: + depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps( + self.sdf_network_lod0, sdf_volume_lod0, intrinsics_l_4x, c2ws, + sizeH // 4, sizeW // 4, near * 1.5, far) + depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear', + align_corners=True) + depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest') + + # *************** losses + loss_lod0, losses_lod0, depth_statis_lod0 = None, None, None + + if not self.if_fix_lod0_networks: + + render_out = self.sdf_renderer_lod0.render( + rays_o, rays_d, near, far, + self.sdf_network_lod0, + self.rendering_network_lod0, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod0, + # * related to conditional feature + lod=0, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume=con_valid_mask_volume_lod0, + # * 2d feature maps + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + if_general_rendering=True, + if_render_with_grad=True, + ) + + loss_lod0, losses_lod0, depth_statis_lod0 = self.cal_losses_sdf(render_out, sample_rays, + iter_step, lod=0) + + # *********************** Lod==1 *********************** + + loss_lod1, losses_lod1, depth_statis_lod1 = None, None, None + + if self.num_lods > 1: + geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + # geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + if self.prune_depth_filter: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], + depth_maps_lod0, proj_matrices[0], + partial_vol_origin, self.sdf_network_lod0.voxel_size, + near, far, self.sdf_network_lod0.voxel_size, 12) + else: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) + + pre_coords[:, 1:] = pre_coords[:, 1:] * 2 + + # ? It seems that training gru_fusion, this part should be trainable too + conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( + feature_maps=geometry_feature_maps_lod1[None, 1:, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices[:,1:], + # proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + pre_coords=pre_coords, + pre_feats=pre_feats, + ) + + con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] + con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] + + # if not self.if_gru_fusion_lod1: + render_out_lod1 = self.sdf_renderer_lod1.render( + rays_o, rays_d, near, far, + self.sdf_network_lod1, + self.rendering_network_lod1, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod1, + # * related to conditional feature + lod=1, + conditional_volume=con_volume_lod1, + conditional_valid_mask_volume=con_valid_mask_volume_lod1, + # * 2d feature maps + feature_maps=geometry_feature_maps_lod1, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + bg_ratio=self.bg_ratio, + ) + loss_lod1, losses_lod1, depth_statis_lod1 = self.cal_losses_sdf(render_out_lod1, sample_rays, + iter_step, lod=1) + + # print("Checker3:, compute losses") + # # - extract mesh + if iter_step % self.val_mesh_freq == 0: + torch.cuda.empty_cache() + self.validate_mesh(self.sdf_network_lod0, + self.sdf_renderer_lod0.extract_geometry, + conditional_volume=con_volume_lod0, lod=0, + threshold=0, + # occupancy_mask=con_valid_mask_volume_lod0[0, 0], + mode='train_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, + trans_mat=trans_mat) + torch.cuda.empty_cache() + + if self.num_lods > 1: + self.validate_mesh(self.sdf_network_lod1, + self.sdf_renderer_lod1.extract_geometry, + conditional_volume=con_volume_lod1, lod=1, + # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(), + mode='train_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, + trans_mat=trans_mat) + losses = { + # - lod 0 + 'loss_lod0': loss_lod0, + 'losses_lod0': losses_lod0, + 'depth_statis_lod0': depth_statis_lod0, + + # - lod 1 + 'loss_lod1': loss_lod1, + 'losses_lod1': losses_lod1, + 'depth_statis_lod1': depth_statis_lod1, + + } + + return losses + + def val_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + chunk_size=512, + save_vis=False, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + H, W = sizeH, sizeW + + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] + + # the ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + rays_ndc_uv = sample_rays['rays_ndc_uv'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + proj_matrices = sample['affine_mats'] + + # render_img_idx = sample['render_img_idx'][0] + # true_img = sample['images'][0][render_img_idx] + + # - the image to render + scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale + trans_mat = sample['trans_mat'] + query_c2w = sample['query_c2w'] # [1,4,4] + query_w2c = sample['query_w2c'] # [1,4,4] + true_img = sample['query_image'][0] + true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) + + depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy() + + scale_factor = sample['scale_factor'][0].cpu().numpy() + true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None + if true_depth is not None: + true_depth = true_depth[0].cpu().numpy() + true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0] + else: + true_depth_colored = None + + rays_o = rays_o.reshape(-1, 3).split(chunk_size) + rays_d = rays_d.reshape(-1, 3).split(chunk_size) + + # - obtain conditional features + with torch.no_grad(): + # - obtain conditional features + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + # - lod 0 + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + if self.num_lods > 1: + sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( + con_volume_lod0, con_valid_mask_volume_lod0, + coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] + + depth_maps_lod0, depth_masks_lod0 = None, None + if self.prune_depth_filter: + depth_maps_lod0_l4x, depth_masks_lod0_l4x = self.sdf_renderer_lod0.extract_depth_maps( + self.sdf_network_lod0, sdf_volume_lod0, + intrinsics_l_4x, c2ws, + sizeH // 4, sizeW // 4, near * 1.5, far) # - near*1.5 is a experienced number + depth_maps_lod0 = F.interpolate(depth_maps_lod0_l4x, size=(sizeH, sizeW), mode='bilinear', + align_corners=True) + depth_masks_lod0 = F.interpolate(depth_masks_lod0_l4x.float(), size=(sizeH, sizeW), mode='nearest') + + #### visualize the depth_maps_lod0 for checking + colored_depth_maps_lod0 = [] + for i in range(depth_maps_lod0.shape[0]): + colored_depth_maps_lod0.append( + visualize_depth_numpy(depth_maps_lod0[i, 0].cpu().numpy(), [depth_min, depth_max])[0]) + + colored_depth_maps_lod0 = np.concatenate(colored_depth_maps_lod0, axis=0).astype(np.uint8) + os.makedirs(os.path.join(self.base_exp_dir, 'depth_maps_lod0'), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'depth_maps_lod0', + '{:0>8d}_{}.png'.format(iter_step, meta)), + colored_depth_maps_lod0[:, :, ::-1]) + + if self.num_lods > 1: + geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + + if self.prune_depth_filter: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], + depth_maps_lod0, proj_matrices[0], + partial_vol_origin, self.sdf_network_lod0.voxel_size, + near, far, self.sdf_network_lod0.voxel_size, 12) + else: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) + + pre_coords[:, 1:] = pre_coords[:, 1:] * 2 + + with torch.no_grad(): + conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( + feature_maps=geometry_feature_maps_lod1[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + pre_coords=pre_coords, + pre_feats=pre_feats, + ) + + con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] + con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] + + out_rgb_fine = [] + out_normal_fine = [] + out_depth_fine = [] + + out_rgb_fine_lod1 = [] + out_normal_fine_lod1 = [] + out_depth_fine_lod1 = [] + + # out_depth_fine_explicit = [] + if save_vis: + for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): + + # ****** lod 0 **** + render_out = self.sdf_renderer_lod0.render( + rays_o_batch, rays_d_batch, near, far, + self.sdf_network_lod0, + self.rendering_network_lod0, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod0, + # * related to conditional feature + lod=0, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume=con_valid_mask_volume_lod0, + # * 2d feature maps + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + query_c2w=query_c2w, + if_render_with_grad=False, + ) + + feasible = lambda key: ((key in render_out) and (render_out[key] is not None)) + + if feasible('depth'): + out_depth_fine.append(render_out['depth'].detach().cpu().numpy()) + + # if render_out['color_coarse'] is not None: + if feasible('color_fine'): + out_rgb_fine.append(render_out['color_fine'].detach().cpu().numpy()) + if feasible('gradients') and feasible('weights'): + if render_out['inside_sphere'] is not None: + out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + :self.n_samples_lod0 + self.n_importance_lod0, + None] * render_out['inside_sphere'][ + ..., None]).sum(dim=1).detach().cpu().numpy()) + else: + out_normal_fine.append((render_out['gradients'] * render_out['weights'][:, + :self.n_samples_lod0 + self.n_importance_lod0, + None]).sum(dim=1).detach().cpu().numpy()) + del render_out + + # ****************** lod 1 ************************** + if self.num_lods > 1: + for rays_o_batch, rays_d_batch in zip(rays_o, rays_d): + render_out_lod1 = self.sdf_renderer_lod1.render( + rays_o_batch, rays_d_batch, near, far, + self.sdf_network_lod1, + self.rendering_network_lod1, + background_rgb=background_rgb, + alpha_inter_ratio=alpha_inter_ratio_lod1, + # * related to conditional feature + lod=1, + conditional_volume=con_volume_lod1, + conditional_valid_mask_volume=con_valid_mask_volume_lod1, + # * 2d feature maps + feature_maps=geometry_feature_maps_lod1, + color_maps=imgs, + w2cs=w2cs, + intrinsics=intrinsics, + img_wh=[sizeW, sizeH], + query_c2w=query_c2w, + if_render_with_grad=False, + ) + + feasible = lambda key: ((key in render_out_lod1) and (render_out_lod1[key] is not None)) + + if feasible('depth'): + out_depth_fine_lod1.append(render_out_lod1['depth'].detach().cpu().numpy()) + + # if render_out['color_coarse'] is not None: + if feasible('color_fine'): + out_rgb_fine_lod1.append(render_out_lod1['color_fine'].detach().cpu().numpy()) + if feasible('gradients') and feasible('weights'): + if render_out_lod1['inside_sphere'] is not None: + out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:, + :self.n_samples_lod1 + self.n_importance_lod1, + None] * + render_out_lod1['inside_sphere'][ + ..., None]).sum(dim=1).detach().cpu().numpy()) + else: + out_normal_fine_lod1.append((render_out_lod1['gradients'] * render_out_lod1['weights'][:, + :self.n_samples_lod1 + self.n_importance_lod1, + None]).sum( + dim=1).detach().cpu().numpy()) + del render_out_lod1 + + # - save visualization of lod 0 + + self.save_visualization(true_img, true_depth_colored, out_depth_fine, out_normal_fine, + query_w2c[0], out_rgb_fine, H, W, + depth_min, depth_max, iter_step, meta, "val_lod0", true_depth=true_depth, scale_factor=scale_factor) + + if self.num_lods > 1: + self.save_visualization(true_img, true_depth_colored, out_depth_fine_lod1, out_normal_fine_lod1, + query_w2c[0], out_rgb_fine_lod1, H, W, + depth_min, depth_max, iter_step, meta, "val_lod1", true_depth=true_depth, scale_factor=scale_factor) + + # - extract mesh + if (iter_step % self.val_mesh_freq == 0): + torch.cuda.empty_cache() + self.validate_mesh(self.sdf_network_lod0, + self.sdf_renderer_lod0.extract_geometry, + conditional_volume=con_volume_lod0, lod=0, + threshold=0, + # occupancy_mask=con_valid_mask_volume_lod0[0, 0], + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) + torch.cuda.empty_cache() + + if self.num_lods > 1: + self.validate_mesh(self.sdf_network_lod1, + self.sdf_renderer_lod1.extract_geometry, + conditional_volume=con_volume_lod1, lod=1, + # occupancy_mask=con_valid_mask_volume_lod1[0, 0].detach(), + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat) + + torch.cuda.empty_cache() + + + + def export_mesh_step(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + chunk_size=512, + save_vis=False, + ): + # * only support batch_size==1 + # ! attention: the list of string cannot be splited in DataParallel + batch_idx = sample['batch_idx'][0] + meta = sample['meta'][batch_idx] # the scan lighting ref_view info + + sizeW = sample['img_wh'][0][0] + sizeH = sample['img_wh'][0][1] + H, W = sizeH, sizeW + + partial_vol_origin = sample['partial_vol_origin'] # [B, 3] + near, far = sample['query_near_far'][0, :1], sample['query_near_far'][0, 1:] + + # the ray variables + sample_rays = sample['rays'] + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + rays_ndc_uv = sample_rays['rays_ndc_uv'][0] + + imgs = sample['images'][0] + intrinsics = sample['intrinsics'][0] + intrinsics_l_4x = intrinsics.clone() + intrinsics_l_4x[:, :2] *= 0.25 + w2cs = sample['w2cs'][0] + c2ws = sample['c2ws'][0] + # target_candidate_w2cs = sample['target_candidate_w2cs'][0] + proj_matrices = sample['affine_mats'] + + + # - the image to render + scale_mat = sample['scale_mat'] # [1,4,4] used to convert mesh into true scale + trans_mat = sample['trans_mat'] + query_c2w = sample['query_c2w'] # [1,4,4] + query_w2c = sample['query_w2c'] # [1,4,4] + true_img = sample['query_image'][0] + true_img = np.uint8(true_img.permute(1, 2, 0).cpu().numpy() * 255) + + depth_min, depth_max = near.cpu().numpy(), far.cpu().numpy() + + scale_factor = sample['scale_factor'][0].cpu().numpy() + true_depth = sample['query_depth'] if 'query_depth' in sample.keys() else None + if true_depth is not None: + true_depth = true_depth[0].cpu().numpy() + true_depth_colored = visualize_depth_numpy(true_depth, [depth_min, depth_max])[0] + else: + true_depth_colored = None + + rays_o = rays_o.reshape(-1, 3).split(chunk_size) + rays_d = rays_d.reshape(-1, 3).split(chunk_size) + # import time + # jha_begin1 = time.time() + # - obtain conditional features + with torch.no_grad(): + # - obtain conditional features + geometry_feature_maps = self.obtain_pyramid_feature_maps(imgs, lod=0) + # - lod 0 + conditional_features_lod0 = self.sdf_network_lod0.get_conditional_volume( + feature_maps=geometry_feature_maps[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + lod=0, + ) + # jha_end1 = time.time() + # print("get_conditional_volume: ", jha_end1 - jha_begin1) + # jha_begin2 = time.time() + con_volume_lod0 = conditional_features_lod0['dense_volume_scale0'] + con_valid_mask_volume_lod0 = conditional_features_lod0['valid_mask_volume_scale0'] + coords_lod0 = conditional_features_lod0['coords_scale0'] # [1,3,wX,wY,wZ] + + if self.num_lods > 1: + sdf_volume_lod0 = self.sdf_network_lod0.get_sdf_volume( + con_volume_lod0, con_valid_mask_volume_lod0, + coords_lod0, partial_vol_origin) # [1, 1, dX, dY, dZ] + + depth_maps_lod0, depth_masks_lod0 = None, None + + + if self.num_lods > 1: + geometry_feature_maps_lod1 = self.obtain_pyramid_feature_maps(imgs, lod=1) + + if self.prune_depth_filter: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf_depthfilter( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0], + depth_maps_lod0, proj_matrices[0], + partial_vol_origin, self.sdf_network_lod0.voxel_size, + near, far, self.sdf_network_lod0.voxel_size, 12) + else: + pre_coords, pre_feats = self.sdf_renderer_lod0.get_valid_sparse_coords_by_sdf( + sdf_volume_lod0[0], coords_lod0[0], con_valid_mask_volume_lod0[0], con_volume_lod0[0]) + + pre_coords[:, 1:] = pre_coords[:, 1:] * 2 + + with torch.no_grad(): + conditional_features_lod1 = self.sdf_network_lod1.get_conditional_volume( + feature_maps=geometry_feature_maps_lod1[None, :, :, :, :], + partial_vol_origin=partial_vol_origin, + proj_mats=proj_matrices, + sizeH=sizeH, + sizeW=sizeW, + pre_coords=pre_coords, + pre_feats=pre_feats, + ) + + con_volume_lod1 = conditional_features_lod1['dense_volume_scale1'] + con_valid_mask_volume_lod1 = conditional_features_lod1['valid_mask_volume_scale1'] + + out_rgb_fine = [] + out_normal_fine = [] + out_depth_fine = [] + + out_rgb_fine_lod1 = [] + out_normal_fine_lod1 = [] + out_depth_fine_lod1 = [] + + # jha_end2 = time.time() + # print("interval before starting mesh export: ", jha_end2 - jha_begin2) + + # - extract mesh + if (iter_step % self.val_mesh_freq == 0): + torch.cuda.empty_cache() + # jha_begin3 = time.time() + self.validate_colored_mesh( + density_or_sdf_network=self.sdf_network_lod0, + func_extract_geometry=self.sdf_renderer_lod0.extract_geometry, + conditional_volume=con_volume_lod0, + conditional_valid_mask_volume = con_valid_mask_volume_lod0, + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + target_candidate_w2cs=None, + intrinsics=intrinsics, + rendering_network=self.rendering_network_lod0, + rendering_projector=self.sdf_renderer_lod0.rendering_projector, + lod=0, + threshold=0, + query_c2w=query_c2w, + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat + ) + torch.cuda.empty_cache() + # jha_end3 = time.time() + # print("validate_colored_mesh_test_time: ", jha_end3 - jha_begin3) + + if self.num_lods > 1: + self.validate_colored_mesh( + density_or_sdf_network=self.sdf_network_lod1, + func_extract_geometry=self.sdf_renderer_lod1.extract_geometry, + conditional_volume=con_volume_lod1, + conditional_valid_mask_volume = con_valid_mask_volume_lod1, + feature_maps=geometry_feature_maps, + color_maps=imgs, + w2cs=w2cs, + target_candidate_w2cs=None, + intrinsics=intrinsics, + rendering_network=self.rendering_network_lod1, + rendering_projector=self.sdf_renderer_lod1.rendering_projector, + lod=1, + threshold=0, + query_c2w=query_c2w, + mode='val_bg', meta=meta, + iter_step=iter_step, scale_mat=scale_mat, trans_mat=trans_mat + ) + torch.cuda.empty_cache() + + + + def save_visualization(self, true_img, true_colored_depth, out_depth, out_normal, w2cs, out_color, H, W, + depth_min, depth_max, iter_step, meta, comment, out_color_mlp=[], true_depth=None, scale_factor=1.0): + if len(out_color) > 0: + img_fine = (np.concatenate(out_color, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) + + if len(out_color_mlp) > 0: + img_mlp = (np.concatenate(out_color_mlp, axis=0).reshape([H, W, 3]) * 256).clip(0, 255) + + if len(out_normal) > 0: + normal_img = np.concatenate(out_normal, axis=0) + rot = w2cs[:3, :3].detach().cpu().numpy() + # - convert normal from world space to camera space + normal_img = (np.matmul(rot[None, :, :], + normal_img[:, :, None]).reshape([H, W, 3]) * 128 + 128).clip(0, 255) + if len(out_depth) > 0: + pred_depth = np.concatenate(out_depth, axis=0).reshape([H, W]) + pred_depth_colored = visualize_depth_numpy(pred_depth, [depth_min, depth_max])[0] + + if len(out_depth) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'depths_' + comment), exist_ok=True) + if true_colored_depth is not None: + + if true_depth is not None: + depth_error_map = np.abs(true_depth - pred_depth) * 2.0 / scale_factor + # [256, 256, 1] -> [256, 256, 3] + depth_error_map = np.tile(depth_error_map[:, :, None], [1, 1, 3]) + print("meta: ", meta) + print("scale_factor: ", scale_factor) + print("depth_error_mean: ", depth_error_map.mean()) + depth_visualized = np.concatenate( + [(depth_error_map * 255).astype(np.uint8), true_colored_depth, pred_depth_colored, true_img], axis=1)[:, :, ::-1] + # print("depth_visualized.shape: ", depth_visualized.shape) + # write depth error result text on img, the input is a numpy array of [256, 1024, 3] + # cv.putText(depth_visualized.copy(), "depth_error_mean: {:.4f}".format(depth_error_map.mean()), (10, 30), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2) + else: + depth_visualized = np.concatenate( + [true_colored_depth, pred_depth_colored, true_img])[:, :, ::-1] + cv.imwrite( + os.path.join(self.base_exp_dir, 'depths_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), depth_visualized + ) + else: + cv.imwrite( + os.path.join(self.base_exp_dir, 'depths_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [pred_depth_colored, true_img])[:, :, ::-1]) + if len(out_color) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [img_fine, true_img])[:, :, ::-1]) # bgr2rgb + # compute psnr (image pixel lie in [0, 255]) + mse_loss = np.mean((img_fine - true_img) ** 2) + psnr = 10 * np.log10(255 ** 2 / mse_loss) + + print("PSNR: ", psnr) + + if len(out_color_mlp) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'synthesized_color_mlp_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + np.concatenate( + [img_mlp, true_img])[:, :, ::-1]) # bgr2rgb + + if len(out_normal) > 0: + os.makedirs(os.path.join(self.base_exp_dir, 'normals_' + comment), exist_ok=True) + cv.imwrite(os.path.join(self.base_exp_dir, 'normals_' + comment, + '{:0>8d}_{}.png'.format(iter_step, meta)), + normal_img[:, :, ::-1]) + + def forward(self, sample, + perturb_overwrite=-1, + background_rgb=None, + alpha_inter_ratio_lod0=0.0, + alpha_inter_ratio_lod1=0.0, + iter_step=0, + mode='train', + save_vis=False, + ): + + if mode == 'train': + return self.train_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=iter_step + ) + elif mode == 'val': + import time + begin = time.time() + result = self.val_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=iter_step, + save_vis=save_vis, + ) + end = time.time() + print("val_step time: ", end - begin) + return result + elif mode == 'export_mesh': + import time + begin = time.time() + result = self.export_mesh_step(sample, + perturb_overwrite=perturb_overwrite, + background_rgb=background_rgb, + alpha_inter_ratio_lod0=alpha_inter_ratio_lod0, + alpha_inter_ratio_lod1=alpha_inter_ratio_lod1, + iter_step=iter_step, + save_vis=save_vis, + ) + end = time.time() + print("export mesh time: ", end - begin) + return result + def obtain_pyramid_feature_maps(self, imgs, lod=0): + """ + get feature maps of all conditional images + :param imgs: + :return: + """ + + if lod == 0: + extractor = self.pyramid_feature_network_geometry_lod0 + elif lod >= 1: + extractor = self.pyramid_feature_network_geometry_lod1 + + pyramid_feature_maps = extractor(imgs) + + # * the pyramid features are very important, if only use the coarst features, hard to optimize + fused_feature_maps = torch.cat([ + F.interpolate(pyramid_feature_maps[0], scale_factor=4, mode='bilinear', align_corners=True), + F.interpolate(pyramid_feature_maps[1], scale_factor=2, mode='bilinear', align_corners=True), + pyramid_feature_maps[2] + ], dim=1) + + return fused_feature_maps + + def cal_losses_sdf(self, render_out, sample_rays, iter_step=-1, lod=0): + + # loss weight schedule; the regularization terms should be added in later training stage + def get_weight(iter_step, weight): + if lod == 1: + anneal_start = self.anneal_end if lod == 0 else self.anneal_end_lod1 + anneal_end = self.anneal_end if lod == 0 else self.anneal_end_lod1 + anneal_end = anneal_end * 2 + else: + anneal_start = self.anneal_start if lod == 0 else self.anneal_start_lod1 + anneal_end = self.anneal_end if lod == 0 else self.anneal_end_lod1 + anneal_end = anneal_end * 2 + + if iter_step < 0: + return weight + + if anneal_end == 0.0: + return weight + elif iter_step < anneal_start: + return 0.0 + else: + return np.min( + [1.0, + (iter_step - anneal_start) / (anneal_end - anneal_start)]) * weight + + rays_o = sample_rays['rays_o'][0] + rays_d = sample_rays['rays_v'][0] + true_rgb = sample_rays['rays_color'][0] + + if 'rays_depth' in sample_rays.keys(): + true_depth = sample_rays['rays_depth'][0] + else: + true_depth = None + mask = sample_rays['rays_mask'][0] + + color_fine = render_out['color_fine'] + color_fine_mask = render_out['color_fine_mask'] + depth_pred = render_out['depth'] + + variance = render_out['variance'] + cdf_fine = render_out['cdf_fine'] + weight_sum = render_out['weights_sum'] + + gradient_error_fine = render_out['gradient_error_fine'] + + sdf = render_out['sdf'] + + # * color generated by mlp + color_mlp = render_out['color_mlp'] + color_mlp_mask = render_out['color_mlp_mask'] + + if color_fine is not None: + # Color loss + color_mask = color_fine_mask if color_fine_mask is not None else mask + color_mask = color_mask[..., 0] + color_error = (color_fine[color_mask] - true_rgb[color_mask]) + # print("Nan number", torch.isnan(color_error).sum()) + # print("Color error shape", color_error.shape) + color_fine_loss = F.l1_loss(color_error, torch.zeros_like(color_error).to(color_error.device), + reduction='mean') + # print(color_fine_loss) + psnr = 20.0 * torch.log10( + 1.0 / (((color_fine[color_mask] - true_rgb[color_mask]) ** 2).mean() / (3.0)).sqrt()) + else: + color_fine_loss = 0. + psnr = 0. + + if color_mlp is not None: + # Color loss + color_mlp_mask = color_mlp_mask[..., 0] + color_error_mlp = (color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) + color_mlp_loss = F.l1_loss(color_error_mlp, + torch.zeros_like(color_error_mlp).to(color_error_mlp.device), + reduction='mean') + + psnr_mlp = 20.0 * torch.log10( + 1.0 / (((color_mlp[color_mlp_mask] - true_rgb[color_mlp_mask]) ** 2).mean() / (3.0)).sqrt()) + else: + color_mlp_loss = 0. + psnr_mlp = 0. + + # depth loss is only used for inference, not included in total loss + if true_depth is not None: + # depth_loss = self.depth_criterion(depth_pred, true_depth, mask) + depth_loss = self.depth_criterion(depth_pred, true_depth) + + # # depth evaluation + # depth_statis = compute_depth_errors(depth_pred.detach().cpu().numpy(), true_depth.cpu().numpy()) + # depth_statis = numpy2tensor(depth_statis, device=rays_o.device) + depth_statis = None + else: + depth_loss = 0. + depth_statis = None + + sparse_loss_1 = torch.exp( + -1 * torch.abs(render_out['sdf_random']) * self.sdf_decay_param).mean() # - should equal + sparse_loss_2 = torch.exp(-1 * torch.abs(sdf) * self.sdf_decay_param).mean() + sparse_loss = (sparse_loss_1 + sparse_loss_2) / 2 + + sdf_mean = torch.abs(sdf).mean() + sparseness_1 = (torch.abs(sdf) < 0.01).to(torch.float32).mean() + sparseness_2 = (torch.abs(sdf) < 0.02).to(torch.float32).mean() + + # Eikonal loss + gradient_error_loss = gradient_error_fine + + # ! the first 50k, don't use bg constraint + fg_bg_weight = 0.0 if iter_step < 50000 else get_weight(iter_step, self.fg_bg_weight) + + # Mask loss, optional + # The images of DTU dataset contain large black regions (0 rgb values), + # can use this data prior to make fg more clean + background_loss = 0.0 + fg_bg_loss = 0.0 + if self.fg_bg_weight > 0 and torch.mean((mask < 0.5).to(torch.float32)) > 0.02: + weights_sum_fg = render_out['weights_sum_fg'] + fg_bg_error = (weights_sum_fg - mask)[mask < 0.5] + fg_bg_loss = F.l1_loss(fg_bg_error, + torch.zeros_like(fg_bg_error).to(fg_bg_error.device), + reduction='mean') + + + + loss = 1.0 * depth_loss + color_fine_loss + color_mlp_loss + \ + sparse_loss * get_weight(iter_step, self.sdf_sparse_weight) + \ + fg_bg_loss * fg_bg_weight + \ + gradient_error_loss * self.sdf_igr_weight # ! gradient_error_loss need a mask + + losses = { + "loss": loss, + "depth_loss": depth_loss, + "color_fine_loss": color_fine_loss, + "color_mlp_loss": color_mlp_loss, + "gradient_error_loss": gradient_error_loss, + "background_loss": background_loss, + "sparse_loss": sparse_loss, + "sparseness_1": sparseness_1, + "sparseness_2": sparseness_2, + "sdf_mean": sdf_mean, + "psnr": psnr, + "psnr_mlp": psnr_mlp, + "weights_sum": render_out['weights_sum'], + "weights_sum_fg": render_out['weights_sum_fg'], + "alpha_sum": render_out['alpha_sum'], + "variance": render_out['variance'], + "sparse_weight": get_weight(iter_step, self.sdf_sparse_weight), + "fg_bg_weight": fg_bg_weight, + "fg_bg_loss": fg_bg_loss, # added by jha, bug of sparseNeuS + } + losses = torch.tensor(losses, device=rays_o.device) + return loss, losses, depth_statis + + @torch.no_grad() + def validate_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360, + threshold=0.0, mode='val', + # * 3d feature volume + conditional_volume=None, lod=None, occupancy_mask=None, + bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None, + trans_mat=None + ): + + bound_min = torch.tensor(bound_min, dtype=torch.float32) + bound_max = torch.tensor(bound_max, dtype=torch.float32) + + vertices, triangles, fields = func_extract_geometry( + density_or_sdf_network, + bound_min, bound_max, resolution=resolution, + threshold=threshold, device=conditional_volume.device, + # * 3d feature volume + conditional_volume=conditional_volume, lod=lod, + occupancy_mask=occupancy_mask + ) + + + if scale_mat is not None: + scale_mat_np = scale_mat.cpu().numpy() + vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None] + + if trans_mat is not None: # w2c_ref_inv + trans_mat_np = trans_mat.cpu().numpy() + vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1) + vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0] + + mesh = trimesh.Trimesh(vertices, triangles) + os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode), exist_ok=True) + mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, + 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod))) + + + + def validate_colored_mesh(self, density_or_sdf_network, func_extract_geometry, world_space=True, resolution=360, + threshold=0.0, mode='val', + # * 3d feature volume + conditional_volume=None, + conditional_valid_mask_volume=None, + feature_maps=None, + color_maps = None, + w2cs=None, + target_candidate_w2cs=None, + intrinsics=None, + rendering_network=None, + rendering_projector=None, + query_c2w=None, + lod=None, occupancy_mask=None, + bound_min=[-1, -1, -1], bound_max=[1, 1, 1], meta='', iter_step=0, scale_mat=None, + trans_mat=None + ): + + bound_min = torch.tensor(bound_min, dtype=torch.float32) + bound_max = torch.tensor(bound_max, dtype=torch.float32) + # import time + # jha_begin4 = time.time() + vertices, triangles, fields = func_extract_geometry( + density_or_sdf_network, + bound_min, bound_max, resolution=resolution, + threshold=threshold, device=conditional_volume.device, + # * 3d feature volume + conditional_volume=conditional_volume, lod=lod, + occupancy_mask=occupancy_mask + ) + # jha_end4 = time.time() + # print("[TEST]: func_extract_geometry time", jha_end4 - jha_begin4) + + # import time + # jha_begin5 = time.time() + + + with torch.no_grad(): + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask, _, _ = rendering_projector.compute_view_independent( + torch.tensor(vertices).to(conditional_volume), + lod=lod, # JHA EDITED + # * 3d geometry feature volumes + geometryVolume=conditional_volume[0], + geometryVolumeMask=conditional_valid_mask_volume[0], + sdf_network=density_or_sdf_network, + # * 2d rendering feature maps + rendering_feature_maps=feature_maps, # [n_view, 56, 256, 256] + color_maps=color_maps, + w2cs=w2cs, + target_candidate_w2cs=target_candidate_w2cs, + intrinsics=intrinsics, + img_wh=[256,256], + query_img_idx=0, # the index of the N_views dim for rendering + query_c2w=query_c2w, + ) + + + vertices_color, rendering_valid_mask = rendering_network( + ren_geo_feats, ren_rgb_feats, ren_ray_diff, ren_mask) + + # jha_end5 = time.time() + # print("[TEST]: rendering_network time", jha_end5 - jha_begin5) + + if scale_mat is not None: + scale_mat_np = scale_mat.cpu().numpy() + vertices = vertices * scale_mat_np[0][0, 0] + scale_mat_np[0][:3, 3][None] + + if trans_mat is not None: # w2c_ref_inv + trans_mat_np = trans_mat.cpu().numpy() + vertices_homo = np.concatenate([vertices, np.ones_like(vertices[:, :1])], axis=1) + vertices = np.matmul(trans_mat_np, vertices_homo[:, :, None])[:, :3, 0] + + vertices_color = np.array(vertices_color.squeeze(0).cpu() * 255, dtype=np.uint8) + mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertices_color) + os.makedirs(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod)), exist_ok=True) + # mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod), + # 'mesh_{:0>8d}_{}_lod{:0>1d}.ply'.format(iter_step, meta, lod))) + # MODIFIED + mesh.export(os.path.join(self.base_exp_dir, 'meshes_' + mode, 'lod{:0>1d}'.format(lod), + 'mesh_{:0>8d}_gradio_lod{:0>1d}.ply'.format(iter_step, lod))) \ No newline at end of file diff --git a/SparseNeuS_demo_v1/ops/__init__.py b/SparseNeuS_demo_v1/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/ops/back_project.py b/SparseNeuS_demo_v1/ops/back_project.py new file mode 100644 index 0000000000000000000000000000000000000000..5398f285f786a0e6c7a029138aa8a6554aae6e58 --- /dev/null +++ b/SparseNeuS_demo_v1/ops/back_project.py @@ -0,0 +1,175 @@ +import torch +from torch.nn.functional import grid_sample + + +def back_project_sparse_type(coords, origin, voxel_size, feats, KRcam, sizeH=None, sizeW=None, only_mask=False, + with_proj_z=False): + # - modified version from NeuRecon + ''' + Unproject the image fetures to form a 3D (sparse) feature volume + + :param coords: coordinates of voxels, + dim: (num of voxels, 4) (4 : batch ind, x, y, z) + :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0)) + dim: (batch size, 3) (3: x, y, z) + :param voxel_size: floats specifying the size of a voxel + :param feats: image features + dim: (num of views, batch size, C, H, W) + :param KRcam: projection matrix + dim: (num of views, batch size, 4, 4) + :return: feature_volume_all: 3D feature volumes + dim: (num of voxels, num_of_views, c) + :return: mask_volume_all: indicate the voxel of sampled feature volume is valid or not + dim: (num of voxels, num_of_views) + ''' + n_views, bs, c, h, w = feats.shape + device = feats.device + + if sizeH is None: + sizeH, sizeW = h, w # - if the KRcam is not suitable for the current feats + + feature_volume_all = torch.zeros(coords.shape[0], n_views, c).to(device) + mask_volume_all = torch.zeros([coords.shape[0], n_views], dtype=torch.int32).to(device) + # import ipdb; ipdb.set_trace() + for batch in range(bs): + # import ipdb; ipdb.set_trace() + batch_ind = torch.nonzero(coords[:, 0] == batch).squeeze(1) + coords_batch = coords[batch_ind][:, 1:] + + coords_batch = coords_batch.view(-1, 3) + origin_batch = origin[batch].unsqueeze(0) + feats_batch = feats[:, batch] + proj_batch = KRcam[:, batch] + + grid_batch = coords_batch * voxel_size + origin_batch.float() + rs_grid = grid_batch.unsqueeze(0).expand(n_views, -1, -1) + rs_grid = rs_grid.permute(0, 2, 1).contiguous() + nV = rs_grid.shape[-1] + rs_grid = torch.cat([rs_grid, torch.ones([n_views, 1, nV]).to(device)], dim=1) + + # Project grid + im_p = proj_batch @ rs_grid # - transform world pts to image UV space + im_x, im_y, im_z = im_p[:, 0], im_p[:, 1], im_p[:, 2] + + im_z[im_z >= 0] = im_z[im_z >= 0].clamp(min=1e-6) + + im_x = im_x / im_z + im_y = im_y / im_z + + im_grid = torch.stack([2 * im_x / (sizeW - 1) - 1, 2 * im_y / (sizeH - 1) - 1], dim=-1) + mask = im_grid.abs() <= 1 + mask = (mask.sum(dim=-1) == 2) & (im_z > 0) + + mask = mask.view(n_views, -1) + mask = mask.permute(1, 0).contiguous() # [num_pts, nviews] + + mask_volume_all[batch_ind] = mask.to(torch.int32) + + if only_mask: + return mask_volume_all + + feats_batch = feats_batch.view(n_views, c, h, w) + im_grid = im_grid.view(n_views, 1, -1, 2) + features = grid_sample(feats_batch, im_grid, padding_mode='zeros', align_corners=True) + # if features.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + features = features.view(n_views, c, -1) + features = features.permute(2, 0, 1).contiguous() # [num_pts, nviews, c] + + feature_volume_all[batch_ind] = features + + if with_proj_z: + im_z = im_z.view(n_views, 1, -1).permute(2, 0, 1).contiguous() # [num_pts, nviews, 1] + return feature_volume_all, mask_volume_all, im_z + # if feature_volume_all.isnan().sum() > 0: + # import ipdb; ipdb.set_trace() + return feature_volume_all, mask_volume_all + + +def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode, sizeH=None, sizeW=None, with_depth=False): + """Transform coordinates in the camera frame to the pixel frame. + Args: + cam_coords: pixel coordinates defined in the first camera coordinates system -- [B, 3, H, W] + proj_c2p_rot: rotation matrix of cameras -- [B, 3, 3] + proj_c2p_tr: translation vectors of cameras -- [B, 3, 1] + Returns: + array of [-1,1] coordinates -- [B, H, W, 2] + """ + b, _, h, w = cam_coords.size() + if sizeH is None: + sizeH = h + sizeW = w + + cam_coords_flat = cam_coords.view(b, 3, -1) # [B, 3, H*W] + if proj_c2p_rot is not None: + pcoords = proj_c2p_rot.bmm(cam_coords_flat) + else: + pcoords = cam_coords_flat + + if proj_c2p_tr is not None: + pcoords = pcoords + proj_c2p_tr # [B, 3, H*W] + X = pcoords[:, 0] + Y = pcoords[:, 1] + Z = pcoords[:, 2].clamp(min=1e-3) + + X_norm = 2 * (X / Z) / (sizeW - 1) - 1 # Normalized, -1 if on extreme left, + # 1 if on extreme right (x = w-1) [B, H*W] + Y_norm = 2 * (Y / Z) / (sizeH - 1) - 1 # Idem [B, H*W] + if padding_mode == 'zeros': + X_mask = ((X_norm > 1) + (X_norm < -1)).detach() + X_norm[X_mask] = 2 # make sure that no point in warped image is a combinaison of im and gray + Y_mask = ((Y_norm > 1) + (Y_norm < -1)).detach() + Y_norm[Y_mask] = 2 + + if with_depth: + pixel_coords = torch.stack([X_norm, Y_norm, Z], dim=2) # [B, H*W, 3] + return pixel_coords.view(b, h, w, 3) + else: + pixel_coords = torch.stack([X_norm, Y_norm], dim=2) # [B, H*W, 2] + return pixel_coords.view(b, h, w, 2) + + +# * have already checked, should check whether proj_matrix is for right coordinate system and resolution +def back_project_dense_type(coords, origin, voxel_size, feats, proj_matrix, sizeH=None, sizeW=None): + ''' + Unproject the image fetures to form a 3D (dense) feature volume + + :param coords: coordinates of voxels, + dim: (batch, nviews, 3, X,Y,Z) + :param origin: origin of the partial voxel volume (xyz position of voxel (0, 0, 0)) + dim: (batch size, 3) (3: x, y, z) + :param voxel_size: floats specifying the size of a voxel + :param feats: image features + dim: (batch size, num of views, C, H, W) + :param proj_matrix: projection matrix + dim: (batch size, num of views, 4, 4) + :return: feature_volume_all: 3D feature volumes + dim: (batch, nviews, C, X,Y,Z) + :return: count: number of times each voxel can be seen + dim: (batch, nviews, 1, X,Y,Z) + ''' + + batch, nviews, _, wX, wY, wZ = coords.shape + + if sizeH is None: + sizeH, sizeW = feats.shape[-2:] + proj_matrix = proj_matrix.view(batch * nviews, *proj_matrix.shape[2:]) + + coords_wrd = coords * voxel_size + origin.view(batch, 1, 3, 1, 1, 1) + coords_wrd = coords_wrd.view(batch * nviews, 3, wX * wY * wZ, 1) # (b*nviews,3,wX*wY*wZ, 1) + + pixel_grids = cam2pixel(coords_wrd, proj_matrix[:, :3, :3], proj_matrix[:, :3, 3:], + 'zeros', sizeH=sizeH, sizeW=sizeW) # (b*nviews,wX*wY*wZ, 2) + pixel_grids = pixel_grids.view(batch * nviews, 1, wX * wY * wZ, 2) + + feats = feats.view(batch * nviews, *feats.shape[2:]) # (b*nviews,c,h,w) + + ones = torch.ones((batch * nviews, 1, *feats.shape[2:])).to(feats.dtype).to(feats.device) + + features_volume = torch.nn.functional.grid_sample(feats, pixel_grids, padding_mode='zeros', align_corners=True) + counts_volume = torch.nn.functional.grid_sample(ones, pixel_grids, padding_mode='zeros', align_corners=True) + + features_volume = features_volume.view(batch, nviews, -1, wX, wY, wZ) # (batch, nviews, C, X,Y,Z) + counts_volume = counts_volume.view(batch, nviews, -1, wX, wY, wZ) + return features_volume, counts_volume + diff --git a/SparseNeuS_demo_v1/ops/generate_grids.py b/SparseNeuS_demo_v1/ops/generate_grids.py new file mode 100644 index 0000000000000000000000000000000000000000..304c1c4c1a424c4bc219f39815ed43fea1d9de5d --- /dev/null +++ b/SparseNeuS_demo_v1/ops/generate_grids.py @@ -0,0 +1,33 @@ +import torch + + +def generate_grid(n_vox, interval): + """ + generate grid + if 3D volume, grid[:,:,x,y,z] = (x,y,z) + :param n_vox: + :param interval: + :return: + """ + with torch.no_grad(): + # Create voxel grid + grid_range = [torch.arange(0, n_vox[axis], interval) for axis in range(3)] + grid = torch.stack(torch.meshgrid(grid_range[0], grid_range[1], grid_range[2], indexing="ij")) # 3 dx dy dz + # ! don't create tensor on gpu; imbalanced gpu memory in ddp mode + grid = grid.unsqueeze(0).type(torch.float32) # 1 3 dx dy dz + + return grid + + +if __name__ == "__main__": + import torch.nn.functional as F + grid = generate_grid([5, 6, 8], 1) + + pts = 2 * torch.tensor([1, 2, 3]) / (torch.tensor([5, 6, 8]) - 1) - 1 + pts = pts.view(1, 1, 1, 1, 3) + + pts = torch.flip(pts, dims=[-1]) + + sampled = F.grid_sample(grid, pts, mode='nearest') + + print(sampled) diff --git a/SparseNeuS_demo_v1/ops/grid_sampler.py b/SparseNeuS_demo_v1/ops/grid_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..44113faa705f0b98a5689c0e4fb9e7a95865d6c1 --- /dev/null +++ b/SparseNeuS_demo_v1/ops/grid_sampler.py @@ -0,0 +1,467 @@ +""" +pytorch grid_sample doesn't support second-order derivative +implement custom version +""" + +import torch +import torch.nn.functional as F +import numpy as np + + +def grid_sample_2d(image, optical): + N, C, IH, IW = image.shape + _, H, W, _ = optical.shape + + ix = optical[..., 0] + iy = optical[..., 1] + + ix = ((ix + 1) / 2) * (IW - 1); + iy = ((iy + 1) / 2) * (IH - 1); + with torch.no_grad(): + ix_nw = torch.floor(ix); + iy_nw = torch.floor(iy); + ix_ne = ix_nw + 1; + iy_ne = iy_nw; + ix_sw = ix_nw; + iy_sw = iy_nw + 1; + ix_se = ix_nw + 1; + iy_se = iy_nw + 1; + + nw = (ix_se - ix) * (iy_se - iy) + ne = (ix - ix_sw) * (iy_sw - iy) + sw = (ix_ne - ix) * (iy - iy_ne) + se = (ix - ix_nw) * (iy - iy_nw) + + with torch.no_grad(): + torch.clamp(ix_nw, 0, IW - 1, out=ix_nw) + torch.clamp(iy_nw, 0, IH - 1, out=iy_nw) + + torch.clamp(ix_ne, 0, IW - 1, out=ix_ne) + torch.clamp(iy_ne, 0, IH - 1, out=iy_ne) + + torch.clamp(ix_sw, 0, IW - 1, out=ix_sw) + torch.clamp(iy_sw, 0, IH - 1, out=iy_sw) + + torch.clamp(ix_se, 0, IW - 1, out=ix_se) + torch.clamp(iy_se, 0, IH - 1, out=iy_se) + + image = image.view(N, C, IH * IW) + + nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1)) + ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1)) + sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1)) + se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1)) + + out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) + + ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) + + sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) + + se_val.view(N, C, H, W) * se.view(N, 1, H, W)) + + return out_val + + +# - checked for correctness +def grid_sample_3d(volume, optical): + """ + bilinear sampling cannot guarantee continuous first-order gradient + mimic pytorch grid_sample function + The 8 corner points of a volume noted as: 4 points (front view); 4 points (back view) + fnw (front north west) point + bse (back south east) point + :param volume: [B, C, X, Y, Z] + :param optical: [B, x, y, z, 3] + :return: + """ + N, C, ID, IH, IW = volume.shape + _, D, H, W, _ = optical.shape + + ix = optical[..., 0] + iy = optical[..., 1] + iz = optical[..., 2] + + ix = ((ix + 1) / 2) * (IW - 1) + iy = ((iy + 1) / 2) * (IH - 1) + iz = ((iz + 1) / 2) * (ID - 1) + + mask_x = (ix > 0) & (ix < IW) + mask_y = (iy > 0) & (iy < IH) + mask_z = (iz > 0) & (iz < ID) + + mask = mask_x & mask_y & mask_z # [B, x, y, z] + mask = mask[:, None, :, :, :].repeat(1, C, 1, 1, 1) # [B, C, x, y, z] + + with torch.no_grad(): + # back north west + ix_bnw = torch.floor(ix) + iy_bnw = torch.floor(iy) + iz_bnw = torch.floor(iz) + + ix_bne = ix_bnw + 1 + iy_bne = iy_bnw + iz_bne = iz_bnw + + ix_bsw = ix_bnw + iy_bsw = iy_bnw + 1 + iz_bsw = iz_bnw + + ix_bse = ix_bnw + 1 + iy_bse = iy_bnw + 1 + iz_bse = iz_bnw + + # front view + ix_fnw = ix_bnw + iy_fnw = iy_bnw + iz_fnw = iz_bnw + 1 + + ix_fne = ix_bnw + 1 + iy_fne = iy_bnw + iz_fne = iz_bnw + 1 + + ix_fsw = ix_bnw + iy_fsw = iy_bnw + 1 + iz_fsw = iz_bnw + 1 + + ix_fse = ix_bnw + 1 + iy_fse = iy_bnw + 1 + iz_fse = iz_bnw + 1 + + # back view + bnw = (ix_fse - ix) * (iy_fse - iy) * (iz_fse - iz) # smaller volume, larger weight + bne = (ix - ix_fsw) * (iy_fsw - iy) * (iz_fsw - iz) + bsw = (ix_fne - ix) * (iy - iy_fne) * (iz_fne - iz) + bse = (ix - ix_fnw) * (iy - iy_fnw) * (iz_fnw - iz) + + # front view + fnw = (ix_bse - ix) * (iy_bse - iy) * (iz - iz_bse) # smaller volume, larger weight + fne = (ix - ix_bsw) * (iy_bsw - iy) * (iz - iz_bsw) + fsw = (ix_bne - ix) * (iy - iy_bne) * (iz - iz_bne) + fse = (ix - ix_bnw) * (iy - iy_bnw) * (iz - iz_bnw) + + with torch.no_grad(): + # back view + torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw) + torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw) + torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw) + + torch.clamp(ix_bne, 0, IW - 1, out=ix_bne) + torch.clamp(iy_bne, 0, IH - 1, out=iy_bne) + torch.clamp(iz_bne, 0, ID - 1, out=iz_bne) + + torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw) + torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw) + torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw) + + torch.clamp(ix_bse, 0, IW - 1, out=ix_bse) + torch.clamp(iy_bse, 0, IH - 1, out=iy_bse) + torch.clamp(iz_bse, 0, ID - 1, out=iz_bse) + + # front view + torch.clamp(ix_fnw, 0, IW - 1, out=ix_fnw) + torch.clamp(iy_fnw, 0, IH - 1, out=iy_fnw) + torch.clamp(iz_fnw, 0, ID - 1, out=iz_fnw) + + torch.clamp(ix_fne, 0, IW - 1, out=ix_fne) + torch.clamp(iy_fne, 0, IH - 1, out=iy_fne) + torch.clamp(iz_fne, 0, ID - 1, out=iz_fne) + + torch.clamp(ix_fsw, 0, IW - 1, out=ix_fsw) + torch.clamp(iy_fsw, 0, IH - 1, out=iy_fsw) + torch.clamp(iz_fsw, 0, ID - 1, out=iz_fsw) + + torch.clamp(ix_fse, 0, IW - 1, out=ix_fse) + torch.clamp(iy_fse, 0, IH - 1, out=iy_fse) + torch.clamp(iz_fse, 0, ID - 1, out=iz_fse) + + # xxx = volume[:, :, iz_bnw.long(), iy_bnw.long(), ix_bnw.long()] + volume = volume.view(N, C, ID * IH * IW) + # yyy = volume[:, :, (iz_bnw * ID + iy_bnw * IW + ix_bnw).long()] + + # back view + bnw_val = torch.gather(volume, 2, + (iz_bnw * ID ** 2 + iy_bnw * IW + ix_bnw).long().view(N, 1, D * H * W).repeat(1, C, 1)) + bne_val = torch.gather(volume, 2, + (iz_bne * ID ** 2 + iy_bne * IW + ix_bne).long().view(N, 1, D * H * W).repeat(1, C, 1)) + bsw_val = torch.gather(volume, 2, + (iz_bsw * ID ** 2 + iy_bsw * IW + ix_bsw).long().view(N, 1, D * H * W).repeat(1, C, 1)) + bse_val = torch.gather(volume, 2, + (iz_bse * ID ** 2 + iy_bse * IW + ix_bse).long().view(N, 1, D * H * W).repeat(1, C, 1)) + + # front view + fnw_val = torch.gather(volume, 2, + (iz_fnw * ID ** 2 + iy_fnw * IW + ix_fnw).long().view(N, 1, D * H * W).repeat(1, C, 1)) + fne_val = torch.gather(volume, 2, + (iz_fne * ID ** 2 + iy_fne * IW + ix_fne).long().view(N, 1, D * H * W).repeat(1, C, 1)) + fsw_val = torch.gather(volume, 2, + (iz_fsw * ID ** 2 + iy_fsw * IW + ix_fsw).long().view(N, 1, D * H * W).repeat(1, C, 1)) + fse_val = torch.gather(volume, 2, + (iz_fse * ID ** 2 + iy_fse * IW + ix_fse).long().view(N, 1, D * H * W).repeat(1, C, 1)) + + out_val = ( + # back + bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) + + bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) + + bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) + + bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W) + + # front + fnw_val.view(N, C, D, H, W) * fnw.view(N, 1, D, H, W) + + fne_val.view(N, C, D, H, W) * fne.view(N, 1, D, H, W) + + fsw_val.view(N, C, D, H, W) * fsw.view(N, 1, D, H, W) + + fse_val.view(N, C, D, H, W) * fse.view(N, 1, D, H, W) + + ) + + # * zero padding + out_val = torch.where(mask, out_val, torch.zeros_like(out_val).float().to(out_val.device)) + + return out_val + + +# Interpolation kernel +def get_weight(s, a=-0.5): + mask_0 = (torch.abs(s) >= 0) & (torch.abs(s) <= 1) + mask_1 = (torch.abs(s) > 1) & (torch.abs(s) <= 2) + mask_2 = torch.abs(s) > 2 + + weight = torch.zeros_like(s).to(s.device) + weight = torch.where(mask_0, (a + 2) * (torch.abs(s) ** 3) - (a + 3) * (torch.abs(s) ** 2) + 1, weight) + weight = torch.where(mask_1, + a * (torch.abs(s) ** 3) - (5 * a) * (torch.abs(s) ** 2) + (8 * a) * torch.abs(s) - 4 * a, + weight) + + # if (torch.abs(s) >= 0) & (torch.abs(s) <= 1): + # return (a + 2) * (torch.abs(s) ** 3) - (a + 3) * (torch.abs(s) ** 2) + 1 + # + # elif (torch.abs(s) > 1) & (torch.abs(s) <= 2): + # return a * (torch.abs(s) ** 3) - (5 * a) * (torch.abs(s) ** 2) + (8 * a) * torch.abs(s) - 4 * a + # return 0 + + return weight + + +def cubic_interpolate(p, x): + """ + one dimensional cubic interpolation + :param p: [N, 4] (4) should be in order + :param x: [N] + :return: + """ + return p[:, 1] + 0.5 * x * (p[:, 2] - p[:, 0] + x * ( + 2.0 * p[:, 0] - 5.0 * p[:, 1] + 4.0 * p[:, 2] - p[:, 3] + x * ( + 3.0 * (p[:, 1] - p[:, 2]) + p[:, 3] - p[:, 0]))) + + +def bicubic_interpolate(p, x, y, if_batch=True): + """ + two dimensional cubic interpolation + :param p: [N, 4, 4] + :param x: [N] + :param y: [N] + :return: + """ + num = p.shape[0] + + if not if_batch: + arr0 = cubic_interpolate(p[:, 0, :], x) # [N] + arr1 = cubic_interpolate(p[:, 1, :], x) + arr2 = cubic_interpolate(p[:, 2, :], x) + arr3 = cubic_interpolate(p[:, 3, :], x) + return cubic_interpolate(torch.stack([arr0, arr1, arr2, arr3], dim=-1), y) # [N] + else: + x = x[:, None].repeat(1, 4).view(-1) + p = p.contiguous().view(num * 4, 4) + arr = cubic_interpolate(p, x) + arr = arr.view(num, 4) + + return cubic_interpolate(arr, y) + + +def tricubic_interpolate(p, x, y, z): + """ + three dimensional cubic interpolation + :param p: [N,4,4,4] + :param x: [N] + :param y: [N] + :param z: [N] + :return: + """ + num = p.shape[0] + + arr0 = bicubic_interpolate(p[:, 0, :, :], x, y) # [N] + arr1 = bicubic_interpolate(p[:, 1, :, :], x, y) + arr2 = bicubic_interpolate(p[:, 2, :, :], x, y) + arr3 = bicubic_interpolate(p[:, 3, :, :], x, y) + + return cubic_interpolate(torch.stack([arr0, arr1, arr2, arr3], dim=-1), z) # [N] + + +def cubic_interpolate_batch(p, x): + """ + one dimensional cubic interpolation + :param p: [B, N, 4] (4) should be in order + :param x: [B, N] + :return: + """ + return p[:, :, 1] + 0.5 * x * (p[:, :, 2] - p[:, :, 0] + x * ( + 2.0 * p[:, :, 0] - 5.0 * p[:, :, 1] + 4.0 * p[:, :, 2] - p[:, :, 3] + x * ( + 3.0 * (p[:, :, 1] - p[:, :, 2]) + p[:, :, 3] - p[:, :, 0]))) + + +def bicubic_interpolate_batch(p, x, y): + """ + two dimensional cubic interpolation + :param p: [B, N, 4, 4] + :param x: [B, N] + :param y: [B, N] + :return: + """ + B, N, _, _ = p.shape + + x = x[:, :, None].repeat(1, 1, 4).view(B, N * 4) # [B, N*4] + arr = cubic_interpolate_batch(p.contiguous().view(B, N * 4, 4), x) + arr = arr.view(B, N, 4) + return cubic_interpolate_batch(arr, y) # [B, N] + + +# * batch version cannot speed up training +def tricubic_interpolate_batch(p, x, y, z): + """ + three dimensional cubic interpolation + :param p: [N,4,4,4] + :param x: [N] + :param y: [N] + :param z: [N] + :return: + """ + N = p.shape[0] + + x = x[None, :].repeat(4, 1) + y = y[None, :].repeat(4, 1) + + p = p.permute(1, 0, 2, 3).contiguous() + + arr = bicubic_interpolate_batch(p[:, :, :, :], x, y) # [4, N] + + arr = arr.permute(1, 0).contiguous() # [N, 4] + + return cubic_interpolate(arr, z) # [N] + + +def tricubic_sample_3d(volume, optical): + """ + tricubic sampling; can guarantee continuous gradient (interpolation border) + :param volume: [B, C, ID, IH, IW] + :param optical: [B, D, H, W, 3] + :param sample_num: + :return: + """ + + @torch.no_grad() + def get_shifts(x): + x1 = -1 * (1 + x - torch.floor(x)) + x2 = -1 * (x - torch.floor(x)) + x3 = torch.floor(x) + 1 - x + x4 = torch.floor(x) + 2 - x + + return torch.stack([x1, x2, x3, x4], dim=-1) # (B,d,h,w,4) + + N, C, ID, IH, IW = volume.shape + _, D, H, W, _ = optical.shape + + device = volume.device + + ix = optical[..., 0] + iy = optical[..., 1] + iz = optical[..., 2] + + ix = ((ix + 1) / 2) * (IW - 1) # (B,d,h,w) + iy = ((iy + 1) / 2) * (IH - 1) + iz = ((iz + 1) / 2) * (ID - 1) + + ix = ix.view(-1) + iy = iy.view(-1) + iz = iz.view(-1) + + with torch.no_grad(): + shifts_x = get_shifts(ix).view(-1, 4) # (B*d*h*w,4) + shifts_y = get_shifts(iy).view(-1, 4) + shifts_z = get_shifts(iz).view(-1, 4) + + perm_weights = torch.ones([N * D * H * W, 4 * 4 * 4]).long().to(device) + perm = torch.cumsum(perm_weights, dim=-1) - 1 # (B*d*h*w,64) + + perm_z = perm // 16 # [N*D*H*W, num] + perm_y = (perm - perm_z * 16) // 4 + perm_x = (perm - perm_z * 16 - perm_y * 4) + + shifts_x = torch.gather(shifts_x, 1, perm_x) # [N*D*H*W, num] + shifts_y = torch.gather(shifts_y, 1, perm_y) + shifts_z = torch.gather(shifts_z, 1, perm_z) + + ix_target = (ix[:, None] + shifts_x).long() # [N*D*H*W, num] + iy_target = (iy[:, None] + shifts_y).long() + iz_target = (iz[:, None] + shifts_z).long() + + torch.clamp(ix_target, 0, IW - 1, out=ix_target) + torch.clamp(iy_target, 0, IH - 1, out=iy_target) + torch.clamp(iz_target, 0, ID - 1, out=iz_target) + + local_dist_x = ix - ix_target[:, 1] # ! attention here is [:, 1] + local_dist_y = iy - iy_target[:, 1 + 4] + local_dist_z = iz - iz_target[:, 1 + 16] + + local_dist_x = local_dist_x.view(N, 1, D * H * W).repeat(1, C, 1).view(-1) + local_dist_y = local_dist_y.view(N, 1, D * H * W).repeat(1, C, 1).view(-1) + local_dist_z = local_dist_z.view(N, 1, D * H * W).repeat(1, C, 1).view(-1) + + # ! attention: IW is correct + idx_target = iz_target * ID ** 2 + iy_target * IW + ix_target # [N*D*H*W, num] + + volume = volume.view(N, C, ID * IH * IW) + + out = torch.gather(volume, 2, + idx_target.view(N, 1, D * H * W * 64).repeat(1, C, 1)) + out = out.view(N * C * D * H * W, 4, 4, 4) + + # - tricubic_interpolate() is a bit faster than tricubic_interpolate_batch() + final = tricubic_interpolate(out, local_dist_x, local_dist_y, local_dist_z).view(N, C, D, H, W) # [N,C,D,H,W] + + return final + + + +if __name__ == "__main__": + # image = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).view(1, 3, 1, 3) + # + # optical = torch.Tensor([0.9, 0.5, 0.6, -0.7]).view(1, 1, 2, 2) + # + # print(grid_sample_2d(image, optical)) + # + # print(F.grid_sample(image, optical, padding_mode='border', align_corners=True)) + + from ops.generate_grids import generate_grid + + p = torch.tensor([x for x in range(4)]).view(1, 4).float() + + v = cubic_interpolate(p, torch.tensor([0.5]).view(1)) + # v = bicubic_interpolate(p, torch.tensor([2/3]).view(1) , torch.tensor([2/3]).view(1)) + + vsize = 9 + volume = generate_grid([vsize, vsize, vsize], 1) # [1,3,10,10,10] + # volume = torch.tensor([x for x in range(1000)]).view(1, 1, 10, 10, 10).float() + X, Y, Z = 0, 0, 6 + x = 2 * X / (vsize - 1) - 1 + y = 2 * Y / (vsize - 1) - 1 + z = 2 * Z / (vsize - 1) - 1 + + # print(volume[:, :, Z, Y, X]) + + # volume = volume.view(1, 3, -1) + # xx = volume[:, :, Z * 9*9 + Y * 9 + X] + + optical = torch.Tensor([-0.6, -0.7, 0.5, 0.3, 0.5, 0.5]).view(1, 1, 1, 2, 3) + + print(F.grid_sample(volume, optical, padding_mode='border', align_corners=True)) + print(grid_sample_3d(volume, optical)) + print(tricubic_sample_3d(volume, optical)) + # target, relative_coords = implicit_sample_3d(volume, optical, 1) + # print(target) diff --git a/SparseNeuS_demo_v1/tsparse/__init__.py b/SparseNeuS_demo_v1/tsparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/tsparse/modules.py b/SparseNeuS_demo_v1/tsparse/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..520809144718d84b77708bbc7a582a64078958b4 --- /dev/null +++ b/SparseNeuS_demo_v1/tsparse/modules.py @@ -0,0 +1,326 @@ +import torch +import torch.nn as nn +import torchsparse +import torchsparse.nn as spnn +from torchsparse.tensor import PointTensor + +from tsparse.torchsparse_utils import * + + +# __all__ = ['SPVCNN', 'SConv3d', 'SparseConvGRU'] + + +class ConvBnReLU(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size=3, stride=1, pad=1): + super(ConvBnReLU, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, + kernel_size, stride=stride, padding=pad, bias=False) + self.bn = nn.BatchNorm2d(out_channels) + self.activation = nn.ReLU(inplace=True) + + def forward(self, x): + return self.activation(self.bn(self.conv(x))) + + +class ConvBnReLU3D(nn.Module): + def __init__(self, in_channels, out_channels, + kernel_size=3, stride=1, pad=1): + super(ConvBnReLU3D, self).__init__() + self.conv = nn.Conv3d(in_channels, out_channels, + kernel_size, stride=stride, padding=pad, bias=False) + self.bn = nn.BatchNorm3d(out_channels) + self.activation = nn.ReLU(inplace=True) + + def forward(self, x): + return self.activation(self.bn(self.conv(x))) + + +################################### feature net ###################################### +class FeatureNet(nn.Module): + """ + output 3 levels of features using a FPN structure + """ + + def __init__(self): + super(FeatureNet, self).__init__() + + self.conv0 = nn.Sequential( + ConvBnReLU(3, 8, 3, 1, 1), + ConvBnReLU(8, 8, 3, 1, 1)) + + self.conv1 = nn.Sequential( + ConvBnReLU(8, 16, 5, 2, 2), + ConvBnReLU(16, 16, 3, 1, 1), + ConvBnReLU(16, 16, 3, 1, 1)) + + self.conv2 = nn.Sequential( + ConvBnReLU(16, 32, 5, 2, 2), + ConvBnReLU(32, 32, 3, 1, 1), + ConvBnReLU(32, 32, 3, 1, 1)) + + self.toplayer = nn.Conv2d(32, 32, 1) + self.lat1 = nn.Conv2d(16, 32, 1) + self.lat0 = nn.Conv2d(8, 32, 1) + + # to reduce channel size of the outputs from FPN + self.smooth1 = nn.Conv2d(32, 16, 3, padding=1) + self.smooth0 = nn.Conv2d(32, 8, 3, padding=1) + + def _upsample_add(self, x, y): + return torch.nn.functional.interpolate(x, scale_factor=2, + mode="bilinear", align_corners=True) + y + + def forward(self, x): + # x: (B, 3, H, W) + conv0 = self.conv0(x) # (B, 8, H, W) + conv1 = self.conv1(conv0) # (B, 16, H//2, W//2) + conv2 = self.conv2(conv1) # (B, 32, H//4, W//4) + feat2 = self.toplayer(conv2) # (B, 32, H//4, W//4) + feat1 = self._upsample_add(feat2, self.lat1(conv1)) # (B, 32, H//2, W//2) + feat0 = self._upsample_add(feat1, self.lat0(conv0)) # (B, 32, H, W) + + # reduce output channels + feat1 = self.smooth1(feat1) # (B, 16, H//2, W//2) + feat0 = self.smooth0(feat0) # (B, 8, H, W) + + # feats = {"level_0": feat0, + # "level_1": feat1, + # "level_2": feat2} + + return [feat2, feat1, feat0] # coarser to finer features + + +class BasicSparseConvolutionBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1, dilation=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d(inc, + outc, + kernel_size=ks, + dilation=dilation, + stride=stride), + spnn.BatchNorm(outc), + spnn.ReLU(True)) + + def forward(self, x): + out = self.net(x) + return out + + +class BasicSparseDeconvolutionBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d(inc, + outc, + kernel_size=ks, + stride=stride, + transposed=True), + spnn.BatchNorm(outc), + spnn.ReLU(True)) + + def forward(self, x): + return self.net(x) + + +class SparseResidualBlock(nn.Module): + def __init__(self, inc, outc, ks=3, stride=1, dilation=1): + super().__init__() + self.net = nn.Sequential( + spnn.Conv3d(inc, + outc, + kernel_size=ks, + dilation=dilation, + stride=stride), spnn.BatchNorm(outc), + spnn.ReLU(True), + spnn.Conv3d(outc, + outc, + kernel_size=ks, + dilation=dilation, + stride=1), spnn.BatchNorm(outc)) + + self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \ + nn.Sequential( + spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride), + spnn.BatchNorm(outc) + ) + + self.relu = spnn.ReLU(True) + + def forward(self, x): + out = self.relu(self.net(x) + self.downsample(x)) + return out + + +class SPVCNN(nn.Module): + def __init__(self, **kwargs): + super().__init__() + + self.dropout = kwargs['dropout'] + + cr = kwargs.get('cr', 1.0) + cs = [32, 64, 128, 96, 96] + cs = [int(cr * x) for x in cs] + + if 'pres' in kwargs and 'vres' in kwargs: + self.pres = kwargs['pres'] + self.vres = kwargs['vres'] + + self.stem = nn.Sequential( + spnn.Conv3d(kwargs['in_channels'], cs[0], kernel_size=3, stride=1), + spnn.BatchNorm(cs[0]), spnn.ReLU(True) + ) + + self.stage1 = nn.Sequential( + BasicSparseConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1), + SparseResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1), + SparseResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1), + ) + + self.stage2 = nn.Sequential( + BasicSparseConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1), + SparseResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1), + SparseResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1), + ) + + self.up1 = nn.ModuleList([ + BasicSparseDeconvolutionBlock(cs[2], cs[3], ks=2, stride=2), + nn.Sequential( + SparseResidualBlock(cs[3] + cs[1], cs[3], ks=3, stride=1, + dilation=1), + SparseResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1), + ) + ]) + + self.up2 = nn.ModuleList([ + BasicSparseDeconvolutionBlock(cs[3], cs[4], ks=2, stride=2), + nn.Sequential( + SparseResidualBlock(cs[4] + cs[0], cs[4], ks=3, stride=1, + dilation=1), + SparseResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1), + ) + ]) + + self.point_transforms = nn.ModuleList([ + nn.Sequential( + nn.Linear(cs[0], cs[2]), + nn.BatchNorm1d(cs[2]), + nn.ReLU(True), + ), + nn.Sequential( + nn.Linear(cs[2], cs[4]), + nn.BatchNorm1d(cs[4]), + nn.ReLU(True), + ) + ]) + + self.weight_initialization() + + if self.dropout: + self.dropout = nn.Dropout(0.3, True) + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def forward(self, z): + # x: SparseTensor z: PointTensor + x0 = initial_voxelize(z, self.pres, self.vres) + + x0 = self.stem(x0) + z0 = voxel_to_point(x0, z, nearest=False) + z0.F = z0.F + + x1 = point_to_voxel(x0, z0) + x1 = self.stage1(x1) + x2 = self.stage2(x1) + z1 = voxel_to_point(x2, z0) + z1.F = z1.F + self.point_transforms[0](z0.F) + + y3 = point_to_voxel(x2, z1) + if self.dropout: + y3.F = self.dropout(y3.F) + y3 = self.up1[0](y3) + y3 = torchsparse.cat([y3, x1]) + y3 = self.up1[1](y3) + + y4 = self.up2[0](y3) + y4 = torchsparse.cat([y4, x0]) + y4 = self.up2[1](y4) + z3 = voxel_to_point(y4, z1) + z3.F = z3.F + self.point_transforms[1](z1.F) + + return z3.F + + +class SparseCostRegNet(nn.Module): + """ + Sparse cost regularization network; + require sparse tensors as input + """ + + def __init__(self, d_in, d_out=8): + super(SparseCostRegNet, self).__init__() + self.d_in = d_in + self.d_out = d_out + + self.conv0 = BasicSparseConvolutionBlock(d_in, d_out) + + self.conv1 = BasicSparseConvolutionBlock(d_out, 16, stride=2) + self.conv2 = BasicSparseConvolutionBlock(16, 16) + + self.conv3 = BasicSparseConvolutionBlock(16, 32, stride=2) + self.conv4 = BasicSparseConvolutionBlock(32, 32) + + self.conv5 = BasicSparseConvolutionBlock(32, 64, stride=2) + self.conv6 = BasicSparseConvolutionBlock(64, 64) + + self.conv7 = BasicSparseDeconvolutionBlock(64, 32, ks=3, stride=2) + + self.conv9 = BasicSparseDeconvolutionBlock(32, 16, ks=3, stride=2) + + self.conv11 = BasicSparseDeconvolutionBlock(16, d_out, ks=3, stride=2) + + def forward(self, x): + """ + + :param x: sparse tensor + :return: sparse tensor + """ + conv0 = self.conv0(x) + conv2 = self.conv2(self.conv1(conv0)) + conv4 = self.conv4(self.conv3(conv2)) + + x = self.conv6(self.conv5(conv4)) + x = conv4 + self.conv7(x) + del conv4 + x = conv2 + self.conv9(x) + del conv2 + x = conv0 + self.conv11(x) + del conv0 + return x.F + + +class SConv3d(nn.Module): + def __init__(self, inc, outc, pres, vres, ks=3, stride=1, dilation=1): + super().__init__() + self.net = spnn.Conv3d(inc, + outc, + kernel_size=ks, + dilation=dilation, + stride=stride) + self.point_transforms = nn.Sequential( + nn.Linear(inc, outc), + ) + self.pres = pres + self.vres = vres + + def forward(self, z): + x = initial_voxelize(z, self.pres, self.vres) + x = self.net(x) + out = voxel_to_point(x, z, nearest=False) + out.F = out.F + self.point_transforms(z.F) + return out diff --git a/SparseNeuS_demo_v1/tsparse/torchsparse_utils.py b/SparseNeuS_demo_v1/tsparse/torchsparse_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32f5b92ae5ef4bf9836b1e4c1dc17eaf3f7c93f9 --- /dev/null +++ b/SparseNeuS_demo_v1/tsparse/torchsparse_utils.py @@ -0,0 +1,137 @@ +""" +Copied from: +https://github.com/mit-han-lab/spvnas/blob/b24f50379ed888d3a0e784508a809d4e92e820c0/core/models/utils.py +""" +import torch +import torchsparse.nn.functional as F +from torchsparse import PointTensor, SparseTensor +from torchsparse.nn.utils import get_kernel_offsets + +import numpy as np + +# __all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point'] + + +# z: PointTensor +# return: SparseTensor +def initial_voxelize(z, init_res, after_res): + new_float_coord = torch.cat( + [(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1) + + pc_hash = F.sphash(torch.floor(new_float_coord).int()) + sparse_hash = torch.unique(pc_hash) + idx_query = F.sphashquery(pc_hash, sparse_hash) + counts = F.spcount(idx_query.int(), len(sparse_hash)) + + inserted_coords = F.spvoxelize(torch.floor(new_float_coord), idx_query, + counts) + inserted_coords = torch.round(inserted_coords).int() + inserted_feat = F.spvoxelize(z.F, idx_query, counts) + + new_tensor = SparseTensor(inserted_feat, inserted_coords, 1) + new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords) + z.additional_features['idx_query'][1] = idx_query + z.additional_features['counts'][1] = counts + z.C = new_float_coord + + return new_tensor + + +# x: SparseTensor, z: PointTensor +# return: SparseTensor +def point_to_voxel(x, z): + if z.additional_features is None or z.additional_features.get('idx_query') is None \ + or z.additional_features['idx_query'].get(x.s) is None: + # pc_hash = hash_gpu(torch.floor(z.C).int()) + pc_hash = F.sphash( + torch.cat([ + torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], + z.C[:, -1].int().view(-1, 1) + ], 1)) + sparse_hash = F.sphash(x.C) + idx_query = F.sphashquery(pc_hash, sparse_hash) + counts = F.spcount(idx_query.int(), x.C.shape[0]) + z.additional_features['idx_query'][x.s] = idx_query + z.additional_features['counts'][x.s] = counts + else: + idx_query = z.additional_features['idx_query'][x.s] + counts = z.additional_features['counts'][x.s] + + inserted_feat = F.spvoxelize(z.F, idx_query, counts) + new_tensor = SparseTensor(inserted_feat, x.C, x.s) + new_tensor.cmaps = x.cmaps + new_tensor.kmaps = x.kmaps + + return new_tensor + + +# x: SparseTensor, z: PointTensor +# return: PointTensor +def voxel_to_point(x, z, nearest=False): + if z.idx_query is None or z.weights is None or z.idx_query.get( + x.s) is None or z.weights.get(x.s) is None: + off = get_kernel_offsets(2, x.s, 1, device=z.F.device) + # old_hash = kernel_hash_gpu(torch.floor(z.C).int(), off) + old_hash = F.sphash( + torch.cat([ + torch.floor(z.C[:, :3] / x.s[0]).int() * x.s[0], + z.C[:, -1].int().view(-1, 1) + ], 1), off) + mm = x.C.to(z.F.device) + pc_hash = F.sphash(x.C.to(z.F.device)) + idx_query = F.sphashquery(old_hash, pc_hash) + weights = F.calc_ti_weights(z.C, idx_query, + scale=x.s[0]).transpose(0, 1).contiguous() + idx_query = idx_query.transpose(0, 1).contiguous() + if nearest: + weights[:, 1:] = 0. + idx_query[:, 1:] = -1 + new_feat = F.spdevoxelize(x.F, idx_query, weights) + new_tensor = PointTensor(new_feat, + z.C, + idx_query=z.idx_query, + weights=z.weights) + new_tensor.additional_features = z.additional_features + new_tensor.idx_query[x.s] = idx_query + new_tensor.weights[x.s] = weights + z.idx_query[x.s] = idx_query + z.weights[x.s] = weights + + else: + new_feat = F.spdevoxelize(x.F, z.idx_query.get(x.s), + z.weights.get(x.s)) # - sparse trilinear interpoltation operation + new_tensor = PointTensor(new_feat, + z.C, + idx_query=z.idx_query, + weights=z.weights) + new_tensor.additional_features = z.additional_features + + return new_tensor + + +def sparse_to_dense_torch_batch(locs, values, dim, default_val): + dense = torch.full([dim[0], dim[1], dim[2], dim[3]], float(default_val), device=locs.device) + dense[locs[:, 0], locs[:, 1], locs[:, 2], locs[:, 3]] = values + return dense + + +def sparse_to_dense_torch(locs, values, dim, default_val, device): + dense = torch.full([dim[0], dim[1], dim[2]], float(default_val), device=device) + if locs.shape[0] > 0: + dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values + return dense + + +def sparse_to_dense_channel(locs, values, dim, c, default_val, device): + locs = locs.to(torch.int64) + dense = torch.full([dim[0], dim[1], dim[2], c], float(default_val), device=device) + if locs.shape[0] > 0: + dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values + return dense + + +def sparse_to_dense_np(locs, values, dim, default_val): + dense = np.zeros([dim[0], dim[1], dim[2]], dtype=values.dtype) + dense.fill(default_val) + dense[locs[:, 0], locs[:, 1], locs[:, 2]] = values + return dense diff --git a/SparseNeuS_demo_v1/utils/__init__.py b/SparseNeuS_demo_v1/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/SparseNeuS_demo_v1/utils/misc_utils.py b/SparseNeuS_demo_v1/utils/misc_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..85e80cf4e2bcf8bed0086e2b6c8a3bf3da40a056 --- /dev/null +++ b/SparseNeuS_demo_v1/utils/misc_utils.py @@ -0,0 +1,219 @@ +import os, torch, cv2, re +import numpy as np + +from PIL import Image +import torch.nn.functional as F +import torchvision.transforms as T + +# Misc +img2mse = lambda x, y: torch.mean((x - y) ** 2) +mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.])) +to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8) +mse2psnr2 = lambda x: -10. * np.log(x) / np.log(10.) + + +def get_psnr(imgs_pred, imgs_gt): + psnrs = [] + for (img, tar) in zip(imgs_pred, imgs_gt): + psnrs.append(mse2psnr2(np.mean((img - tar.cpu().numpy()) ** 2))) + return np.array(psnrs) + + +def init_log(log, keys): + for key in keys: + log[key] = torch.tensor([0.0], dtype=float) + return log + + +def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET): + """ + depth: (H, W) + """ + + x = np.nan_to_num(depth) # change nan to 0 + if minmax is None: + mi = np.min(x[x > 0]) # get minimum positive depth (ignore background) + ma = np.max(x) + else: + mi, ma = minmax + + x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1 + x = (255 * x).astype(np.uint8) + x_ = cv2.applyColorMap(x, cmap) + return x_, [mi, ma] + + +def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET): + """ + depth: (H, W) + """ + if type(depth) is not np.ndarray: + depth = depth.cpu().numpy() + + x = np.nan_to_num(depth) # change nan to 0 + if minmax is None: + mi = np.min(x[x > 0]) # get minimum positive depth (ignore background) + ma = np.max(x) + else: + mi, ma = minmax + + x = (x - mi) / (ma - mi + 1e-8) # normalize to 0~1 + x = (255 * x).astype(np.uint8) + x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) + x_ = T.ToTensor()(x_) # (3, H, W) + return x_, [mi, ma] + + +def abs_error_numpy(depth_pred, depth_gt, mask): + depth_pred, depth_gt = depth_pred[mask], depth_gt[mask] + return np.abs(depth_pred - depth_gt) + + +def abs_error(depth_pred, depth_gt, mask): + depth_pred, depth_gt = depth_pred[mask], depth_gt[mask] + err = depth_pred - depth_gt + return np.abs(err) if type(depth_pred) is np.ndarray else err.abs() + + +def acc_threshold(depth_pred, depth_gt, mask, threshold): + """ + computes the percentage of pixels whose depth error is less than @threshold + """ + errors = abs_error(depth_pred, depth_gt, mask) + acc_mask = errors < threshold + return acc_mask.astype('float') if type(depth_pred) is np.ndarray else acc_mask.float() + + +def to_tensor_cuda(data, device, filter): + for item in data.keys(): + + if item in filter: + continue + + if type(data[item]) is np.ndarray: + data[item] = torch.tensor(data[item], dtype=torch.float32, device=device) + else: + data[item] = data[item].float().to(device) + return data + + +def to_cuda(data, device, filter): + for item in data.keys(): + if item in filter: + continue + + data[item] = data[item].float().to(device) + return data + + +def tensor_unsqueeze(data, filter): + for item in data.keys(): + if item in filter: + continue + + data[item] = data[item][None] + return data + + +def filter_keys(dict): + dict.pop('N_samples') + if 'ndc' in dict.keys(): + dict.pop('ndc') + if 'lindisp' in dict.keys(): + dict.pop('lindisp') + return dict + + +def sub_selete_data(data_batch, device, idx, filtKey=[], + filtIndex=['view_ids_all', 'c2ws_all', 'scan', 'bbox', 'w2ref', 'ref2w', 'light_id', 'ckpt', + 'idx']): + data_sub_selete = {} + for item in data_batch.keys(): + data_sub_selete[item] = data_batch[item][:, idx].float() if ( + item not in filtIndex and torch.is_tensor(item) and item.dim() > 2) else data_batch[item].float() + if not data_sub_selete[item].is_cuda: + data_sub_selete[item] = data_sub_selete[item].to(device) + return data_sub_selete + + +def detach_data(dictionary): + dictionary_new = {} + for key in dictionary.keys(): + dictionary_new[key] = dictionary[key].detach().clone() + return dictionary_new + + +def read_pfm(filename): + file = open(filename, 'rb') + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().decode('utf-8').rstrip() + if header == 'PF': + color = True + elif header == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + file.close() + return data, scale + + +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR + + +# from warmup_scheduler import GradualWarmupScheduler +def get_scheduler(hparams, optimizer): + eps = 1e-8 + if hparams.lr_scheduler == 'steplr': + scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step, + gamma=hparams.decay_gamma) + elif hparams.lr_scheduler == 'cosine': + scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps) + + else: + raise ValueError('scheduler not recognized!') + + # if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']: + # scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier, + # total_epoch=hparams.warmup_epochs, after_scheduler=scheduler) + return scheduler + + +#### pairing #### +def get_nearest_pose_ids(tar_pose, ref_poses, num_select): + ''' + Args: + tar_pose: target pose [N, 4, 4] + ref_poses: reference poses [M, 4, 4] + num_select: the number of nearest views to select + Returns: the selected indices + ''' + + dists = np.linalg.norm(tar_pose[:, None, :3, 3] - ref_poses[None, :, :3, 3], axis=-1) + + sorted_ids = np.argsort(dists, axis=-1) + selected_ids = sorted_ids[:, :num_select] + return selected_ids diff --git a/configs/sd-objaverse-finetune-c_concat-256.yaml b/configs/sd-objaverse-finetune-c_concat-256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..488dafa27fcd632215ab869f9ab15c8ed452b66a --- /dev/null +++ b/configs/sd-objaverse-finetune-c_concat-256.yaml @@ -0,0 +1,117 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.LatentDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "image_target" + cond_stage_key: "image_cond" + image_size: 32 + channels: 4 + cond_stage_trainable: false # Note: different from the one we trained before + conditioning_key: hybrid + monitor: val/loss_simple_ema + scale_factor: 0.18215 + + scheduler_config: # 10000 warmup steps + target: ldm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [ 100 ] + cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases + f_start: [ 1.e-6 ] + f_max: [ 1. ] + f_min: [ 1. ] + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + image_size: 32 # unused + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_heads: 8 + use_spatial_transformer: True + transformer_depth: 1 + context_dim: 768 + use_checkpoint: True + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder + + +data: + target: ldm.data.simple.ObjaverseDataModuleFromConfig + params: + root_dir: 'views_whole_sphere' + batch_size: 192 + num_workers: 16 + total_view: 4 + train: + validation: False + image_transforms: + size: 256 + + validation: + validation: True + image_transforms: + size: 256 + + +lightning: + find_unused_parameters: false + metrics_over_trainsteps_checkpoint: True + modelcheckpoint: + params: + every_n_train_steps: 5000 + callbacks: + image_logger: + target: main.ImageLogger + params: + batch_frequency: 500 + max_images: 32 + increase_log_steps: False + log_first_step: True + log_images_kwargs: + use_ema_scope: False + inpaint: False + plot_progressive_rows: False + plot_diffusion_rows: False + N: 32 + unconditional_guidance_scale: 3.0 + unconditional_guidance_label: [""] + + trainer: + benchmark: True + val_check_interval: 5000000 # really sorry + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/data/base.py b/ldm/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..742794e631081bbfa7c44f3df6f83373ca5c15c1 --- /dev/null +++ b/ldm/data/base.py @@ -0,0 +1,40 @@ +import os +import numpy as np +from abc import abstractmethod +from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset + + +class Txt2ImgIterableBaseDataset(IterableDataset): + ''' + Define an interface to make the IterableDatasets for text2img data chainable + ''' + def __init__(self, num_records=0, valid_ids=None, size=256): + super().__init__() + self.num_records = num_records + self.valid_ids = valid_ids + self.sample_ids = valid_ids + self.size = size + + print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') + + def __len__(self): + return self.num_records + + @abstractmethod + def __iter__(self): + pass + + +class PRNGMixin(object): + """ + Adds a prng property which is a numpy RandomState which gets + reinitialized whenever the pid changes to avoid synchronized sampling + behavior when used in conjunction with multiprocessing. + """ + @property + def prng(self): + currentpid = os.getpid() + if getattr(self, "_initpid", None) != currentpid: + self._initpid = currentpid + self._prng = np.random.RandomState() + return self._prng diff --git a/ldm/data/coco.py b/ldm/data/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5e27e6ec6a51932f67b83dd88533cb39631e26 --- /dev/null +++ b/ldm/data/coco.py @@ -0,0 +1,253 @@ +import os +import json +import albumentations +import numpy as np +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset +from abc import abstractmethod + + +class CocoBase(Dataset): + """needed for (image, caption, segmentation) pairs""" + def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, + crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True,crop_type=None): + self.split = self.get_split() + self.size = size + if crop_size is None: + self.crop_size = size + else: + self.crop_size = crop_size + + assert crop_type in [None, 'random', 'center'] + self.crop_type = crop_type + self.use_segmenation = use_segmentation + self.onehot = onehot_segmentation # return segmentation as rgb or one hot + self.stuffthing = use_stuffthing # include thing in segmentation + if self.onehot and not self.stuffthing: + raise NotImplemented("One hot mode is only supported for the " + "stuffthings version because labels are stored " + "a bit different.") + + data_json = datajson + with open(data_json) as json_file: + self.json_data = json.load(json_file) + self.img_id_to_captions = dict() + self.img_id_to_filepath = dict() + self.img_id_to_segmentation_filepath = dict() + + assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json", + f"captions_val{self.year()}.json"] + # TODO currently hardcoded paths, would be better to follow logic in + # cocstuff pixelmaps + if self.use_segmenation: + if self.stuffthing: + self.segmentation_prefix = ( + f"data/cocostuffthings/val{self.year()}" if + data_json.endswith(f"captions_val{self.year()}.json") else + f"data/cocostuffthings/train{self.year()}") + else: + self.segmentation_prefix = ( + f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if + data_json.endswith(f"captions_val{self.year()}.json") else + f"data/coco/annotations/stuff_train{self.year()}_pixelmaps") + + imagedirs = self.json_data["images"] + self.labels = {"image_ids": list()} + for imgdir in tqdm(imagedirs, desc="ImgToPath"): + self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) + self.img_id_to_captions[imgdir["id"]] = list() + pngfilename = imgdir["file_name"].replace("jpg", "png") + if self.use_segmenation: + self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( + self.segmentation_prefix, pngfilename) + if given_files is not None: + if pngfilename in given_files: + self.labels["image_ids"].append(imgdir["id"]) + else: + self.labels["image_ids"].append(imgdir["id"]) + + capdirs = self.json_data["annotations"] + for capdir in tqdm(capdirs, desc="ImgToCaptions"): + # there are in average 5 captions per image + #self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) + self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"]) + + self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) + if self.split=="validation": + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + else: + # default option for train is random crop + if self.crop_type in [None, 'random']: + self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) + else: + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + self.preprocessor = albumentations.Compose( + [self.rescaler, self.cropper], + additional_targets={"segmentation": "image"}) + if force_no_crop: + self.rescaler = albumentations.Resize(height=self.size, width=self.size) + self.preprocessor = albumentations.Compose( + [self.rescaler], + additional_targets={"segmentation": "image"}) + + @abstractmethod + def year(self): + raise NotImplementedError() + + def __len__(self): + return len(self.labels["image_ids"]) + + def preprocess_image(self, image_path, segmentation_path=None): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + if segmentation_path: + segmentation = Image.open(segmentation_path) + if not self.onehot and not segmentation.mode == "RGB": + segmentation = segmentation.convert("RGB") + segmentation = np.array(segmentation).astype(np.uint8) + if self.onehot: + assert self.stuffthing + # stored in caffe format: unlabeled==255. stuff and thing from + # 0-181. to be compatible with the labels in + # https://github.com/nightrome/cocostuff/blob/master/labels.txt + # we shift stuffthing one to the right and put unlabeled in zero + # as long as segmentation is uint8 shifting to right handles the + # latter too + assert segmentation.dtype == np.uint8 + segmentation = segmentation + 1 + + processed = self.preprocessor(image=image, segmentation=segmentation) + + image, segmentation = processed["image"], processed["segmentation"] + else: + image = self.preprocessor(image=image,)['image'] + + image = (image / 127.5 - 1.0).astype(np.float32) + if segmentation_path: + if self.onehot: + assert segmentation.dtype == np.uint8 + # make it one hot + n_labels = 183 + flatseg = np.ravel(segmentation) + onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) + onehot[np.arange(flatseg.size), flatseg] = True + onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) + segmentation = onehot + else: + segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) + return image, segmentation + else: + return image + + def __getitem__(self, i): + img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] + if self.use_segmenation: + seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] + image, segmentation = self.preprocess_image(img_path, seg_path) + else: + image = self.preprocess_image(img_path) + captions = self.img_id_to_captions[self.labels["image_ids"][i]] + # randomly draw one of all available captions per image + caption = captions[np.random.randint(0, len(captions))] + example = {"image": image, + #"caption": [str(caption[0])], + "caption": caption, + "img_path": img_path, + "filename_": img_path.split(os.sep)[-1] + } + if self.use_segmenation: + example.update({"seg_path": seg_path, 'segmentation': segmentation}) + return example + + +class CocoImagesAndCaptionsTrain2017(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,): + super().__init__(size=size, + dataroot="data/coco/train2017", + datajson="data/coco/annotations/captions_train2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) + + def get_split(self): + return "train" + + def year(self): + return '2017' + + +class CocoImagesAndCaptionsValidation2017(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, + given_files=None): + super().__init__(size=size, + dataroot="data/coco/val2017", + datajson="data/coco/annotations/captions_val2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, + given_files=given_files) + + def get_split(self): + return "validation" + + def year(self): + return '2017' + + + +class CocoImagesAndCaptionsTrain2014(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,crop_type='random'): + super().__init__(size=size, + dataroot="data/coco/train2014", + datajson="data/coco/annotations2014/annotations/captions_train2014.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, + use_segmentation=False, + crop_type=crop_type) + + def get_split(self): + return "train" + + def year(self): + return '2014' + +class CocoImagesAndCaptionsValidation2014(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, + given_files=None,crop_type='center',**kwargs): + super().__init__(size=size, + dataroot="data/coco/val2014", + datajson="data/coco/annotations2014/annotations/captions_val2014.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, + given_files=given_files, + use_segmentation=False, + crop_type=crop_type) + + def get_split(self): + return "validation" + + def year(self): + return '2014' + +if __name__ == '__main__': + with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file: + json_data = json.load(json_file) + capdirs = json_data["annotations"] + import pudb; pudb.set_trace() + #d2 = CocoImagesAndCaptionsTrain2014(size=256) + d2 = CocoImagesAndCaptionsValidation2014(size=256) + print("constructed dataset.") + print(f"length of {d2.__class__.__name__}: {len(d2)}") + + ex2 = d2[0] + # ex3 = d3[0] + # print(ex1["image"].shape) + print(ex2["image"].shape) + # print(ex3["image"].shape) + # print(ex1["segmentation"].shape) + print(ex2["caption"].__class__.__name__) diff --git a/ldm/data/dummy.py b/ldm/data/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..3b74a77fe8954686e480d28aaed19e52d3e3c9b7 --- /dev/null +++ b/ldm/data/dummy.py @@ -0,0 +1,34 @@ +import numpy as np +import random +import string +from torch.utils.data import Dataset, Subset + +class DummyData(Dataset): + def __init__(self, length, size): + self.length = length + self.size = size + + def __len__(self): + return self.length + + def __getitem__(self, i): + x = np.random.randn(*self.size) + letters = string.ascii_lowercase + y = ''.join(random.choice(string.ascii_lowercase) for i in range(10)) + return {"jpg": x, "txt": y} + + +class DummyDataWithEmbeddings(Dataset): + def __init__(self, length, size, emb_size): + self.length = length + self.size = size + self.emb_size = emb_size + + def __len__(self): + return self.length + + def __getitem__(self, i): + x = np.random.randn(*self.size) + y = np.random.randn(*self.emb_size).astype(np.float32) + return {"jpg": x, "txt": y} + diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..66231964a685cc875243018461a6aaa63a96dbf0 --- /dev/null +++ b/ldm/data/imagenet.py @@ -0,0 +1,394 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from taming.data.imagenet import ImagePaths + +from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + self.size = retrieve(self.config, "size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + example["caption"] = example["human_label"] # dummy caption + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/ldm/data/inpainting/__init__.py b/ldm/data/inpainting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/data/inpainting/synthetic_mask.py b/ldm/data/inpainting/synthetic_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..bb4c38f3a79b8eb40553469d6f0656ad2f54609a --- /dev/null +++ b/ldm/data/inpainting/synthetic_mask.py @@ -0,0 +1,166 @@ +from PIL import Image, ImageDraw +import numpy as np + +settings = { + "256narrow": { + "p_irr": 1, + "min_n_irr": 4, + "max_n_irr": 50, + "max_l_irr": 40, + "max_w_irr": 10, + "min_n_box": None, + "max_n_box": None, + "min_s_box": None, + "max_s_box": None, + "marg": None, + }, + "256train": { + "p_irr": 0.5, + "min_n_irr": 1, + "max_n_irr": 5, + "max_l_irr": 200, + "max_w_irr": 100, + "min_n_box": 1, + "max_n_box": 4, + "min_s_box": 30, + "max_s_box": 150, + "marg": 10, + }, + "512train": { # TODO: experimental + "p_irr": 0.5, + "min_n_irr": 1, + "max_n_irr": 5, + "max_l_irr": 450, + "max_w_irr": 250, + "min_n_box": 1, + "max_n_box": 4, + "min_s_box": 30, + "max_s_box": 300, + "marg": 10, + }, + "512train-large": { # TODO: experimental + "p_irr": 0.5, + "min_n_irr": 1, + "max_n_irr": 5, + "max_l_irr": 450, + "max_w_irr": 400, + "min_n_box": 1, + "max_n_box": 4, + "min_s_box": 75, + "max_s_box": 450, + "marg": 10, + }, +} + + +def gen_segment_mask(mask, start, end, brush_width): + mask = mask > 0 + mask = (255 * mask).astype(np.uint8) + mask = Image.fromarray(mask) + draw = ImageDraw.Draw(mask) + draw.line([start, end], fill=255, width=brush_width, joint="curve") + mask = np.array(mask) / 255 + return mask + + +def gen_box_mask(mask, masked): + x_0, y_0, w, h = masked + mask[y_0:y_0 + h, x_0:x_0 + w] = 1 + return mask + + +def gen_round_mask(mask, masked, radius): + x_0, y_0, w, h = masked + xy = [(x_0, y_0), (x_0 + w, y_0 + w)] + + mask = mask > 0 + mask = (255 * mask).astype(np.uint8) + mask = Image.fromarray(mask) + draw = ImageDraw.Draw(mask) + draw.rounded_rectangle(xy, radius=radius, fill=255) + mask = np.array(mask) / 255 + return mask + + +def gen_large_mask(prng, img_h, img_w, + marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr, + min_n_box, max_n_box, min_s_box, max_s_box): + """ + img_h: int, an image height + img_w: int, an image width + marg: int, a margin for a box starting coordinate + p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask + + min_n_irr: int, min number of segments + max_n_irr: int, max number of segments + max_l_irr: max length of a segment in polygonal chain + max_w_irr: max width of a segment in polygonal chain + + min_n_box: int, min bound for the number of box primitives + max_n_box: int, max bound for the number of box primitives + min_s_box: int, min length of a box side + max_s_box: int, max length of a box side + """ + + mask = np.zeros((img_h, img_w)) + uniform = prng.randint + + if np.random.uniform(0, 1) < p_irr: # generate polygonal chain + n = uniform(min_n_irr, max_n_irr) # sample number of segments + + for _ in range(n): + y = uniform(0, img_h) # sample a starting point + x = uniform(0, img_w) + + a = uniform(0, 360) # sample angle + l = uniform(10, max_l_irr) # sample segment length + w = uniform(5, max_w_irr) # sample a segment width + + # draw segment starting from (x,y) to (x_,y_) using brush of width w + x_ = x + l * np.sin(a) + y_ = y + l * np.cos(a) + + mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w) + x, y = x_, y_ + else: # generate Box masks + n = uniform(min_n_box, max_n_box) # sample number of rectangles + + for _ in range(n): + h = uniform(min_s_box, max_s_box) # sample box shape + w = uniform(min_s_box, max_s_box) + + x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box + y_0 = uniform(marg, img_h - marg - h) + + if np.random.uniform(0, 1) < 0.5: + mask = gen_box_mask(mask, masked=(x_0, y_0, w, h)) + else: + r = uniform(0, 60) # sample radius + mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r) + return mask + + +make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"]) +make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"]) +make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"]) +make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"]) + + +MASK_MODES = { + "256train": make_lama_mask, + "256narrow": make_narrow_lama_mask, + "512train": make_512_lama_mask, + "512train-large": make_512_lama_mask_large +} + +if __name__ == "__main__": + import sys + + out = sys.argv[1] + + prng = np.random.RandomState(1) + kwargs = settings["256train"] + mask = gen_large_mask(prng, 256, 256, **kwargs) + mask = (255 * mask).astype(np.uint8) + mask = Image.fromarray(mask) + mask.save(out) diff --git a/ldm/data/laion.py b/ldm/data/laion.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb608c1a4cf2b7c0215bdd7c1c81841e3a39b0c --- /dev/null +++ b/ldm/data/laion.py @@ -0,0 +1,537 @@ +import webdataset as wds +import kornia +from PIL import Image +import io +import os +import torchvision +from PIL import Image +import glob +import random +import numpy as np +import pytorch_lightning as pl +from tqdm import tqdm +from omegaconf import OmegaConf +from einops import rearrange +import torch +from webdataset.handlers import warn_and_continue + + +from ldm.util import instantiate_from_config +from ldm.data.inpainting.synthetic_mask import gen_large_mask, MASK_MODES +from ldm.data.base import PRNGMixin + + +class DataWithWings(torch.utils.data.IterableDataset): + def __init__(self, min_size, transform=None, target_transform=None): + self.min_size = min_size + self.transform = transform if transform is not None else nn.Identity() + self.target_transform = target_transform if target_transform is not None else nn.Identity() + self.kv = OnDiskKV(file='/home/ubuntu/laion5B-watermark-safety-ordered', key_format='q', value_format='ee') + self.kv_aesthetic = OnDiskKV(file='/home/ubuntu/laion5B-aesthetic-tags-kv', key_format='q', value_format='e') + self.pwatermark_threshold = 0.8 + self.punsafe_threshold = 0.5 + self.aesthetic_threshold = 5. + self.total_samples = 0 + self.samples = 0 + location = 'pipe:aws s3 cp --quiet s3://s-datasets/laion5b/laion2B-data/{000000..231349}.tar -' + + self.inner_dataset = wds.DataPipeline( + wds.ResampledShards(location), + wds.tarfile_to_samples(handler=wds.warn_and_continue), + wds.shuffle(1000, handler=wds.warn_and_continue), + wds.decode('pilrgb', handler=wds.warn_and_continue), + wds.map(self._add_tags, handler=wds.ignore_and_continue), + wds.select(self._filter_predicate), + wds.map_dict(jpg=self.transform, txt=self.target_transform, punsafe=self._punsafe_to_class, handler=wds.warn_and_continue), + wds.to_tuple('jpg', 'txt', 'punsafe', handler=wds.warn_and_continue), + ) + + @staticmethod + def _compute_hash(url, text): + if url is None: + url = '' + if text is None: + text = '' + total = (url + text).encode('utf-8') + return mmh3.hash64(total)[0] + + def _add_tags(self, x): + hsh = self._compute_hash(x['json']['url'], x['txt']) + pwatermark, punsafe = self.kv[hsh] + aesthetic = self.kv_aesthetic[hsh][0] + return {**x, 'pwatermark': pwatermark, 'punsafe': punsafe, 'aesthetic': aesthetic} + + def _punsafe_to_class(self, punsafe): + return torch.tensor(punsafe >= self.punsafe_threshold).long() + + def _filter_predicate(self, x): + try: + return x['pwatermark'] < self.pwatermark_threshold and x['aesthetic'] >= self.aesthetic_threshold and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size + except: + return False + + def __iter__(self): + return iter(self.inner_dataset) + + +def dict_collation_fn(samples, combine_tensors=True, combine_scalars=True): + """Take a list of samples (as dictionary) and create a batch, preserving the keys. + If `tensors` is True, `ndarray` objects are combined into + tensor batches. + :param dict samples: list of samples + :param bool tensors: whether to turn lists of ndarrays into a single ndarray + :returns: single sample consisting of a batch + :rtype: dict + """ + keys = set.intersection(*[set(sample.keys()) for sample in samples]) + batched = {key: [] for key in keys} + + for s in samples: + [batched[key].append(s[key]) for key in batched] + + result = {} + for key in batched: + if isinstance(batched[key][0], (int, float)): + if combine_scalars: + result[key] = np.array(list(batched[key])) + elif isinstance(batched[key][0], torch.Tensor): + if combine_tensors: + result[key] = torch.stack(list(batched[key])) + elif isinstance(batched[key][0], np.ndarray): + if combine_tensors: + result[key] = np.array(list(batched[key])) + else: + result[key] = list(batched[key]) + return result + + +class WebDataModuleFromConfig(pl.LightningDataModule): + def __init__(self, tar_base, batch_size, train=None, validation=None, + test=None, num_workers=4, multinode=True, min_size=None, + max_pwatermark=1.0, + **kwargs): + super().__init__(self) + print(f'Setting tar base to {tar_base}') + self.tar_base = tar_base + self.batch_size = batch_size + self.num_workers = num_workers + self.train = train + self.validation = validation + self.test = test + self.multinode = multinode + self.min_size = min_size # filter out very small images + self.max_pwatermark = max_pwatermark # filter out watermarked images + + def make_loader(self, dataset_config, train=True): + if 'image_transforms' in dataset_config: + image_transforms = [instantiate_from_config(tt) for tt in dataset_config.image_transforms] + else: + image_transforms = [] + + image_transforms.extend([torchvision.transforms.ToTensor(), + torchvision.transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = torchvision.transforms.Compose(image_transforms) + + if 'transforms' in dataset_config: + transforms_config = OmegaConf.to_container(dataset_config.transforms) + else: + transforms_config = dict() + + transform_dict = {dkey: load_partial_from_config(transforms_config[dkey]) + if transforms_config[dkey] != 'identity' else identity + for dkey in transforms_config} + img_key = dataset_config.get('image_key', 'jpeg') + transform_dict.update({img_key: image_transforms}) + + if 'postprocess' in dataset_config: + postprocess = instantiate_from_config(dataset_config['postprocess']) + else: + postprocess = None + + shuffle = dataset_config.get('shuffle', 0) + shardshuffle = shuffle > 0 + + nodesplitter = wds.shardlists.split_by_node if self.multinode else wds.shardlists.single_node_only + + if self.tar_base == "__improvedaesthetic__": + print("## Warning, loading the same improved aesthetic dataset " + "for all splits and ignoring shards parameter.") + tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -" + else: + tars = os.path.join(self.tar_base, dataset_config.shards) + + dset = wds.WebDataset( + tars, + nodesplitter=nodesplitter, + shardshuffle=shardshuffle, + handler=wds.warn_and_continue).repeat().shuffle(shuffle) + print(f'Loading webdataset with {len(dset.pipeline[0].urls)} shards.') + + dset = (dset + .select(self.filter_keys) + .decode('pil', handler=wds.warn_and_continue) + .select(self.filter_size) + .map_dict(**transform_dict, handler=wds.warn_and_continue) + ) + if postprocess is not None: + dset = dset.map(postprocess) + dset = (dset + .batched(self.batch_size, partial=False, + collation_fn=dict_collation_fn) + ) + + loader = wds.WebLoader(dset, batch_size=None, shuffle=False, + num_workers=self.num_workers) + + return loader + + def filter_size(self, x): + try: + valid = True + if self.min_size is not None and self.min_size > 1: + try: + valid = valid and x['json']['original_width'] >= self.min_size and x['json']['original_height'] >= self.min_size + except Exception: + valid = False + if self.max_pwatermark is not None and self.max_pwatermark < 1.0: + try: + valid = valid and x['json']['pwatermark'] <= self.max_pwatermark + except Exception: + valid = False + return valid + except Exception: + return False + + def filter_keys(self, x): + try: + return ("jpg" in x) and ("txt" in x) + except Exception: + return False + + def train_dataloader(self): + return self.make_loader(self.train) + + def val_dataloader(self): + return self.make_loader(self.validation, train=False) + + def test_dataloader(self): + return self.make_loader(self.test, train=False) + + +from ldm.modules.image_degradation import degradation_fn_bsr_light +import cv2 + +class AddLR(object): + def __init__(self, factor, output_size, initial_size=None, image_key="jpg"): + self.factor = factor + self.output_size = output_size + self.image_key = image_key + self.initial_size = initial_size + + def pt2np(self, x): + x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy() + return x + + def np2pt(self, x): + x = torch.from_numpy(x)/127.5-1.0 + return x + + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = self.pt2np(sample[self.image_key]) + if self.initial_size is not None: + x = cv2.resize(x, (self.initial_size, self.initial_size), interpolation=2) + x = degradation_fn_bsr_light(x, sf=self.factor)['image'] + x = cv2.resize(x, (self.output_size, self.output_size), interpolation=2) + x = self.np2pt(x) + sample['lr'] = x + return sample + +class AddBW(object): + def __init__(self, image_key="jpg"): + self.image_key = image_key + + def pt2np(self, x): + x = ((x+1.0)*127.5).clamp(0, 255).to(dtype=torch.uint8).detach().cpu().numpy() + return x + + def np2pt(self, x): + x = torch.from_numpy(x)/127.5-1.0 + return x + + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = sample[self.image_key] + w = torch.rand(3, device=x.device) + w /= w.sum() + out = torch.einsum('hwc,c->hw', x, w) + + # Keep as 3ch so we can pass to encoder, also we might want to add hints + sample['lr'] = out.unsqueeze(-1).tile(1,1,3) + return sample + +class AddMask(PRNGMixin): + def __init__(self, mode="512train", p_drop=0.): + super().__init__() + assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"' + self.make_mask = MASK_MODES[mode] + self.p_drop = p_drop + + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = sample['jpg'] + mask = self.make_mask(self.prng, x.shape[0], x.shape[1]) + if self.prng.choice(2, p=[1 - self.p_drop, self.p_drop]): + mask = np.ones_like(mask) + mask[mask < 0.5] = 0 + mask[mask > 0.5] = 1 + mask = torch.from_numpy(mask[..., None]) + sample['mask'] = mask + sample['masked_image'] = x * (mask < 0.5) + return sample + + +class AddEdge(PRNGMixin): + def __init__(self, mode="512train", mask_edges=True): + super().__init__() + assert mode in list(MASK_MODES.keys()), f'unknown mask generation mode "{mode}"' + self.make_mask = MASK_MODES[mode] + self.n_down_choices = [0] + self.sigma_choices = [1, 2] + self.mask_edges = mask_edges + + @torch.no_grad() + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = sample['jpg'] + + mask = self.make_mask(self.prng, x.shape[0], x.shape[1]) + mask[mask < 0.5] = 0 + mask[mask > 0.5] = 1 + mask = torch.from_numpy(mask[..., None]) + sample['mask'] = mask + + n_down_idx = self.prng.choice(len(self.n_down_choices)) + sigma_idx = self.prng.choice(len(self.sigma_choices)) + + n_choices = len(self.n_down_choices)*len(self.sigma_choices) + raveled_idx = np.ravel_multi_index((n_down_idx, sigma_idx), + (len(self.n_down_choices), len(self.sigma_choices))) + normalized_idx = raveled_idx/max(1, n_choices-1) + + n_down = self.n_down_choices[n_down_idx] + sigma = self.sigma_choices[sigma_idx] + + kernel_size = 4*sigma+1 + kernel_size = (kernel_size, kernel_size) + sigma = (sigma, sigma) + canny = kornia.filters.Canny( + low_threshold=0.1, + high_threshold=0.2, + kernel_size=kernel_size, + sigma=sigma, + hysteresis=True, + ) + y = (x+1.0)/2.0 # in 01 + y = y.unsqueeze(0).permute(0, 3, 1, 2).contiguous() + + # down + for i_down in range(n_down): + size = min(y.shape[-2], y.shape[-1])//2 + y = kornia.geometry.transform.resize(y, size, antialias=True) + + # edge + _, y = canny(y) + + if n_down > 0: + size = x.shape[0], x.shape[1] + y = kornia.geometry.transform.resize(y, size, interpolation="nearest") + + y = y.permute(0, 2, 3, 1)[0].expand(-1, -1, 3).contiguous() + y = y*2.0-1.0 + + if self.mask_edges: + sample['masked_image'] = y * (mask < 0.5) + else: + sample['masked_image'] = y + sample['mask'] = torch.zeros_like(sample['mask']) + + # concat normalized idx + sample['smoothing_strength'] = torch.ones_like(sample['mask'])*normalized_idx + + return sample + + +def example00(): + url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/000000.tar -" + dataset = wds.WebDataset(url) + example = next(iter(dataset)) + for k in example: + print(k, type(example[k])) + + print(example["__key__"]) + for k in ["json", "txt"]: + print(example[k].decode()) + + image = Image.open(io.BytesIO(example["jpg"])) + outdir = "tmp" + os.makedirs(outdir, exist_ok=True) + image.save(os.path.join(outdir, example["__key__"] + ".png")) + + + def load_example(example): + return { + "key": example["__key__"], + "image": Image.open(io.BytesIO(example["jpg"])), + "text": example["txt"].decode(), + } + + + for i, example in tqdm(enumerate(dataset)): + ex = load_example(example) + print(ex["image"].size, ex["text"]) + if i >= 100: + break + + +def example01(): + # the first laion shards contain ~10k examples each + url = "pipe:aws s3 cp s3://s-datasets/laion5b/laion2B-data/{000000..000002}.tar -" + + batch_size = 3 + shuffle_buffer = 10000 + dset = wds.WebDataset( + url, + nodesplitter=wds.shardlists.split_by_node, + shardshuffle=True, + ) + dset = (dset + .shuffle(shuffle_buffer, initial=shuffle_buffer) + .decode('pil', handler=warn_and_continue) + .batched(batch_size, partial=False, + collation_fn=dict_collation_fn) + ) + + num_workers = 2 + loader = wds.WebLoader(dset, batch_size=None, shuffle=False, num_workers=num_workers) + + batch_sizes = list() + keys_per_epoch = list() + for epoch in range(5): + keys = list() + for batch in tqdm(loader): + batch_sizes.append(len(batch["__key__"])) + keys.append(batch["__key__"]) + + for bs in batch_sizes: + assert bs==batch_size + print(f"{len(batch_sizes)} batches of size {batch_size}.") + batch_sizes = list() + + keys_per_epoch.append(keys) + for i_batch in [0, 1, -1]: + print(f"Batch {i_batch} of epoch {epoch}:") + print(keys[i_batch]) + print("next epoch.") + + +def example02(): + from omegaconf import OmegaConf + from torch.utils.data.distributed import DistributedSampler + from torch.utils.data import IterableDataset + from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler + from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator + + #config = OmegaConf.load("configs/stable-diffusion/txt2img-1p4B-multinode-clip-encoder-high-res-512.yaml") + #config = OmegaConf.load("configs/stable-diffusion/txt2img-upscale-clip-encoder-f16-1024.yaml") + config = OmegaConf.load("configs/stable-diffusion/txt2img-v2-clip-encoder-improved_aesthetics-256.yaml") + datamod = WebDataModuleFromConfig(**config["data"]["params"]) + dataloader = datamod.train_dataloader() + + for batch in dataloader: + print(batch.keys()) + print(batch["jpg"].shape) + break + + +def example03(): + # improved aesthetics + tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{000000..060207}.tar -" + dataset = wds.WebDataset(tars) + + def filter_keys(x): + try: + return ("jpg" in x) and ("txt" in x) + except Exception: + return False + + def filter_size(x): + try: + return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512 + except Exception: + return False + + def filter_watermark(x): + try: + return x['json']['pwatermark'] < 0.5 + except Exception: + return False + + dataset = (dataset + .select(filter_keys) + .decode('pil', handler=wds.warn_and_continue)) + n_save = 20 + n_total = 0 + n_large = 0 + n_large_nowm = 0 + for i, example in enumerate(dataset): + n_total += 1 + if filter_size(example): + n_large += 1 + if filter_watermark(example): + n_large_nowm += 1 + if n_large_nowm < n_save+1: + image = example["jpg"] + image.save(os.path.join("tmp", f"{n_large_nowm-1:06}.png")) + + if i%500 == 0: + print(i) + print(f"Large: {n_large}/{n_total} | {n_large/n_total*100:.2f}%") + if n_large > 0: + print(f"No Watermark: {n_large_nowm}/{n_large} | {n_large_nowm/n_large*100:.2f}%") + + + +def example04(): + # improved aesthetics + for i_shard in range(60208)[::-1]: + print(i_shard) + tars = "pipe:aws s3 cp s3://s-laion/improved-aesthetics-laion-2B-en-subsets/aesthetics_tars/{:06}.tar -".format(i_shard) + dataset = wds.WebDataset(tars) + + def filter_keys(x): + try: + return ("jpg" in x) and ("txt" in x) + except Exception: + return False + + def filter_size(x): + try: + return x['json']['original_width'] >= 512 and x['json']['original_height'] >= 512 + except Exception: + return False + + dataset = (dataset + .select(filter_keys) + .decode('pil', handler=wds.warn_and_continue)) + try: + example = next(iter(dataset)) + except Exception: + print(f"Error @ {i_shard}") + + +if __name__ == "__main__": + #example01() + #example02() + example03() + #example04() diff --git a/ldm/data/lsun.py b/ldm/data/lsun.py new file mode 100644 index 0000000000000000000000000000000000000000..6256e45715ff0b57c53f985594d27cbbbff0e68e --- /dev/null +++ b/ldm/data/lsun.py @@ -0,0 +1,92 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + + +class LSUNBase(Dataset): + def __init__(self, + txt_file, + data_root, + size=None, + interpolation="bicubic", + flip_p=0.5 + ): + self.data_paths = txt_file + self.data_root = data_root + with open(self.data_paths, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + } + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example + + +class LSUNChurchesTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) + + +class LSUNChurchesValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", + flip_p=flip_p, **kwargs) + + +class LSUNBedroomsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) + + +class LSUNBedroomsValidation(LSUNBase): + def __init__(self, flip_p=0.0, **kwargs): + super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", + flip_p=flip_p, **kwargs) + + +class LSUNCatsTrain(LSUNBase): + def __init__(self, **kwargs): + super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) + + +class LSUNCatsValidation(LSUNBase): + def __init__(self, flip_p=0., **kwargs): + super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", + flip_p=flip_p, **kwargs) diff --git a/ldm/data/nerf_like.py b/ldm/data/nerf_like.py new file mode 100644 index 0000000000000000000000000000000000000000..84ef18288db005c72d3b5832144a7bd5cfffe9b2 --- /dev/null +++ b/ldm/data/nerf_like.py @@ -0,0 +1,165 @@ +from torch.utils.data import Dataset +import os +import json +import numpy as np +import torch +import imageio +import math +import cv2 +from torchvision import transforms + +def cartesian_to_spherical(xyz): + ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) + xy = xyz[:,0]**2 + xyz[:,1]**2 + z = np.sqrt(xy + xyz[:,2]**2) + theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down + #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up + azimuth = np.arctan2(xyz[:,1], xyz[:,0]) + return np.array([theta, azimuth, z]) + + +def get_T(T_target, T_cond): + theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :]) + theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :]) + + d_theta = theta_target - theta_cond + d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) + d_z = z_target - z_cond + + d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()]) + return d_T + +def get_spherical(T_target, T_cond): + theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :]) + theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :]) + + d_theta = theta_target - theta_cond + d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) + d_z = z_target - z_cond + + d_T = torch.tensor([math.degrees(d_theta.item()), math.degrees(d_azimuth.item()), d_z.item()]) + return d_T + +class RTMV(Dataset): + def __init__(self, root_dir='datasets/RTMV/google_scanned',\ + first_K=64, resolution=256, load_target=False): + self.root_dir = root_dir + self.scene_list = sorted(next(os.walk(root_dir))[1]) + self.resolution = resolution + self.first_K = first_K + self.load_target = load_target + + def __len__(self): + return len(self.scene_list) + + def __getitem__(self, idx): + scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) + with open(os.path.join(scene_dir, 'transforms.json'), "r") as f: + meta = json.load(f) + imgs = [] + poses = [] + for i_img in range(self.first_K): + meta_img = meta['frames'][i_img] + + if i_img == 0 or self.load_target: + img_path = os.path.join(scene_dir, meta_img['file_path']) + img = imageio.imread(img_path) + img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) + imgs.append(img) + + c2w = meta_img['transform_matrix'] + poses.append(c2w) + + imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs + imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) + imgs = imgs * 2 - 1. # convert to stable diffusion range + poses = torch.tensor(np.array(poses).astype(np.float32)) + return imgs, poses + + def blend_rgba(self, img): + img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB + return img + + +class GSO(Dataset): + def __init__(self, root_dir='datasets/GoogleScannedObjects',\ + split='val', first_K=5, resolution=256, load_target=False, name='render_mvs'): + self.root_dir = root_dir + with open(os.path.join(root_dir, '%s.json' % split), "r") as f: + self.scene_list = json.load(f) + self.resolution = resolution + self.first_K = first_K + self.load_target = load_target + self.name = name + + def __len__(self): + return len(self.scene_list) + + def __getitem__(self, idx): + scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) + with open(os.path.join(scene_dir, 'transforms_%s.json' % self.name), "r") as f: + meta = json.load(f) + imgs = [] + poses = [] + for i_img in range(self.first_K): + meta_img = meta['frames'][i_img] + + if i_img == 0 or self.load_target: + img_path = os.path.join(scene_dir, meta_img['file_path']) + img = imageio.imread(img_path) + img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) + imgs.append(img) + + c2w = meta_img['transform_matrix'] + poses.append(c2w) + + imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs + mask = imgs[:, :, :, -1] + imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) + imgs = imgs * 2 - 1. # convert to stable diffusion range + poses = torch.tensor(np.array(poses).astype(np.float32)) + return imgs, poses + + def blend_rgba(self, img): + img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB + return img + +class WILD(Dataset): + def __init__(self, root_dir='data/nerf_wild',\ + first_K=33, resolution=256, load_target=False): + self.root_dir = root_dir + self.scene_list = sorted(next(os.walk(root_dir))[1]) + self.resolution = resolution + self.first_K = first_K + self.load_target = load_target + + def __len__(self): + return len(self.scene_list) + + def __getitem__(self, idx): + scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) + with open(os.path.join(scene_dir, 'transforms_train.json'), "r") as f: + meta = json.load(f) + imgs = [] + poses = [] + for i_img in range(self.first_K): + meta_img = meta['frames'][i_img] + + if i_img == 0 or self.load_target: + img_path = os.path.join(scene_dir, meta_img['file_path']) + img = imageio.imread(img_path + '.png') + img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) + imgs.append(img) + + c2w = meta_img['transform_matrix'] + poses.append(c2w) + + imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs + imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) + imgs = imgs * 2 - 1. # convert to stable diffusion range + poses = torch.tensor(np.array(poses).astype(np.float32)) + return imgs, poses + + def blend_rgba(self, img): + img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB + return img \ No newline at end of file diff --git a/ldm/data/simple.py b/ldm/data/simple.py new file mode 100644 index 0000000000000000000000000000000000000000..a853e2188e4e61cf91c3e1ca0da3e4f0069dbcee --- /dev/null +++ b/ldm/data/simple.py @@ -0,0 +1,526 @@ +from typing import Dict +import webdataset as wds +import numpy as np +from omegaconf import DictConfig, ListConfig +import torch +from torch.utils.data import Dataset +from pathlib import Path +import json +from PIL import Image +from torchvision import transforms +import torchvision +from einops import rearrange +from ldm.util import instantiate_from_config +from datasets import load_dataset +import pytorch_lightning as pl +import copy +import csv +import cv2 +import random +import matplotlib.pyplot as plt +from torch.utils.data import DataLoader +import json +import os, sys +import webdataset as wds +import math +from torch.utils.data.distributed import DistributedSampler + +# Some hacky things to make experimentation easier +def make_transform_multi_folder_data(paths, caption_files=None, **kwargs): + ds = make_multi_folder_data(paths, caption_files, **kwargs) + return TransformDataset(ds) + +def make_nfp_data(base_path): + dirs = list(Path(base_path).glob("*/")) + print(f"Found {len(dirs)} folders") + print(dirs) + tforms = [transforms.Resize(512), transforms.CenterCrop(512)] + datasets = [NfpDataset(x, image_transforms=copy.copy(tforms), default_caption="A view from a train window") for x in dirs] + return torch.utils.data.ConcatDataset(datasets) + + +class VideoDataset(Dataset): + def __init__(self, root_dir, image_transforms, caption_file, offset=8, n=2): + self.root_dir = Path(root_dir) + self.caption_file = caption_file + self.n = n + ext = "mp4" + self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}"))) + self.offset = offset + + if isinstance(image_transforms, ListConfig): + image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + self.tform = image_transforms + with open(self.caption_file) as f: + reader = csv.reader(f) + rows = [row for row in reader] + self.captions = dict(rows) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + for i in range(10): + try: + return self._load_sample(index) + except Exception: + # Not really good enough but... + print("uh oh") + + def _load_sample(self, index): + n = self.n + filename = self.paths[index] + min_frame = 2*self.offset + 2 + vid = cv2.VideoCapture(str(filename)) + max_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT)) + curr_frame_n = random.randint(min_frame, max_frames) + vid.set(cv2.CAP_PROP_POS_FRAMES,curr_frame_n) + _, curr_frame = vid.read() + + prev_frames = [] + for i in range(n): + prev_frame_n = curr_frame_n - (i+1)*self.offset + vid.set(cv2.CAP_PROP_POS_FRAMES,prev_frame_n) + _, prev_frame = vid.read() + prev_frame = self.tform(Image.fromarray(prev_frame[...,::-1])) + prev_frames.append(prev_frame) + + vid.release() + caption = self.captions[filename.name] + data = { + "image": self.tform(Image.fromarray(curr_frame[...,::-1])), + "prev": torch.cat(prev_frames, dim=-1), + "txt": caption + } + return data + +# end hacky things + + +def make_tranforms(image_transforms): + # if isinstance(image_transforms, ListConfig): + # image_transforms = [instantiate_from_config(tt) for tt in image_transforms] + image_transforms = [] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + image_transforms = transforms.Compose(image_transforms) + return image_transforms + + +def make_multi_folder_data(paths, caption_files=None, **kwargs): + """Make a concat dataset from multiple folders + Don't suport captions yet + + If paths is a list, that's ok, if it's a Dict interpret it as: + k=folder v=n_times to repeat that + """ + list_of_paths = [] + if isinstance(paths, (Dict, DictConfig)): + assert caption_files is None, \ + "Caption files not yet supported for repeats" + for folder_path, repeats in paths.items(): + list_of_paths.extend([folder_path]*repeats) + paths = list_of_paths + + if caption_files is not None: + datasets = [FolderData(p, caption_file=c, **kwargs) for (p, c) in zip(paths, caption_files)] + else: + datasets = [FolderData(p, **kwargs) for p in paths] + return torch.utils.data.ConcatDataset(datasets) + + + +class NfpDataset(Dataset): + def __init__(self, + root_dir, + image_transforms=[], + ext="jpg", + default_caption="", + ) -> None: + """assume sequential frames and a deterministic transform""" + + self.root_dir = Path(root_dir) + self.default_caption = default_caption + + self.paths = sorted(list(self.root_dir.rglob(f"*.{ext}"))) + self.tform = make_tranforms(image_transforms) + + def __len__(self): + return len(self.paths) - 1 + + + def __getitem__(self, index): + prev = self.paths[index] + curr = self.paths[index+1] + data = {} + data["image"] = self._load_im(curr) + data["prev"] = self._load_im(prev) + data["txt"] = self.default_caption + return data + + def _load_im(self, filename): + im = Image.open(filename).convert("RGB") + return self.tform(im) + +class ObjaverseDataModuleFromConfig(pl.LightningDataModule): + def __init__(self, root_dir, batch_size, total_view, train=None, validation=None, + test=None, num_workers=4, **kwargs): + super().__init__(self) + self.root_dir = root_dir + self.batch_size = batch_size + self.num_workers = num_workers + self.total_view = total_view + + if train is not None: + dataset_config = train + if validation is not None: + dataset_config = validation + + if 'image_transforms' in dataset_config: + image_transforms = [torchvision.transforms.Resize(dataset_config.image_transforms.size)] + else: + image_transforms = [] + image_transforms.extend([transforms.ToTensor(), + transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c'))]) + self.image_transforms = torchvision.transforms.Compose(image_transforms) + + + def train_dataloader(self): + dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=False, \ + image_transforms=self.image_transforms) + sampler = DistributedSampler(dataset) + return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler) + + def val_dataloader(self): + dataset = ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=True, \ + image_transforms=self.image_transforms) + sampler = DistributedSampler(dataset) + return wds.WebLoader(dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) + + def test_dataloader(self): + return wds.WebLoader(ObjaverseData(root_dir=self.root_dir, total_view=self.total_view, validation=self.validation),\ + batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) + + +class ObjaverseData(Dataset): + def __init__(self, + root_dir='.objaverse/hf-objaverse-v1/views', + image_transforms=[], + ext="png", + default_trans=torch.zeros(3), + postprocess=None, + return_paths=False, + total_view=4, + validation=False + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_trans = default_trans + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + self.total_view = total_view + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + with open(os.path.join(root_dir, 'valid_paths.json')) as f: + self.paths = json.load(f) + + total_objects = len(self.paths) + if validation: + self.paths = self.paths[math.floor(total_objects / 100. * 99.):] # used last 1% as validation + else: + self.paths = self.paths[:math.floor(total_objects / 100. * 99.)] # used first 99% as training + print('============= length of dataset %d =============' % len(self.paths)) + self.tform = image_transforms + + def __len__(self): + return len(self.paths) + + def cartesian_to_spherical(self, xyz): + ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) + xy = xyz[:,0]**2 + xyz[:,1]**2 + z = np.sqrt(xy + xyz[:,2]**2) + theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down + #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up + azimuth = np.arctan2(xyz[:,1], xyz[:,0]) + return np.array([theta, azimuth, z]) + + def get_T(self, target_RT, cond_RT): + R, T = target_RT[:3, :3], target_RT[:, -1] + T_target = -R.T @ T + + R, T = cond_RT[:3, :3], cond_RT[:, -1] + T_cond = -R.T @ T + + theta_cond, azimuth_cond, z_cond = self.cartesian_to_spherical(T_cond[None, :]) + theta_target, azimuth_target, z_target = self.cartesian_to_spherical(T_target[None, :]) + + d_theta = theta_target - theta_cond + d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) + d_z = z_target - z_cond + + d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()]) + return d_T + + def load_im(self, path, color): + ''' + replace background pixel with random color in rendering + ''' + try: + img = plt.imread(path) + except: + print(path) + sys.exit() + img[img[:, :, -1] == 0.] = color + img = Image.fromarray(np.uint8(img[:, :, :3] * 255.)) + return img + + def __getitem__(self, index): + + data = {} + if self.paths[index][-2:] == '_1': # dirty fix for rendering dataset twice + total_view = 8 + else: + total_view = 4 + index_target, index_cond = random.sample(range(total_view), 2) # without replacement + filename = os.path.join(self.root_dir, self.paths[index]) + + # print(self.paths[index]) + + if self.return_paths: + data["path"] = str(filename) + + color = [1., 1., 1., 1.] + + try: + target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color)) + cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color)) + target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target)) + cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond)) + except: + # very hacky solution, sorry about this + filename = os.path.join(self.root_dir, '692db5f2d3a04bb286cb977a7dba903e_1') # this one we know is valid + target_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_target), color)) + cond_im = self.process_im(self.load_im(os.path.join(filename, '%03d.png' % index_cond), color)) + target_RT = np.load(os.path.join(filename, '%03d.npy' % index_target)) + cond_RT = np.load(os.path.join(filename, '%03d.npy' % index_cond)) + target_im = torch.zeros_like(target_im) + cond_im = torch.zeros_like(cond_im) + + data["image_target"] = target_im + data["image_cond"] = cond_im + data["T"] = self.get_T(target_RT, cond_RT) + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) + +class FolderData(Dataset): + def __init__(self, + root_dir, + caption_file=None, + image_transforms=[], + ext="jpg", + default_caption="", + postprocess=None, + return_paths=False, + ) -> None: + """Create a dataset from a folder of images. + If you pass in a root directory it will be searched for images + ending in ext (ext can be a list) + """ + self.root_dir = Path(root_dir) + self.default_caption = default_caption + self.return_paths = return_paths + if isinstance(postprocess, DictConfig): + postprocess = instantiate_from_config(postprocess) + self.postprocess = postprocess + if caption_file is not None: + with open(caption_file, "rt") as f: + ext = Path(caption_file).suffix.lower() + if ext == ".json": + captions = json.load(f) + elif ext == ".jsonl": + lines = f.readlines() + lines = [json.loads(x) for x in lines] + captions = {x["file_name"]: x["text"].strip("\n") for x in lines} + else: + raise ValueError(f"Unrecognised format: {ext}") + self.captions = captions + else: + self.captions = None + + if not isinstance(ext, (tuple, list, ListConfig)): + ext = [ext] + + # Only used if there is no caption file + self.paths = [] + for e in ext: + self.paths.extend(sorted(list(self.root_dir.rglob(f"*.{e}")))) + self.tform = make_tranforms(image_transforms) + + def __len__(self): + if self.captions is not None: + return len(self.captions.keys()) + else: + return len(self.paths) + + def __getitem__(self, index): + data = {} + if self.captions is not None: + chosen = list(self.captions.keys())[index] + caption = self.captions.get(chosen, None) + if caption is None: + caption = self.default_caption + filename = self.root_dir/chosen + else: + filename = self.paths[index] + + if self.return_paths: + data["path"] = str(filename) + + im = Image.open(filename).convert("RGB") + im = self.process_im(im) + data["image"] = im + + if self.captions is not None: + data["txt"] = caption + else: + data["txt"] = self.default_caption + + if self.postprocess is not None: + data = self.postprocess(data) + + return data + + def process_im(self, im): + im = im.convert("RGB") + return self.tform(im) +import random + +class TransformDataset(): + def __init__(self, ds, extra_label="sksbspic"): + self.ds = ds + self.extra_label = extra_label + self.transforms = { + "align": transforms.Resize(768), + "centerzoom": transforms.CenterCrop(768), + "randzoom": transforms.RandomCrop(768), + } + + + def __getitem__(self, index): + data = self.ds[index] + + im = data['image'] + im = im.permute(2,0,1) + # In case data is smaller than expected + im = transforms.Resize(1024)(im) + + tform_name = random.choice(list(self.transforms.keys())) + im = self.transforms[tform_name](im) + + im = im.permute(1,2,0) + + data['image'] = im + data['txt'] = data['txt'] + f" {self.extra_label} {tform_name}" + + return data + + def __len__(self): + return len(self.ds) + +def hf_dataset( + name, + image_transforms=[], + image_column="image", + text_column="text", + split='train', + image_key='image', + caption_key='txt', + ): + """Make huggingface dataset with appropriate list of transforms applied + """ + ds = load_dataset(name, split=split) + tform = make_tranforms(image_transforms) + + assert image_column in ds.column_names, f"Didn't find column {image_column} in {ds.column_names}" + assert text_column in ds.column_names, f"Didn't find column {text_column} in {ds.column_names}" + + def pre_process(examples): + processed = {} + processed[image_key] = [tform(im) for im in examples[image_column]] + processed[caption_key] = examples[text_column] + return processed + + ds.set_transform(pre_process) + return ds + +class TextOnly(Dataset): + def __init__(self, captions, output_size, image_key="image", caption_key="txt", n_gpus=1): + """Returns only captions with dummy images""" + self.output_size = output_size + self.image_key = image_key + self.caption_key = caption_key + if isinstance(captions, Path): + self.captions = self._load_caption_file(captions) + else: + self.captions = captions + + if n_gpus > 1: + # hack to make sure that all the captions appear on each gpu + repeated = [n_gpus*[x] for x in self.captions] + self.captions = [] + [self.captions.extend(x) for x in repeated] + + def __len__(self): + return len(self.captions) + + def __getitem__(self, index): + dummy_im = torch.zeros(3, self.output_size, self.output_size) + dummy_im = rearrange(dummy_im * 2. - 1., 'c h w -> h w c') + return {self.image_key: dummy_im, self.caption_key: self.captions[index]} + + def _load_caption_file(self, filename): + with open(filename, 'rt') as f: + captions = f.readlines() + return [x.strip('\n') for x in captions] + + + +import random +import json +class IdRetreivalDataset(FolderData): + def __init__(self, ret_file, *args, **kwargs): + super().__init__(*args, **kwargs) + with open(ret_file, "rt") as f: + self.ret = json.load(f) + + def __getitem__(self, index): + data = super().__getitem__(index) + key = self.paths[index].name + matches = self.ret[key] + if len(matches) > 0: + retreived = random.choice(matches) + else: + retreived = key + filename = self.root_dir/retreived + im = Image.open(filename).convert("RGB") + im = self.process_im(im) + # data["match"] = im + data["match"] = torch.cat((data["image"], im), dim=-1) + return data diff --git a/ldm/extras.py b/ldm/extras.py new file mode 100644 index 0000000000000000000000000000000000000000..62e654b330c44b85565f958d04bee217a168d7ec --- /dev/null +++ b/ldm/extras.py @@ -0,0 +1,77 @@ +from pathlib import Path +from omegaconf import OmegaConf +import torch +from ldm.util import instantiate_from_config +import logging +from contextlib import contextmanager + +from contextlib import contextmanager +import logging + +@contextmanager +def all_logging_disabled(highest_level=logging.CRITICAL): + """ + A context manager that will prevent any logging messages + triggered during the body from being processed. + + :param highest_level: the maximum logging level in use. + This would only need to be changed if a custom level greater than CRITICAL + is defined. + + https://gist.github.com/simon-weber/7853144 + """ + # two kind-of hacks here: + # * can't get the highest logging level in effect => delegate to the user + # * can't get the current module-level override => use an undocumented + # (but non-private!) interface + + previous_level = logging.root.manager.disable + + logging.disable(highest_level) + + try: + yield + finally: + logging.disable(previous_level) + +def load_training_dir(train_dir, device, epoch="last"): + """Load a checkpoint and config from training directory""" + train_dir = Path(train_dir) + ckpt = list(train_dir.rglob(f"*{epoch}.ckpt")) + assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files" + config = list(train_dir.rglob(f"*-project.yaml")) + assert len(ckpt) > 0, f"didn't find any config in {train_dir}" + if len(config) > 1: + print(f"found {len(config)} matching config files") + config = sorted(config)[-1] + print(f"selecting {config}") + else: + config = config[0] + + + config = OmegaConf.load(config) + return load_model_from_config(config, ckpt[0], device) + +def load_model_from_config(config, ckpt, device="cpu", verbose=False): + """Loads a model from config and a ckpt + if config is a path will use omegaconf to load + """ + if isinstance(config, (str, Path)): + config = OmegaConf.load(config) + + with all_logging_disabled(): + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location="cpu") + global_step = pl_sd["global_step"] + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + model.to(device) + model.eval() + model.cond_stage_model.device = device + return model \ No newline at end of file diff --git a/ldm/guidance.py b/ldm/guidance.py new file mode 100644 index 0000000000000000000000000000000000000000..53d1a2a61b5f2f086178154cf04ea078e0835845 --- /dev/null +++ b/ldm/guidance.py @@ -0,0 +1,96 @@ +from typing import List, Tuple +from scipy import interpolate +import numpy as np +import torch +import matplotlib.pyplot as plt +from IPython.display import clear_output +import abc + + +class GuideModel(torch.nn.Module, abc.ABC): + def __init__(self) -> None: + super().__init__() + + @abc.abstractmethod + def preprocess(self, x_img): + pass + + @abc.abstractmethod + def compute_loss(self, inp): + pass + + +class Guider(torch.nn.Module): + def __init__(self, sampler, guide_model, scale=1.0, verbose=False): + """Apply classifier guidance + + Specify a guidance scale as either a scalar + Or a schedule as a list of tuples t = 0->1 and scale, e.g. + [(0, 10), (0.5, 20), (1, 50)] + """ + super().__init__() + self.sampler = sampler + self.index = 0 + self.show = verbose + self.guide_model = guide_model + self.history = [] + + if isinstance(scale, (Tuple, List)): + times = np.array([x[0] for x in scale]) + values = np.array([x[1] for x in scale]) + self.scale_schedule = {"times": times, "values": values} + else: + self.scale_schedule = float(scale) + + self.ddim_timesteps = sampler.ddim_timesteps + self.ddpm_num_timesteps = sampler.ddpm_num_timesteps + + + def get_scales(self): + if isinstance(self.scale_schedule, float): + return len(self.ddim_timesteps)*[self.scale_schedule] + + interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"]) + fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps + return interpolater(fractional_steps) + + def modify_score(self, model, e_t, x, t, c): + + # TODO look up index by t + scale = self.get_scales()[self.index] + + if (scale == 0): + return e_t + + sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device) + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t) + x_img = model.first_stage_model.decode((1/0.18215)*pred_x0) + + inp = self.guide_model.preprocess(x_img) + loss = self.guide_model.compute_loss(inp) + grads = torch.autograd.grad(loss.sum(), x_in)[0] + correction = grads * scale + + if self.show: + clear_output(wait=True) + print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item()) + self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()]) + plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2) + plt.axis('off') + plt.show() + plt.imshow(correction[0][0].detach().cpu()) + plt.axis('off') + plt.show() + + + e_t_mod = e_t - sqrt_1ma*correction + if self.show: + fig, axs = plt.subplots(1, 3) + axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2) + axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2) + axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2) + plt.show() + self.index += 1 + return e_t_mod \ No newline at end of file diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..be39da9ca6dacc22bf3df9c7389bbb403a4a3ade --- /dev/null +++ b/ldm/lr_scheduler.py @@ -0,0 +1,98 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n,**kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0. + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) + self.last_f = f + return f + diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6a9c4f45498561953b8085981609b2a3298a5473 --- /dev/null +++ b/ldm/models/autoencoder.py @@ -0,0 +1,443 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..67e98b9d8ffb96a150b517497ace0a242d7163ef --- /dev/null +++ b/ldm/models/diffusion/classifier.py @@ -0,0 +1,267 @@ +import os +import torch +import pytorch_lightning as pl +from omegaconf import OmegaConf +from torch.nn import functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from copy import deepcopy +from einops import rearrange +from glob import glob +from natsort import natsorted + +from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel +from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config + +__models__ = { + 'class_label': EncoderUNetModel, + 'segmentation': UNetModel +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class NoisyLatentImageClassifier(pl.LightningModule): + + def __init__(self, + diffusion_path, + num_classes, + ckpt_path=None, + pool='attention', + label_key=None, + diffusion_ckpt_path=None, + scheduler_config=None, + weight_decay=1.e-2, + log_steps=10, + monitor='val/loss', + *args, + **kwargs): + super().__init__(*args, **kwargs) + self.num_classes = num_classes + # get latest config of diffusion model + diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] + self.diffusion_config = OmegaConf.load(diffusion_config).model + self.diffusion_config.params.ckpt_path = diffusion_ckpt_path + self.load_diffusion() + + self.monitor = monitor + self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 + self.log_time_interval = self.diffusion_model.num_timesteps // log_steps + self.log_steps = log_steps + + self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ + else self.diffusion_model.cond_stage_key + + assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' + + if self.label_key not in __models__: + raise NotImplementedError() + + self.load_classifier(ckpt_path, pool) + + self.scheduler_config = scheduler_config + self.use_scheduler = self.scheduler_config is not None + self.weight_decay = weight_decay + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def load_diffusion(self): + model = instantiate_from_config(self.diffusion_config) + self.diffusion_model = model.eval() + self.diffusion_model.train = disabled_train + for param in self.diffusion_model.parameters(): + param.requires_grad = False + + def load_classifier(self, ckpt_path, pool): + model_config = deepcopy(self.diffusion_config.params.unet_config.params) + model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels + model_config.out_channels = self.num_classes + if self.label_key == 'class_label': + model_config.pool = pool + + self.model = __models__[self.label_key](**model_config) + if ckpt_path is not None: + print('#####################################################################') + print(f'load from ckpt "{ckpt_path}"') + print('#####################################################################') + self.init_from_ckpt(ckpt_path) + + @torch.no_grad() + def get_x_noisy(self, x, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x)) + continuous_sqrt_alpha_cumprod = None + if self.diffusion_model.use_continuous_noise: + continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) + # todo: make sure t+1 is correct here + + return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, + continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) + + def forward(self, x_noisy, t, *args, **kwargs): + return self.model(x_noisy, t) + + @torch.no_grad() + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + @torch.no_grad() + def get_conditioning(self, batch, k=None): + if k is None: + k = self.label_key + assert k is not None, 'Needs to provide label key' + + targets = batch[k].to(self.device) + + if self.label_key == 'segmentation': + targets = rearrange(targets, 'b h w c -> b c h w') + for down in range(self.numd): + h, w = targets.shape[-2:] + targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') + + # targets = rearrange(targets,'b c h w -> b h w c') + + return targets + + def compute_top_k(self, logits, labels, k, reduction="mean"): + _, top_ks = torch.topk(logits, k, dim=1) + if reduction == "mean": + return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() + elif reduction == "none": + return (top_ks == labels[:, None]).float().sum(dim=-1) + + def on_train_epoch_start(self): + # save some memory + self.diffusion_model.model.to('cpu') + + @torch.no_grad() + def write_logs(self, loss, logits, targets): + log_prefix = 'train' if self.training else 'val' + log = {} + log[f"{log_prefix}/loss"] = loss.mean() + log[f"{log_prefix}/acc@1"] = self.compute_top_k( + logits, targets, k=1, reduction="mean" + ) + log[f"{log_prefix}/acc@5"] = self.compute_top_k( + logits, targets, k=5, reduction="mean" + ) + + self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) + self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) + self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) + + def shared_step(self, batch, t=None): + x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) + targets = self.get_conditioning(batch) + if targets.dim() == 4: + targets = targets.argmax(dim=1) + if t is None: + t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() + else: + t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() + x_noisy = self.get_x_noisy(x, t) + logits = self(x_noisy, t) + + loss = F.cross_entropy(logits, targets, reduction='none') + + self.write_logs(loss.detach(), logits.detach(), targets.detach()) + + loss = loss.mean() + return loss, logits, x_noisy, targets + + def training_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + return loss + + def reset_noise_accs(self): + self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in + range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} + + def on_validation_start(self): + self.reset_noise_accs() + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + loss, *_ = self.shared_step(batch) + + for t in self.noisy_acc: + _, logits, _, targets = self.shared_step(batch, t) + self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) + self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) + + return loss + + def configure_optimizers(self): + optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) + + if self.use_scheduler: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [optimizer], scheduler + + return optimizer + + @torch.no_grad() + def log_images(self, batch, N=8, *args, **kwargs): + log = dict() + x = self.get_input(batch, self.diffusion_model.first_stage_key) + log['inputs'] = x + + y = self.get_conditioning(batch) + + if self.label_key == 'class_label': + y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) + log['labels'] = y + + if ismap(y): + log['labels'] = self.diffusion_model.to_rgb(y) + + for step in range(self.log_steps): + current_time = step * self.log_time_interval + + _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) + + log[f'inputs@t{current_time}'] = x_noisy + + pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) + pred = rearrange(pred, 'b h w c -> b c h w') + + log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) + + for key in log: + log[key] = log[key][:N] + + return log diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..5db306d8dd82ca8868e34cddfeb4a01daf259c08 --- /dev/null +++ b/ldm/models/diffusion/ddim.py @@ -0,0 +1,326 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial +from einops import rearrange + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor +from ldm.models.diffusion.sampling_util import renorm_thresholding, norm_thresholding, spatial_norm_thresholding + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.device = model.device + + def to(self, device): + """Same as to in torch module + Don't really underestand why this isn't a module in the first place""" + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + new_v = getattr(self, k).to(device) + setattr(self, k, new_v) + + + def register_buffer(self, name, attr, device=None): + if type(attr) == torch.Tensor: + attr = attr.to(device) + # if attr.device != torch.device("cuda"): + # attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas), self.device) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod), self.device) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev), self.device) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())), self.device) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())), self.device) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())), self.device) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())), self.device) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)), self.device) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas, self.device) + self.register_buffer('ddim_alphas', ddim_alphas, self.device) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev, self.device) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas), self.device) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps, self.device) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + t_start=-1): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + timesteps = timesteps[:t_start] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: + img = callback(i, img, pred_x0) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + return x_dec \ No newline at end of file diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..6a6d5017af4f84fdc95c6389a2dcc8d6b8a03080 --- /dev/null +++ b/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1994 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager, nullcontext +from functools import partial +import itertools +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.rank_zero import rank_zero_only +from omegaconf import ListConfig + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import VQModelInterface, IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.modules.attention import CrossAttention + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + make_it_fit=False, + ucg_training=None, + ): + super().__init__() + assert parameterization in ["eps", "x0"], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + + if self.make_it_fit: + n_params = len([name for name, _ in + itertools.chain(self.named_parameters(), + self.named_buffers())]) + for name, param in tqdm( + itertools.chain(self.named_parameters(), + self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape)==len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + else: + raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1-p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + unet_trainable=True, + *args, **kwargs): + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__': + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.unet_trainable = unet_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + + # construct linear projection layer for concatenating image CLIP embedding and RT + self.cc_projection = nn.Linear(772, 768) + nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768]) + nn.init.zeros_(list(self.cc_projection.parameters())[1]) + self.cc_projection.requires_grad_(True) + + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, uncond=0.05): + x = super().get_input(batch, k) + T = batch['T'].to(memory_format=torch.contiguous_format).float() + + if bs is not None: + x = x[:bs] + T = T[:bs].to(self.device) + + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + cond_key = cond_key or self.cond_stage_key + xc = super().get_input(batch, cond_key).to(self.device) + if bs is not None: + xc = xc[:bs] + cond = {} + + # To support classifier-free guidance, randomly drop out only text conditioning 5%, only image conditioning 5%, and both 5%. + random = torch.rand(x.size(0), device=x.device) + prompt_mask = rearrange(random < 2 * uncond, "n -> n 1 1") + input_mask = 1 - rearrange((random >= uncond).float() * (random < 3 * uncond).float(), "n -> n 1 1 1") + null_prompt = self.get_learned_conditioning([""]) + + # z.shape: [8, 4, 64, 64]; c.shape: [8, 1, 768] + # print('=========== xc shape ===========', xc.shape) + with torch.enable_grad(): + clip_emb = self.get_learned_conditioning(xc).detach() + null_prompt = self.get_learned_conditioning([""]).detach() + cond["c_crossattn"] = [self.cc_projection(torch.cat([torch.where(prompt_mask, null_prompt, clip_emb), T[:, None, :]], dim=-1))] + cond["c_concat"] = [input_mask * self.encode_first_stage((xc.to(self.device))).mode().detach()] + out = [z, cond] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + # @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + uf = self.split_input_params["vqf"] + bs, nc, h, w = z.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(z, ks, stride, uf=uf) + + z = unfold(z) # (bn, nc * prod(**ks), L) + # 1. Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + # 2. apply model loop over last dim + if isinstance(self.first_stage_model, VQModelInterface): + output_list = [self.first_stage_model.decode(z[:, :, :, :, i], + force_not_quantize=predict_cids or force_not_quantize) + for i in range(z.shape[-1])] + else: + + output_list = [self.first_stage_model.decode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) # # (bn, nc, ks[0], ks[1], L) + o = o * weighting + # Reverse 1. reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization # norm is shape (1, 1, h, w) + return decoded + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + else: + if isinstance(self.first_stage_model, VQModelInterface): + return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize) + else: + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + if hasattr(self, "split_input_params"): + if self.split_input_params["patch_distributed_vq"]: + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + df = self.split_input_params["vqf"] + self.split_input_params['original_image_size'] = x.shape[-2:] + bs, nc, h, w = x.shape + if ks[0] > h or ks[1] > w: + ks = (min(ks[0], h), min(ks[1], w)) + print("reducing Kernel") + + if stride[0] > h or stride[1] > w: + stride = (min(stride[0], h), min(stride[1], w)) + print("reducing stride") + + fold, unfold, normalization, weighting = self.get_fold_unfold(x, ks, stride, df=df) + z = unfold(x) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + output_list = [self.first_stage_model.encode(z[:, :, :, :, i]) + for i in range(z.shape[-1])] + + o = torch.stack(output_list, axis=-1) + o = o * weighting + + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + decoded = fold(o) + decoded = decoded / normalization + return decoded + + else: + return self.first_stage_model.encode(x) + else: + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + # if self.cond_stage_trainable: + # c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset + def rescale_bbox(bbox): + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + return x0, y0, w, h + + return [rescale_bbox(b) for b in bboxes] + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + if hasattr(self, "split_input_params"): + assert len(cond) == 1 # todo can only deal with one conditioning atm + assert not return_ids + ks = self.split_input_params["ks"] # eg. (128, 128) + stride = self.split_input_params["stride"] # eg. (64, 64) + + h, w = x_noisy.shape[-2:] + + fold, unfold, normalization, weighting = self.get_fold_unfold(x_noisy, ks, stride) + + z = unfold(x_noisy) # (bn, nc * prod(**ks), L) + # Reshape to img shape + z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + z_list = [z[:, :, :, :, i] for i in range(z.shape[-1])] + + if self.cond_stage_key in ["image", "LR_image", "segmentation", + 'bbox_img'] and self.model.conditioning_key: # todo check for completeness + c_key = next(iter(cond.keys())) # get key + c = next(iter(cond.values())) # get value + assert (len(c) == 1) # todo extend to list with more than one elem + c = c[0] # get element + + c = unfold(c) + c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1])) # (bn, nc, ks[0], ks[1], L ) + + cond_list = [{c_key: [c[:, :, :, :, i]]} for i in range(c.shape[-1])] + + elif self.cond_stage_key == 'coordinates_bbox': + assert 'original_image_size' in self.split_input_params, 'BoudingBoxRescaling is missing original_image_size' + + # assuming padding of unfold is always 0 and its dilation is always 1 + n_patches_per_row = int((w - ks[0]) / stride[0] + 1) + full_img_h, full_img_w = self.split_input_params['original_image_size'] + # as we are operating on latents, we need the factor from the original image size to the + # spatial latent size to properly rescale the crops for regenerating the bbox annotations + num_downs = self.first_stage_model.encoder.num_resolutions - 1 + rescale_latent = 2 ** (num_downs) + + # get top left postions of patches as conforming for the bbbox tokenizer, therefore we + # need to rescale the tl patch coordinates to be in between (0,1) + tl_patch_coordinates = [(rescale_latent * stride[0] * (patch_nr % n_patches_per_row) / full_img_w, + rescale_latent * stride[1] * (patch_nr // n_patches_per_row) / full_img_h) + for patch_nr in range(z.shape[-1])] + + # patch_limits are tl_coord, width and height coordinates as (x_tl, y_tl, h, w) + patch_limits = [(x_tl, y_tl, + rescale_latent * ks[0] / full_img_w, + rescale_latent * ks[1] / full_img_h) for x_tl, y_tl in tl_patch_coordinates] + # patch_values = [(np.arange(x_tl,min(x_tl+ks, 1.)),np.arange(y_tl,min(y_tl+ks, 1.))) for x_tl, y_tl in tl_patch_coordinates] + + # tokenize crop coordinates for the bounding boxes of the respective patches + patch_limits_tknzd = [torch.LongTensor(self.bbox_tokenizer._crop_encoder(bbox))[None].to(self.device) + for bbox in patch_limits] # list of length l with tensors of shape (1, 2) + # cut tknzd crop position from conditioning + assert isinstance(cond, dict), 'cond must be dict to be fed into model' + cut_cond = cond['c_crossattn'][0][..., :-2].to(self.device) + + adapted_cond = torch.stack([torch.cat([cut_cond, p], dim=1) for p in patch_limits_tknzd]) + adapted_cond = rearrange(adapted_cond, 'l b n -> (l b) n') + adapted_cond = self.get_learned_conditioning(adapted_cond) + adapted_cond = rearrange(adapted_cond, '(l b) n d -> l b n d', l=z.shape[-1]) + + cond_list = [{'c_crossattn': [e]} for e in adapted_cond] + + else: + cond_list = [cond for i in range(z.shape[-1])] # Todo make this more efficient + + # apply model by loop over crops + output_list = [self.model(z_list[i], t, **cond_list[i]) for i in range(z.shape[-1])] + assert not isinstance(output_list[0], + tuple) # todo cant deal with multiple model outputs check this never happens + + o = torch.stack(output_list, axis=-1) + o = o * weighting + # Reverse reshape to img shape + o = o.view((o.shape[0], -1, o.shape[-1])) # (bn, nc * ks[0] * ks[1], L) + # stitch crops together + x_recon = fold(o) / normalization + + else: + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None,**kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, + shape, cond, verbose=False, **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True, **kwargs) + + return samples, intermediates + + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None, image_size=512): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + # todo: get null label from cond_stage_model + raise NotImplementedError() + c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + cond = {} + cond["c_crossattn"] = [c] + cond["c_concat"] = [torch.zeros([batch_size, 4, image_size // 8, image_size // 8]).to(self.device)] + return cond + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, + ddim_steps=ddim_steps,eta=ddim_eta, + quantize_denoised=True) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning(N, unconditional_guidance_label, image_size=x.shape[-1]) + # uc = torch.zeros_like(c) + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + + samples, _ = self.sample_log(cond=c,batch_size=N,ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1. - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = [] + if self.unet_trainable == "attn": + print("Training only unet attention layers") + for n, m in self.model.named_modules(): + if isinstance(m, CrossAttention) and n.endswith('attn2'): + params.extend(m.parameters()) + if self.unet_trainable == "conv_in": + print("Training only unet input conv layers") + params = list(self.model.diffusion_model.input_blocks[0][0].parameters()) + elif self.unet_trainable is True or self.unet_trainable == "all": + print("Training the full unet") + params = list(self.model.parameters()) + else: + raise ValueError(f"Unrecognised setting for unet_trainable: {self.unet_trainable}") + + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + + if self.cc_projection is not None: + params = params + list(self.cc_projection.parameters()) + print('========== optimizing for cc projection weight ==========') + + opt = torch.optim.AdamW([{"params": self.model.parameters(), "lr": lr}, + {"params": self.cc_projection.parameters(), "lr": 10. * lr}], lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. + return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + # c_crossattn dimension: torch.Size([8, 1, 768]) 1 + # cc dimension: torch.Size([8, 1, 768] + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_config, low_scale_key="LR", **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + #import pudb; pu.db + if log_mode: + # TODO: maybe disable if too expensive + interpretability = False + if interpretability: + zx = zx[:, :, ::2, ::2] + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, + log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? + assert isinstance(c[k], torch.Tensor) + uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + return log + + +class LatentInpaintDiffusion(LatentDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ + def __init__(self, + finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight" + ), + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + keep_finetune_dims=4, # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, **kwargs + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' + if exists(ckpt_path): + self.init_from_ckpt(ckpt_path, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + print(f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), 'did not find matching parameter to modify' + new_entry[:, :self.keep_dims, ...] = sd[k] + sd[k] = new_entry + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:,self.c_concat_log_start:self.c_concat_log_end]) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + log["masked_image"] = rearrange(batch["masked_image"], + 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + return log + + +class Layout2ImgDiffusion(LatentDiffusion): + # TODO: move all layout-specific hacks to this class + def __init__(self, cond_stage_key, *args, **kwargs): + assert cond_stage_key == 'coordinates_bbox', 'Layout2ImgDiffusion only for cond_stage_key="coordinates_bbox"' + super().__init__(cond_stage_key=cond_stage_key, *args, **kwargs) + + def log_images(self, batch, N=8, *args, **kwargs): + logs = super().log_images(batch=batch, N=N, *args, **kwargs) + + key = 'train' if self.training else 'validation' + dset = self.trainer.datamodule.datasets[key] + mapper = dset.conditional_builders[self.cond_stage_key] + + bbox_imgs = [] + map_fn = lambda catno: dset.get_textual_label(dset.get_category_id(catno)) + for tknzd_bbox in batch[self.cond_stage_key][:N]: + bboximg = mapper.plot(tknzd_bbox.detach().cpu(), map_fn, (256, 256)) + bbox_imgs.append(bboximg) + + cond_img = torch.stack(bbox_imgs, dim=0) + logs['bbox_image'] = cond_img + return logs + + +class SimpleUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_key="LR", **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.low_scale_key = low_scale_key + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = x_low.to(memory_format=torch.contiguous_format).float() + + encoder_posterior = self.encode_first_stage(x_low) + zx = self.get_first_stage_encoding(encoder_posterior).detach() + all_conds = {"c_concat": [zx], "c_crossattn": [c]} + + if log_mode: + # TODO: maybe disable if too expensive + interpretability = False + if interpretability: + zx = zx[:, :, ::2, ::2] + return z, all_conds, x, xrec, xc, x_low + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + return log + +class MultiCatFrameDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_key="LR", **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.low_scale_key = low_scale_key + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + n = 2 + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + cat_conds = batch[self.low_scale_key][:bs] + cats = [] + for i in range(n): + x_low = cat_conds[:,:,:,3*i:3*(i+1)] + x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = x_low.to(memory_format=torch.contiguous_format).float() + encoder_posterior = self.encode_first_stage(x_low) + zx = self.get_first_stage_encoding(encoder_posterior).detach() + cats.append(zx) + + all_conds = {"c_concat": [torch.cat(cats, dim=1)], "c_crossattn": [c]} + + if log_mode: + # TODO: maybe disable if too expensive + interpretability = False + if interpretability: + zx = zx[:, :, ::2, ::2] + return z, all_conds, x, xrec, xc, x_low + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low = self.get_input(batch, self.first_stage_key, bs=N, log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2]//25) + log["conditioning"] = xc + elif self.cond_stage_key == 'class_label': + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2]//25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + return log diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py new file mode 100644 index 0000000000000000000000000000000000000000..080edeec9efed663f0e01de0afbbf3bed1cfa1d1 --- /dev/null +++ b/ldm/models/diffusion/plms.py @@ -0,0 +1,259 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from ldm.models.diffusion.sampling_util import norm_thresholding + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next, + dynamic_threshold=dynamic_threshold) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/ldm/models/diffusion/sampling_util.py b/ldm/models/diffusion/sampling_util.py new file mode 100644 index 0000000000000000000000000000000000000000..a0ae00fe86044456fc403af403be71ff15112424 --- /dev/null +++ b/ldm/models/diffusion/sampling_util.py @@ -0,0 +1,50 @@ +import torch +import numpy as np + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def renorm_thresholding(x0, value): + # renorm + pred_max = x0.max() + pred_min = x0.min() + pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 + pred_x0 = 2 * pred_x0 - 1. # -1 ... 1 + + s = torch.quantile( + rearrange(pred_x0, 'b ... -> b (...)').abs(), + value, + dim=-1 + ) + s.clamp_(min=1.0) + s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) + + # clip by threshold + # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max + + # temporary hack: numpy on cpu + pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() + pred_x0 = torch.tensor(pred_x0).to(self.model.device) + + # re.renorm + pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1 + pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range + return pred_x0 + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) \ No newline at end of file diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..124effbeee03d2f0950f6cac6aa455be5a6d359f --- /dev/null +++ b/ldm/modules/attention.py @@ -0,0 +1,266 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return{el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', attn, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + self.disable_self_attn = disable_self_attn + self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False): + super().__init__() + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, + disable_self_attn=disable_self_attn) + for d in range(depth)] + ) + + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + x = self.proj_out(x) + return x + x_in diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..533e589a2024f1d7c52093d8c472c3b1b6617e26 --- /dev/null +++ b/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,835 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from ldm.util import instantiate_from_config +from ldm.modules.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..6b994cca787464d34f6367edf486974b3542f808 --- /dev/null +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,996 @@ +from abc import abstractmethod +from functools import partial +import math +from typing import Iterable + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer +from ldm.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + #self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") # todo: convert to warning + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) + diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a952e6c40308c33edd422da0ce6a60f47e73661b --- /dev/null +++ b/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,267 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..f2b8ef901130efc171aa69742ca0244d94d3f2e9 --- /dev/null +++ b/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c75af43565f6e140287644aaaefa97dd6e67c5 --- /dev/null +++ b/ldm/modules/ema.py @@ -0,0 +1,76 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates + else torch.tensor(-1,dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + #remove as '.'-character is not allowed in buffers + s_name = name.replace('.','') + self.m_name2s_name.update({name:s_name}) + self.register_buffer(s_name,p.clone().detach().data) + + self.collected_params = [] + + def forward(self,model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b1afccfc55d1b8162d6da8c0316082584a4bde34 --- /dev/null +++ b/ldm/modules/encoders/modules.py @@ -0,0 +1,550 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial +import kornia + +from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test +from ldm.util import default +import clip + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + +class FaceClipEncoder(AbstractEncoder): + def __init__(self, augment=True, retreival_key=None): + super().__init__() + self.encoder = FrozenCLIPImageEmbedder() + self.augment = augment + self.retreival_key = retreival_key + + def forward(self, img): + encodings = [] + with torch.no_grad(): + x_offset = 125 + if self.retreival_key: + # Assumes retrieved image are packed into the second half of channels + face = img[:,3:,190:440,x_offset:(512-x_offset)] + other = img[:,:3,...].clone() + else: + face = img[:,:,190:440,x_offset:(512-x_offset)] + other = img.clone() + + if self.augment: + face = K.RandomHorizontalFlip()(face) + + other[:,:,190:440,x_offset:(512-x_offset)] *= 0 + encodings = [ + self.encoder.encode(face), + self.encoder.encode(other), + ] + + return torch.cat(encodings, dim=1) + + def encode(self, img): + if isinstance(img, list): + # Uncondition + return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device) + + return self(img) + +class FaceIdClipEncoder(AbstractEncoder): + def __init__(self): + super().__init__() + self.encoder = FrozenCLIPImageEmbedder() + for p in self.encoder.parameters(): + p.requires_grad = False + self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True) + + def forward(self, img): + encodings = [] + with torch.no_grad(): + face = kornia.geometry.resize(img, (256, 256), + interpolation='bilinear', align_corners=True) + + other = img.clone() + other[:,:,184:452,122:396] *= 0 + encodings = [ + self.id.encode(face), + self.encoder.encode(other), + ] + + return torch.cat(encodings, dim=1) + + def encode(self, img): + if isinstance(img, list): + # Uncondition + return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device) + + return self(img) + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class'): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + + def forward(self, batch, key=None): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + c = self.embedding(c) + return c + + +class TransformerEmbedder(AbstractEncoder): + """Some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + super().__init__() + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer)) + + def forward(self, tokens): + tokens = tokens.to(self.device) # meh + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, x): + return self(x) + + +class BERTTokenizer(AbstractEncoder): + """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" + def __init__(self, device="cuda", vq_interface=True, max_length=77): + super().__init__() + from transformers import BertTokenizerFast # TODO: add to reuquirements + self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") + self.device = device + self.vq_interface = vq_interface + self.max_length = max_length + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + return tokens + + @torch.no_grad() + def encode(self, text): + tokens = self(text) + if not self.vq_interface: + return tokens + return None, None, [None, None, tokens] + + def decode(self, text): + return text + + +class BERTEmbedder(AbstractEncoder): + """Uses the BERT tokenizr model and add some transformer encoder layers""" + def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, + device="cuda",use_tokenizer=True, embedding_dropout=0.0): + super().__init__() + self.use_tknz_fn = use_tokenizer + if self.use_tknz_fn: + self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) + self.device = device + self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, + attn_layers=Encoder(dim=n_embed, depth=n_layer), + emb_dropout=embedding_dropout) + + def forward(self, text): + if self.use_tknz_fn: + tokens = self.tknz_fn(text)#.to(self.device) + else: + tokens = text + z = self.transformer(tokens, return_embeddings=True) + return z + + def encode(self, text): + # output of length 77 + return self(text) + + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + +from ldm.thirdp.psp.id_loss import IDFeatures +import kornia.augmentation as K + +class FrozenFaceEncoder(AbstractEncoder): + def __init__(self, model_path, augment=False): + super().__init__() + self.loss_fn = IDFeatures(model_path) + # face encoder is frozen + for p in self.loss_fn.parameters(): + p.requires_grad = False + # Mapper is trainable + self.mapper = torch.nn.Linear(512, 768) + p = 0.25 + if augment: + self.augment = K.AugmentationSequential( + K.RandomHorizontalFlip(p=0.5), + K.RandomEqualize(p=p), + # K.RandomPlanckianJitter(p=p), + # K.RandomPlasmaBrightness(p=p), + # K.RandomPlasmaContrast(p=p), + # K.ColorJiggle(0.02, 0.2, 0.2, p=p), + ) + else: + self.augment = False + + def forward(self, img): + if isinstance(img, list): + # Uncondition + return torch.zeros((1, 1, 768), device=self.mapper.weight.device) + + if self.augment is not None: + # Transforms require 0-1 + img = self.augment((img + 1)/2) + img = 2*img - 1 + + feat = self.loss_fn(img, crop=True) + feat = self.mapper(feat.unsqueeze(1)) + return feat + + def encode(self, img): + return self(img) + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32 + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + +import torch.nn.functional as F +from transformers import CLIPVisionModel +class ClipImageProjector(AbstractEncoder): + """ + Uses the CLIP image encoder. + """ + def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # clip-vit-base-patch32 + super().__init__() + self.model = CLIPVisionModel.from_pretrained(version) + self.model.train() + self.max_length = max_length # TODO: typical value? + self.antialias = True + self.mapper = torch.nn.Linear(1024, 768) + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + null_cond = self.get_null_cond(version, max_length) + self.register_buffer('null_cond', null_cond) + + @torch.no_grad() + def get_null_cond(self, version, max_length): + device = self.mean.device + embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length) + null_cond = embedder([""]) + return null_cond + + def preprocess(self, x): + # Expects inputs in the range -1, 1 + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + if isinstance(x, list): + return self.null_cond + # x is assumed to be in range [-1,1] + x = self.preprocess(x) + outputs = self.model(pixel_values=x) + last_hidden_state = outputs.last_hidden_state + last_hidden_state = self.mapper(last_hidden_state) + return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0]) + + def encode(self, im): + return self(im) + +class ProjectedFrozenCLIPEmbedder(AbstractEncoder): + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32 + super().__init__() + self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length) + self.projection = torch.nn.Linear(768, 768) + + def forward(self, text): + z = self.embedder(text) + return self.projection(z) + + def encode(self, text): + return self(text) + +class FrozenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the CLIP image encoder. + Not actually frozen... If you want that set cond_stage_trainable=False in cfg + """ + def __init__( + self, + model='ViT-L/14', + jit=False, + device='cpu', + antialias=False, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + # We don't use the text part so delete it + del self.model.transformer + self.antialias = antialias + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + + def preprocess(self, x): + # Expects inputs in the range -1, 1 + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic',align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + if isinstance(x, list): + # [""] denotes condition dropout for ucg + device = self.model.visual.conv1.weight.device + return torch.zeros(1, 768, device=device) + return self.model.encode_image(self.preprocess(x)).float() + + def encode(self, im): + return self(im).unsqueeze(1) + +from torchvision import transforms +import random + +class FrozenCLIPImageMutliEmbedder(AbstractEncoder): + """ + Uses the CLIP image encoder. + Not actually frozen... If you want that set cond_stage_trainable=False in cfg + """ + def __init__( + self, + model='ViT-L/14', + jit=False, + device='cpu', + antialias=True, + max_crops=5, + ): + super().__init__() + self.model, _ = clip.load(name=model, device=device, jit=jit) + # We don't use the text part so delete it + del self.model.transformer + self.antialias = antialias + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.max_crops = max_crops + + def preprocess(self, x): + + # Expects inputs in the range -1, 1 + randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1)) + max_crops = self.max_crops + patches = [] + crops = [randcrop(x) for _ in range(max_crops)] + patches.extend(crops) + x = torch.cat(patches, dim=0) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x): + # x is assumed to be in range [-1,1] + if isinstance(x, list): + # [""] denotes condition dropout for ucg + device = self.model.visual.conv1.weight.device + return torch.zeros(1, self.max_crops, 768, device=device) + batch_tokens = [] + for im in x: + patches = self.preprocess(im.unsqueeze(0)) + tokens = self.model.encode_image(patches).float() + for t in tokens: + if random.random() < 0.1: + t *= 0 + batch_tokens.append(tokens.unsqueeze(0)) + + return torch.cat(batch_tokens, dim=0) + + def encode(self, im): + return self(im) + +class SpatialRescaler(nn.Module): + def __init__(self, + n_stages=1, + method='bilinear', + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None + if self.remap_output: + print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') + self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) + + def forward(self,x): + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +from ldm.util import instantiate_from_config +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like + + +class LowScaleEncoder(nn.Module): + def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64, + scale_factor=1.0): + super().__init__() + self.max_noise_level = max_noise_level + self.model = instantiate_from_config(model_config) + self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start, + linear_end=linear_end) + self.out_size = output_size + self.scale_factor = scale_factor + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + z = self.model.encode(x).sample() + z = z * self.scale_factor + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + z = self.q_sample(z, noise_level) + if self.out_size is not None: + z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode + # z = z.repeat_interleave(2, -2).repeat_interleave(2, -1) + return z, noise_level + + def decode(self, z): + z = z / self.scale_factor + return self.model.decode(z) + + +if __name__ == "__main__": + from ldm.util import count_params + sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"] + model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda() + count_params(model, True) + z = model(sentences) + print(z.shape) + + model = FrozenCLIPEmbedder().cuda() + count_params(model, True) + z = model(sentences) + print(z.shape) + + print("done.") diff --git a/ldm/modules/evaluate/adm_evaluator.py b/ldm/modules/evaluate/adm_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..508cddf206e9aa8b2fa1de32e69a7b78acee13c0 --- /dev/null +++ b/ldm/modules/evaluate/adm_evaluator.py @@ -0,0 +1,676 @@ +import argparse +import io +import os +import random +import warnings +import zipfile +from abc import ABC, abstractmethod +from contextlib import contextmanager +from functools import partial +from multiprocessing import cpu_count +from multiprocessing.pool import ThreadPool +from typing import Iterable, Optional, Tuple +import yaml + +import numpy as np +import requests +import tensorflow.compat.v1 as tf +from scipy import linalg +from tqdm.auto import tqdm + +INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" +INCEPTION_V3_PATH = "classify_image_graph_def.pb" + +FID_POOL_NAME = "pool_3:0" +FID_SPATIAL_NAME = "mixed_6/conv:0" + +REQUIREMENTS = f"This script has the following requirements: \n" \ + 'tensorflow-gpu>=2.0' + "\n" + 'scipy' + "\n" + "requests" + "\n" + "tqdm" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--ref_batch", help="path to reference batch npz file") + parser.add_argument("--sample_batch", help="path to sample batch npz file") + args = parser.parse_args() + + config = tf.ConfigProto( + allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph + ) + config.gpu_options.allow_growth = True + evaluator = Evaluator(tf.Session(config=config)) + + print("warming up TensorFlow...") + # This will cause TF to print a bunch of verbose stuff now rather + # than after the next print(), to help prevent confusion. + evaluator.warmup() + + print("computing reference batch activations...") + ref_acts = evaluator.read_activations(args.ref_batch) + print("computing/reading reference batch statistics...") + ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts) + + print("computing sample batch activations...") + sample_acts = evaluator.read_activations(args.sample_batch) + print("computing/reading sample batch statistics...") + sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts) + + print("Computing evaluations...") + is_ = evaluator.compute_inception_score(sample_acts[0]) + print("Inception Score:", is_) + fid = sample_stats.frechet_distance(ref_stats) + print("FID:", fid) + sfid = sample_stats_spatial.frechet_distance(ref_stats_spatial) + print("sFID:", sfid) + prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0]) + print("Precision:", prec) + print("Recall:", recall) + + savepath = '/'.join(args.sample_batch.split('/')[:-1]) + results_file = os.path.join(savepath,'evaluation_metrics.yaml') + print(f'Saving evaluation results to "{results_file}"') + + results = { + 'IS': is_, + 'FID': fid, + 'sFID': sfid, + 'Precision:':prec, + 'Recall': recall + } + + with open(results_file, 'w') as f: + yaml.dump(results, f, default_flow_style=False) + +class InvalidFIDException(Exception): + pass + + +class FIDStatistics: + def __init__(self, mu: np.ndarray, sigma: np.ndarray): + self.mu = mu + self.sigma = sigma + + def frechet_distance(self, other, eps=1e-6): + """ + Compute the Frechet distance between two sets of statistics. + """ + # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 + mu1, sigma1 = self.mu, self.sigma + mu2, sigma2 = other.mu, other.sigma + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert ( + mu1.shape == mu2.shape + ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" + assert ( + sigma1.shape == sigma2.shape + ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" + + diff = mu1 - mu2 + + # product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; adding %s to diagonal of cov estimates" + % eps + ) + warnings.warn(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +class Evaluator: + def __init__( + self, + session, + batch_size=64, + softmax_batch_size=512, + ): + self.sess = session + self.batch_size = batch_size + self.softmax_batch_size = softmax_batch_size + self.manifold_estimator = ManifoldEstimator(session) + with self.sess.graph.as_default(): + self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3]) + self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048]) + self.pool_features, self.spatial_features = _create_feature_graph(self.image_input) + self.softmax = _create_softmax_graph(self.softmax_input) + + def warmup(self): + self.compute_activations(np.zeros([1, 8, 64, 64, 3])) + + def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]: + with open_npz_array(npz_path, "arr_0") as reader: + return self.compute_activations(reader.read_batches(self.batch_size)) + + def compute_activations(self, batches: Iterable[np.ndarray],silent=False) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute image features for downstream evals. + + :param batches: a iterator over NHWC numpy arrays in [0, 255]. + :return: a tuple of numpy arrays of shape [N x X], where X is a feature + dimension. The tuple is (pool_3, spatial). + """ + preds = [] + spatial_preds = [] + it = batches if silent else tqdm(batches) + for batch in it: + batch = batch.astype(np.float32) + pred, spatial_pred = self.sess.run( + [self.pool_features, self.spatial_features], {self.image_input: batch} + ) + preds.append(pred.reshape([pred.shape[0], -1])) + spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) + return ( + np.concatenate(preds, axis=0), + np.concatenate(spatial_preds, axis=0), + ) + + def read_statistics( + self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray] + ) -> Tuple[FIDStatistics, FIDStatistics]: + obj = np.load(npz_path) + if "mu" in list(obj.keys()): + return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics( + obj["mu_s"], obj["sigma_s"] + ) + return tuple(self.compute_statistics(x) for x in activations) + + def compute_statistics(self, activations: np.ndarray) -> FIDStatistics: + mu = np.mean(activations, axis=0) + sigma = np.cov(activations, rowvar=False) + return FIDStatistics(mu, sigma) + + def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float: + softmax_out = [] + for i in range(0, len(activations), self.softmax_batch_size): + acts = activations[i : i + self.softmax_batch_size] + softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts})) + preds = np.concatenate(softmax_out, axis=0) + # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 + scores = [] + for i in range(0, len(preds), split_size): + part = preds[i : i + split_size] + kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) + kl = np.mean(np.sum(kl, 1)) + scores.append(np.exp(kl)) + return float(np.mean(scores)) + + def compute_prec_recall( + self, activations_ref: np.ndarray, activations_sample: np.ndarray + ) -> Tuple[float, float]: + radii_1 = self.manifold_estimator.manifold_radii(activations_ref) + radii_2 = self.manifold_estimator.manifold_radii(activations_sample) + pr = self.manifold_estimator.evaluate_pr( + activations_ref, radii_1, activations_sample, radii_2 + ) + return (float(pr[0][0]), float(pr[1][0])) + + +class ManifoldEstimator: + """ + A helper for comparing manifolds of feature vectors. + + Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57 + """ + + def __init__( + self, + session, + row_batch_size=10000, + col_batch_size=10000, + nhood_sizes=(3,), + clamp_to_percentile=None, + eps=1e-5, + ): + """ + Estimate the manifold of given feature vectors. + + :param session: the TensorFlow session. + :param row_batch_size: row batch size to compute pairwise distances + (parameter to trade-off between memory usage and performance). + :param col_batch_size: column batch size to compute pairwise distances. + :param nhood_sizes: number of neighbors used to estimate the manifold. + :param clamp_to_percentile: prune hyperspheres that have radius larger than + the given percentile. + :param eps: small number for numerical stability. + """ + self.distance_block = DistanceBlock(session) + self.row_batch_size = row_batch_size + self.col_batch_size = col_batch_size + self.nhood_sizes = nhood_sizes + self.num_nhoods = len(nhood_sizes) + self.clamp_to_percentile = clamp_to_percentile + self.eps = eps + + def warmup(self): + feats, radii = ( + np.zeros([1, 2048], dtype=np.float32), + np.zeros([1, 1], dtype=np.float32), + ) + self.evaluate_pr(feats, radii, feats, radii) + + def manifold_radii(self, features: np.ndarray) -> np.ndarray: + num_images = len(features) + + # Estimate manifold of features by calculating distances to k-NN of each sample. + radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32) + distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32) + seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) + + for begin1 in range(0, num_images, self.row_batch_size): + end1 = min(begin1 + self.row_batch_size, num_images) + row_batch = features[begin1:end1] + + for begin2 in range(0, num_images, self.col_batch_size): + end2 = min(begin2 + self.col_batch_size, num_images) + col_batch = features[begin2:end2] + + # Compute distances between batches. + distance_batch[ + 0 : end1 - begin1, begin2:end2 + ] = self.distance_block.pairwise_distances(row_batch, col_batch) + + # Find the k-nearest neighbor from the current batch. + radii[begin1:end1, :] = np.concatenate( + [ + x[:, self.nhood_sizes] + for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1) + ], + axis=0, + ) + + if self.clamp_to_percentile is not None: + max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0) + radii[radii > max_distances] = 0 + return radii + + def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray): + """ + Evaluate if new feature vectors are at the manifold. + """ + num_eval_images = eval_features.shape[0] + num_ref_images = radii.shape[0] + distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32) + batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) + max_realism_score = np.zeros([num_eval_images], dtype=np.float32) + nearest_indices = np.zeros([num_eval_images], dtype=np.int32) + + for begin1 in range(0, num_eval_images, self.row_batch_size): + end1 = min(begin1 + self.row_batch_size, num_eval_images) + feature_batch = eval_features[begin1:end1] + + for begin2 in range(0, num_ref_images, self.col_batch_size): + end2 = min(begin2 + self.col_batch_size, num_ref_images) + ref_batch = features[begin2:end2] + + distance_batch[ + 0 : end1 - begin1, begin2:end2 + ] = self.distance_block.pairwise_distances(feature_batch, ref_batch) + + # From the minibatch of new feature vectors, determine if they are in the estimated manifold. + # If a feature vector is inside a hypersphere of some reference sample, then + # the new sample lies at the estimated manifold. + # The radii of the hyperspheres are determined from distances of neighborhood size k. + samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii + batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32) + + max_realism_score[begin1:end1] = np.max( + radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1 + ) + nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1) + + return { + "fraction": float(np.mean(batch_predictions)), + "batch_predictions": batch_predictions, + "max_realisim_score": max_realism_score, + "nearest_indices": nearest_indices, + } + + def evaluate_pr( + self, + features_1: np.ndarray, + radii_1: np.ndarray, + features_2: np.ndarray, + radii_2: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Evaluate precision and recall efficiently. + + :param features_1: [N1 x D] feature vectors for reference batch. + :param radii_1: [N1 x K1] radii for reference vectors. + :param features_2: [N2 x D] feature vectors for the other batch. + :param radii_2: [N x K2] radii for other vectors. + :return: a tuple of arrays for (precision, recall): + - precision: an np.ndarray of length K1 + - recall: an np.ndarray of length K2 + """ + features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool) + features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool) + for begin_1 in range(0, len(features_1), self.row_batch_size): + end_1 = begin_1 + self.row_batch_size + batch_1 = features_1[begin_1:end_1] + for begin_2 in range(0, len(features_2), self.col_batch_size): + end_2 = begin_2 + self.col_batch_size + batch_2 = features_2[begin_2:end_2] + batch_1_in, batch_2_in = self.distance_block.less_thans( + batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2] + ) + features_1_status[begin_1:end_1] |= batch_1_in + features_2_status[begin_2:end_2] |= batch_2_in + return ( + np.mean(features_2_status.astype(np.float64), axis=0), + np.mean(features_1_status.astype(np.float64), axis=0), + ) + + +class DistanceBlock: + """ + Calculate pairwise distances between vectors. + + Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34 + """ + + def __init__(self, session): + self.session = session + + # Initialize TF graph to calculate pairwise distances. + with session.graph.as_default(): + self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None]) + self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None]) + distance_block_16 = _batch_pairwise_distances( + tf.cast(self._features_batch1, tf.float16), + tf.cast(self._features_batch2, tf.float16), + ) + self.distance_block = tf.cond( + tf.reduce_all(tf.math.is_finite(distance_block_16)), + lambda: tf.cast(distance_block_16, tf.float32), + lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2), + ) + + # Extra logic for less thans. + self._radii1 = tf.placeholder(tf.float32, shape=[None, None]) + self._radii2 = tf.placeholder(tf.float32, shape=[None, None]) + dist32 = tf.cast(self.distance_block, tf.float32)[..., None] + self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1) + self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0) + + def pairwise_distances(self, U, V): + """ + Evaluate pairwise distances between two batches of feature vectors. + """ + return self.session.run( + self.distance_block, + feed_dict={self._features_batch1: U, self._features_batch2: V}, + ) + + def less_thans(self, batch_1, radii_1, batch_2, radii_2): + return self.session.run( + [self._batch_1_in, self._batch_2_in], + feed_dict={ + self._features_batch1: batch_1, + self._features_batch2: batch_2, + self._radii1: radii_1, + self._radii2: radii_2, + }, + ) + + +def _batch_pairwise_distances(U, V): + """ + Compute pairwise distances between two batches of feature vectors. + """ + with tf.variable_scope("pairwise_dist_block"): + # Squared norms of each row in U and V. + norm_u = tf.reduce_sum(tf.square(U), 1) + norm_v = tf.reduce_sum(tf.square(V), 1) + + # norm_u as a column and norm_v as a row vectors. + norm_u = tf.reshape(norm_u, [-1, 1]) + norm_v = tf.reshape(norm_v, [1, -1]) + + # Pairwise squared Euclidean distances. + D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0) + + return D + + +class NpzArrayReader(ABC): + @abstractmethod + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + pass + + @abstractmethod + def remaining(self) -> int: + pass + + def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: + def gen_fn(): + while True: + batch = self.read_batch(batch_size) + if batch is None: + break + yield batch + + rem = self.remaining() + num_batches = rem // batch_size + int(rem % batch_size != 0) + return BatchIterator(gen_fn, num_batches) + + +class BatchIterator: + def __init__(self, gen_fn, length): + self.gen_fn = gen_fn + self.length = length + + def __len__(self): + return self.length + + def __iter__(self): + return self.gen_fn() + + +class StreamingNpzArrayReader(NpzArrayReader): + def __init__(self, arr_f, shape, dtype): + self.arr_f = arr_f + self.shape = shape + self.dtype = dtype + self.idx = 0 + + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + if self.idx >= self.shape[0]: + return None + + bs = min(batch_size, self.shape[0] - self.idx) + self.idx += bs + + if self.dtype.itemsize == 0: + return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) + + read_count = bs * np.prod(self.shape[1:]) + read_size = int(read_count * self.dtype.itemsize) + data = _read_bytes(self.arr_f, read_size, "array data") + return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) + + def remaining(self) -> int: + return max(0, self.shape[0] - self.idx) + + +class MemoryNpzArrayReader(NpzArrayReader): + def __init__(self, arr): + self.arr = arr + self.idx = 0 + + @classmethod + def load(cls, path: str, arr_name: str): + with open(path, "rb") as f: + arr = np.load(f)[arr_name] + return cls(arr) + + def read_batch(self, batch_size: int) -> Optional[np.ndarray]: + if self.idx >= self.arr.shape[0]: + return None + + res = self.arr[self.idx : self.idx + batch_size] + self.idx += batch_size + return res + + def remaining(self) -> int: + return max(0, self.arr.shape[0] - self.idx) + + +@contextmanager +def open_npz_array(path: str, arr_name: str) -> NpzArrayReader: + with _open_npy_file(path, arr_name) as arr_f: + version = np.lib.format.read_magic(arr_f) + if version == (1, 0): + header = np.lib.format.read_array_header_1_0(arr_f) + elif version == (2, 0): + header = np.lib.format.read_array_header_2_0(arr_f) + else: + yield MemoryNpzArrayReader.load(path, arr_name) + return + shape, fortran, dtype = header + if fortran or dtype.hasobject: + yield MemoryNpzArrayReader.load(path, arr_name) + else: + yield StreamingNpzArrayReader(arr_f, shape, dtype) + + +def _read_bytes(fp, size, error_template="ran out of data"): + """ + Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 + + Read from file-like object until size bytes are read. + Raises ValueError if not EOF is encountered before size bytes are read. + Non-blocking objects only supported if they derive from io objects. + Required as e.g. ZipExtFile in python 2.6 can return less data than + requested. + """ + data = bytes() + while True: + # io files (default in python3) return None or raise on + # would-block, python2 file will truncate, probably nothing can be + # done about that. note that regular files can't be non-blocking + try: + r = fp.read(size - len(data)) + data += r + if len(r) == 0 or len(data) == size: + break + except io.BlockingIOError: + pass + if len(data) != size: + msg = "EOF: reading %s, expected %d bytes got %d" + raise ValueError(msg % (error_template, size, len(data))) + else: + return data + + +@contextmanager +def _open_npy_file(path: str, arr_name: str): + with open(path, "rb") as f: + with zipfile.ZipFile(f, "r") as zip_f: + if f"{arr_name}.npy" not in zip_f.namelist(): + raise ValueError(f"missing {arr_name} in npz file") + with zip_f.open(f"{arr_name}.npy", "r") as arr_f: + yield arr_f + + +def _download_inception_model(): + if os.path.exists(INCEPTION_V3_PATH): + return + print("downloading InceptionV3 model...") + with requests.get(INCEPTION_V3_URL, stream=True) as r: + r.raise_for_status() + tmp_path = INCEPTION_V3_PATH + ".tmp" + with open(tmp_path, "wb") as f: + for chunk in tqdm(r.iter_content(chunk_size=8192)): + f.write(chunk) + os.rename(tmp_path, INCEPTION_V3_PATH) + + +def _create_feature_graph(input_batch): + _download_inception_model() + prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" + with open(INCEPTION_V3_PATH, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + pool3, spatial = tf.import_graph_def( + graph_def, + input_map={f"ExpandDims:0": input_batch}, + return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], + name=prefix, + ) + _update_shapes(pool3) + spatial = spatial[..., :7] + return pool3, spatial + + +def _create_softmax_graph(input_batch): + _download_inception_model() + prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" + with open(INCEPTION_V3_PATH, "rb") as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + (matmul,) = tf.import_graph_def( + graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix + ) + w = matmul.inputs[1] + logits = tf.matmul(input_batch, w) + return tf.nn.softmax(logits) + + +def _update_shapes(pool3): + # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63 + ops = pool3.graph.get_operations() + for op in ops: + for o in op.outputs: + shape = o.get_shape() + if shape._dims is not None: # pylint: disable=protected-access + # shape = [s.value for s in shape] TF 1.x + shape = [s for s in shape] # TF 2.x + new_shape = [] + for j, s in enumerate(shape): + if s == 1 and j == 0: + new_shape.append(None) + else: + new_shape.append(s) + o.__dict__["_shape_val"] = tf.TensorShape(new_shape) + return pool3 + + +def _numpy_partition(arr, kth, **kwargs): + num_workers = min(cpu_count(), len(arr)) + chunk_size = len(arr) // num_workers + extra = len(arr) % num_workers + + start_idx = 0 + batches = [] + for i in range(num_workers): + size = chunk_size + (1 if i < extra else 0) + batches.append(arr[start_idx : start_idx + size]) + start_idx += size + + with ThreadPool(num_workers) as pool: + return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches)) + + +if __name__ == "__main__": + print(REQUIREMENTS) + main() diff --git a/ldm/modules/evaluate/evaluate_perceptualsim.py b/ldm/modules/evaluate/evaluate_perceptualsim.py new file mode 100644 index 0000000000000000000000000000000000000000..c85fef967b60b90e3001b0cc29aa70b1a80ed36f --- /dev/null +++ b/ldm/modules/evaluate/evaluate_perceptualsim.py @@ -0,0 +1,630 @@ +import argparse +import glob +import os +from tqdm import tqdm +from collections import namedtuple + +import numpy as np +import torch +import torchvision.transforms as transforms +from torchvision import models +from PIL import Image + +from ldm.modules.evaluate.ssim import ssim + + +transform = transforms.Compose([transforms.ToTensor()]) + +def normalize_tensor(in_feat, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1)).view( + in_feat.size()[0], 1, in_feat.size()[2], in_feat.size()[3] + ) + return in_feat / (norm_factor.expand_as(in_feat) + eps) + + +def cos_sim(in0, in1): + in0_norm = normalize_tensor(in0) + in1_norm = normalize_tensor(in1) + N = in0.size()[0] + X = in0.size()[2] + Y = in0.size()[3] + + return torch.mean( + torch.mean( + torch.sum(in0_norm * in1_norm, dim=1).view(N, 1, X, Y), dim=2 + ).view(N, 1, 1, Y), + dim=3, + ).view(N) + + +class squeezenet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(squeezenet, self).__init__() + pretrained_features = models.squeezenet1_1( + pretrained=pretrained + ).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.slice6 = torch.nn.Sequential() + self.slice7 = torch.nn.Sequential() + self.N_slices = 7 + for x in range(2): + self.slice1.add_module(str(x), pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), pretrained_features[x]) + for x in range(10, 11): + self.slice5.add_module(str(x), pretrained_features[x]) + for x in range(11, 12): + self.slice6.add_module(str(x), pretrained_features[x]) + for x in range(12, 13): + self.slice7.add_module(str(x), pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + h = self.slice6(h) + h_relu6 = h + h = self.slice7(h) + h_relu7 = h + vgg_outputs = namedtuple( + "SqueezeOutputs", + ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"], + ) + out = vgg_outputs( + h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7 + ) + + return out + + +class alexnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(alexnet, self).__init__() + alexnet_pretrained_features = models.alexnet( + pretrained=pretrained + ).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(2): + self.slice1.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(2, 5): + self.slice2.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(5, 8): + self.slice3.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(8, 10): + self.slice4.add_module(str(x), alexnet_pretrained_features[x]) + for x in range(10, 12): + self.slice5.add_module(str(x), alexnet_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1 = h + h = self.slice2(h) + h_relu2 = h + h = self.slice3(h) + h_relu3 = h + h = self.slice4(h) + h_relu4 = h + h = self.slice5(h) + h_relu5 = h + alexnet_outputs = namedtuple( + "AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"] + ) + out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) + + return out + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", + ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"], + ) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + + return out + + +class resnet(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True, num=18): + super(resnet, self).__init__() + if num == 18: + self.net = models.resnet18(pretrained=pretrained) + elif num == 34: + self.net = models.resnet34(pretrained=pretrained) + elif num == 50: + self.net = models.resnet50(pretrained=pretrained) + elif num == 101: + self.net = models.resnet101(pretrained=pretrained) + elif num == 152: + self.net = models.resnet152(pretrained=pretrained) + self.N_slices = 5 + + self.conv1 = self.net.conv1 + self.bn1 = self.net.bn1 + self.relu = self.net.relu + self.maxpool = self.net.maxpool + self.layer1 = self.net.layer1 + self.layer2 = self.net.layer2 + self.layer3 = self.net.layer3 + self.layer4 = self.net.layer4 + + def forward(self, X): + h = self.conv1(X) + h = self.bn1(h) + h = self.relu(h) + h_relu1 = h + h = self.maxpool(h) + h = self.layer1(h) + h_conv2 = h + h = self.layer2(h) + h_conv3 = h + h = self.layer3(h) + h_conv4 = h + h = self.layer4(h) + h_conv5 = h + + outputs = namedtuple( + "Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"] + ) + out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) + + return out + +# Off-the-shelf deep network +class PNet(torch.nn.Module): + """Pre-trained network with all channels equally weighted by default""" + + def __init__(self, pnet_type="vgg", pnet_rand=False, use_gpu=True): + super(PNet, self).__init__() + + self.use_gpu = use_gpu + + self.pnet_type = pnet_type + self.pnet_rand = pnet_rand + + self.shift = torch.Tensor([-0.030, -0.088, -0.188]).view(1, 3, 1, 1) + self.scale = torch.Tensor([0.458, 0.448, 0.450]).view(1, 3, 1, 1) + + if self.pnet_type in ["vgg", "vgg16"]: + self.net = vgg16(pretrained=not self.pnet_rand, requires_grad=False) + elif self.pnet_type == "alex": + self.net = alexnet( + pretrained=not self.pnet_rand, requires_grad=False + ) + elif self.pnet_type[:-2] == "resnet": + self.net = resnet( + pretrained=not self.pnet_rand, + requires_grad=False, + num=int(self.pnet_type[-2:]), + ) + elif self.pnet_type == "squeeze": + self.net = squeezenet( + pretrained=not self.pnet_rand, requires_grad=False + ) + + self.L = self.net.N_slices + + if use_gpu: + self.net.cuda() + self.shift = self.shift.cuda() + self.scale = self.scale.cuda() + + def forward(self, in0, in1, retPerLayer=False): + in0_sc = (in0 - self.shift.expand_as(in0)) / self.scale.expand_as(in0) + in1_sc = (in1 - self.shift.expand_as(in0)) / self.scale.expand_as(in0) + + outs0 = self.net.forward(in0_sc) + outs1 = self.net.forward(in1_sc) + + if retPerLayer: + all_scores = [] + for (kk, out0) in enumerate(outs0): + cur_score = 1.0 - cos_sim(outs0[kk], outs1[kk]) + if kk == 0: + val = 1.0 * cur_score + else: + val = val + cur_score + if retPerLayer: + all_scores += [cur_score] + + if retPerLayer: + return (val, all_scores) + else: + return val + + + + +# The SSIM metric +def ssim_metric(img1, img2, mask=None): + return ssim(img1, img2, mask=mask, size_average=False) + + +# The PSNR metric +def psnr(img1, img2, mask=None,reshape=False): + b = img1.size(0) + if not (mask is None): + b = img1.size(0) + mse_err = (img1 - img2).pow(2) * mask + if reshape: + mse_err = mse_err.reshape(b, -1).sum(dim=1) / ( + 3 * mask.reshape(b, -1).sum(dim=1).clamp(min=1) + ) + else: + mse_err = mse_err.view(b, -1).sum(dim=1) / ( + 3 * mask.view(b, -1).sum(dim=1).clamp(min=1) + ) + else: + if reshape: + mse_err = (img1 - img2).pow(2).reshape(b, -1).mean(dim=1) + else: + mse_err = (img1 - img2).pow(2).view(b, -1).mean(dim=1) + + psnr = 10 * (1 / mse_err).log10() + return psnr + + +# The perceptual similarity metric +def perceptual_sim(img1, img2, vgg16): + # First extract features + dist = vgg16(img1 * 2 - 1, img2 * 2 - 1) + + return dist + +def load_img(img_name, size=None): + try: + img = Image.open(img_name) + + if type(size) == int: + img = img.resize((size, size)) + elif size is not None: + img = img.resize((size[1], size[0])) + + img = transform(img).cuda() + img = img.unsqueeze(0) + except Exception as e: + print("Failed at loading %s " % img_name) + print(e) + img = torch.zeros(1, 3, 256, 256).cuda() + raise + return img + + +def compute_perceptual_similarity(folder, pred_img, tgt_img, take_every_other): + + # Load VGG16 for feature similarity + vgg16 = PNet().to("cuda") + vgg16.eval() + vgg16.cuda() + + values_percsim = [] + values_ssim = [] + values_psnr = [] + folders = os.listdir(folder) + for i, f in tqdm(enumerate(sorted(folders))): + pred_imgs = glob.glob(folder + f + "/" + pred_img) + tgt_imgs = glob.glob(folder + f + "/" + tgt_img) + assert len(tgt_imgs) == 1 + + perc_sim = 10000 + ssim_sim = -10 + psnr_sim = -10 + for p_img in pred_imgs: + t_img = load_img(tgt_imgs[0]) + p_img = load_img(p_img, size=t_img.shape[2:]) + t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item() + perc_sim = min(perc_sim, t_perc_sim) + + ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item()) + psnr_sim = max(psnr_sim, psnr(p_img, t_img).item()) + + values_percsim += [perc_sim] + values_ssim += [ssim_sim] + values_psnr += [psnr_sim] + + if take_every_other: + n_valuespercsim = [] + n_valuesssim = [] + n_valuespsnr = [] + for i in range(0, len(values_percsim) // 2): + n_valuespercsim += [ + min(values_percsim[2 * i], values_percsim[2 * i + 1]) + ] + n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])] + n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])] + + values_percsim = n_valuespercsim + values_ssim = n_valuesssim + values_psnr = n_valuespsnr + + avg_percsim = np.mean(np.array(values_percsim)) + std_percsim = np.std(np.array(values_percsim)) + + avg_psnr = np.mean(np.array(values_psnr)) + std_psnr = np.std(np.array(values_psnr)) + + avg_ssim = np.mean(np.array(values_ssim)) + std_ssim = np.std(np.array(values_ssim)) + + return { + "Perceptual similarity": (avg_percsim, std_percsim), + "PSNR": (avg_psnr, std_psnr), + "SSIM": (avg_ssim, std_ssim), + } + + +def compute_perceptual_similarity_from_list(pred_imgs_list, tgt_imgs_list, + take_every_other, + simple_format=True): + + # Load VGG16 for feature similarity + vgg16 = PNet().to("cuda") + vgg16.eval() + vgg16.cuda() + + values_percsim = [] + values_ssim = [] + values_psnr = [] + equal_count = 0 + ambig_count = 0 + for i, tgt_img in enumerate(tqdm(tgt_imgs_list)): + pred_imgs = pred_imgs_list[i] + tgt_imgs = [tgt_img] + assert len(tgt_imgs) == 1 + + if type(pred_imgs) != list: + pred_imgs = [pred_imgs] + + perc_sim = 10000 + ssim_sim = -10 + psnr_sim = -10 + assert len(pred_imgs)>0 + for p_img in pred_imgs: + t_img = load_img(tgt_imgs[0]) + p_img = load_img(p_img, size=t_img.shape[2:]) + t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item() + perc_sim = min(perc_sim, t_perc_sim) + + ssim_sim = max(ssim_sim, ssim_metric(p_img, t_img).item()) + psnr_sim = max(psnr_sim, psnr(p_img, t_img).item()) + + values_percsim += [perc_sim] + values_ssim += [ssim_sim] + if psnr_sim != np.float("inf"): + values_psnr += [psnr_sim] + else: + if torch.allclose(p_img, t_img): + equal_count += 1 + print("{} equal src and wrp images.".format(equal_count)) + else: + ambig_count += 1 + print("{} ambiguous src and wrp images.".format(ambig_count)) + + if take_every_other: + n_valuespercsim = [] + n_valuesssim = [] + n_valuespsnr = [] + for i in range(0, len(values_percsim) // 2): + n_valuespercsim += [ + min(values_percsim[2 * i], values_percsim[2 * i + 1]) + ] + n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])] + n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])] + + values_percsim = n_valuespercsim + values_ssim = n_valuesssim + values_psnr = n_valuespsnr + + avg_percsim = np.mean(np.array(values_percsim)) + std_percsim = np.std(np.array(values_percsim)) + + avg_psnr = np.mean(np.array(values_psnr)) + std_psnr = np.std(np.array(values_psnr)) + + avg_ssim = np.mean(np.array(values_ssim)) + std_ssim = np.std(np.array(values_ssim)) + + if simple_format: + # just to make yaml formatting readable + return { + "Perceptual similarity": [float(avg_percsim), float(std_percsim)], + "PSNR": [float(avg_psnr), float(std_psnr)], + "SSIM": [float(avg_ssim), float(std_ssim)], + } + else: + return { + "Perceptual similarity": (avg_percsim, std_percsim), + "PSNR": (avg_psnr, std_psnr), + "SSIM": (avg_ssim, std_ssim), + } + + +def compute_perceptual_similarity_from_list_topk(pred_imgs_list, tgt_imgs_list, + take_every_other, resize=False): + + # Load VGG16 for feature similarity + vgg16 = PNet().to("cuda") + vgg16.eval() + vgg16.cuda() + + values_percsim = [] + values_ssim = [] + values_psnr = [] + individual_percsim = [] + individual_ssim = [] + individual_psnr = [] + for i, tgt_img in enumerate(tqdm(tgt_imgs_list)): + pred_imgs = pred_imgs_list[i] + tgt_imgs = [tgt_img] + assert len(tgt_imgs) == 1 + + if type(pred_imgs) != list: + assert False + pred_imgs = [pred_imgs] + + perc_sim = 10000 + ssim_sim = -10 + psnr_sim = -10 + sample_percsim = list() + sample_ssim = list() + sample_psnr = list() + for p_img in pred_imgs: + if resize: + t_img = load_img(tgt_imgs[0], size=(256,256)) + else: + t_img = load_img(tgt_imgs[0]) + p_img = load_img(p_img, size=t_img.shape[2:]) + + t_perc_sim = perceptual_sim(p_img, t_img, vgg16).item() + sample_percsim.append(t_perc_sim) + perc_sim = min(perc_sim, t_perc_sim) + + t_ssim = ssim_metric(p_img, t_img).item() + sample_ssim.append(t_ssim) + ssim_sim = max(ssim_sim, t_ssim) + + t_psnr = psnr(p_img, t_img).item() + sample_psnr.append(t_psnr) + psnr_sim = max(psnr_sim, t_psnr) + + values_percsim += [perc_sim] + values_ssim += [ssim_sim] + values_psnr += [psnr_sim] + individual_percsim.append(sample_percsim) + individual_ssim.append(sample_ssim) + individual_psnr.append(sample_psnr) + + if take_every_other: + assert False, "Do this later, after specifying topk to get proper results" + n_valuespercsim = [] + n_valuesssim = [] + n_valuespsnr = [] + for i in range(0, len(values_percsim) // 2): + n_valuespercsim += [ + min(values_percsim[2 * i], values_percsim[2 * i + 1]) + ] + n_valuespsnr += [max(values_psnr[2 * i], values_psnr[2 * i + 1])] + n_valuesssim += [max(values_ssim[2 * i], values_ssim[2 * i + 1])] + + values_percsim = n_valuespercsim + values_ssim = n_valuesssim + values_psnr = n_valuespsnr + + avg_percsim = np.mean(np.array(values_percsim)) + std_percsim = np.std(np.array(values_percsim)) + + avg_psnr = np.mean(np.array(values_psnr)) + std_psnr = np.std(np.array(values_psnr)) + + avg_ssim = np.mean(np.array(values_ssim)) + std_ssim = np.std(np.array(values_ssim)) + + individual_percsim = np.array(individual_percsim) + individual_psnr = np.array(individual_psnr) + individual_ssim = np.array(individual_ssim) + + return { + "avg_of_best": { + "Perceptual similarity": [float(avg_percsim), float(std_percsim)], + "PSNR": [float(avg_psnr), float(std_psnr)], + "SSIM": [float(avg_ssim), float(std_ssim)], + }, + "individual": { + "PSIM": individual_percsim, + "PSNR": individual_psnr, + "SSIM": individual_ssim, + } + } + + +if __name__ == "__main__": + args = argparse.ArgumentParser() + args.add_argument("--folder", type=str, default="") + args.add_argument("--pred_image", type=str, default="") + args.add_argument("--target_image", type=str, default="") + args.add_argument("--take_every_other", action="store_true", default=False) + args.add_argument("--output_file", type=str, default="") + + opts = args.parse_args() + + folder = opts.folder + pred_img = opts.pred_image + tgt_img = opts.target_image + + results = compute_perceptual_similarity( + folder, pred_img, tgt_img, opts.take_every_other + ) + + f = open(opts.output_file, 'w') + for key in results: + print("%s for %s: \n" % (key, opts.folder)) + print( + "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]) + ) + + f.write("%s for %s: \n" % (key, opts.folder)) + f.write( + "\t {:0.4f} | {:0.4f} \n".format(results[key][0], results[key][1]) + ) + + f.close() diff --git a/ldm/modules/evaluate/frechet_video_distance.py b/ldm/modules/evaluate/frechet_video_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e13c41505d9895016cdda1a1fd59aec33ab4d0 --- /dev/null +++ b/ldm/modules/evaluate/frechet_video_distance.py @@ -0,0 +1,147 @@ +# coding=utf-8 +# Copyright 2022 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python2, python3 +"""Minimal Reference implementation for the Frechet Video Distance (FVD). + +FVD is a metric for the quality of video generation models. It is inspired by +the FID (Frechet Inception Distance) used for images, but uses a different +embedding to be better suitable for videos. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +import six +import tensorflow.compat.v1 as tf +import tensorflow_gan as tfgan +import tensorflow_hub as hub + + +def preprocess(videos, target_resolution): + """Runs some preprocessing on the videos for I3D model. + + Args: + videos: [batch_size, num_frames, height, width, depth] The videos to be + preprocessed. We don't care about the specific dtype of the videos, it can + be anything that tf.image.resize_bilinear accepts. Values are expected to + be in the range 0-255. + target_resolution: (width, height): target video resolution + + Returns: + videos: [batch_size, num_frames, height, width, depth] + """ + videos_shape = list(videos.shape) + all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) + resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) + target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] + output_videos = tf.reshape(resized_videos, target_shape) + scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 + return scaled_videos + + +def _is_in_graph(tensor_name): + """Checks whether a given tensor does exists in the graph.""" + try: + tf.get_default_graph().get_tensor_by_name(tensor_name) + except KeyError: + return False + return True + + +def create_id3_embedding(videos,warmup=False,batch_size=16): + """Embeds the given videos using the Inflated 3D Convolution ne twork. + + Downloads the graph of the I3D from tf.hub and adds it to the graph on the + first call. + + Args: + videos: [batch_size, num_frames, height=224, width=224, depth=3]. + Expected range is [-1, 1]. + + Returns: + embedding: [batch_size, embedding_size]. embedding_size depends + on the model used. + + Raises: + ValueError: when a provided embedding_layer is not supported. + """ + + # batch_size = 16 + module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" + + + # Making sure that we import the graph separately for + # each different input video tensor. + module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( + videos.name).replace(":", "_") + + + + assert_ops = [ + tf.Assert( + tf.reduce_max(videos) <= 1.001, + ["max value in frame is > 1", videos]), + tf.Assert( + tf.reduce_min(videos) >= -1.001, + ["min value in frame is < -1", videos]), + tf.assert_equal( + tf.shape(videos)[0], + batch_size, ["invalid frame batch size: ", + tf.shape(videos)], + summarize=6), + ] + with tf.control_dependencies(assert_ops): + videos = tf.identity(videos) + + module_scope = "%s_apply_default/" % module_name + + # To check whether the module has already been loaded into the graph, we look + # for a given tensor name. If this tensor name exists, we assume the function + # has been called before and the graph was imported. Otherwise we import it. + # Note: in theory, the tensor could exist, but have wrong shapes. + # This will happen if create_id3_embedding is called with a frames_placehoder + # of wrong size/batch size, because even though that will throw a tf.Assert + # on graph-execution time, it will insert the tensor (with wrong shape) into + # the graph. This is why we need the following assert. + if warmup: + video_batch_size = int(videos.shape[0]) + assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}" + tensor_name = module_scope + "RGB/inception_i3d/Mean:0" + if not _is_in_graph(tensor_name): + i3d_model = hub.Module(module_spec, name=module_name) + i3d_model(videos) + + # gets the kinetics-i3d-400-logits layer + tensor_name = module_scope + "RGB/inception_i3d/Mean:0" + tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) + return tensor + + +def calculate_fvd(real_activations, + generated_activations): + """Returns a list of ops that compute metrics as funcs of activations. + + Args: + real_activations: [num_samples, embedding_size] + generated_activations: [num_samples, embedding_size] + + Returns: + A scalar that contains the requested FVD. + """ + return tfgan.eval.frechet_classifier_distance_from_activations( + real_activations, generated_activations) diff --git a/ldm/modules/evaluate/ssim.py b/ldm/modules/evaluate/ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..4e8883ccb3b30455a76caf2e4d1e04745f75d214 --- /dev/null +++ b/ldm/modules/evaluate/ssim.py @@ -0,0 +1,124 @@ +# MIT Licence + +# Methods to predict the SSIM, taken from +# https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py + +from math import exp + +import torch +import torch.nn.functional as F +from torch.autograd import Variable + +def gaussian(window_size, sigma): + gauss = torch.Tensor( + [ + exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) + for x in range(window_size) + ] + ) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable( + _2D_window.expand(channel, 1, window_size, window_size).contiguous() + ) + return window + + +def _ssim( + img1, img2, window, window_size, channel, mask=None, size_average=True +): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = ( + F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) + - mu1_sq + ) + sigma2_sq = ( + F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) + - mu2_sq + ) + sigma12 = ( + F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) + - mu1_mu2 + ) + + C1 = (0.01) ** 2 + C2 = (0.03) ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ) + + if not (mask is None): + b = mask.size(0) + ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask + ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( + dim=1 + ).clamp(min=1) + return ssim_map + + import pdb + + pdb.set_trace + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2, mask=None): + (_, channel, _, _) = img1.size() + + if ( + channel == self.channel + and self.window.data.type() == img1.data.type() + ): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim( + img1, + img2, + window, + self.window_size, + channel, + mask, + self.size_average, + ) + + +def ssim(img1, img2, window_size=11, mask=None, size_average=True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, mask, size_average) diff --git a/ldm/modules/evaluate/torch_frechet_video_distance.py b/ldm/modules/evaluate/torch_frechet_video_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..04856b828a17cdc97fa88a7b9d2f7fe0f735b3fc --- /dev/null +++ b/ldm/modules/evaluate/torch_frechet_video_distance.py @@ -0,0 +1,294 @@ +# based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks! +import os +import numpy as np +import io +import re +import requests +import html +import hashlib +import urllib +import urllib.request +import scipy.linalg +import multiprocessing as mp +import glob + + +from tqdm import tqdm +from typing import Any, List, Tuple, Union, Dict, Callable + +from torchvision.io import read_video +import torch; torch.set_grad_enabled(False) +from einops import rearrange + +from nitro.util import isvideo + +def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float: + print('Calculate frechet distance...') + m = np.square(mu_sample - mu_ref).sum() + s, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) # pylint: disable=no-member + fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2)) + + return float(fid) + + +def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + mu = feats.mean(axis=0) # [d] + sigma = np.cov(feats, rowvar=False) # [d, d] + + return mu, sigma + + +def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any: + """Download the given URL and return a binary-mode file object to access the data.""" + assert num_attempts >= 1 + + # Doesn't look like an URL scheme so interpret it as a local filename. + if not re.match('^[a-z]+://', url): + return url if return_filename else open(url, "rb") + + # Handle file URLs. This code handles unusual file:// patterns that + # arise on Windows: + # + # file:///c:/foo.txt + # + # which would translate to a local '/c:/foo.txt' filename that's + # invalid. Drop the forward slash for such pathnames. + # + # If you touch this code path, you should test it on both Linux and + # Windows. + # + # Some internet resources suggest using urllib.request.url2pathname() but + # but that converts forward slashes to backslashes and this causes + # its own set of problems. + if url.startswith('file://'): + filename = urllib.parse.urlparse(url).path + if re.match(r'^/[a-zA-Z]:', filename): + filename = filename[1:] + return filename if return_filename else open(filename, "rb") + + url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() + + # Download. + url_name = None + url_data = None + with requests.Session() as session: + if verbose: + print("Downloading %s ..." % url, end="", flush=True) + for attempts_left in reversed(range(num_attempts)): + try: + with session.get(url) as res: + res.raise_for_status() + if len(res.content) == 0: + raise IOError("No data received") + + if len(res.content) < 8192: + content_str = res.content.decode("utf-8") + if "download_warning" in res.headers.get("Set-Cookie", ""): + links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] + if len(links) == 1: + url = requests.compat.urljoin(url, links[0]) + raise IOError("Google Drive virus checker nag") + if "Google Drive - Quota exceeded" in content_str: + raise IOError("Google Drive download quota exceeded -- please try again later") + + match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) + url_name = match[1] if match else url + url_data = res.content + if verbose: + print(" done") + break + except KeyboardInterrupt: + raise + except: + if not attempts_left: + if verbose: + print(" failed") + raise + if verbose: + print(".", end="", flush=True) + + # Return data as file object. + assert not return_filename + return io.BytesIO(url_data) + +def load_video(ip): + vid, *_ = read_video(ip) + vid = rearrange(vid, 't h w c -> t c h w').to(torch.uint8) + return vid + +def get_data_from_str(input_str,nprc = None): + assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory' + vid_filelist = glob.glob(os.path.join(input_str,'*.mp4')) + print(f'Found {len(vid_filelist)} videos in dir {input_str}') + + if nprc is None: + try: + nprc = mp.cpu_count() + except NotImplementedError: + print('WARNING: cpu_count() not avlailable, using only 1 cpu for video loading') + nprc = 1 + + pool = mp.Pool(processes=nprc) + + vids = [] + for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'): + vids.append(v) + + + vids = torch.stack(vids,dim=0).float() + + return vids + +def get_stats(stats): + assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}' + + print(f'Using precomputed statistics under {stats}') + stats = np.load(stats) + stats = {key: stats[key] for key in stats.files} + + return stats + + + + +@torch.no_grad() +def compute_fvd(ref_input, sample_input, bs=32, + ref_stats=None, + sample_stats=None, + nprc_load=None): + + + + calc_stats = ref_stats is None or sample_stats is None + + if calc_stats: + + only_ref = sample_stats is not None + only_sample = ref_stats is not None + + + if isinstance(ref_input,str) and not only_sample: + ref_input = get_data_from_str(ref_input,nprc_load) + + if isinstance(sample_input, str) and not only_ref: + sample_input = get_data_from_str(sample_input, nprc_load) + + stats = compute_statistics(sample_input,ref_input, + device='cuda' if torch.cuda.is_available() else 'cpu', + bs=bs, + only_ref=only_ref, + only_sample=only_sample) + + if only_ref: + stats.update(get_stats(sample_stats)) + elif only_sample: + stats.update(get_stats(ref_stats)) + + + + else: + stats = get_stats(sample_stats) + stats.update(get_stats(ref_stats)) + + fvd = compute_frechet_distance(**stats) + + return {'FVD' : fvd,} + + +@torch.no_grad() +def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict: + detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1' + detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer. + + with open_url(detector_url, verbose=False) as f: + detector = torch.jit.load(f).eval().to(device) + + + + assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive' + + ref_embed, sample_embed = [], [] + + info = f'Computing I3D activations for FVD score with batch size {bs}' + + if only_ref: + + if not isvideo(videos_real): + # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] + videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() + print(videos_real.shape) + + if videos_real.shape[0] % bs == 0: + n_secs = videos_real.shape[0] // bs + else: + n_secs = videos_real.shape[0] // bs + 1 + + videos_real = torch.tensor_split(videos_real, n_secs, dim=0) + + for ref_v in tqdm(videos_real, total=len(videos_real),desc=info): + + feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() + ref_embed.append(feats_ref) + + elif only_sample: + + if not isvideo(videos_fake): + # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] + videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() + print(videos_fake.shape) + + if videos_fake.shape[0] % bs == 0: + n_secs = videos_fake.shape[0] // bs + else: + n_secs = videos_fake.shape[0] // bs + 1 + + videos_real = torch.tensor_split(videos_real, n_secs, dim=0) + + for sample_v in tqdm(videos_fake, total=len(videos_real),desc=info): + feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() + sample_embed.append(feats_sample) + + + else: + + if not isvideo(videos_real): + # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] + videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() + + if not isvideo(videos_fake): + videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() + + if videos_fake.shape[0] % bs == 0: + n_secs = videos_fake.shape[0] // bs + else: + n_secs = videos_fake.shape[0] // bs + 1 + + videos_real = torch.tensor_split(videos_real, n_secs, dim=0) + videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0) + + for ref_v, sample_v in tqdm(zip(videos_real,videos_fake),total=len(videos_fake),desc=info): + # print(ref_v.shape) + # ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) + # sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) + + + feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() + feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() + sample_embed.append(feats_sample) + ref_embed.append(feats_ref) + + out = dict() + if len(sample_embed) > 0: + sample_embed = np.concatenate(sample_embed,axis=0) + mu_sample, sigma_sample = compute_stats(sample_embed) + out.update({'mu_sample': mu_sample, + 'sigma_sample': sigma_sample}) + + if len(ref_embed) > 0: + ref_embed = np.concatenate(ref_embed,axis=0) + mu_ref, sigma_ref = compute_stats(ref_embed) + out.update({'mu_ref': mu_ref, + 'sigma_ref': sigma_ref}) + + + return out diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7836cada81f90ded99c58d5942eea4c3477f58fc --- /dev/null +++ b/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 0000000000000000000000000000000000000000..32ef56169978e550090261cddbcf5eb611a6173b --- /dev/null +++ b/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + + return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = util.uint2single(img) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_lq = deg_fn(img) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + + diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 0000000000000000000000000000000000000000..dfa760689762d4e9490fe4d817f844955f1b35de --- /dev/null +++ b/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. + Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. + threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. + vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] # nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] # nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 0000000000000000000000000000000000000000..4249b43de0f22707758d13c240268a401642f6e6 Binary files /dev/null and b/ldm/modules/image_degradation/utils/test.png differ diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 0000000000000000000000000000000000000000..0175f155ad900ae33c3c46ed87f49b352e3faf98 --- /dev/null +++ b/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) + + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. +# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. + weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/ldm/modules/losses/__init__.py b/ldm/modules/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..876d7c5bd6e3245ee77feb4c482b7a8143604ad5 --- /dev/null +++ b/ldm/modules/losses/__init__.py @@ -0,0 +1 @@ +from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator \ No newline at end of file diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..672c1e32a1389def02461c0781339681060c540e --- /dev/null +++ b/ldm/modules/losses/contperceptual.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn + +from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? + + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..f69981769e4bd5462600458c4fcf26620f7e4306 --- /dev/null +++ b/ldm/modules/losses/vqperceptual.py @@ -0,0 +1,167 @@ +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init +from taming.modules.losses.lpips import LPIPS +from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss + + +def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): + assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] + loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) + loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) + loss_real = (weights * loss_real).sum() / weights.sum() + loss_fake = (weights * loss_fake).sum() / weights.sum() + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def measure_perplexity(predicted_indices, n_embed): + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally + encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use + +def l1(x, y): + return torch.abs(x-y) + + +def l2(x, y): + return torch.pow((x-y), 2) + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", + pixel_loss="l1"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + assert perceptual_loss in ["lpips", "clips", "dists"] + assert pixel_loss in ["l1", "l2"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + if perceptual_loss == "lpips": + print(f"{self.__class__.__name__}: Running with LPIPS.") + self.perceptual_loss = LPIPS().eval() + else: + raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") + self.perceptual_weight = perceptual_weight + + if pixel_loss == "l1": + self.pixel_loss = l1 + else: + self.pixel_loss = l2 + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + self.n_classes = n_classes + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", predicted_indices=None): + if not exists(codebook_loss): + codebook_loss = torch.tensor([0.]).to(inputs.device) + #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + if predicted_indices is not None: + assert self.n_classes is not None + with torch.no_grad(): + perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) + log[f"{split}/perplexity"] = perplexity + log[f"{split}/cluster_usage"] = cluster_usage + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc15bf9cfe0111a910e7de33d04ffdec3877576 --- /dev/null +++ b/ldm/modules/x_transformer.py @@ -0,0 +1,641 @@ +"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers""" +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial +from inspect import isfunction +from collections import namedtuple +from einops import rearrange, repeat, reduce + +# constants + +DEFAULT_DIM_HEAD = 64 + +Intermediates = namedtuple('Intermediates', [ + 'pre_softmax_attn', + 'post_softmax_attn' +]) + +LayerIntermediates = namedtuple('Intermediates', [ + 'hiddens', + 'attn_intermediates' +]) + + +class AbsolutePositionalEmbedding(nn.Module): + def __init__(self, dim, max_seq_len): + super().__init__() + self.emb = nn.Embedding(max_seq_len, dim) + self.init_() + + def init_(self): + nn.init.normal_(self.emb.weight, std=0.02) + + def forward(self, x): + n = torch.arange(x.shape[1], device=x.device) + return self.emb(n)[None, :, :] + + +class FixedPositionalEmbedding(nn.Module): + def __init__(self, dim): + super().__init__() + inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x, seq_dim=1, offset=0): + t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset + sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1) + return emb[None, :, :] + + +# helpers + +def exists(val): + return val is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def always(val): + def inner(*args, **kwargs): + return val + return inner + + +def not_equals(val): + def inner(x): + return x != val + return inner + + +def equals(val): + def inner(x): + return x == val + return inner + + +def max_neg_value(tensor): + return -torch.finfo(tensor.dtype).max + + +# keyword argument helpers + +def pick_and_pop(keys, d): + values = list(map(lambda key: d.pop(key), keys)) + return dict(zip(keys, values)) + + +def group_dict_by_key(cond, d): + return_val = [dict(), dict()] + for key in d.keys(): + match = bool(cond(key)) + ind = int(not match) + return_val[ind][key] = d[key] + return (*return_val,) + + +def string_begins_with(prefix, str): + return str.startswith(prefix) + + +def group_by_key_prefix(prefix, d): + return group_dict_by_key(partial(string_begins_with, prefix), d) + + +def groupby_prefix_and_trim(prefix, d): + kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) + kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + return kwargs_without_prefix, kwargs + + +# classes +class Scale(nn.Module): + def __init__(self, value, fn): + super().__init__() + self.value = value + self.fn = fn + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.value, *rest) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x, **kwargs): + x, *rest = self.fn(x, **kwargs) + return (x * self.g, *rest) + + +class ScaleNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps=1e-8): + super().__init__() + self.scale = dim ** -0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class Residual(nn.Module): + def forward(self, x, residual): + return x + residual + + +class GRUGating(nn.Module): + def __init__(self, dim): + super().__init__() + self.gru = nn.GRUCell(dim, dim) + + def forward(self, x, residual): + gated_output = self.gru( + rearrange(x, 'b n d -> (b n) d'), + rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +# feedforward + +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +# attention. +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + mask=None, + talking_heads=False, + sparse_topk=None, + use_entmax15=False, + num_mem_kv=0, + dropout=0., + on_attn=False + ): + super().__init__() + if use_entmax15: + raise NotImplementedError("Check out entmax activation instead of softmax activation!") + self.scale = dim_head ** -0.5 + self.heads = heads + self.causal = causal + self.mask = mask + + inner_dim = dim_head * heads + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_k = nn.Linear(dim, inner_dim, bias=False) + self.to_v = nn.Linear(dim, inner_dim, bias=False) + self.dropout = nn.Dropout(dropout) + + # talking heads + self.talking_heads = talking_heads + if talking_heads: + self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads)) + + # explicit topk sparse attention + self.sparse_topk = sparse_topk + + # entmax + #self.attn_fn = entmax15 if use_entmax15 else F.softmax + self.attn_fn = F.softmax + + # add memory key / values + self.num_mem_kv = num_mem_kv + if num_mem_kv > 0: + self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head)) + + # attention on attention + self.attn_on_attn = on_attn + self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + rel_pos=None, + sinusoidal_emb=None, + prev_attn=None, + mem=None + ): + b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device + kv_input = default(context, x) + + q_input = x + k_input = kv_input + v_input = kv_input + + if exists(mem): + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) + + if exists(sinusoidal_emb): + # in shortformer, the query would start at a position offset depending on the past cached memory + offset = k_input.shape[-2] - q_input.shape[-2] + q_input = q_input + sinusoidal_emb(q_input, offset=offset) + k_input = k_input + sinusoidal_emb(k_input) + + q = self.to_q(q_input) + k = self.to_k(k_input) + v = self.to_v(v_input) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v)) + + input_mask = None + if any(map(exists, (mask, context_mask))): + q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool()) + k_mask = q_mask if not exists(context) else context_mask + k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()) + q_mask = rearrange(q_mask, 'b i -> b () i ()') + k_mask = rearrange(k_mask, 'b j -> b () () j') + input_mask = q_mask * k_mask + + if self.num_mem_kv > 0: + mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v)) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) + if exists(input_mask): + input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + mask_value = max_neg_value(dots) + + if exists(prev_attn): + dots = dots + prev_attn + + pre_softmax_attn = dots + + if talking_heads: + dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous() + + if exists(rel_pos): + dots = rel_pos(dots) + + if exists(input_mask): + dots.masked_fill_(~input_mask, mask_value) + del input_mask + + if self.causal: + i, j = dots.shape[-2:] + r = torch.arange(i, device=device) + mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j') + mask = F.pad(mask, (j - i, 0), value=False) + dots.masked_fill_(mask, mask_value) + del mask + + if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]: + top, _ = dots.topk(self.sparse_topk, dim=-1) + vk = top[..., -1].unsqueeze(-1).expand_as(dots) + mask = dots < vk + dots.masked_fill_(mask, mask_value) + del mask + + attn = self.attn_fn(dots, dim=-1) + post_softmax_attn = attn + + attn = self.dropout(attn) + + if talking_heads: + attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous() + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + intermediates = Intermediates( + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn + ) + + return self.to_out(out), intermediates + + +class AttentionLayers(nn.Module): + def __init__( + self, + dim, + depth, + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_rezero=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + position_infused_attn=False, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + gate_residual=False, + **kwargs + ): + super().__init__() + ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) + attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs) + + dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + + self.dim = dim + self.depth = depth + self.layers = nn.ModuleList([]) + + self.has_pos_emb = position_infused_attn + self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None + self.rotary_pos_emb = always(None) + + assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + self.rel_pos = None + + self.pre_norm = pre_norm + + self.residual_attn = residual_attn + self.cross_residual_attn = cross_residual_attn + + norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm + norm_class = RMSNorm if use_rmsnorm else norm_class + norm_fn = partial(norm_class, dim) + + norm_fn = nn.Identity if use_rezero else norm_fn + branch_fn = Rezero if use_rezero else None + + if cross_attend and not only_cross: + default_block = ('a', 'c', 'f') + elif cross_attend and only_cross: + default_block = ('c', 'f') + else: + default_block = ('a', 'f') + + if macaron: + default_block = ('f',) + default_block + + if exists(custom_layers): + layer_types = custom_layers + elif exists(par_ratio): + par_depth = depth * len(default_block) + assert 1 < par_ratio <= par_depth, 'par ratio out of range' + default_block = tuple(filter(not_equals('f'), default_block)) + par_attn = par_depth // par_ratio + depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + par_width = (depth_cut + depth_cut // par_attn) // par_attn + assert len(default_block) <= par_width, 'default block is too large for par_ratio' + par_block = default_block + ('f',) * (par_width - len(default_block)) + par_head = par_block * par_attn + layer_types = par_head + ('f',) * (par_depth - len(par_head)) + elif exists(sandwich_coef): + assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' + layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + else: + layer_types = default_block * depth + + self.layer_types = layer_types + self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + + for layer_type in self.layer_types: + if layer_type == 'a': + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == 'c': + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == 'f': + layer = FeedForward(dim, **ff_kwargs) + layer = layer if not macaron else Scale(0.5, layer) + else: + raise Exception(f'invalid layer type {layer_type}') + + if isinstance(layer, Attention) and exists(branch_fn): + layer = branch_fn(layer) + + if gate_residual: + residual_fn = GRUGating(dim) + else: + residual_fn = Residual() + + self.layers.append(nn.ModuleList([ + norm_fn(), + layer, + residual_fn + ])) + + def forward( + self, + x, + context=None, + mask=None, + context_mask=None, + mems=None, + return_hiddens=False + ): + hiddens = [] + intermediates = [] + prev_attn = None + prev_cross_attn = None + + mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers + + for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)): + is_last = ind == (len(self.layers) - 1) + + if layer_type == 'a': + hiddens.append(x) + layer_mem = mems.pop(0) + + residual = x + + if self.pre_norm: + x = norm(x) + + if layer_type == 'a': + out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos, + prev_attn=prev_attn, mem=layer_mem) + elif layer_type == 'c': + out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn) + elif layer_type == 'f': + out = block(x) + + x = residual_fn(out, residual) + + if layer_type in ('a', 'c'): + intermediates.append(inter) + + if layer_type == 'a' and self.residual_attn: + prev_attn = inter.pre_softmax_attn + elif layer_type == 'c' and self.cross_residual_attn: + prev_cross_attn = inter.pre_softmax_attn + + if not self.pre_norm and not is_last: + x = norm(x) + + if return_hiddens: + intermediates = LayerIntermediates( + hiddens=hiddens, + attn_intermediates=intermediates + ) + + return x, intermediates + + return x + + +class Encoder(AttentionLayers): + def __init__(self, **kwargs): + assert 'causal' not in kwargs, 'cannot set causality on encoder' + super().__init__(causal=False, **kwargs) + + + +class TransformerWrapper(nn.Module): + def __init__( + self, + *, + num_tokens, + max_seq_len, + attn_layers, + emb_dim=None, + max_mem_len=0., + emb_dropout=0., + num_memory_tokens=None, + tie_embedding=False, + use_pos_emb=True + ): + super().__init__() + assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + + dim = attn_layers.dim + emb_dim = default(emb_dim, dim) + + self.max_seq_len = max_seq_len + self.max_mem_len = max_mem_len + self.num_tokens = num_tokens + + self.token_emb = nn.Embedding(num_tokens, emb_dim) + self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if ( + use_pos_emb and not attn_layers.has_pos_emb) else always(0) + self.emb_dropout = nn.Dropout(emb_dropout) + + self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() + self.attn_layers = attn_layers + self.norm = nn.LayerNorm(dim) + + self.init_() + + self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t() + + # memory tokens (like [cls]) from Memory Transformers paper + num_memory_tokens = default(num_memory_tokens, 0) + self.num_memory_tokens = num_memory_tokens + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim)) + + # let funnel encoder know number of memory tokens, if specified + if hasattr(attn_layers, 'num_memory_tokens'): + attn_layers.num_memory_tokens = num_memory_tokens + + def init_(self): + nn.init.normal_(self.token_emb.weight, std=0.02) + + def forward( + self, + x, + return_embeddings=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + **kwargs + ): + b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens + x = self.token_emb(x) + x += self.pos_emb(x) + x = self.emb_dropout(x) + + x = self.project_emb(x) + + if num_mem > 0: + mem = repeat(self.memory_tokens, 'n d -> b n d', b=b) + x = torch.cat((mem, x), dim=1) + + # auto-handle masking after appending memory tokens + if exists(mask): + mask = F.pad(mask, (num_mem, 0), value=True) + + x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs) + x = self.norm(x) + + mem, x = x[:, :num_mem], x[:, num_mem:] + + out = self.to_logits(x) if not return_embeddings else x + + if return_mems: + hiddens = intermediates.hiddens + new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens + new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + return out, new_mems + + if return_attn: + attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + return out, attn_maps + + return out + diff --git a/ldm/thirdp/psp/helpers.py b/ldm/thirdp/psp/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..983baaa50ea9df0cbabe09aba80293ddf7709845 --- /dev/null +++ b/ldm/thirdp/psp/helpers.py @@ -0,0 +1,121 @@ +# https://github.com/eladrich/pixel2style2pixel + +from collections import namedtuple +import torch +from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module + +""" +ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +def l2_norm(input, axis=1): + norm = torch.norm(input, 2, axis, True) + output = torch.div(input, norm) + return output + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + """ A named tuple describing a ResNet block. """ + + +def get_block(in_channel, depth, num_units, stride=2): + return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=8), + get_block(in_channel=128, depth=256, num_units=36), + get_block(in_channel=256, depth=512, num_units=3) + ] + else: + raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) + return blocks + + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + return module_input * x + + +class bottleneck_IR(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut + + +class bottleneck_IR_SE(Module): + def __init__(self, in_channel, depth, stride): + super(bottleneck_IR_SE, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth) + ) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth), + SEModule(depth, 16) + ) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + return res + shortcut \ No newline at end of file diff --git a/ldm/thirdp/psp/id_loss.py b/ldm/thirdp/psp/id_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e08ee095bd20ff664dcf470de15ff54f839b38e2 --- /dev/null +++ b/ldm/thirdp/psp/id_loss.py @@ -0,0 +1,23 @@ +# https://github.com/eladrich/pixel2style2pixel +import torch +from torch import nn +from ldm.thirdp.psp.model_irse import Backbone + + +class IDFeatures(nn.Module): + def __init__(self, model_path): + super(IDFeatures, self).__init__() + print('Loading ResNet ArcFace') + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) + self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) + self.facenet.eval() + + def forward(self, x, crop=False): + # Not sure of the image range here + if crop: + x = torch.nn.functional.interpolate(x, (256, 256), mode="area") + x = x[:, :, 35:223, 32:220] + x = self.face_pool(x) + x_feats = self.facenet(x) + return x_feats diff --git a/ldm/thirdp/psp/model_irse.py b/ldm/thirdp/psp/model_irse.py new file mode 100644 index 0000000000000000000000000000000000000000..21cedd2994a6eed5a0afd451b08dd09801fe60c0 --- /dev/null +++ b/ldm/thirdp/psp/model_irse.py @@ -0,0 +1,86 @@ +# https://github.com/eladrich/pixel2style2pixel + +from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module +from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm + +""" +Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) +""" + + +class Backbone(Module): + def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): + super(Backbone, self).__init__() + assert input_size in [112, 224], "input_size should be 112 or 224" + assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" + assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" + blocks = get_blocks(num_layers) + if mode == 'ir': + unit_module = bottleneck_IR + elif mode == 'ir_se': + unit_module = bottleneck_IR_SE + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), + PReLU(64)) + if input_size == 112: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 7 * 7, 512), + BatchNorm1d(512, affine=affine)) + else: + self.output_layer = Sequential(BatchNorm2d(512), + Dropout(drop_ratio), + Flatten(), + Linear(512 * 14 * 14, 512), + BatchNorm1d(512, affine=affine)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append(unit_module(bottleneck.in_channel, + bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + def forward(self, x): + x = self.input_layer(x) + x = self.body(x) + x = self.output_layer(x) + return l2_norm(x) + + +def IR_50(input_size): + """Constructs a ir-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_101(input_size): + """Constructs a ir-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_152(input_size): + """Constructs a ir-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_50(input_size): + """Constructs a ir_se-50 model.""" + model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_101(input_size): + """Constructs a ir_se-101 model.""" + model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) + return model + + +def IR_SE_152(input_size): + """Constructs a ir_se-152 model.""" + model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) + return model \ No newline at end of file diff --git a/ldm/util.py b/ldm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..07e2689a919f605a50866bdfd1e0faf5cc7fadc0 --- /dev/null +++ b/ldm/util.py @@ -0,0 +1,256 @@ +import importlib + +import torch +from torch import optim +import numpy as np + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + +import os +import numpy as np +import matplotlib.pyplot as plt +from PIL import Image +import torch +import time +import cv2 +import PIL + +def pil_rectangle_crop(im): + width, height = im.size # Get dimensions + + if width <= height: + left = 0 + right = width + top = (height - width)/2 + bottom = (height + width)/2 + else: + + top = 0 + bottom = height + left = (width - height) / 2 + bottom = (width + height) / 2 + + # Crop the center of the image + im = im.crop((left, top, right, bottom)) + return im + +def add_margin(pil_img, color, size=256): + width, height = pil_img.size + result = Image.new(pil_img.mode, (size, size), color) + result.paste(pil_img, ((size - width) // 2, (size - height) // 2)) + return result + +def load_and_preprocess(interface, input_im): + ''' + :param input_im (PIL Image). + :return image (H, W, 3) array in [0, 1]. + ''' + # See https://github.com/Ir1d/image-background-remove-tool + image = input_im.convert('RGB') + + image_without_background = interface([image])[0] + image_without_background = np.array(image_without_background) + est_seg = image_without_background > 127 + image = np.array(image) + foreground = est_seg[:, : , -1].astype(np.bool_) + image[~foreground] = [255., 255., 255.] + x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8)) + image = image[y:y+h, x:x+w, :] + image = PIL.Image.fromarray(np.array(image)) + + # resize image such that long edge is 512 + image.thumbnail([200, 200], Image.Resampling.LANCZOS) + image = add_margin(image, (255, 255, 255), size=256) + image = np.array(image) + + return image + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file diff --git a/one2345_elev_est/.gitignore b/one2345_elev_est/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..0fe207cdc4cb61b3622443c8f5c739097174306c --- /dev/null +++ b/one2345_elev_est/.gitignore @@ -0,0 +1,3 @@ +build/ +.idea/ +*.egg-info/ diff --git a/one2345_elev_est/oee/models/loftr/__init__.py b/one2345_elev_est/oee/models/loftr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d69b9c131cf41e95c5c6ee7d389b375267b22fa --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/__init__.py @@ -0,0 +1,2 @@ +from .loftr import LoFTR +from .utils.cvpr_ds_config import default_cfg diff --git a/one2345_elev_est/oee/models/loftr/backbone/__init__.py b/one2345_elev_est/oee/models/loftr/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b6e731b3f53ab367c89ef0ea8e1cbffb0d990775 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/backbone/__init__.py @@ -0,0 +1,11 @@ +from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4 + + +def build_backbone(config): + if config['backbone_type'] == 'ResNetFPN': + if config['resolution'] == (8, 2): + return ResNetFPN_8_2(config['resnetfpn']) + elif config['resolution'] == (16, 4): + return ResNetFPN_16_4(config['resnetfpn']) + else: + raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") diff --git a/one2345_elev_est/oee/models/loftr/backbone/resnet_fpn.py b/one2345_elev_est/oee/models/loftr/backbone/resnet_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..985e5b3f273a51e51447a8025ca3aadbe46752eb --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/backbone/resnet_fpn.py @@ -0,0 +1,199 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_planes, planes, stride=1): + super().__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.conv2 = conv3x3(planes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + conv1x1(in_planes, planes, stride=stride), + nn.BatchNorm2d(planes) + ) + + def forward(self, x): + y = x + y = self.relu(self.bn1(self.conv1(y))) + y = self.bn2(self.conv2(y)) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResNetFPN_8_2(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + # 3. FPN upsample + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + # FPN + x3_out = self.layer3_outconv(x3) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + return [x3_out, x1_out] + + +class ResNetFPN_16_4(nn.Module): + """ + ResNet+FPN, output resolution are 1/16 and 1/4. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16 + + # 3. FPN upsample + self.layer4_outconv = conv1x1(block_dims[3], block_dims[3]) + self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) + self.layer3_outconv2 = nn.Sequential( + conv3x3(block_dims[3], block_dims[3]), + nn.BatchNorm2d(block_dims[3]), + nn.LeakyReLU(), + conv3x3(block_dims[3], block_dims[2]), + ) + + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + x4 = self.layer4(x3) # 1/16 + + # FPN + x4_out = self.layer4_outconv(x4) + + x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) + x3_out = self.layer3_outconv(x3) + x3_out = self.layer3_outconv2(x3_out+x4_out_2x) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + return [x4_out, x2_out] diff --git a/one2345_elev_est/oee/models/loftr/loftr.py b/one2345_elev_est/oee/models/loftr/loftr.py new file mode 100644 index 0000000000000000000000000000000000000000..79c491ee47a4d67cb8b3fe493397349e0867accd --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/loftr.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange + +from .backbone import build_backbone +from .utils.position_encoding import PositionEncodingSine +from .loftr_module import LocalFeatureTransformer, FinePreprocess +from .utils.coarse_matching import CoarseMatching +from .utils.fine_matching import FineMatching + + +class LoFTR(nn.Module): + def __init__(self, config): + super().__init__() + # Misc + self.config = config + + # Modules + self.backbone = build_backbone(config) + self.pos_encoding = PositionEncodingSine( + config['coarse']['d_model'], + temp_bug_fix=config['coarse']['temp_bug_fix']) + self.loftr_coarse = LocalFeatureTransformer(config['coarse']) + self.coarse_matching = CoarseMatching(config['match_coarse']) + self.fine_preprocess = FinePreprocess(config) + self.loftr_fine = LocalFeatureTransformer(config["fine"]) + self.fine_matching = FineMatching() + + def forward(self, data): + """ + Update: + data (dict): { + 'image0': (torch.Tensor): (N, 1, H, W) + 'image1': (torch.Tensor): (N, 1, H, W) + 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position + 'mask1'(optional) : (torch.Tensor): (N, H, W) + } + """ + # 1. Local Feature CNN + data.update({ + 'bs': data['image0'].size(0), + 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] + }) + + if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence + feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0)) + (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs']) + else: # handle different input shapes + (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1']) + + data.update({ + 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], + 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] + }) + + # 2. coarse-level loftr module + # add featmap with positional encoding, then flatten it to sequence [N, HW, C] + feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c') + feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c') + + mask_c0 = mask_c1 = None # mask is useful in training + if 'mask0' in data: + mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) + feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) + + # 3. match coarse-level + self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1) + + # 4. fine-level refinement + feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data) + if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted + feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold) + + # 5. match fine-level + self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) + + def load_state_dict(self, state_dict, *args, **kwargs): + for k in list(state_dict.keys()): + if k.startswith('matcher.'): + state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) + return super().load_state_dict(state_dict, *args, **kwargs) diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/__init__.py b/one2345_elev_est/oee/models/loftr/loftr_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca51db4f50a0c4f3dcd795e74b83e633ab2e990a --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/loftr_module/__init__.py @@ -0,0 +1,2 @@ +from .transformer import LocalFeatureTransformer +from .fine_preprocess import FinePreprocess diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/fine_preprocess.py b/one2345_elev_est/oee/models/loftr/loftr_module/fine_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb8eefd362240a9901a335f0e6e07770ff04567 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/loftr_module/fine_preprocess.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange, repeat + + +class FinePreprocess(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.cat_c_feat = config['fine_concat_coarse_feat'] + self.W = self.config['fine_window_size'] + + d_model_c = self.config['coarse']['d_model'] + d_model_f = self.config['fine']['d_model'] + self.d_model_f = d_model_f + if self.cat_c_feat: + self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) + self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) + + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") + + def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): + W = self.W + stride = data['hw0_f'][0] // data['hw0_c'][0] + + data.update({'W': W}) + if data['b_ids'].shape[0] == 0: + feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) + return feat0, feat1 + + # 1. unfold(crop) all local windows + feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) + feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) + + # 2. select only the predicted matches + feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] + feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] + + # option: use coarse-level loftr feature as context: concat and linear + if self.cat_c_feat: + feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], + feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] + feat_cf_win = self.merge_feat(torch.cat([ + torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] + repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] + ], -1)) + feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) + + return feat_f0_unfold, feat_f1_unfold diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/linear_attention.py b/one2345_elev_est/oee/models/loftr/loftr_module/linear_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..b73c5a6a6a722a44c0b68f70cb77c0988b8a5fb3 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/loftr_module/linear_attention.py @@ -0,0 +1,81 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module, Dropout + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) + + # Compute the attention and the weighted average + softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() diff --git a/one2345_elev_est/oee/models/loftr/loftr_module/transformer.py b/one2345_elev_est/oee/models/loftr/loftr_module/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d79390ca08953bbef44e98149e662a681a16e42e --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/loftr_module/transformer.py @@ -0,0 +1,101 @@ +import copy +import torch +import torch.nn as nn +from .linear_attention import LinearAttention, FullAttention + + +class LoFTREncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + attention='linear'): + super(LoFTREncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = LinearAttention() if attention == 'linear' else FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.size(0) + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, config): + super(LocalFeatureTransformer, self).__init__() + + self.config = config + self.d_model = config['d_model'] + self.nhead = config['nhead'] + self.layer_names = config['layer_names'] + encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" + + for layer, name in zip(self.layers, self.layer_names): + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + else: + raise KeyError + + return feat0, feat1 diff --git a/one2345_elev_est/oee/models/loftr/utils/coarse_matching.py b/one2345_elev_est/oee/models/loftr/utils/coarse_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..a97263339462dec3af9705d33d6ee634e2f46914 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/coarse_matching.py @@ -0,0 +1,261 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.einops import rearrange + +INF = 1e9 + +def mask_border(m, b: int, v): + """ Mask borders with value + Args: + m (torch.Tensor): [N, H0, W0, H1, W1] + b (int) + v (m.dtype) + """ + if b <= 0: + return + + m[:, :b] = v + m[:, :, :b] = v + m[:, :, :, :b] = v + m[:, :, :, :, :b] = v + m[:, -b:] = v + m[:, :, -b:] = v + m[:, :, :, -b:] = v + m[:, :, :, :, -b:] = v + + +def mask_border_with_padding(m, bd, v, p_m0, p_m1): + if bd <= 0: + return + + m[:, :bd] = v + m[:, :, :bd] = v + m[:, :, :, :bd] = v + m[:, :, :, :, :bd] = v + + h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int() + h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int() + for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)): + m[b_idx, h0 - bd:] = v + m[b_idx, :, w0 - bd:] = v + m[b_idx, :, :, h1 - bd:] = v + m[b_idx, :, :, :, w1 - bd:] = v + + +def compute_max_candidates(p_m0, p_m1): + """Compute the max candidates of all pairs within a batch + + Args: + p_m0, p_m1 (torch.Tensor): padded masks + """ + h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0] + h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0] + max_cand = torch.sum( + torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0]) + return max_cand + + +class CoarseMatching(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + # general config + self.thr = config['thr'] + self.border_rm = config['border_rm'] + # -- # for trainig fine-level LoFTR + self.train_coarse_percent = config['train_coarse_percent'] + self.train_pad_num_gt_min = config['train_pad_num_gt_min'] + + # we provide 2 options for differentiable matching + self.match_type = config['match_type'] + if self.match_type == 'dual_softmax': + self.temperature = config['dsmax_temperature'] + elif self.match_type == 'sinkhorn': + try: + from .superglue import log_optimal_transport + except ImportError: + raise ImportError("download superglue.py first!") + self.log_optimal_transport = log_optimal_transport + self.bin_score = nn.Parameter( + torch.tensor(config['skh_init_bin_score'], requires_grad=True)) + self.skh_iters = config['skh_iters'] + self.skh_prefilter = config['skh_prefilter'] + else: + raise NotImplementedError() + + def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + data (dict) + mask_c0 (torch.Tensor): [N, L] (optional) + mask_c1 (torch.Tensor): [N, S] (optional) + Update: + data (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + NOTE: M' != M during training. + """ + N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2) + + # normalize + feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5, + [feat_c0, feat_c1]) + + if self.match_type == 'dual_softmax': + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, + feat_c1) / self.temperature + if mask_c0 is not None: + sim_matrix.masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2) + + elif self.match_type == 'sinkhorn': + # sinkhorn, dustbin included + sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1) + if mask_c0 is not None: + sim_matrix[:, :L, :S].masked_fill_( + ~(mask_c0[..., None] * mask_c1[:, None]).bool(), + -INF) + + # build uniform prior & use sinkhorn + log_assign_matrix = self.log_optimal_transport( + sim_matrix, self.bin_score, self.skh_iters) + assign_matrix = log_assign_matrix.exp() + conf_matrix = assign_matrix[:, :-1, :-1] + + # filter prediction with dustbin score (only in evaluation mode) + if not self.training and self.skh_prefilter: + filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L] + filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S] + conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0 + conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0 + + if self.config['sparse_spvs']: + data.update({'conf_matrix_with_bin': assign_matrix.clone()}) + + data.update({'conf_matrix': conf_matrix}) + + # predict coarse matches from conf_matrix + data.update(**self.get_coarse_match(conf_matrix, data)) + + @torch.no_grad() + def get_coarse_match(self, conf_matrix, data): + """ + Args: + conf_matrix (torch.Tensor): [N, L, S] + data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c'] + Returns: + coarse_matches (dict): { + 'b_ids' (torch.Tensor): [M'], + 'i_ids' (torch.Tensor): [M'], + 'j_ids' (torch.Tensor): [M'], + 'gt_mask' (torch.Tensor): [M'], + 'm_bids' (torch.Tensor): [M], + 'mkpts0_c' (torch.Tensor): [M, 2], + 'mkpts1_c' (torch.Tensor): [M, 2], + 'mconf' (torch.Tensor): [M]} + """ + axes_lengths = { + 'h0c': data['hw0_c'][0], + 'w0c': data['hw0_c'][1], + 'h1c': data['hw1_c'][0], + 'w1c': data['hw1_c'][1] + } + _device = conf_matrix.device + # 1. confidence thresholding + mask = conf_matrix > self.thr + mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c', + **axes_lengths) + if 'mask0' not in data: + mask_border(mask, self.border_rm, False) + else: + mask_border_with_padding(mask, self.border_rm, False, + data['mask0'], data['mask1']) + mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)', + **axes_lengths) + + # 2. mutual nearest + mask = mask \ + * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \ + * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0]) + + # 3. find all valid coarse matches + # this only works when at most one `True` in each row + mask_v, all_j_ids = mask.max(dim=2) + b_ids, i_ids = torch.where(mask_v) + j_ids = all_j_ids[b_ids, i_ids] + mconf = conf_matrix[b_ids, i_ids, j_ids] + + # 4. Random sampling of training samples for fine-level LoFTR + # (optional) pad samples with gt coarse-level matches + if self.training: + # NOTE: + # The sampling is performed across all pairs in a batch without manually balancing + # #samples for fine-level increases w.r.t. batch_size + if 'mask0' not in data: + num_candidates_max = mask.size(0) * max( + mask.size(1), mask.size(2)) + else: + num_candidates_max = compute_max_candidates( + data['mask0'], data['mask1']) + num_matches_train = int(num_candidates_max * + self.train_coarse_percent) + num_matches_pred = len(b_ids) + assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches" + + # pred_indices is to select from prediction + if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min: + pred_indices = torch.arange(num_matches_pred, device=_device) + else: + pred_indices = torch.randint( + num_matches_pred, + (num_matches_train - self.train_pad_num_gt_min, ), + device=_device) + + # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200) + gt_pad_indices = torch.randint( + len(data['spv_b_ids']), + (max(num_matches_train - num_matches_pred, + self.train_pad_num_gt_min), ), + device=_device) + mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero + + b_ids, i_ids, j_ids, mconf = map( + lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]], + dim=0), + *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']], + [j_ids, data['spv_j_ids']], [mconf, mconf_gt])) + + # These matches select patches that feed into fine-level network + coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids} + + # 4. Update with matches in original image resolution + scale = data['hw0_i'][0] / data['hw0_c'][0] + scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale + scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale + mkpts0_c = torch.stack( + [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]], + dim=1) * scale0 + mkpts1_c = torch.stack( + [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]], + dim=1) * scale1 + + # These matches is the current prediction (for visualization) + coarse_matches.update({ + 'gt_mask': mconf == 0, + 'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches + 'mkpts0_c': mkpts0_c[mconf != 0], + 'mkpts1_c': mkpts1_c[mconf != 0], + 'mconf': mconf[mconf != 0] + }) + + return coarse_matches diff --git a/one2345_elev_est/oee/models/loftr/utils/cvpr_ds_config.py b/one2345_elev_est/oee/models/loftr/utils/cvpr_ds_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9ce70154d3a1b961d3b4f08897415720f451f8 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/cvpr_ds_config.py @@ -0,0 +1,50 @@ +from yacs.config import CfgNode as CN + + +def lower_config(yacs_cfg): + if not isinstance(yacs_cfg, CN): + return yacs_cfg + return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} + + +_CN = CN() +_CN.BACKBONE_TYPE = 'ResNetFPN' +_CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] +_CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd +_CN.FINE_CONCAT_COARSE_FEAT = True + +# 1. LoFTR-backbone (local feature CNN) config +_CN.RESNETFPN = CN() +_CN.RESNETFPN.INITIAL_DIM = 128 +_CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 + +# 2. LoFTR-coarse module config +_CN.COARSE = CN() +_CN.COARSE.D_MODEL = 256 +_CN.COARSE.D_FFN = 256 +_CN.COARSE.NHEAD = 8 +_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 +_CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] +_CN.COARSE.TEMP_BUG_FIX = False + +# 3. Coarse-Matching config +_CN.MATCH_COARSE = CN() +_CN.MATCH_COARSE.THR = 0.2 +_CN.MATCH_COARSE.BORDER_RM = 2 +_CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] +_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 +_CN.MATCH_COARSE.SKH_ITERS = 3 +_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 +_CN.MATCH_COARSE.SKH_PREFILTER = True +_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory +_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock + +# 4. LoFTR-fine module config +_CN.FINE = CN() +_CN.FINE.D_MODEL = 128 +_CN.FINE.D_FFN = 128 +_CN.FINE.NHEAD = 8 +_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1 +_CN.FINE.ATTENTION = 'linear' + +default_cfg = lower_config(_CN) diff --git a/one2345_elev_est/oee/models/loftr/utils/fine_matching.py b/one2345_elev_est/oee/models/loftr/utils/fine_matching.py new file mode 100644 index 0000000000000000000000000000000000000000..6e77aded52e1eb5c01e22c2738104f3b09d6922a --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/fine_matching.py @@ -0,0 +1,74 @@ +import math +import torch +import torch.nn as nn + +from kornia.geometry.subpix import dsnt +from kornia.utils.grid import create_meshgrid + + +class FineMatching(nn.Module): + """FineMatching with s2d paradigm""" + + def __init__(self): + super().__init__() + + def forward(self, feat_f0, feat_f1, data): + """ + Args: + feat0 (torch.Tensor): [M, WW, C] + feat1 (torch.Tensor): [M, WW, C] + data (dict) + Update: + data (dict):{ + 'expec_f' (torch.Tensor): [M, 3], + 'mkpts0_f' (torch.Tensor): [M, 2], + 'mkpts1_f' (torch.Tensor): [M, 2]} + """ + M, WW, C = feat_f0.shape + W = int(math.sqrt(WW)) + scale = data['hw0_i'][0] / data['hw0_f'][0] + self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale + + # corner case: if no coarse matches found + if M == 0: + assert self.training == False, "M is always >0, when training, see coarse_matching.py" + # logger.warning('No matches found in coarse-level.') + data.update({ + 'expec_f': torch.empty(0, 3, device=feat_f0.device), + 'mkpts0_f': data['mkpts0_c'], + 'mkpts1_f': data['mkpts1_c'], + }) + return + + feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] + sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) + softmax_temp = 1. / C**.5 + heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) + + # compute coordinates from heatmap + coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] + grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] + + # compute std over + var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] + std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability + + # for fine-level supervision + data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) + + # compute absolute kpt coords + self.get_fine_match(coords_normalized, data) + + @torch.no_grad() + def get_fine_match(self, coords_normed, data): + W, WW, C, scale = self.W, self.WW, self.C, self.scale + + # mkpts0_f and mkpts1_f + mkpts0_f = data['mkpts0_c'] + scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale + mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] + + data.update({ + "mkpts0_f": mkpts0_f, + "mkpts1_f": mkpts1_f + }) diff --git a/one2345_elev_est/oee/models/loftr/utils/geometry.py b/one2345_elev_est/oee/models/loftr/utils/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..f95cdb65b48324c4f4ceb20231b1bed992b41116 --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/geometry.py @@ -0,0 +1,54 @@ +import torch + + +@torch.no_grad() +def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): + """ Warp kpts0 from I0 to I1 with depth, K and Rt + Also check covisibility and depth consistency. + Depth is consistent if relative error < 0.2 (hard-coded). + + Args: + kpts0 (torch.Tensor): [N, L, 2] - , + depth0 (torch.Tensor): [N, H, W], + depth1 (torch.Tensor): [N, H, W], + T_0to1 (torch.Tensor): [N, 3, 4], + K0 (torch.Tensor): [N, 3, 3], + K1 (torch.Tensor): [N, 3, 3], + Returns: + calculable_mask (torch.Tensor): [N, L] + warped_keypoints0 (torch.Tensor): [N, L, 2] + """ + kpts0_long = kpts0.round().long() + + # Sample depth, get calculable_mask on depth != 0 + kpts0_depth = torch.stack( + [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 + ) # (N, L) + nonzero_mask = kpts0_depth != 0 + + # Unproject + kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) + kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) + + # Rigid Transform + w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) + w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] + + # Project + w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) + w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth + + # Covisible Check + h, w = depth1.shape[1:3] + covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ + (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) + w_kpts0_long = w_kpts0.long() + w_kpts0_long[~covisible_mask, :] = 0 + + w_kpts0_depth = torch.stack( + [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 + ) # (N, L) + consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 + valid_mask = nonzero_mask * covisible_mask * consistent_mask + + return valid_mask, w_kpts0 diff --git a/one2345_elev_est/oee/models/loftr/utils/position_encoding.py b/one2345_elev_est/oee/models/loftr/utils/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..732d28c814ef93bf48d338ba7554f6dcfc3b880e --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/position_encoding.py @@ -0,0 +1,42 @@ +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), + the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact + on the final performance. For now, we keep both impls for backward compatability. + We will remove the buggy impl after re-training all variants of our released models. + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + if temp_bug_fix: + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) + else: # a buggy implementation (for backward compatability only) + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, :x.size(2), :x.size(3)] diff --git a/one2345_elev_est/oee/models/loftr/utils/supervision.py b/one2345_elev_est/oee/models/loftr/utils/supervision.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce6e79ec72b45fcb6b187e33bda93a47b168acd --- /dev/null +++ b/one2345_elev_est/oee/models/loftr/utils/supervision.py @@ -0,0 +1,151 @@ +from math import log +from loguru import logger + +import torch +from einops import repeat +from kornia.utils import create_meshgrid + +from .geometry import warp_kpts + +############## ↓ Coarse-Level supervision ↓ ############## + + +@torch.no_grad() +def mask_pts_at_padded_regions(grid_pt, mask): + """For megadepth dataset, zero-padding exists in images""" + mask = repeat(mask, 'n h w -> n (h w) c', c=2) + grid_pt[~mask.bool()] = 0 + return grid_pt + + +@torch.no_grad() +def spvs_coarse(data, config): + """ + Update: + data (dict): { + "conf_matrix_gt": [N, hw0, hw1], + 'spv_b_ids': [M] + 'spv_i_ids': [M] + 'spv_j_ids': [M] + 'spv_w_pt0_i': [N, hw0, 2], in original image resolution + 'spv_pt1_i': [N, hw1, 2], in original image resolution + } + + NOTE: + - for scannet dataset, there're 3 kinds of resolution {i, c, f} + - for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f} + """ + # 1. misc + device = data['image0'].device + N, _, H0, W0 = data['image0'].shape + _, _, H1, W1 = data['image1'].shape + scale = config['LOFTR']['RESOLUTION'][0] + scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale + scale1 = scale * data['scale1'][:, None] if 'scale0' in data else scale + h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1]) + + # 2. warp grids + # create kpts in meshgrid and resize them to image resolution + grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2] + grid_pt0_i = scale0 * grid_pt0_c + grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) + grid_pt1_i = scale1 * grid_pt1_c + + # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt + if 'mask0' in data: + grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0']) + grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1']) + + # warp kpts bi-directionally and resize them to coarse-level resolution + # (no depth consistency check, since it leads to worse results experimentally) + # (unhandled edge case: points with 0-depth will be warped to the left-up corner) + _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1']) + _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0']) + w_pt0_c = w_pt0_i / scale1 + w_pt1_c = w_pt1_i / scale0 + + # 3. check if mutual nearest neighbor + w_pt0_c_round = w_pt0_c[:, :, :].round().long() + nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1 + w_pt1_c_round = w_pt1_c[:, :, :].round().long() + nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0 + + # corner case: out of boundary + def out_bound_mask(pt, w, h): + return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h) + nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0 + nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0 + + loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0) + correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1) + correct_0to1[:, 0] = False # ignore the top-left corner + + # 4. construct a gt conf_matrix + conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device) + b_ids, i_ids = torch.where(correct_0to1 != 0) + j_ids = nearest_index1[b_ids, i_ids] + + conf_matrix_gt[b_ids, i_ids, j_ids] = 1 + data.update({'conf_matrix_gt': conf_matrix_gt}) + + # 5. save coarse matches(gt) for training fine level + if len(b_ids) == 0: + logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}") + # this won't affect fine-level loss calculation + b_ids = torch.tensor([0], device=device) + i_ids = torch.tensor([0], device=device) + j_ids = torch.tensor([0], device=device) + + data.update({ + 'spv_b_ids': b_ids, + 'spv_i_ids': i_ids, + 'spv_j_ids': j_ids + }) + + # 6. save intermediate results (for fast fine-level computation) + data.update({ + 'spv_w_pt0_i': w_pt0_i, + 'spv_pt1_i': grid_pt1_i + }) + + +def compute_supervision_coarse(data, config): + assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!" + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_coarse(data, config) + else: + raise ValueError(f'Unknown data source: {data_source}') + + +############## ↓ Fine-Level supervision ↓ ############## + +@torch.no_grad() +def spvs_fine(data, config): + """ + Update: + data (dict):{ + "expec_f_gt": [M, 2]} + """ + # 1. misc + # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i') + w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i'] + scale = config['LOFTR']['RESOLUTION'][1] + radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2 + + # 2. get coarse prediction + b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids'] + + # 3. compute gt + scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale + # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later + expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2] + data.update({"expec_f_gt": expec_f_gt}) + + +def compute_supervision_fine(data, config): + data_source = data['dataset_name'][0] + if data_source.lower() in ['scannet', 'megadepth']: + spvs_fine(data, config) + else: + raise NotImplementedError diff --git a/one2345_elev_est/oee/utils/elev_est_api.py b/one2345_elev_est/oee/utils/elev_est_api.py new file mode 100644 index 0000000000000000000000000000000000000000..a82345c02eae79e19e450c7d4583467da153f501 --- /dev/null +++ b/one2345_elev_est/oee/utils/elev_est_api.py @@ -0,0 +1,206 @@ +import matplotlib.pyplot as plt +import warnings + +import numpy as np +import cv2 +import os +import os.path as osp +import imageio +from copy import deepcopy + +import loguru +import torch +from oee.models.loftr import LoFTR, default_cfg +import matplotlib.cm as cm + +from oee.utils import plt_utils +from oee.utils.plotting import make_matching_figure +from oee.utils.utils3d import rect_to_img, canonical_to_camera, calc_pose + + +class ElevEstHelper: + _feature_matcher = None + + @classmethod + def get_feature_matcher(cls): + if cls._feature_matcher is None: + loguru.logger.info("Loading feature matcher...") + _default_cfg = deepcopy(default_cfg) + _default_cfg['coarse']['temp_bug_fix'] = True # set to False when using the old ckpt + matcher = LoFTR(config=_default_cfg) + ckpt_path = "weights/indoor_ds_new.ckpt" + if not osp.exists(ckpt_path): + loguru.logger.info("Downloading feature matcher...") + os.makedirs("weights", exist_ok=True) + import gdown + gdown.cached_download(url="https://drive.google.com/uc?id=19s3QvcCWQ6g-N1PrYlDCg-2mOJZ3kkgS", + path=ckpt_path) + matcher.load_state_dict(torch.load(ckpt_path)['state_dict']) + matcher = matcher.eval().cuda() + cls._feature_matcher = matcher + return cls._feature_matcher + + +def mask_out_bkgd(img_path, dbg=False): + img = imageio.imread_v2(img_path) + if img.shape[-1] == 4: + fg_mask = img[:, :, :3] + else: + loguru.logger.info("Image has no alpha channel, using thresholding to mask out background") + fg_mask = ~(img > 245).all(axis=-1) + if dbg: + plt.imshow(plt_utils.vis_mask(img, fg_mask.astype(np.uint8), color=[0, 255, 0])) + plt.show() + return fg_mask + + +def get_feature_matching(img_paths, dbg=False): + assert len(img_paths) == 4 + matcher = ElevEstHelper.get_feature_matcher() + feature_matching = {} + masks = [] + for i in range(4): + mask = mask_out_bkgd(img_paths[i], dbg=dbg) + masks.append(mask) + for i in range(0, 4): + for j in range(i + 1, 4): + img0_pth = img_paths[i] + img1_pth = img_paths[j] + mask0 = masks[i] + mask1 = masks[j] + img0_raw = cv2.imread(img0_pth, cv2.IMREAD_GRAYSCALE) + img1_raw = cv2.imread(img1_pth, cv2.IMREAD_GRAYSCALE) + original_shape = img0_raw.shape + img0_raw_resized = cv2.resize(img0_raw, (480, 480)) + img1_raw_resized = cv2.resize(img1_raw, (480, 480)) + + img0 = torch.from_numpy(img0_raw_resized)[None][None].cuda() / 255. + img1 = torch.from_numpy(img1_raw_resized)[None][None].cuda() / 255. + batch = {'image0': img0, 'image1': img1} + + # Inference with LoFTR and get prediction + with torch.no_grad(): + matcher(batch) + mkpts0 = batch['mkpts0_f'].cpu().numpy() + mkpts1 = batch['mkpts1_f'].cpu().numpy() + mconf = batch['mconf'].cpu().numpy() + mkpts0[:, 0] = mkpts0[:, 0] * original_shape[1] / 480 + mkpts0[:, 1] = mkpts0[:, 1] * original_shape[0] / 480 + mkpts1[:, 0] = mkpts1[:, 0] * original_shape[1] / 480 + mkpts1[:, 1] = mkpts1[:, 1] * original_shape[0] / 480 + keep0 = mask0[mkpts0[:, 1].astype(int), mkpts1[:, 0].astype(int)] + keep1 = mask1[mkpts1[:, 1].astype(int), mkpts1[:, 0].astype(int)] + keep = np.logical_and(keep0, keep1) + mkpts0 = mkpts0[keep] + mkpts1 = mkpts1[keep] + mconf = mconf[keep] + if dbg: + # Draw visualization + color = cm.jet(mconf) + text = [ + 'LoFTR', + 'Matches: {}'.format(len(mkpts0)), + ] + fig = make_matching_figure(img0_raw, img1_raw, mkpts0, mkpts1, color, text=text) + fig.show() + feature_matching[f"{i}_{j}"] = np.concatenate([mkpts0, mkpts1, mconf[:, None]], axis=1) + + return feature_matching + + +def gen_pose_hypothesis(center_elevation): + elevations = np.radians( + [center_elevation, center_elevation - 10, center_elevation + 10, center_elevation, center_elevation]) # 45~120 + azimuths = np.radians([30, 30, 30, 20, 40]) + input_poses = calc_pose(elevations, azimuths, len(azimuths)) + input_poses = input_poses[1:] + input_poses[..., 1] *= -1 + input_poses[..., 2] *= -1 + return input_poses + + +def ba_error_general(K, matches, poses): + projmat0 = K @ poses[0].inverse()[:3, :4] + projmat1 = K @ poses[1].inverse()[:3, :4] + match_01 = matches[0] + pts0 = match_01[:, :2] + pts1 = match_01[:, 2:4] + Xref = cv2.triangulatePoints(projmat0.cpu().numpy(), projmat1.cpu().numpy(), + pts0.cpu().numpy().T, pts1.cpu().numpy().T) + Xref = Xref[:3] / Xref[3:] + Xref = Xref.T + Xref = torch.from_numpy(Xref).cuda().float() + reproj_error = 0 + for match, cp in zip(matches[1:], poses[2:]): + dist = (torch.norm(match_01[:, :2][:, None, :] - match[:, :2][None, :, :], dim=-1)) + if dist.numel() > 0: + # print("dist.shape", dist.shape) + m0to2_index = dist.argmin(1) + keep = dist[torch.arange(match_01.shape[0]), m0to2_index] < 1 + if keep.sum() > 0: + xref_in2 = rect_to_img(K, canonical_to_camera(Xref, cp.inverse())) + reproj_error2 = torch.norm(match[m0to2_index][keep][:, 2:4] - xref_in2[keep], dim=-1) + conf02 = match[m0to2_index][keep][:, -1] + reproj_error += (reproj_error2 * conf02).sum() / (conf02.sum()) + + return reproj_error + + +def find_optim_elev(elevs, nimgs, matches, K, dbg=False): + errs = [] + for elev in elevs: + err = 0 + cam_poses = gen_pose_hypothesis(elev) + for start in range(nimgs - 1): + batch_matches, batch_poses = [], [] + for i in range(start, nimgs + start): + ci = i % nimgs + batch_poses.append(cam_poses[ci]) + for j in range(nimgs - 1): + key = f"{start}_{(start + j + 1) % nimgs}" + match = matches[key] + batch_matches.append(match) + err += ba_error_general(K, batch_matches, batch_poses) + errs.append(err) + errs = torch.tensor(errs) + if dbg: + plt.plot(elevs, errs) + plt.show() + optim_elev = elevs[torch.argmin(errs)].item() + return optim_elev + + +def get_elev_est(feature_matching, min_elev=30, max_elev=150, K=None, dbg=False): + flag = True + matches = {} + for i in range(4): + for j in range(i + 1, 4): + match_ij = feature_matching[f"{i}_{j}"] + if len(match_ij) == 0: + flag = False + match_ji = np.concatenate([match_ij[:, 2:4], match_ij[:, 0:2], match_ij[:, 4:5]], axis=1) + matches[f"{i}_{j}"] = torch.from_numpy(match_ij).float().cuda() + matches[f"{j}_{i}"] = torch.from_numpy(match_ji).float().cuda() + if not flag: + loguru.logger.info("0 matches, could not estimate elevation") + return None + interval = 10 + elevs = np.arange(min_elev, max_elev, interval) + optim_elev1 = find_optim_elev(elevs, 4, matches, K) + + elevs = np.arange(optim_elev1 - 10, optim_elev1 + 10, 1) + optim_elev2 = find_optim_elev(elevs, 4, matches, K) + + return optim_elev2 + + +def elev_est_api(img_paths, min_elev=30, max_elev=150, K=None, dbg=False): + feature_matching = get_feature_matching(img_paths, dbg=dbg) + if K is None: + loguru.logger.warning("K is not provided, using default K") + K = np.array([[280.0, 0, 128.0], + [0, 280.0, 128.0], + [0, 0, 1]]) + K = torch.from_numpy(K).cuda().float() + elev = get_elev_est(feature_matching, min_elev, max_elev, K, dbg=dbg) + return elev diff --git a/one2345_elev_est/oee/utils/plotting.py b/one2345_elev_est/oee/utils/plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7ac1de4b1fb6d0cbeda2f61eca81c68a9ba423 --- /dev/null +++ b/one2345_elev_est/oee/utils/plotting.py @@ -0,0 +1,154 @@ +import bisect +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + + +def _compute_conf_thresh(data): + dataset_name = data['dataset_name'][0].lower() + if dataset_name == 'scannet': + thr = 5e-4 + elif dataset_name == 'megadepth': + thr = 1e-4 + else: + raise ValueError(f'Unknown dataset: {dataset_name}') + return thr + + +# --- VISUALIZATION --- # + +def make_matching_figure( + img0, img1, mkpts0, mkpts1, color, + kpts0=None, kpts1=None, text=[], dpi=75, path=None): + # draw image pair + assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}' + fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi) + axes[0].imshow(img0, cmap='gray') + axes[1].imshow(img1, cmap='gray') + for i in range(2): # clear all frames + axes[i].get_yaxis().set_ticks([]) + axes[i].get_xaxis().set_ticks([]) + for spine in axes[i].spines.values(): + spine.set_visible(False) + plt.tight_layout(pad=1) + + if kpts0 is not None: + assert kpts1 is not None + axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2) + axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2) + + # draw matches + if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0: + fig.canvas.draw() + transFigure = fig.transFigure.inverted() + fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0)) + fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1)) + fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]), + (fkpts0[i, 1], fkpts1[i, 1]), + transform=fig.transFigure, c=color[i], linewidth=1) + for i in range(len(mkpts0))] + + axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4) + axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4) + + # put txts + txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w' + fig.text( + 0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes, + fontsize=15, va='top', ha='left', color=txt_color) + + # save or return figure + if path: + plt.savefig(str(path), bbox_inches='tight', pad_inches=0) + plt.close() + else: + return fig + + +def _make_evaluation_figure(data, b_id, alpha='dynamic'): + b_mask = data['m_bids'] == b_id + conf_thr = _compute_conf_thresh(data) + + img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32) + kpts0 = data['mkpts0_f'][b_mask].cpu().numpy() + kpts1 = data['mkpts1_f'][b_mask].cpu().numpy() + + # for megadepth, we visualize matches on the resized image + if 'scale0' in data: + kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]] + kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]] + + epi_errs = data['epi_errs'][b_mask].cpu().numpy() + correct_mask = epi_errs < conf_thr + precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0 + n_correct = np.sum(correct_mask) + n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu()) + recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches) + # recall might be larger than 1, since the calculation of conf_matrix_gt + # uses groundtruth depths and camera poses, but epipolar distance is used here. + + # matching info + if alpha == 'dynamic': + alpha = dynamic_alpha(len(correct_mask)) + color = error_colormap(epi_errs, conf_thr, alpha=alpha) + + text = [ + f'#Matches {len(kpts0)}', + f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}', + f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}' + ] + + # make the figure + figure = make_matching_figure(img0, img1, kpts0, kpts1, + color, text=text) + return figure + +def _make_confidence_figure(data, b_id): + # TODO: Implement confidence figure + raise NotImplementedError() + + +def make_matching_figures(data, config, mode='evaluation'): + """ Make matching figures for a batch. + + Args: + data (Dict): a batch updated by PL_LoFTR. + config (Dict): matcher config + Returns: + figures (Dict[str, List[plt.figure]] + """ + assert mode in ['evaluation', 'confidence'] # 'confidence' + figures = {mode: []} + for b_id in range(data['image0'].size(0)): + if mode == 'evaluation': + fig = _make_evaluation_figure( + data, b_id, + alpha=config.TRAINER.PLOT_MATCHES_ALPHA) + elif mode == 'confidence': + fig = _make_confidence_figure(data, b_id) + else: + raise ValueError(f'Unknown plot mode: {mode}') + figures[mode].append(fig) + return figures + + +def dynamic_alpha(n_matches, + milestones=[0, 300, 1000, 2000], + alphas=[1.0, 0.8, 0.4, 0.2]): + if n_matches == 0: + return 1.0 + ranges = list(zip(alphas, alphas[1:] + [None])) + loc = bisect.bisect_right(milestones, n_matches) - 1 + _range = ranges[loc] + if _range[1] is None: + return _range[0] + return _range[1] + (milestones[loc + 1] - n_matches) / ( + milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) + + +def error_colormap(err, thr, alpha=1.0): + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) diff --git a/one2345_elev_est/oee/utils/plt_utils.py b/one2345_elev_est/oee/utils/plt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..92353edab179de9f702633a01e123e94403bd83f --- /dev/null +++ b/one2345_elev_est/oee/utils/plt_utils.py @@ -0,0 +1,318 @@ +import os.path as osp +import os +import matplotlib.pyplot as plt +import torch +import cv2 +import math + +import numpy as np +import tqdm +from cv2 import findContours +from dl_ext.primitive import safe_zip +from dl_ext.timer import EvalTime + + +def plot_confidence(confidence): + n = len(confidence) + plt.plot(np.arange(n), confidence) + plt.show() + + +def image_grid( + images, + rows=None, + cols=None, + fill: bool = True, + show_axes: bool = False, + rgb=None, + show=True, + label=None, + **kwargs +): + """ + A util function for plotting a grid of images. + Args: + images: (N, H, W, 4) array of RGBA images + rows: number of rows in the grid + cols: number of columns in the grid + fill: boolean indicating if the space between images should be filled + show_axes: boolean indicating if the axes of the plots should be visible + rgb: boolean, If True, only RGB channels are plotted. + If False, only the alpha channel is plotted. + Returns: + None + """ + evaltime = EvalTime(disable=True) + evaltime('') + if isinstance(images, torch.Tensor): + images = images.detach().cpu() + if len(images[0].shape) == 2: + rgb = False + if images[0].shape[-1] == 2: + # flow + images = [flow_to_image(im) for im in images] + if (rows is None) != (cols is None): + raise ValueError("Specify either both rows and cols or neither.") + + if rows is None: + rows = int(len(images) ** 0.5) + cols = math.ceil(len(images) / rows) + + gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {} + if len(images) < 50: + figsize = (10, 10) + else: + figsize = (15, 15) + evaltime('0.5') + plt.figure(figsize=figsize) + # fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=figsize) + if label: + # fig.suptitle(label, fontsize=30) + plt.suptitle(label, fontsize=30) + # bleed = 0 + # fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed)) + evaltime('subplots') + + # for i, (ax, im) in enumerate(tqdm.tqdm(zip(axarr.ravel(), images), leave=True, total=len(images))): + for i in range(len(images)): + # evaltime(f'{i} begin') + plt.subplot(rows, cols, i + 1) + if rgb: + # only render RGB channels + plt.imshow(images[i][..., :3], **kwargs) + # ax.imshow(im[..., :3], **kwargs) + else: + # only render Alpha channel + plt.imshow(images[i], **kwargs) + # ax.imshow(im, **kwargs) + if not show_axes: + plt.axis('off') + # ax.set_axis_off() + # ax.set_title(f'{i}') + plt.title(f'{i}') + # evaltime(f'{i} end') + evaltime('2') + if show: + plt.show() + # return fig + + +def depth_grid( + depths, + rows=None, + cols=None, + fill: bool = True, + show_axes: bool = False, +): + """ + A util function for plotting a grid of images. + Args: + images: (N, H, W, 4) array of RGBA images + rows: number of rows in the grid + cols: number of columns in the grid + fill: boolean indicating if the space between images should be filled + show_axes: boolean indicating if the axes of the plots should be visible + rgb: boolean, If True, only RGB channels are plotted. + If False, only the alpha channel is plotted. + Returns: + None + """ + if (rows is None) != (cols is None): + raise ValueError("Specify either both rows and cols or neither.") + + if rows is None: + rows = len(depths) + cols = 1 + + gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {} + fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9)) + bleed = 0 + fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed)) + + for ax, im in zip(axarr.ravel(), depths): + ax.imshow(im) + if not show_axes: + ax.set_axis_off() + plt.show() + + +def hover_masks_on_imgs(images, masks): + masks = np.array(masks) + new_imgs = [] + tids = list(range(1, masks.max() + 1)) + colors = colormap(rgb=True, lighten=True) + for im, mask in tqdm.tqdm(safe_zip(images, masks), total=len(images)): + for tid in tids: + im = vis_mask( + im, + (mask == tid).astype(np.uint8), + color=colors[tid], + alpha=0.5, + border_alpha=0.5, + border_color=[255, 255, 255], + border_thick=3) + new_imgs.append(im) + return new_imgs + + +def vis_mask(img, + mask, + color=[255, 255, 255], + alpha=0.4, + show_border=True, + border_alpha=0.5, + border_thick=1, + border_color=None): + """Visualizes a single binary mask.""" + if isinstance(mask, torch.Tensor): + from anypose.utils.pn_utils import to_array + mask = to_array(mask > 0).astype(np.uint8) + img = img.astype(np.float32) + idx = np.nonzero(mask) + + img[idx[0], idx[1], :] *= 1.0 - alpha + img[idx[0], idx[1], :] += [alpha * x for x in color] + + if show_border: + contours, _ = findContours( + mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + # contours = [c for c in contours if c.shape[0] > 10] + if border_color is None: + border_color = color + if not isinstance(border_color, list): + border_color = border_color.tolist() + if border_alpha < 1: + with_border = img.copy() + cv2.drawContours(with_border, contours, -1, border_color, + border_thick, cv2.LINE_AA) + img = (1 - border_alpha) * img + border_alpha * with_border + else: + cv2.drawContours(img, contours, -1, border_color, border_thick, + cv2.LINE_AA) + + return img.astype(np.uint8) + + +def colormap(rgb=False, lighten=True): + """Copied from Detectron codebase.""" + color_list = np.array( + [ + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.167, 0.000, 0.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.000, 0.000, 0.000, + 0.143, 0.143, 0.143, + 0.286, 0.286, 0.286, + 0.429, 0.429, 0.429, + 0.571, 0.571, 0.571, + 0.714, 0.714, 0.714, + 0.857, 0.857, 0.857, + 1.000, 1.000, 1.000 + ] + ).astype(np.float32) + color_list = color_list.reshape((-1, 3)) + if not rgb: + color_list = color_list[:, ::-1] + + if lighten: + # Make all the colors a little lighter / whiter. This is copied + # from the detectron visualization code (search for 'w_ratio'). + w_ratio = 0.4 + color_list = (color_list * (1 - w_ratio) + w_ratio) + return color_list * 255 + + +def vis_layer_mask(masks, save_path=None): + masks = torch.as_tensor(masks) + tids = masks.unique().tolist() + tids.remove(0) + for tid in tqdm.tqdm(tids): + show = save_path is None + image_grid(masks == tid, label=f'{tid}', show=show) + if save_path: + os.makedirs(osp.dirname(save_path), exist_ok=True) + plt.savefig(save_path % tid) + plt.close('all') + + +def show(x, **kwargs): + if isinstance(x, torch.Tensor): + x = x.detach().cpu() + plt.imshow(x, **kwargs) + plt.show() + + +def vis_title(rgb, text, shift_y=30): + tmp = rgb.copy() + shift_x = rgb.shape[1] // 2 + cv2.putText(tmp, text, + (shift_x, shift_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=2, lineType=cv2.LINE_AA) + return tmp diff --git a/one2345_elev_est/oee/utils/utils3d.py b/one2345_elev_est/oee/utils/utils3d.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc92fbde4143a4ed5187c989e3f98a896e7caab --- /dev/null +++ b/one2345_elev_est/oee/utils/utils3d.py @@ -0,0 +1,62 @@ +import numpy as np +import torch + + +def cart_to_hom(pts): + """ + :param pts: (N, 3 or 2) + :return pts_hom: (N, 4 or 3) + """ + if isinstance(pts, np.ndarray): + pts_hom = np.concatenate((pts, np.ones([*pts.shape[:-1], 1], dtype=np.float32)), -1) + else: + ones = torch.ones([*pts.shape[:-1], 1], dtype=torch.float32, device=pts.device) + pts_hom = torch.cat((pts, ones), dim=-1) + return pts_hom + + +def hom_to_cart(pts): + return pts[..., :-1] / pts[..., -1:] + + +def canonical_to_camera(pts, pose): + pts = cart_to_hom(pts) + pts = pts @ pose.transpose(-1, -2) + pts = hom_to_cart(pts) + return pts + + +def rect_to_img(K, pts_rect): + from dl_ext.vision_ext.datasets.kitti.structures import Calibration + pts_2d_hom = pts_rect @ K.t() + pts_img = Calibration.hom_to_cart(pts_2d_hom) + return pts_img + + +def calc_pose(phis, thetas, size, radius=1.2): + import torch + def normalize(vectors): + return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) + + device = torch.device('cuda') + thetas = torch.FloatTensor(thetas).to(device) + phis = torch.FloatTensor(phis).to(device) + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + -radius * torch.cos(thetas) * torch.sin(phis), + radius * torch.cos(phis), + ], dim=-1) # [B, 3] + + # lookat + forward_vector = normalize(centers).squeeze(0) + up_vector = torch.FloatTensor([0, 0, 1]).to(device).unsqueeze(0).repeat(size, 1) + right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1)) + if right_vector.pow(2).sum() < 0.01: + right_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1) + up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1)) + + poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0).repeat(size, 1, 1) + poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) + poses[:, :3, 3] = centers + return poses diff --git a/one2345_elev_est/pyproject.toml b/one2345_elev_est/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..13cbbaaab1cdb4adc8e5553027d34cf9d1c1abe4 --- /dev/null +++ b/one2345_elev_est/pyproject.toml @@ -0,0 +1,7 @@ +[project] +name = "one2345_elev_est" +version = "0.1" + +[tool.setuptools.packages.find] +exclude = ["configs", "tests"] # empty by default +namespaces = false # true by default \ No newline at end of file diff --git a/one2345_elev_est/tools/estimate_wild_imgs.py b/one2345_elev_est/tools/estimate_wild_imgs.py new file mode 100644 index 0000000000000000000000000000000000000000..47103d46e48d3758a2cc237f659dd6fe5fc183c7 --- /dev/null +++ b/one2345_elev_est/tools/estimate_wild_imgs.py @@ -0,0 +1,35 @@ +import tqdm +import imageio +import json +import os.path as osp +import os + +from oee.utils import plt_utils +from oee.utils.elev_est_api import elev_est_api + + +def visualize(img_paths, elev): + imgs = [imageio.imread_v2(img_path) for img_path in img_paths] + plt_utils.image_grid(imgs, 2, 2, label=f"elev={elev}") + + +def estimate_elev(root_dir): + # root_dir = "/home/linghao/Datasets/objaverse-processed/zero12345_img/wild" + # dataset = "supp_fail" + # root_dir = "/home/chao/chao/OpenComplete/zero123/zero123/gradio_tmp/" + # obj_names = sorted(os.listdir(root_dir)) + # results = {} + # for obj_name in tqdm.tqdm(obj_names): + img_dir = osp.join(root_dir, "stage2_8") + img_paths = [] + for i in range(4): + img_paths.append(f"{img_dir}/0_{i}.png") + elev = elev_est_api(img_paths) + # visualize(img_paths, elev) + # results[obj_name] = elev + # json.dump(results, open(osp.join(root_dir, f"../{dataset}_elev.json"), "w"), indent=4) + return elev + + +# if __name__ == '__main__': +# main() diff --git a/one2345_elev_est/tools/example.py b/one2345_elev_est/tools/example.py new file mode 100644 index 0000000000000000000000000000000000000000..065f31e3e64f21494b04ca2aed87e665ddc6d23d --- /dev/null +++ b/one2345_elev_est/tools/example.py @@ -0,0 +1,38 @@ +import imageio +import numpy as np + +from oee.utils import plt_utils +from oee.utils.elev_est_api import elev_est_api +import argparse + + +def visualize(img_paths, elev): + imgs = [imageio.imread_v2(img_path) for img_path in img_paths] + plt_utils.image_grid(imgs, 2, 2, label=f"elev={elev}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--img_paths", type=str, nargs=4, help="image paths", + default=["assets/example_data/0_0.png", + "assets/example_data/0_1.png", + "assets/example_data/0_2.png", + "assets/example_data/0_3.png"]) + parser.add_argument("--min_elev", type=float, default=30, help="min elevation") + parser.add_argument("--max_elev", type=float, default=150, help="max elevation") + parser.add_argument("--dbg", default=False, action="store_true", help="debug mode") + parser.add_argument("--K_path", type=str, default=None, help="path to K") + args = parser.parse_args() + + if args.K_path is not None: + K = np.loadtxt(args.K_path) + else: + K = None + + elev = elev_est_api(args.img_paths, args.min_elev, args.max_elev, K, args.dbg) + + visualize(args.img_paths, elev) + + +if __name__ == '__main__': + main() diff --git a/one2345_elev_est/tools/weights/indoor_ds_new.ckpt b/one2345_elev_est/tools/weights/indoor_ds_new.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..ef68cc903b08e710d46a51c2aeb97407e047f3a5 --- /dev/null +++ b/one2345_elev_est/tools/weights/indoor_ds_new.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be9ff88b323ec27889114719f668ae41aff7034b56a4c4acbd46b8b180b87ed3 +size 46355053 diff --git a/sam_utils.py b/sam_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6371910204a4b1826261c2eed450bfdb9244cf --- /dev/null +++ b/sam_utils.py @@ -0,0 +1,50 @@ +import os +import numpy as np +import torch +from PIL import Image +import time + +from segment_anything import sam_model_registry, SamPredictor + +def sam_init(device_id=0): + sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_vit_h_4b8939.pth") + model_type = "vit_h" + + device = "cuda:{}".format(device_id) if torch.cuda.is_available() else "cpu" + + sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device) + predictor = SamPredictor(sam) + return predictor + +def sam_out_nosave(predictor, input_image, *bbox_sliders): + bbox = np.array(bbox_sliders) + image = np.asarray(input_image) + + start_time = time.time() + predictor.set_image(image) + + h, w, _ = image.shape + input_point = np.array([[h//2, w//2]]) + input_label = np.array([1]) + + masks, scores, logits = predictor.predict( + point_coords=input_point, + point_labels=input_label, + multimask_output=True, + ) + + masks_bbox, scores_bbox, logits_bbox = predictor.predict( + box=bbox, + multimask_output=True + ) + + print(f"SAM Time: {time.time() - start_time:.3f}s") + opt_idx = np.argmax(scores) + mask = masks[opt_idx] + out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8) + out_image[:, :, :3] = image + out_image_bbox = out_image.copy() + out_image[:, :, 3] = mask.astype(np.uint8) * 255 + out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox) + torch.cuda.empty_cache() + return Image.fromarray(out_image_bbox, mode='RGBA') \ No newline at end of file diff --git a/sam_vit_h_4b8939.pth b/sam_vit_h_4b8939.pth new file mode 100644 index 0000000000000000000000000000000000000000..8523acce9ddab1cf7e355628a08b1aab8ce08a72 --- /dev/null +++ b/sam_vit_h_4b8939.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e +size 2564550879 diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5549ce1de43f3f0d6e385506941852e503a8184d --- /dev/null +++ b/utils.py @@ -0,0 +1,103 @@ +import os +import json +import numpy as np +import cv2 +from PIL import Image + +# contrast correction, rescale and recenter +def image_preprocess_nosave(input_image, lower_contrast=True, rescale=True): + + image_arr = np.array(input_image) + in_w, in_h = image_arr.shape[:2] + + if lower_contrast: + alpha = 0.8 # Contrast control (1.0-3.0) + beta = 0 # Brightness control (0-100) + # Apply the contrast adjustment + image_arr = cv2.convertScaleAbs(image_arr, alpha=alpha, beta=beta) + image_arr[image_arr[...,-1]>200, -1] = 255 + + ret, mask = cv2.threshold(np.array(input_image.split()[-1]), 0, 255, cv2.THRESH_BINARY) + x, y, w, h = cv2.boundingRect(mask) + max_size = max(w, h) + ratio = 0.75 + if rescale: + side_len = int(max_size / ratio) + else: + side_len = in_w + padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8) + center = side_len//2 + padded_image[center-h//2:center-h//2+h, center-w//2:center-w//2+w] = image_arr[y:y+h, x:x+w] + rgba = Image.fromarray(padded_image).resize((256, 256), Image.LANCZOS) + + rgba_arr = np.array(rgba) / 255.0 + rgb = rgba_arr[...,:3] * rgba_arr[...,-1:] + (1 - rgba_arr[...,-1:]) + return Image.fromarray((rgb * 255).astype(np.uint8)) + +# pose generation +def calc_pose(phis, thetas, size, radius = 1.2, device='cuda'): + import torch + def normalize(vectors): + return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) + thetas = torch.FloatTensor(thetas).to(device) + phis = torch.FloatTensor(phis).to(device) + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + -radius * torch.cos(thetas) * torch.sin(phis), + radius * torch.cos(phis), + ], dim=-1) # [B, 3] + + # lookat + forward_vector = normalize(centers).squeeze(0) + up_vector = torch.FloatTensor([0, 0, 1]).to(device).unsqueeze(0).repeat(size, 1) + right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1)) + if right_vector.pow(2).sum() < 0.01: + right_vector = torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(size, 1) + up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1)) + + poses = torch.eye(4, dtype=torch.float, device=device)[:3].unsqueeze(0).repeat(size, 1, 1) + poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) + poses[:, :3, 3] = centers + return poses + +def get_poses(init_elev): + mid = init_elev + deg = 10 + if init_elev <= 75: + low = init_elev + 30 + # e.g. 30, 60, 20, 40, 30, 30, 50, 70, 50, 50 + + elevations = np.radians([mid]*4 + [low]*4 + [mid-deg,mid+deg,mid,mid]*4 + [low-deg,low+deg,low,low]*4) + img_ids = [f"{num}.png" for num in range(8)] + [f"{num}_{view_num}.png" for num in range(8) for view_num in range(4)] + else: + + high = init_elev - 30 + elevations = np.radians([mid]*4 + [high]*4 + [mid-deg,mid+deg,mid,mid]*4 + [high-deg,high+deg,high,high]*4) + img_ids = [f"{num}.png" for num in list(range(4)) + list(range(8,12))] + \ + [f"{num}_{view_num}.png" for num in list(range(4)) + list(range(8,12)) for view_num in range(4)] + overlook_theta = [30+x*90 for x in range(4)] + eyelevel_theta = [60+x*90 for x in range(4)] + source_theta_delta = [0, 0, -deg, deg] + azimuths = np.radians(overlook_theta + eyelevel_theta + \ + [view_theta + source for view_theta in overlook_theta for source in source_theta_delta] + \ + [view_theta + source for view_theta in eyelevel_theta for source in source_theta_delta]) + return img_ids, calc_pose(elevations, azimuths, len(azimuths)).cpu().numpy() + + +def gen_poses(shape_dir, pose_est): + img_ids, input_poses = get_poses(pose_est) + + out_dict = {} + focal = 560/2; h = w = 256 + out_dict['intrinsics'] = [[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]] + out_dict['near_far'] = [1.2-0.7, 1.2+0.7] + out_dict['c2ws'] = {} + for view_id, img_id in enumerate(img_ids): + pose = input_poses[view_id] + pose = pose.tolist() + pose = [pose[0], pose[1], pose[2], [0, 0, 0, 1]] + out_dict['c2ws'][img_id] = pose + json_path = os.path.join(shape_dir, 'pose.json') + with open(json_path, 'w') as f: + json.dump(out_dict, f, indent=4) diff --git a/zero123_utils.py b/zero123_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ad274d47c87065ed576e1cfb803c4f741c89e6 --- /dev/null +++ b/zero123_utils.py @@ -0,0 +1,175 @@ +import os +import numpy as np +import torch +from contextlib import nullcontext +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from einops import rearrange +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from omegaconf import OmegaConf +from PIL import Image +from rich import print +from transformers import CLIPImageProcessor +from torch import autocast +from torchvision import transforms + + +def load_model_from_config(config, ckpt, device, verbose=False): + print(f'Loading model from {ckpt}') + pl_sd = torch.load(ckpt, map_location='cpu') + if 'global_step' in pl_sd: + print(f'Global Step: {pl_sd["global_step"]}') + sd = pl_sd['state_dict'] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print('missing keys:') + print(m) + if len(u) > 0 and verbose: + print('unexpected keys:') + print(u) + + model.to(device) + model.eval() + return model + + +def init_model(device, ckpt): + config = os.path.join(os.path.dirname(__file__), 'configs/sd-objaverse-finetune-c_concat-256.yaml') + config = OmegaConf.load(config) + + # Instantiate all models beforehand for efficiency. + models = dict() + print('Instantiating LatentDiffusion...') + models['turncam'] = torch.compile(load_model_from_config(config, ckpt, device=device)) + print('Instantiating StableDiffusionSafetyChecker...') + models['nsfw'] = StableDiffusionSafetyChecker.from_pretrained( + 'CompVis/stable-diffusion-safety-checker').to(device) + models['clip_fe'] = CLIPImageProcessor.from_pretrained( + "openai/clip-vit-large-patch14") + # We multiply all by some factor > 1 to make them less likely to be triggered. + models['nsfw'].concept_embeds_weights *= 1.2 + models['nsfw'].special_care_embeds_weights *= 1.2 + + return models + +@torch.no_grad() +def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision='autocast', ddim_eta=1.0, ddim_steps=75, scale=3.0, h=256, w=256): + precision_scope = autocast if precision == 'autocast' else nullcontext + with precision_scope("cuda"): + with model.ema_scope(): + c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1) + T = [] + for x, y in zip(xs, ys): + T.append([np.radians(x), np.sin(np.radians(y)), np.cos(np.radians(y)), 0]) + T = torch.tensor(np.array(T))[:, None, :].float().to(c.device) + c = torch.cat([c, T], dim=-1) + c = model.cc_projection(c) + cond = {} + cond['c_crossattn'] = [c] + cond['c_concat'] = [model.encode_first_stage(input_im).mode().detach() + .repeat(n_samples, 1, 1, 1)] + if scale != 1.0: + uc = {} + uc['c_concat'] = [torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)] + uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)] + else: + uc = None + + shape = [4, h // 8, w // 8] + samples_ddim, _ = sampler.sample(S=ddim_steps, + conditioning=cond, + batch_size=n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=uc, + eta=ddim_eta, + x_T=None) + print(samples_ddim.shape) + # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False) + x_samples_ddim = model.decode_first_stage(samples_ddim) + ret_imgs = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu() + del cond, c, x_samples_ddim, samples_ddim, uc, input_im + torch.cuda.empty_cache() + return ret_imgs + +@torch.no_grad() +def predict_stage1_gradio(model, raw_im, save_path = "", adjust_set=[], device="cuda", ddim_steps=75, scale=3.0): + # raw_im = raw_im.resize([256, 256], Image.LANCZOS) + # input_im_init = preprocess_image(models, raw_im, preprocess=False) + input_im_init = np.asarray(raw_im, dtype=np.float32) / 255.0 + input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device) + input_im = input_im * 2 - 1 + + # stage 1: 8 + delta_x_1_8 = [0] * 4 + [30] * 4 + [-30] * 4 + delta_y_1_8 = [0+90*(i%4) if i < 4 else 30+90*(i%4) for i in range(8)] + [30+90*(i%4) for i in range(4)] + + ret_imgs = [] + sampler = DDIMSampler(model) + # sampler.to(device) + if adjust_set != []: + x_samples_ddims_8 = sample_model_batch(model, sampler, input_im, + [delta_x_1_8[i] for i in adjust_set], [delta_y_1_8[i] for i in adjust_set], + n_samples=len(adjust_set), ddim_steps=ddim_steps, scale=scale) + else: + x_samples_ddims_8 = sample_model_batch(model, sampler, input_im, delta_x_1_8, delta_y_1_8, n_samples=len(delta_x_1_8), ddim_steps=ddim_steps, scale=scale) + sample_idx = 0 + for stage1_idx in range(len(delta_x_1_8)): + if adjust_set != [] and stage1_idx not in adjust_set: + continue + x_sample = 255.0 * rearrange(x_samples_ddims_8[sample_idx].numpy(), 'c h w -> h w c') + out_image = Image.fromarray(x_sample.astype(np.uint8)) + ret_imgs.append(out_image) + if save_path: + out_image.save(os.path.join(save_path, '%d.png'%(stage1_idx))) + sample_idx += 1 + del x_samples_ddims_8 + del sampler + torch.cuda.empty_cache() + return ret_imgs + +def infer_stage_2(model, save_path_stage1, save_path_stage2, delta_x_2, delta_y_2, indices, device, ddim_steps=75, scale=3.0): + for stage1_idx in indices: + # save stage 1 image + # x_sample = 255.0 * rearrange(x_samples_ddims[stage1_idx].cpu().numpy(), 'c h w -> h w c') + # Image.fromarray(x_sample.astype(np.uint8)).save() + stage1_image_path = os.path.join(save_path_stage1, '%d.png'%(stage1_idx)) + + raw_im = Image.open(stage1_image_path) + # input_im_init = preprocess_image(models, raw_im, preprocess=False) + input_im_init = np.asarray(raw_im, dtype=np.float32) #/ 255.0 + input_im_init[input_im_init >= 253.0] = 255.0 + input_im_init = input_im_init / 255.0 + input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device) + input_im = input_im * 2 - 1 + # infer stage 2 + sampler = DDIMSampler(model) + # sampler.to(device) + # stage2_in = x_samples_ddims[stage1_idx][None, ...].to(device) * 2 - 1 + x_samples_ddims_stage2 = sample_model_batch(model, sampler, input_im, delta_x_2, delta_y_2, n_samples=len(delta_x_2), ddim_steps=ddim_steps, scale=scale) + for stage2_idx in range(len(delta_x_2)): + x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].numpy(), 'c h w -> h w c') + Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png'%(stage1_idx, stage2_idx))) + del input_im + del x_samples_ddims_stage2 + torch.cuda.empty_cache() + +def zero123_infer(model, input_dir_path, start_idx=0, end_idx=12, indices=None, device="cuda", ddim_steps=75, scale=3.0): + # input_img_path = os.path.join(input_dir_path, "input_256.png") + save_path_8 = os.path.join(input_dir_path, "stage1_8") + save_path_8_2 = os.path.join(input_dir_path, "stage2_8") + os.makedirs(save_path_8_2, exist_ok=True) + + # raw_im = Image.open(input_img_path) + # # input_im_init = preprocess_image(models, raw_im, preprocess=False) + # input_im_init = np.asarray(raw_im, dtype=np.float32) / 255.0 + # input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device) + # input_im = input_im * 2 - 1 + + # stage 2: 6*4 or 8*4 + delta_x_2 = [-10, 10, 0, 0] + delta_y_2 = [0, 0, -10, 10] + + infer_stage_2(model, save_path_8, save_path_8_2, delta_x_2, delta_y_2, indices=indices if indices else list(range(start_idx,end_idx)), device=device, ddim_steps=ddim_steps, scale=scale)