import os import json import gzip import glob import torch import numpy as np import imageio import torch.nn.functional as F import cv2 def load_co3d_data(cfg): # load meta with gzip.open(cfg.annot_path, 'rt', encoding='utf8') as zipfile: annot = [v for v in json.load(zipfile) if v['sequence_name'] == cfg.sequence_name] with open(cfg.split_path) as f: split = json.load(f) train_im_path = set() test_im_path = set() for k, lst in split.items(): for v in lst: if v[0] == cfg.sequence_name: if 'known' in k: train_im_path.add(v[-1]) else: test_im_path.add(v[-1]) assert len(annot) == len(train_im_path) + len(test_im_path), 'Mismatch: '\ f'{len(annot)} == {len(train_im_path) + len(test_im_path)}' # load datas imgs = [] masks = [] poses = [] Ks = [] i_split = [[], []] remove_empty_masks_cnt = [0, 0] for i, meta in enumerate(annot): im_fname = meta['image']['path'] assert im_fname in train_im_path or im_fname in test_im_path sid = 0 if im_fname in train_im_path else 1 if meta['mask']['mass'] == 0: remove_empty_masks_cnt[sid] += 1 continue im_path = os.path.join(cfg.datadir, im_fname) mask_path = os.path.join(cfg.datadir, meta['mask']['path']) mask = imageio.imread(mask_path) / 255. if mask.max() < 0.5: remove_empty_masks_cnt[sid] += 1 continue Rt = np.concatenate([meta['viewpoint']['R'], np.array(meta['viewpoint']['T'])[:,None]], 1) pose = np.linalg.inv(np.concatenate([Rt, [[0,0,0,1]]])) imgs.append(imageio.imread(im_path) / 255.) masks.append(mask) poses.append(pose) assert imgs[-1].shape[:2] == tuple(meta['image']['size']) half_image_size_wh = np.float32(meta['image']['size'][::-1]) * 0.5 principal_point = np.float32(meta['viewpoint']['principal_point']) focal_length = np.float32(meta['viewpoint']['focal_length']) principal_point_px = -1.0 * (principal_point - 1.0) * half_image_size_wh focal_length_px = focal_length * half_image_size_wh Ks.append(np.array([ [focal_length_px[0], 0, principal_point_px[0]], [0, focal_length_px[1], principal_point_px[1]], [0, 0, 1], ])) i_split[sid].append(len(imgs)-1) if sum(remove_empty_masks_cnt) > 0: print('load_co3d_data: removed %d train / %d test due to empty mask' % tuple(remove_empty_masks_cnt)) print(f'load_co3d_data: num images {len(i_split[0])} train / {len(i_split[1])} test') imgs = np.array(imgs) masks = np.array(masks) poses = np.stack(poses, 0) Ks = np.stack(Ks, 0) render_poses = poses[i_split[-1]] i_split.append(i_split[-1]) # visyalization hwf H, W = np.array([im.shape[:2] for im in imgs]).mean(0).astype(int) focal = Ks[:,[0,1],[0,1]].mean() return imgs, masks, poses, render_poses, [H, W, focal], Ks, i_split