HaolinLiu committed
Commit 18bb538
Parent: dcad693

update files for demo

configs/finetune_triplane_diffusion.yaml CHANGED
@@ -37,7 +37,7 @@ model:
  norm: "batch"
  img_in_channels: 1280
  vit_reso: 16
- use_cat_embedding: ???
+ use_cat_embedding: False #only use category embedding when all categories are trained
  block_type: multiview_local
  par_point_encoder:
  plane_reso: 64
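
Note: use_cat_embedding toggles a learned per-category code in the diffusion model's conditioning; per the new comment it stays False for these per-category demo checkpoints. As a hedged illustration of how such a flag is typically consumed (the module and all names below are hypothetical, not the repository's actual code):

    import torch
    import torch.nn as nn

    class ConditionEncoder(nn.Module):
        # Hypothetical conditioning block gated by a use_cat_embedding flag.
        def __init__(self, use_cat_embedding=False, n_categories=6, dim=1280):
            super().__init__()
            self.use_cat_embedding = use_cat_embedding
            # allocate the embedding table only when the flag is on
            self.cat_embedding = nn.Embedding(n_categories, dim) if use_cat_embedding else None

        def forward(self, feats, category_ids):
            # feats: (B, N, dim) image tokens; category_ids: (B,) int64
            if self.use_cat_embedding:
                feats = feats + self.cat_embedding(category_ids)[:, None, :]
            return feats
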
datasets/taxonomy.py CHANGED
@@ -1,20 +1,29 @@
+ # category_map={
+ # "bathtub":0,
+ # "bed":1,
+ # "cabinet":2,
+ # "chair":3,
+ # "dishwasher":4,
+ # "fireplace":5,
+ # "oven":6,
+ # "refrigerator":7,
+ # "shelf":8,
+ # "sink":9,
+ # "sofa":10,
+ # "stool":11,
+ # "stove":12,
+ # "table":13,
+ # "toilet":14,
+ # "washer":15
+ # }
+
  category_map={
-     "bathtub":0,
-     "bed":1,
-     "cabinet":2,
-     "chair":3,
-     "dishwasher":4,
-     "fireplace":5,
-     "oven":6,
-     "refrigerator":7,
-     "shelf":8,
-     "sink":9,
-     "sofa":10,
-     "stool":11,
-     "stove":12,
-     "table":13,
-     "toilet":14,
-     "washer":15
+     "chair":0,
+     "sofa":1,
+     "table":2,
+     "cabinet":3,
+     "bed":4,
+     "shelf":5
  }

  category_map_from_synthetic={
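
Note: the new six-way category_map is what the demo dataset uses to build its category code; the snippet below mirrors the one-hot construction in demo/simple_dataset.py later in this diff:

    import torch
    from datasets.taxonomy import category_map

    category_id = category_map["table"]   # -> 2 under the six-class map
    one_hot = torch.zeros(6)              # one slot per super-category
    one_hot[category_id] = 1.0
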
demo/api.py ADDED
@@ -0,0 +1,151 @@
+ import os,sys
+ sys.path.append("..")
+ from configs.config_utils import CONFIG
+ from models import get_model
+ import torch
+ import numpy as np
+ import open3d as o3d
+ import timm
+ from PIL import Image
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+ from simple_dataset import InTheWild_Dataset,classname_remap,classname_map
+ try:
+     from torchvision.transforms import InterpolationMode
+     BICUBIC = InterpolationMode.BICUBIC
+ except ImportError:
+     BICUBIC = Image.BICUBIC
+ import mcubes
+ import trimesh
+ from torch.utils.data import DataLoader
+
+ def image_transform(n_px):
+     return Compose([
+         Resize(n_px, interpolation=BICUBIC),
+         CenterCrop(n_px),
+         ToTensor(),
+         Normalize((0.48145466, 0.4578275, 0.40821073),
+                   (0.26862954, 0.26130258, 0.27577711)),
+     ])
+
+ MAX_IMG_LENGTH=5 #take up to 5 images as inputs
+
+ ae_paths={
+     "chair":"../checkpoint/ae/chair/best-checkpoint.pth",
+     "table":"../checkpoint/ae/table/best-checkpoint.pth",
+     "cabinet":"../checkpoint/ae/cabinet/best-checkpoint.pth",
+     "shelf":"../checkpoint/ae/shelf/best-checkpoint.pth",
+     "sofa":"../checkpoint/ae/sofa/best-checkpoint.pth",
+     "bed":"../checkpoint/ae/bed/best-checkpoint.pth"
+ }
+ dm_paths={
+     "chair":"../checkpoint/finetune_dm/chair/best-checkpoint.pth",
+     "table":"../checkpoint/finetune_dm/table/best-checkpoint.pth",
+     "cabinet":"../checkpoint/finetune_dm/cabinet/best-checkpoint.pth",
+     "shelf":"../checkpoint/finetune_dm/shelf/best-checkpoint.pth",
+     "sofa":"../checkpoint/finetune_dm/sofa/best-checkpoint.pth",
+     "bed":"../checkpoint/finetune_dm/bed/best-checkpoint.pth"
+ }
+
+ def inference(ae_model,dm_model,data_batch,device,reso=256):
+     density = reso
+     gap = 2.2 / density
+     x = np.linspace(-1.1, 1.1, int(density + 1))
+     y = np.linspace(-1.1, 1.1, int(density + 1))
+     z = np.linspace(-1.1, 1.1, int(density + 1))
+     xv, yv, zv = np.meshgrid(x, y, z, indexing='ij')
+     grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].to(device,
+                                                                                                             non_blocking=True)
+     with torch.no_grad():
+         sample_input = dm_model.prepare_sample_data(data_batch)
+         sampled_array = dm_model.sample(sample_input, num_steps=36).float()
+         sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")
+
+     model_ids = data_batch['model_id']
+     tran_mats = data_batch['tran_mat']
+
+     output_meshes={}
+
+     for j in range(sampled_array.shape[0]):
+         grid_list = torch.split(grid, 128 ** 3, dim=1)
+         output_list = []
+         with torch.no_grad():
+             for sub_grid in grid_list:
+                 output_list.append(ae_model.decode(sampled_array[j:j + 1], sub_grid))
+         output = torch.cat(output_list, dim=1)
+         logits = output[j].detach()
+
+         volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy()
+         verts, faces = mcubes.marching_cubes(volume, 0)
+
+         verts *= gap
+         verts -= 1.1
+
+         tran_mat = tran_mats[j].numpy()
+         verts_homo = np.concatenate([verts, np.ones((verts.shape[0], 1))], axis=1)
+         verts_inwrd = np.dot(verts_homo, tran_mat.T)[:, 0:3]
+         m_inwrd = trimesh.Trimesh(verts_inwrd, faces[:, ::-1]) #transform the mesh into world coordinate
+
+         output_meshes[model_ids[j]]=m_inwrd
+     return output_meshes
+
+ if __name__=="__main__":
+     import argparse
+     parser=argparse.ArgumentParser()
+     parser.add_argument("--data_dir", type=str, default="../example_process_data")
+     parser.add_argument('--scene_id', default="all", type=str)
+     parser.add_argument("--save_dir", type=str,default="../example_output_data")
+     args = parser.parse_args()
+
+     config_path="../configs/finetune_triplane_diffusion.yaml"
+     config=CONFIG(config_path).config
+
+     '''creating save folder'''
+     save_folder=os.path.join(args.save_dir,args.scene_id)
+     os.makedirs(save_folder,exist_ok=True)
+
+     '''prepare model'''
+     device=torch.device("cuda")
+     ae_config=config['model']['ae']
+     dm_config=config['model']['dm']
+     dm_model=get_model(dm_config).to(device)
+     ae_model=get_model(ae_config).to(device)
+     dm_model.eval()
+     ae_model.eval()
+
+     '''preparing data'''
+     '''find out how many classes are there in the whole scene'''
+     images_folder=os.path.join(args.data_dir,args.scene_id,"6_images")
+     object_id_list=os.listdir(images_folder)
+     object_class_list=[item.split("_")[0] for item in object_id_list]
+     all_object_class=list(set(object_class_list))
+
+     exist_super_categories=[]
+     for object_class in all_object_class:
+         if object_class not in classname_remap:
+             continue
+         else:
+             exist_super_categories.append(classname_remap[object_class]) #find which category specific models should be employed
+     exist_super_categories=list(set(exist_super_categories))
+     for super_category in exist_super_categories:
+         print("processing %s"%(super_category))
+         ae_ckpt=torch.load(ae_paths[super_category],map_location="cpu")["model"]
+         dm_ckpt=torch.load(dm_paths[super_category],map_location="cpu")["model"]
+         ae_model.load_state_dict(ae_ckpt)
+         dm_model.load_state_dict(dm_ckpt)
+         dataset = InTheWild_Dataset(data_dir=args.data_dir, scene_id=args.scene_id, category=super_category, max_n_imgs=5)
+         dataloader=DataLoader(
+             dataset=dataset,
+             num_workers=1,
+             batch_size=1,
+             shuffle=False
+         )
+         for data_batch in dataloader:
+             output_meshes=inference(ae_model,dm_model,data_batch,device)
+             #print(output_meshes)
+             for model_id in output_meshes:
+                 mesh=output_meshes[model_id]
+                 save_path=os.path.join(save_folder,model_id+".ply")
+                 print("saving to %s"%(save_path))
+                 mesh.export(save_path)
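
Note: inference() evaluates occupancy on a dense (reso+1)^3 grid spanning [-1.1, 1.1] and splits the queries so each decode call stays within memory; for the default reso=256 the arithmetic works out as follows (a standalone check, not part of the demo code):

    import numpy as np

    reso = 256
    gap = 2.2 / reso                           # grid spacing, matches "verts *= gap"
    n_pts = (reso + 1) ** 3                    # 16,974,593 query points per object
    n_chunks = int(np.ceil(n_pts / 128 ** 3))  # torch.split yields 9 decode calls
    print(gap, n_pts, n_chunks)
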
demo/extract_vit_features.py ADDED
@@ -0,0 +1,57 @@
+ import os,sys
+ sys.path.append("..")
+ import numpy
+ from simple_dataset import Simple_InTheWild_dataset
+ import argparse
+ from torch.utils.data import DataLoader
+ import timm
+ import torch
+ import numpy as np
+ from util import misc
+
+ parser=argparse.ArgumentParser()
+ parser.add_argument("--data_dir",type=str,default="../example_process_data")
+ parser.add_argument('--world_size', default=1, type=int,
+                     help='number of distributed processes')
+ parser.add_argument('--local_rank', default=-1, type=int)
+ parser.add_argument('--dist_on_itp', action='store_true')
+ parser.add_argument('--dist_url', default='env://',
+                     help='url used to set up distributed training')
+ parser.add_argument('--scene_id',default="all",type=str)
+ args=parser.parse_args()
+
+
+ misc.init_distributed_mode(args)
+ dataset=Simple_InTheWild_dataset(dataset_dir=args.data_dir,scene_id=args.scene_id,n_px=224)
+ num_tasks = misc.get_world_size()
+ global_rank = misc.get_rank()
+ print(num_tasks,global_rank)
+ sampler = torch.utils.data.DistributedSampler(
+     dataset, num_replicas=num_tasks, rank=global_rank,
+     shuffle=False)  # shuffle=True to reduce monitor bias
+
+ dataloader=DataLoader(
+     dataset,
+     sampler=sampler,
+     batch_size=10,
+     num_workers=4,
+     pin_memory=True,
+     drop_last=False
+ )
+ VIT_MODEL = 'vit_huge_patch14_224_clip_laion2b'
+ model=timm.create_model(VIT_MODEL, pretrained=True,pretrained_cfg_overlay=dict(file="./open_clip_pytorch_model.bin"))
+ model=model.eval().float().cuda()
+ for idx,data_batch in enumerate(dataloader):
+     if idx%10==0:
+         print("{}/{}".format(dataloader.__len__(),idx))
+     images = data_batch["images"].cuda().float()
+     model_id= data_batch["model_id"]
+     image_name=data_batch["image_name"]
+     scene_id=data_batch["scene_id"]
+     with torch.no_grad():
+         output_features=model.forward_features(images)
+     for j in range(output_features.shape[0]):
+         save_folder=os.path.join(args.data_dir,scene_id[j],"7_img_feature",model_id[j])
+         os.makedirs(save_folder,exist_ok=True)
+         save_path=os.path.join(save_folder,image_name[j]+".npz")
+         np.savez_compressed(save_path,img_features=output_features[j].detach().cpu().numpy().astype(np.float32))
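
Note: for vit_huge_patch14_224_clip_laion2b, forward_features should return one class token plus 224/14 = 16x16 = 256 patch tokens at embed dim 1280, matching the config's vit_reso: 16 and img_in_channels: 1280 above. A quick sanity check (the path is a placeholder):

    import numpy as np

    feat = np.load("<scene>/7_img_feature/<model_id>/<image_name>.npz")["img_features"]
    print(feat.shape)  # expected (257, 1280): 1 class token + 256 patch tokens
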
demo/process_data.py ADDED
@@ -0,0 +1,340 @@
+ import numpy as np
+ import os
+ import argparse
+ import open3d as o3d
+ import glob
+ import cv2
+ import copy
+
+ def get_roll_rot(angle):
+     ca=np.cos(angle)
+     sa=np.sin(angle)
+     rot=np.array([
+         [ca,-sa,0,0],
+         [sa,ca,0,0],
+         [0,0,1,0],
+         [0,0,0,1]
+     ])
+     return rot
+ def rotate_mat(direction):
+     if direction == 'Up':
+         return np.eye(4)
+     elif direction == 'Left':
+         rot_mat=get_roll_rot(np.pi/2)
+     elif direction == 'Right':
+         rot_mat=get_roll_rot(-np.pi/2)
+     elif direction == 'Down':
+         rot_mat=get_roll_rot(np.pi)
+     else:
+         raise Exception(f'No such direction (={direction}) rotation')
+     return rot_mat
+
+ def rotate_K(K,direction):
+     if direction == 'Up' or direction=="Down":
+         new_K4=np.eye(4)
+         new_K4[0:3,0:3]=copy.deepcopy(K)
+         return new_K4
+     elif direction == 'Left' or direction =="Right":
+         fx,fy,cx,cy=K[0,0],K[1,1],K[0,2],K[1,2]
+         new_K4 = np.array([
+             [fy, 0, cy, 0],
+             [0, fx, cx, 0],
+             [0, 0, 1, 0],
+             [0, 0, 0, 1]
+         ])
+         return new_K4
+
+ def rotate_bbox(bbox,direction, H,W):
+
+     x_min,y_min,x_max,y_max=bbox[0:4]
+     if direction == 'Up':
+         return bbox
+     elif direction == 'Left':
+         #print(W-bbox[1],W-bbox[3])
+         new_bbox=[min(H-bbox[1],H-bbox[3]),bbox[0],max(H-bbox[1],H-bbox[3]),bbox[2]]
+     elif direction == 'Right':
+         new_bbox=[bbox[1],min(W-bbox[0],W-bbox[2]),bbox[3],max(W-bbox[0],W-bbox[2])]
+     elif direction == 'Down':
+         new_bbox=[min(W-x_min,W-x_max),min(H-y_min,H-y_max),max(W-x_min,W-x_max),max(H-y_min,H-y_max)]
+     else:
+         raise Exception(f'No such direction (={direction}) rotation')
+     return new_bbox
+
+ def rotate_image(img, direction):
+     if direction == 'Up':
+         pass
+     elif direction == 'Left':
+         img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+     elif direction == 'Right':
+         img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
+     elif direction == 'Down':
+         img = cv2.rotate(img, cv2.ROTATE_180)
+     else:
+         raise Exception(f'No such direction (={direction}) rotation')
+     return img
+
+ parser=argparse.ArgumentParser()
+ parser.add_argument("--data_folder",type=str,required=True)
+ parser.add_argument("--save_dir",type=str,default=r"../example_process_data")
+ parser.add_argument("--debug",action="store_true",default=False)
+ args=parser.parse_args()
+
+ print("processing %s"%(args.data_folder))
+
+ data_folder=args.data_folder
+ scene_name=os.path.basename(data_folder)
+ save_folder=os.path.join(args.save_dir,scene_name)
+ os.makedirs(save_folder,exist_ok=True)
+ color_folder=os.path.join(data_folder,"color")
+ depth_folder=os.path.join(data_folder,"depth")
+ pose_folder=os.path.join(data_folder,"pose")
+
+ print(color_folder)
+
+ color_list=glob.glob(color_folder+"/*.jpg")
+ image_id_list=[os.path.basename(item)[0:-4] for item in color_list]
+ image_id_list.sort()
+
+ bbox_path=os.path.join(data_folder,"objects.npy")
+ bboxes_dict=np.load(bbox_path,allow_pickle=True).item()
+
+ intrinsic_path=os.path.join(data_folder,"intrinsic","intrinsic_color.txt")
+ K=np.loadtxt(intrinsic_path)
+
+ align_path=os.path.join(data_folder,"alignment_matrix.txt")
+ align_matrix=np.loadtxt(align_path)
+ if align_matrix.shape[0]==3:
+     new_align_matrix=np.eye(4)
+     new_align_matrix[0:3,0:3]=align_matrix
+     align_matrix=new_align_matrix
+
+ mesh_path=os.path.join(data_folder,"fused_mesh.ply")
+ o3d_mesh=o3d.io.read_triangle_mesh(mesh_path)
+ o3d_vertices = np.array(o3d_mesh.vertices)
+ o3d_vert_homo=np.concatenate([o3d_vertices,np.ones([o3d_vertices.shape[0],1])],axis=1)
+ align_o3d_vertices = np.dot(o3d_vert_homo,align_matrix)[:,0:3]
+ o3d_mesh.vertices = o3d.utility.Vector3dVector(align_o3d_vertices)
+ align_mesh_save_path=os.path.join(save_folder,"align_mesh.ply")
+ o3d.io.write_triangle_mesh(align_mesh_save_path,o3d_mesh)
+
+ x=np.linspace(-1,1,10)
+ y=np.linspace(-1,1,10)
+ z=np.linspace(-1,1,10)
+ X,Y,Z=np.meshgrid(x,y,z,indexing='ij')
+ vox_coor=np.concatenate([X[:,:,:,np.newaxis],Y[:,:,:,np.newaxis],Z[:,:,:,np.newaxis]],axis=-1)
+ vox_coor=np.reshape(vox_coor,(-1,3))
+ #print(np.amin(vox_coor,axis=0),np.amax(vox_coor,axis=0))
+
+ pre_proj_mates={}
+ obj_points_dict={}
+ trans_mats={}
+ point_save_folder=os.path.join(save_folder,"5_partial_points")
+ os.makedirs(point_save_folder,exist_ok=True)
+ tran_save_folder=os.path.join(save_folder,"10_tran_matrix")
+ os.makedirs(tran_save_folder,exist_ok=True)
+ for object_id in bboxes_dict:
+     object = bboxes_dict[object_id]
+     category = object['category']
+     sizes = object['size']
+     sizes *= 1.1
+     transform_matrix_t = np.array(object['transform']).reshape([4, 4])
+     translate = transform_matrix_t[:3, 3]
+     rotation = transform_matrix_t[:3, :3]
+
+     bbox_o3d = o3d.geometry.OrientedBoundingBox(translate.reshape([3, 1]),
+                                                 rotation,
+                                                 np.array(sizes).reshape([3, 1]))
+     crop_pcd = o3d_mesh.crop(bbox_o3d)
+     crop_vert = np.asarray(crop_pcd.vertices)
+     org_crop_vert = crop_vert[:, :]
+     crop_vert = crop_vert - translate
+     crop_vert = np.dot(crop_vert,np.linalg.inv(rotation).T)
+     crop_vert[:, 2] *= -1
+     bb_min, bb_max = np.amin(crop_vert, axis=0), np.amax(crop_vert, axis=0)
+     max_length = (bb_max - bb_min).max()
+     center = (bb_max + bb_min) / 2
+     crop_vert = (crop_vert - center) / max_length * 2
+
+     obj_points_dict[object_id]=crop_vert
+     crop_pcd.vertices=o3d.utility.Vector3dVector(crop_vert)
+     save_path=os.path.join(point_save_folder,category+"_%d.ply"%(object_id))
+     o3d.io.write_triangle_mesh(save_path,crop_pcd)
+
+     proj_mat = np.eye(4)
+     scale_tran = np.eye(4)
+     scale_tran[0, 0], scale_tran[1, 1], scale_tran[2, 2] = max_length / 2, max_length / 2, max_length / 2
+     proj_mat = np.dot(proj_mat, scale_tran)
+     center_tran = np.eye(4)
+     center_tran[0:3, 3] = center
+     proj_mat = np.dot(center_tran, proj_mat)
+     invert_mat = np.eye(4)
+     invert_mat[2, 2] *= -1
+     proj_mat = np.dot(invert_mat, proj_mat)
+     proj_mat[0:3, 0:3] = np.dot(rotation,proj_mat[0:3, 0:3])
+     translate_mat = np.eye(4)
+     translate_mat[0:3, 3] = translate
+     proj_mat = np.dot(translate_mat, proj_mat)
+
+     '''tran mat is to align output to scene space'''
+     tran_mat=copy.deepcopy(proj_mat)
+     trans_mats[object_id]=tran_mat
+     tran_save_path=os.path.join(tran_save_folder,category+"_%d.npy"%(object_id))
+     np.save(tran_save_path,tran_mat)
+
+     unalign_mat = np.linalg.inv(align_matrix)
+     proj_mat = np.dot(unalign_mat.T, proj_mat)
+     pre_proj_mates[object_id]=proj_mat
+
+ ref=np.array([
+     [0,1.0], #Up
+     [-1.0,0],#Left
+     [0,1.0], #Right
+     [0.0,-1.0] #Down
+ ]) #4*2
+ dir_list=[
+     "Down",
+     "Left",
+     "Right",
+     "Up"
+ ]
+
+ for image_id in image_id_list:
+     color_path=os.path.join(color_folder,image_id+".jpg")
+     depth_path=os.path.join(depth_folder,image_id+".png")
+     pose_path=os.path.join(pose_folder,image_id+".txt")
+
+     color=cv2.imread(color_path)
+     height,width=color.shape[0:2]
+     depth=cv2.imread(depth_path,cv2.IMREAD_ANYCOLOR|cv2.IMREAD_ANYDEPTH)/1000.0
+     pose=np.loadtxt(pose_path)
+     for object_id in bboxes_dict:
+         object=bboxes_dict[object_id]
+         category=object['category']
+         sizes=object['size']
+         object_vox_coor=vox_coor*sizes[np.newaxis,:]
+         #print(np.amin(object_vox_coor,axis=0),np.amax(object_vox_coor,axis=0))
+         #print(sizes)
+
+         prev_proj_mat=pre_proj_mates[object_id]
+         wrd2cam_pose = np.linalg.inv(pose)
+         current_proj_mat = np.dot(wrd2cam_pose, prev_proj_mat)
+         K4=np.eye(4)
+         K4[0:3,0:3]=K
+
+         '''calibrate proj_mat'''
+         up_vectors = np.array([[0, 0, 0, 1.0],
+                                [0, 0.5, 0, 1.0]])
+         up_vec_inimg = np.dot(up_vectors, current_proj_mat.T)
+         up_vec_inimg = np.dot(up_vec_inimg,K4.T)
+         up_x = up_vec_inimg[:, 0] / up_vec_inimg[:, 2]
+         up_y = up_vec_inimg[:, 1] / up_vec_inimg[:, 2]
+         pt1 = np.array((up_x[0], up_y[0]))
+         pt2 = np.array((up_x[1], up_y[1]))
+         up_dir = pt2 - pt1
+         # print(up_dir)
+
+         product = np.sum(up_dir[np.newaxis, :] * ref, axis=1)
+         max_ind = np.argmax(product)
+         direction = dir_list[max_ind]
+         sky_rot = rotate_mat(direction)
+         #final_proj_mat = np.dot(K4,final_proj_mat)
+
+         vox_homo=np.concatenate([object_vox_coor,np.ones((object_vox_coor.shape[0],1))],axis=1)
+         vox_proj=np.dot(vox_homo,current_proj_mat.T)
+         vox_proj=np.dot(vox_proj,K4.T)
+         vox_x=vox_proj[:,0]/vox_proj[:,2]
+         vox_y=vox_proj[:,1]/vox_proj[:,2]
+
+         if np.mean(vox_proj[:,2])>5:
+             continue
+
+         inside_mask=((vox_x<width-1) &(vox_x>0) &(vox_y<height-1) &(vox_y>0)).astype(np.float32)
+         infrustum_ratio=np.sum(inside_mask)/vox_x.shape[0]
+         if infrustum_ratio < 0.4 and category in ["chair", "stool"]:
+             continue
+         elif infrustum_ratio <0.2:
+             continue
+         #print(object_id,image_id,infrustum_ratio)
+
+         '''objects visibility check for every frame'''
+         vox_x_inside=vox_x[inside_mask>0].astype(np.int32)
+         vox_y_inside=vox_y[inside_mask>0].astype(np.int32)
+         vox_depth=vox_proj[inside_mask>0,2]
+         #print(depth.shape,np.amax(vox_y_inside),np.amax(vox_x_inside))
+         depth_sample=depth[vox_y_inside,vox_x_inside]
+         depth_mask=(depth_sample>0)&(depth_sample<10.0)
+         depth_sample=depth_sample[depth_mask]
+         vox_depth=vox_depth[depth_mask]
+
+         if vox_depth.shape[0]<100:
+             continue
+
+         occluded_ratio=np.sum(((vox_depth-depth_sample)>0.2).astype(np.float32))/vox_depth.shape[0]
+         if occluded_ratio>0.6 and category in ["chair"]: #chair is easily occluded, while table is not
+             continue
+
+         depth_near_ratio = np.sum((np.abs(vox_depth - depth_sample) < sizes.max() * 0.5).astype(np.float32)) / \
+                            vox_depth.shape[0]
+         if depth_near_ratio < 0.2:
+             continue
+
+         '''make sure in every image, the object is upward'''
+         bbox=(np.amin(vox_x_inside),np.amin(vox_y_inside),np.amax(vox_x_inside),np.amax(vox_y_inside))
+         rot_image=rotate_image(color,direction)
+         bbox = rotate_bbox(bbox, direction, height, width)
+         crop_image=rot_image[bbox[1]:bbox[3],bbox[0]:bbox[2]]
+         crop_h, crop_w = crop_image.shape[0:2]
+         max_length = max(crop_h, crop_w)
+         if max_length<100:
+             continue
+         pad_image = np.zeros((max_length, max_length, 3))
+         if crop_h > crop_w:
+             margin = crop_h - crop_w
+             pad_image[:, margin // 2:margin // 2 + crop_w] = crop_image[:, :, :]
+             x_start, x_end = bbox[0] - margin // 2, margin // 2 + bbox[2]
+             y_start, y_end = bbox[1], bbox[3]
+         else:
+             margin = crop_w - crop_h
+             pad_image[margin // 2:margin // 2 + crop_h, :] = crop_image[:, :, :]
+
+             y_start, y_end = bbox[1] - margin // 2, bbox[3] + margin // 2
+             x_start, x_end = bbox[0], bbox[2]
+
+         pad_image=cv2.resize(pad_image,dsize=(224,224),interpolation=cv2.INTER_LINEAR)
+         image_save_folder = os.path.join(save_folder, "6_images", category + "_%d" % (object_id))
+         os.makedirs(image_save_folder, exist_ok=True)
+         image_save_path=os.path.join(image_save_folder,image_id+".jpg")
+         #print("saving to %s"%(image_save_path))
+         cv2.imwrite(image_save_path,pad_image)
+
+         proj_mat=np.dot(sky_rot,current_proj_mat)
+         new_K4 = rotate_K(K, direction)
+         new_K4[0, 2] -= x_start
+         new_K4[1, 2] -= y_start
+         new_K4[0] = new_K4[0] / max_length * 224
+         new_K4[1] = new_K4[1] / max_length * 224
+         proj_mat = np.dot(new_K4, proj_mat)
+
+         proj_save_folder=os.path.join(save_folder,"8_proj_matrix",category+"_%d"%(object_id))
+         os.makedirs(proj_save_folder,exist_ok=True)
+         proj_save_path=os.path.join(proj_save_folder,image_id+".npy")
+         np.save(proj_save_path,proj_mat)
+
+         '''debug proj matrix'''
+         if args.debug:
+             proj_save_folder=os.path.join(save_folder,"9_proj_images",category+"_%d"%(object_id))
+             os.makedirs(proj_save_folder,exist_ok=True)
+             canvas=copy.deepcopy(pad_image)
+             par_points=obj_points_dict[object_id]
+             par_homo=np.concatenate([par_points,np.ones((par_points.shape[0],1))],axis=1)
+             par_inimg=np.dot(par_homo,proj_mat.T)
+             x=par_inimg[:,0]/par_inimg[:,2]
+             y=par_inimg[:,1]/par_inimg[:,2]
+             x=np.clip(x,a_min=0,a_max=223).astype(np.int32)
+             y=np.clip(y,a_min=0,a_max=223).astype(np.int32)
+             canvas[y,x]=np.array([[0,255,0]])
+             proj_save_path=os.path.join(proj_save_folder,image_id+".jpg")
+             cv2.imwrite(proj_save_path,canvas)
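
Note: gathering the paths used above, process_data.py expects one capture folder per scene as input and writes the numbered folders consumed by the rest of the demo:

    <data_folder>/                     (input scene)
        color/*.jpg    depth/*.png    pose/*.txt
        intrinsic/intrinsic_color.txt
        alignment_matrix.txt    fused_mesh.ply    objects.npy
    <save_dir>/<scene_name>/           (output)
        align_mesh.ply
        5_partial_points/    6_images/    8_proj_matrix/    10_tran_matrix/
        9_proj_images/       (projection visualizations, only with --debug)
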
demo/simple_dataset.py ADDED
@@ -0,0 +1,182 @@
+ import torch
+ import torch.nn as nn
+ from torch.utils import data
+ import os
+ from PIL import Image
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+ try:
+     from torchvision.transforms import InterpolationMode
+     BICUBIC = InterpolationMode.BICUBIC
+ except ImportError:
+     BICUBIC = Image.BICUBIC
+ import glob
+ import numpy as np
+ import open3d as o3d
+ import cv2
+ from datasets.taxonomy import category_map as category_ids
+
+ classname_map={
+     "chair":["chair","stool"],
+     "cabinet":["dishwasher","cabinet","oven","refrigerator",'storage'],
+     "sofa":["sofa"],
+     "table":["table"],
+     "bed":["bed"],
+     "shelf":["shelf"]
+ }
+ classname_remap={ #map small categories to six large categories
+     "chair":"chair",
+     "stool":"chair",
+     "dishwasher":"cabinet",
+     "cabinet":"cabinet",
+     "oven":"cabinet",
+     "refrigerator":"cabinet",
+     "storage":"cabinet",
+     "sofa":"sofa",
+     "table":"table",
+     "bed":"bed",
+     "shelf":"shelf"
+ }
+
+ def image_transform(n_px):
+     return Compose([
+         Resize(n_px, interpolation=BICUBIC),
+         CenterCrop(n_px),
+         ToTensor(),
+         Normalize((0.48145466, 0.4578275, 0.40821073),
+                   (0.26862954, 0.26130258, 0.27577711)),
+     ])
+ class Simple_InTheWild_dataset(data.Dataset):
+     def __init__(self,dataset_dir="/data1/haolin/data/real_scene_process_data",scene_id="letian-310",n_px=224):
+         self.dataset_dir=dataset_dir
+         self.preprocess = image_transform(n_px)
+         self.image_path = []
+         if scene_id=="all":
+             scene_list=os.listdir(self.dataset_dir)
+             for id in scene_list:
+                 image_folder=os.path.join(self.dataset_dir,id,"6_images")
+                 self.image_path+=glob.glob(image_folder+"/*/*jpg")
+         else:
+             image_folder = os.path.join(self.dataset_dir, scene_id, "6_images")
+             self.image_path += glob.glob(image_folder + "/*/*jpg")
+     def __len__(self):
+         return len(self.image_path)
+
+     def __getitem__(self,index):
+         path=self.image_path[index]
+         basename=os.path.basename(path)[:-4]
+         model_id=path.split(os.sep)[-2]
+         scene_id=path.split(os.sep)[-4]
+         image=Image.open(path)
+         image_tensor=self.preprocess(image)
+
+         return {"images":image_tensor,"image_name":basename,"model_id":model_id,"scene_id":scene_id}
+
+ class InTheWild_Dataset(data.Dataset):
+     def __init__(self,data_dir="/data1/haolin/data/real_scene_process_data/letian-310",scene_id="letian-310",
+                  par_pc_size=2048,category="chair",max_n_imgs=5):
+         self.par_pc_size=par_pc_size
+         self.data_dir=data_dir
+         self.category=category
+         self.max_n_imgs=max_n_imgs
+
+         self.models=[]
+         category_list=classname_map[category]
+         modelid_list=[]
+         for cat in category_list:
+             if scene_id=="all":
+                 scene_list=os.listdir(self.data_dir)
+                 for id in scene_list:
+                     data_folder=os.path.join(self.data_dir,id)
+                     modelid_list+=glob.glob(data_folder+"/6_images/%s*"%(cat))
+             else:
+                 data_folder=os.path.join(self.data_dir,scene_id)
+                 modelid_list+=glob.glob(data_folder+"/6_images/%s*"%(cat))
+         sceneid_list = [item.split("/")[-3] for item in modelid_list]
+         modelid_list=[item.split("/")[-1] for item in modelid_list]
+         for idx,modelid in enumerate(modelid_list):
+             scene_id=sceneid_list[idx]
+             image_folder=os.path.join(self.data_dir,scene_id,"6_images",modelid)
+             image_list=os.listdir(image_folder)
+             if len(image_list)==0:
+                 continue
+             imageid_list=[item[0:-4] for item in image_list]
+             imageid_list.sort(key=lambda x:int(x))
+             partial_path=os.path.join(self.data_dir,scene_id,"5_partial_points",modelid+".ply")
+             if os.path.exists(partial_path)==False: continue
+             self.models+=[
+                 {'model_id':modelid,
+                  "scene_id":scene_id,
+                  "partial_path":partial_path,
+                  "imageid_list":imageid_list,
+                  }
+             ]
+     def __len__(self):
+         return len(self.models)
+
+     def __getitem__(self,idx):
+         model = self.models[idx]['model_id']
+         scene_id=self.models[idx]['scene_id']
+         imageid_list = self.models[idx]['imageid_list']
+         partial_path=self.models[idx]['partial_path']
+         n_frames=min(len(imageid_list),self.max_n_imgs)
+         img_indexes=np.linspace(start=0,stop=len(imageid_list)-1,num=n_frames).astype(np.int32)
+
+         '''load partial points'''
+         par_point_o3d = o3d.io.read_point_cloud(partial_path)
+         par_points = np.asarray(par_point_o3d.points)
+         replace = par_points.shape[0] < self.par_pc_size
+         ind = np.random.default_rng().choice(par_points.shape[0], self.par_pc_size, replace=replace)
+         par_points=par_points[ind]
+         par_points=torch.from_numpy(par_points).float()
+
+         '''load image features'''
+         image_list=[]
+         valid_frames = []
+         image_namelist=[]
+         for img_index in img_indexes:
+             image_name = imageid_list[img_index]
+             image_feat_path = os.path.join(self.data_dir,scene_id, "7_img_feature", model,image_name + '.npz')
+             image = np.load(image_feat_path)["img_features"]
+             image_list.append(torch.from_numpy(image).float())
+             image_namelist.append(image_name)
+             valid_frames.append(True)
+         '''load original image'''
+         org_img_list=[]
+         for img_index in img_indexes:
+             image_name = imageid_list[img_index]
+             image_path = os.path.join(self.data_dir,scene_id, "6_images", model,image_name+".jpg")
+             org_image = cv2.imread(image_path)
+             org_image = cv2.resize(org_image, dsize=(224, 224), interpolation=cv2.INTER_LINEAR)
+             org_img_list.append(org_image)
+
+         '''load project matrix'''
+         proj_mat_list=[]
+         for img_index in img_indexes:
+             image_name = imageid_list[img_index]
+             proj_mat_path = os.path.join(self.data_dir,scene_id, "8_proj_matrix", model, image_name + ".npy")
+             proj_mat = np.load(proj_mat_path)
+             proj_mat_list.append(proj_mat)
+
+         '''load transformation matrix'''
+         tran_mat_path = os.path.join(self.data_dir,scene_id, "10_tran_matrix", model+".npy")
+         tran_mat = np.load(tran_mat_path)
+
+         '''category code, not used for category specific models'''
+         category_id = category_ids[self.category]
+         one_hot = torch.zeros((6)).float()
+         one_hot[category_id] = 1.0
+
+         ret_dict={
+             "model_id":model,
+             "scene_id":scene_id,
+             "par_points":par_points,
+             "proj_mat":torch.stack([torch.from_numpy(mat) for mat in proj_mat_list], dim=0),
+             "tran_mat":torch.from_numpy(tran_mat).float(),
+             "image":torch.stack(image_list,dim=0),
+             "org_image":org_img_list,
+             "valid_frames":torch.tensor(valid_frames).bool(),
+             "category_ids": category_id,
+             "category_code":one_hot,
+         }
+         return ret_dict
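
Note: with batch_size=1 as in demo/api.py, one batch from InTheWild_Dataset carries the shapes below (a sketch assuming processed example data on disk; the 257x1280 token shape comes from the ViT-H/14 features written by extract_vit_features.py):

    from torch.utils.data import DataLoader
    from simple_dataset import InTheWild_Dataset

    dataset = InTheWild_Dataset(data_dir="../example_process_data", scene_id="all",
                                category="chair", max_n_imgs=5)
    batch = next(iter(DataLoader(dataset, batch_size=1)))
    print(batch["par_points"].shape)     # (1, 2048, 3): sampled partial points
    print(batch["image"].shape)          # (1, n_imgs, 257, 1280): per-view ViT tokens
    print(batch["proj_mat"].shape)       # (1, n_imgs, 4, 4): object-to-image projections
    print(batch["tran_mat"].shape)       # (1, 4, 4): output-to-scene transform
    print(batch["category_code"].shape)  # (1, 6): one-hot over the six super-categories
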
train_VAE.sh CHANGED
@@ -12,4 +12,4 @@ train_triplane_vae.py \
  --clip_grad 0.35 \
  --category chair \
  --data-pth ../data \
- --replica 5
+ --replica 5
util/misc.py CHANGED
@@ -15,7 +15,7 @@ from pathlib import Path
  import torch
  import torch.distributed as dist
  #from torch._six import inf
- import inf
+ import math
  import numpy as np

  def log_codefiles(data_root,save_root):
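
Note: torch._six (and its inf) was removed in recent PyTorch releases, so the bare "import inf" left behind when the original import was commented out could never resolve; the standard-library constant is the drop-in replacement (assuming uses elsewhere in the file read math.inf):

    import math
    norm_type = math.inf  # replaces the former "from torch._six import inf"
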
util/simple_image_loader.py CHANGED
@@ -16,12 +16,8 @@ def image_transform(n_px):
      Resize(n_px, interpolation=BICUBIC),
      CenterCrop(n_px),
      ToTensor(),
-     # Normalize((123.675/255.0,116.28/255.0,103.53/255.0),
-     #           (58.395/255.0,57.12/255.0,57.375/255.0))
      Normalize((0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711)),
-     # Normalize((0.5, 0.5, 0.5),
-     #           (0.5, 0.5, 0.5)),
  ])

  class Image_dataset(data.Dataset):