First commit of code and README.md update
- README.md +72 -1
- configs/config_utils.py +70 -0
- configs/finetune_triplane_diffusion.yaml +67 -0
- configs/train_triplane_diffusion.yaml +64 -0
- configs/train_triplane_vae.yaml +30 -0
- data/download_preprocess_data_here +1 -0
- datasets/SingleView_dataset.py +453 -0
- datasets/__init__.py +91 -0
- datasets/taxonomy.py +111 -0
- datasets/transforms.py +180 -0
- engine/engine_triplane_dm.py +136 -0
- engine/engine_triplane_vae.py +185 -0
- evaluation/dist_eval.sh +16 -0
- evaluation/evaluate_object_reconstruction.py +239 -0
- evaluation/pyTorchChamferDistance/.gitignore +3 -0
- evaluation/pyTorchChamferDistance/LICENSE.md +21 -0
- evaluation/pyTorchChamferDistance/README.md +23 -0
- evaluation/pyTorchChamferDistance/__init__.py +0 -0
- evaluation/pyTorchChamferDistance/chamfer_distance/__init__.py +1 -0
- evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp +185 -0
- evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu +209 -0
- evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py +58 -0
- finetune_diffusion.sh +18 -0
- models/TriplaneVAE.py +94 -0
- models/Triplane_Diffusion.py +190 -0
- models/__init__.py +20 -0
- models/modules/PointEMB.py +34 -0
- models/modules/Positional_Embedding.py +15 -0
- models/modules/__init__.py +5 -0
- models/modules/decoder.py +121 -0
- models/modules/diffusion_sampler.py +89 -0
- models/modules/encoder.py +235 -0
- models/modules/image_sampler.py +1046 -0
- models/modules/parpoints_encoder.py +168 -0
- models/modules/point_transformer.py +442 -0
- models/modules/pointnet2_backbone.py +188 -0
- models/modules/resnet_block.py +47 -0
- models/modules/resunet.py +440 -0
- models/modules/unet.py +304 -0
- models/modules/utils.py +25 -0
- output/put_checkpoints_here +1 -0
- process_scripts/augment_arkit_partial_point.py +64 -0
- process_scripts/augment_synthetic_partial_points.py +64 -0
- process_scripts/dist_export_triplane_features.sh +8 -0
- process_scripts/dist_extract_vit.sh +6 -0
- process_scripts/export_triplane_features.py +122 -0
- process_scripts/extract_img_vit_features.py +73 -0
- process_scripts/generate_split_for_arkit.py +102 -0
- process_scripts/generate_split_for_synthetic_data.py +78 -0
- process_scripts/unzip_all_data.py +38 -0
README.md
CHANGED
@@ -8,4 +8,75 @@ Repository of LASA: Instance Reconstruction from Real Scans using A Large-scale
![292080628-a4b020dc-2673-4b1b-bfa6-ec9422625624](https://github.com/GAP-LAB-CUHK-SZ/LASA/assets/40767265/7a0dfc11-5454-428f-bfba-e8cd0d0af96e)
![292080638-324bbef9-c93b-4d96-b814-120204374383](https://github.com/GAP-LAB-CUHK-SZ/LASA/assets/40767265/ee07691a-8767-4701-9a32-19a70e0e240a)

## Dataset
Complete raw data will be released soon.

## Download preprocessed data and processing
Download the preprocessed data from <a href="https://pan.baidu.com/s/1tCEGYBH0DEh8NcAURTnMbw?pwd=62ux">BaiduYun (code: 62ux)</a>. (These data will be updated as the cleaning process continues.) Put all the downloaded data under LASA, and unzip align_mat_all.zip manually.
You can use the script ./process_scripts/unzip_all_data.py to unzip all the data in occ_data and other_data with the following commands:
```bash
cd process_scripts
python unzip_all_data.py --unzip_occ --unzip_other
```
Run the following commands to generate augmented partial point clouds for the synthetic dataset and the LASA dataset:
```bash
cd process_scripts
python augment_arkit_partial_point.py --cat arkit_chair arkit_stool ...
python augment_synthetic_partial_points.py --cat 03001627 future_chair ABO_chair ...
```
Run the following command to extract image features:
```bash
cd process_scripts
bash dist_extract_vit.sh
```
Finally, run the following commands to generate the train/val splits:
```bash
cd process_scripts
python generate_split_for_arkit.py --cat arkit_chair arkit_stool ...
python generate_split_for_synthetic_data.py --cat 03001627 future_chair ABO_chair ...
```

## Evaluation
Download the pretrained weights for the chair category from <a href="https://pan.baidu.com/s/10liUOaC4CXGn7bN6SQkZsw?pwd=hlf9">chair_checkpoint</a> (code: hlf9).
Put these folders under LASA/output.<br> The ae folder stores the VAE weights, the dm folder stores the diffusion model trained on synthetic data, and
the finetune_dm folder stores the diffusion model finetuned on the LASA dataset.
Run the following commands to evaluate and extract the mesh:
```bash
cd evaluation
bash dist_eval.sh
```
The category entries are the sub-categories from ARKitScenes; see ./datasets/taxonomy.py for how they are defined.
For example, if you want to evaluate on LASA's chairs, the category list should contain both arkit_chair and arkit_stool.
Make sure the --ae-pth and --dm-pth entries point to the correct checkpoint paths. If you are evaluating on LASA,
make sure --dm-pth points to the finetuned weights in the ./output/finetune_dm folder. The results will be saved
under ./output_result.

## Training
Run <strong>train_VAE.sh</strong> to train the VAE model. If you aim to train on one category, specify one category from <strong>chair,
cabinet, table, sofa, bed, shelf</strong>. Passing <strong>all</strong> will train on all categories. Make sure to download and preprocess all
the required sub-category data. The sub-category arrangement can be found in ./datasets/taxonomy.py.<br>
After finishing training the VAE model, run the following commands to pre-extract the VAE features for every object:
```bash
cd process_scripts
bash dist_export_triplane_features.sh
```
Then, we can start training the diffusion model on the synthetic dataset by running <strong>train_diffusion.sh</strong>.<br>
Finally, finetune the diffusion model on the LASA dataset by running <strong>finetune_diffusion.sh</strong>.<br><br>

Early stopping is used: training is stopped manually at 150 epochs for the VAE model and at 500 epochs for the diffusion model.
All experiments in the paper are conducted on 8 A100 GPUs with batch size = 22.

## TODO

- [ ] Object Detection Code
- [ ] Code for Demo on both ARKitScenes and in-the-wild data

## Citation
```
@article{liu2023lasa,
  title={LASA: Instance Reconstruction from Real Scans using A Large-scale Aligned Shape Annotation Dataset},
  author={Liu, Haolin and Ye, Chongjie and Nie, Yinyu and He, Yingfan and Han, Xiaoguang},
  journal={arXiv preprint arXiv:2312.12418},
  year={2023}
}
```
configs/config_utils.py
ADDED
@@ -0,0 +1,70 @@
import os
import yaml
import logging
from datetime import datetime

def update_recursive(dict1, dict2):
    ''' Update two config dictionaries recursively.

    Args:
        dict1 (dict): first dictionary to be updated
        dict2 (dict): second dictionary whose entries should be used

    '''
    for k, v in dict2.items():
        if k not in dict1:
            dict1[k] = dict()
        if isinstance(v, dict):
            update_recursive(dict1[k], v)
        else:
            dict1[k] = v

class CONFIG(object):
    '''
    Stores all configurations.
    '''
    def __init__(self, input=None):
        '''
        Loads the config file.
        :param input (str or dict): path to the config file, or a dict of settings
        :return:
        '''
        self.config = self.read_to_dict(input)

    def read_to_dict(self, input):
        if not input:
            return dict()
        if isinstance(input, str) and os.path.isfile(input):
            if input.endswith('yaml'):
                with open(input, 'r') as f:
                    config = yaml.load(f, Loader=yaml.FullLoader)
            else:
                raise ValueError('Config file should be with the format of *.yaml')
        elif isinstance(input, dict):
            config = input
        else:
            raise ValueError('Unrecognized input type (i.e. not *.yaml file nor dict).')

        return config

    def update_config(self, *args, **kwargs):
        '''
        Update the config and the corresponding logger setting.
        :param input: dict settings to add to the config file
        :return:
        '''
        cfg1 = dict()
        for item in args:
            cfg1.update(self.read_to_dict(item))

        cfg2 = self.read_to_dict(kwargs)

        new_cfg = {**cfg1, **cfg2}

        update_recursive(self.config, new_cfg)
        # when the config file is updated, the corresponding logger should also be updated.
        self.__update_logger()

    def write_config(self, save_path):
        with open(save_path, 'w') as file:
            yaml.dump(self.config, file, default_flow_style=False)
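The CONFIG helper above accepts either a YAML path or a plain dict, merges nested settings recursively, and can dump the merged result back to disk. A minimal usage sketch, assuming one of the YAML files in this commit; the override value and output path are illustrative only:

```python
from configs.config_utils import CONFIG, update_recursive

cfg = CONFIG("configs/train_triplane_vae.yaml")     # path is illustrative
print(cfg.config["dataset"]["surface_size"])        # -> 20000

# Merge nested overrides without losing untouched keys.
update_recursive(cfg.config, {"dataset": {"data_path": "/path/to/LASA"}})

cfg.write_config("output/used_config.yaml")         # dump the merged config
```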
configs/finetune_triplane_diffusion.yaml
ADDED
@@ -0,0 +1,67 @@
model:
  ae: #ae model is loaded to
    type: TriVAE
    point_emb_dim: 48
    padding: 0.1
    encoder:
      plane_reso: 128
      plane_latent_dim: 32
      latent_dim: 32
      unet:
        depth: 4
        merge_mode: concat
        start_filts: 32
        output_dim: 64
    decoder:
      plane_reso: 128
      latent_dim: 32
      n_blocks: 5
      query_emb_dim: 48
      hidden_dim: 128
      unet:
        depth: 4
        merge_mode: concat
        start_filts: 64
        output_dim: 32
  dm:
    type: triplane_diff_multiimg_cond
    backbone: resunet_multiimg_direct_atten
    diff_reso: 64
    input_channel: 32
    output_channel: 32
    triplane_padding: 0.1 # should be consistent with padding in ae

    use_par: True
    par_channel: 32
    par_emb_dim: 48
    norm: "batch"
    img_in_channels: 1280
    vit_reso: 16
    use_cat_embedding: ???
    block_type: multiview_local
    par_point_encoder:
      plane_reso: 64
      plane_latent_dim: 32
      n_blocks: 5
      unet:
        depth: 3
        merge_mode: concat
        start_filts: 32
        output_dim: 32
criterion:
  type: EDMLoss_MultiImgCond
  use_par: True
dataset:
  type: Occ_Par_MultiImg_Finetune
  data_path: ???
  surface_size: 20000
  par_pc_size: 2048
  load_proj_mat: True
  load_image: True
  par_point_aug: 0.5
  par_prefix: "aug7_"
  keyword: lowres # train on the low-res or high-res ARKitScenes scans; low-res scenes are more accessible to users
  jitter_partial_pretrain: 0.02
  jitter_partial_finetune: 0.02
  jitter_partial_val: 0.0
  use_pretrain_data: False
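The `???` entries above (for example `data_path` and `use_cat_embedding`) are placeholders that must be filled before training. A small sketch of reading this file and failing early on unfilled values; the file path and the check itself are illustrative and not part of the released training scripts:

```python
import yaml

with open("configs/finetune_triplane_diffusion.yaml") as f:
    cfg = yaml.safe_load(f)

# The released file ships with '???' placeholders; list any that remain.
def find_placeholders(node, prefix=""):
    if isinstance(node, dict):
        for k, v in node.items():
            yield from find_placeholders(v, f"{prefix}{k}.")
    elif node == "???":
        yield prefix.rstrip(".")

missing = list(find_placeholders(cfg))
if missing:
    raise ValueError(f"please set these config entries: {missing}")
```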
configs/train_triplane_diffusion.yaml
ADDED
@@ -0,0 +1,64 @@
model:
  ae: #ae model is loaded to
    type: TriVAE
    point_emb_dim: 48
    padding: 0.1
    encoder:
      plane_reso: 128
      plane_latent_dim: 32
      latent_dim: 32
      unet:
        depth: 4
        merge_mode: concat
        start_filts: 32
        output_dim: 64
    decoder:
      plane_reso: 128
      latent_dim: 32
      n_blocks: 5
      query_emb_dim: 48
      hidden_dim: 128
      unet:
        depth: 4
        merge_mode: concat
        start_filts: 64
        output_dim: 32
  dm:
    type: triplane_diff_multiimg_cond
    backbone: resunet_multiimg_direct_atten
    diff_reso: 64
    input_channel: 32
    output_channel: 32
    triplane_padding: 0.1 # should be consistent with padding in ae

    use_par: True
    par_channel: 32
    par_emb_dim: 48
    norm: "batch"
    img_in_channels: 1280
    vit_reso: 16
    use_cat_embedding: ???
    block_type: multiview_local
    par_point_encoder:
      plane_reso: 64
      plane_latent_dim: 32
      n_blocks: 5
      unet:
        depth: 3
        merge_mode: concat
        start_filts: 32
        output_dim: 32
criterion:
  type: EDMLoss_MultiImgCond
  use_par: True
dataset:
  type: Occ_Par_MultiImg
  data_path: ???
  surface_size: 20000
  par_pc_size: 2048
  load_proj_mat: True
  load_image: True
  par_point_aug: 0.5
  par_prefix: "aug7_" # prefix of the filenames of the augmented partial point clouds
  jitter_partial_train: 0.02
  jitter_partial_val: 0.0
configs/train_triplane_vae.yaml
ADDED
@@ -0,0 +1,30 @@
model:
  type: TriVAE
  point_emb_dim: 48
  padding: 0.1
  encoder:
    plane_reso: 128
    plane_latent_dim: 32
    latent_dim: 32
    unet:
      depth: 4
      merge_mode: concat
      start_filts: 32
      output_dim: 64
  decoder:
    plane_reso: 128
    latent_dim: 32
    n_blocks: 5
    query_emb_dim: 48
    hidden_dim: 128
    unet:
      depth: 4
      merge_mode: concat
      start_filts: 64
      output_dim: 32
dataset:
  type: Occ
  category: chair
  data_path: ???
  surface_size: 20000
  num_samples: 2048
data/download_preprocess_data_here
ADDED
@@ -0,0 +1 @@
1
datasets/SingleView_dataset.py
ADDED
@@ -0,0 +1,453 @@
import os
import glob
import random

import yaml

import torch
from torch.utils import data

import numpy as np
import json

from PIL import Image

import h5py
import torch.distributed as dist
import open3d as o3d
o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)
import pickle as p
import time
import cv2
from torchvision import transforms
import copy
from datasets.taxonomy import category_map_from_synthetic as category_ids

class Object_Occ(data.Dataset):
    def __init__(self, dataset_folder, split, categories=['03001627', "future_chair", 'ABO_chair'], transform=None,
                 sampling=True,
                 num_samples=4096, return_surface=True, surface_sampling=True, surface_size=2048, replica=16):

        self.pc_size = surface_size

        self.transform = transform
        self.num_samples = num_samples
        self.sampling = sampling
        self.split = split

        self.dataset_folder = dataset_folder
        self.return_surface = return_surface
        self.surface_sampling = surface_sampling

        self.dataset_folder = dataset_folder
        self.point_folder = os.path.join(self.dataset_folder, 'occ_data')
        self.mesh_folder = os.path.join(self.dataset_folder, 'other_data')

        if categories is None:
            categories = os.listdir(self.point_folder)
            categories = [c for c in categories if
                          os.path.isdir(os.path.join(self.point_folder, c)) and c.startswith('0')]
        categories.sort()

        print(categories)

        self.models = []
        for c_idx, c in enumerate(categories):
            subpath = os.path.join(self.point_folder, c)
            print(subpath)
            assert os.path.isdir(subpath)

            split_file = os.path.join(subpath, split + '.lst')
            with open(split_file, 'r') as f:
                models_c = f.readlines()
                models_c = [item.rstrip('\n') for item in models_c]

            for m in models_c[:]:
                if len(m) <= 1:
                    continue
                if m.endswith('.npz'):
                    model_id = m[:-4]
                else:
                    model_id = m
                self.models.append({
                    'category': c, 'model': model_id
                })
        self.replica = replica

    def __getitem__(self, idx):
        if self.replica >= 1:
            idx = idx % len(self.models)
        else:
            random_segment = random.randint(0, int(1 / self.replica) - 1)
            idx = int(random_segment * self.replica * len(self.models) + idx)

        category = self.models[idx]['category']
        model = self.models[idx]['model']

        point_path = os.path.join(self.point_folder, category, model + '.npz')
        # print(point_path)
        try:
            start_t = time.time()
            with np.load(point_path) as data:
                vol_points = data['vol_points']
                vol_label = data['vol_label']
                near_points = data['near_points']
                near_label = data['near_label']
            end_t = time.time()
            # print("loading time %f"%(end_t-start_t))
        except Exception as e:
            print(e)
            print(point_path)

        with open(point_path.replace('.npz', '.npy'), 'rb') as f:
            scale = np.load(f).item()
        # scale=1.0

        if self.return_surface:
            pc_path = os.path.join(self.mesh_folder, category, '4_pointcloud', model + '.npz')
            with np.load(pc_path) as data:
                try:
                    surface = data['points'].astype(np.float32)
                except:
                    print(pc_path, "has problems")
                    raise AttributeError
            surface = surface * scale
            if self.surface_sampling:
                ind = np.random.default_rng().choice(surface.shape[0], self.pc_size, replace=False)
                surface = surface[ind]
            surface = torch.from_numpy(surface)

        if self.sampling:
            '''need to conduct label balancing'''
            vol_ind = np.random.default_rng().choice(vol_points.shape[0], self.num_samples,
                                                     replace=(vol_points.shape[0] < self.num_samples))
            near_ind = np.random.default_rng().choice(near_points.shape[0], self.num_samples,
                                                      replace=(near_points.shape[0] < self.num_samples))
            vol_points = vol_points[vol_ind]
            vol_label = vol_label[vol_ind]
            near_points = near_points[near_ind]
            near_label = near_label[near_ind]

        vol_points = torch.from_numpy(vol_points)
        vol_label = torch.from_numpy(vol_label).float()

        if self.split == 'train':
            near_points = torch.from_numpy(near_points)
            near_label = torch.from_numpy(near_label).float()

            points = torch.cat([vol_points, near_points], dim=0)
            labels = torch.cat([vol_label, near_label], dim=0)
        else:
            points = vol_points
            labels = vol_label

        tran_mat = np.eye(4)
        if self.transform:
            surface, points, _, _, tran_mat = self.transform(surface, points)

        data_dict = {
            "points": points,
            "labels": labels,
            "category_ids": category_ids[category],
            "model_id": model,
            "tran_mat": tran_mat,
            "category": category,
        }
        if self.return_surface:
            data_dict["surface"] = surface

        return data_dict

    def __len__(self):
        if self.split != 'train':
            return len(self.models)
        else:
            return int(len(self.models) * self.replica)

class Object_PartialPoints_MultiImg(data.Dataset):
    def __init__(self, dataset_folder, split, split_filename, categories=['03001627', 'future_chair', 'ABO_chair'],
                 transform=None, sampling=True, num_samples=4096,
                 return_surface=True, ret_sample=True, surface_sampling=True,
                 surface_size=20000, par_pc_size=2048, par_point_aug=None, par_prefix="aug7_",
                 load_proj_mat=False, load_image=False, load_org_img=False, max_img_length=5, load_triplane=True, replica=2,
                 eval_multiview=False, scene_id=None, num_objects=-1):

        self.surface_size = surface_size
        self.par_pc_size = par_pc_size
        self.transform = transform
        self.num_samples = num_samples
        self.sampling = sampling
        self.split = split
        self.par_point_aug = par_point_aug
        self.par_prefix = par_prefix

        self.dataset_folder = dataset_folder
        self.return_surface = return_surface
        self.ret_sample = ret_sample
        self.surface_sampling = surface_sampling
        self.load_proj_mat = load_proj_mat
        self.load_img = load_image
        self.load_org_img = load_org_img
        self.load_triplane = load_triplane
        self.max_img_length = max_img_length
        self.eval_multiview = eval_multiview

        self.dataset_folder = dataset_folder
        self.point_folder = os.path.join(self.dataset_folder, 'occ_data')
        self.mesh_folder = os.path.join(self.dataset_folder, 'other_data')

        if scene_id is not None:
            scene_model_map_path = os.path.join(self.dataset_folder, "modelid_in_sceneid.json")
            with open(scene_model_map_path, 'r') as f:
                scene_model_map = json.load(f)
            valid_modelid = scene_model_map[scene_id]

        if categories is None:
            categories = os.listdir(self.point_folder)
            categories = [c for c in categories if
                          os.path.isdir(os.path.join(self.point_folder, c)) and c.startswith('0')]
        categories.sort()

        print(categories)
        self.models = []
        self.model_images_names = {}
        for c_idx, c in enumerate(categories):
            cat_count = 0
            subpath = os.path.join(self.point_folder, c)
            print(subpath)
            assert os.path.isdir(subpath)

            split_file = os.path.join(subpath, split_filename)
            with open(split_file, 'r') as f:
                splits = json.load(f)
            for item in splits:
                # print(item)
                model_id = item['model_id']
                if scene_id is not None and model_id not in valid_modelid:
                    continue
                image_filenames = item['image_filenames']
                partial_filenames = item['partial_filenames']
                if len(image_filenames) == 0 or len(partial_filenames) == 0:
                    continue
                self.model_images_names[model_id] = image_filenames
                if split == "train":
                    self.models += [
                        {'category': c, 'model': model_id, "partial_filenames": partial_filenames,
                         "image_filenames": image_filenames}
                    ]
                else:
                    if self.eval_multiview:
                        for length in range(0, len(image_filenames)):
                            self.models += [
                                {'category': c, 'model': model_id, "partial_filenames": partial_filenames[0:1],
                                 "image_filenames": image_filenames[0:length + 1]}
                            ]
                    self.models += [
                        {'category': c, 'model': model_id, "partial_filenames": partial_filenames[0:1],
                         "image_filenames": image_filenames}
                    ]
        if num_objects != -1:
            indexes = np.linspace(0, len(self.models) - 1, num=num_objects).astype(np.int32)
            self.models = [self.models[i] for i in indexes]

        self.replica = replica

    def load_samples(self, point_path):
        try:
            start_t = time.time()
            with np.load(point_path) as data:
                vol_points = data['vol_points']
                vol_label = data['vol_label']
                near_points = data['near_points']
                near_label = data['near_label']
            end_t = time.time()
            # print("reading time %f"%(end_t-start_t))
        except Exception as e:
            print(e)
            print(point_path)
        return vol_points, vol_label, near_points, near_label

    def load_surface(self, surface_path, scale):
        with np.load(surface_path) as data:
            surface = data['points'].astype(np.float32)
        surface = surface * scale
        if self.surface_sampling:
            ind = np.random.default_rng().choice(surface.shape[0], self.surface_size, replace=False)
            surface = surface[ind]
        surface = torch.from_numpy(surface).float()
        return surface

    def load_par_points(self, partial_path, scale):
        # print(partial_path)
        par_point_o3d = o3d.io.read_point_cloud(partial_path)
        par_points = np.asarray(par_point_o3d.points)
        par_points = par_points * scale
        replace = par_points.shape[0] < self.par_pc_size
        ind = np.random.default_rng().choice(par_points.shape[0], self.par_pc_size, replace=replace)
        par_points = par_points[ind]
        par_points = torch.from_numpy(par_points).float()
        return par_points

    def process_samples(self, vol_points, vol_label, near_points, near_label):
        if self.sampling:
            ind = np.random.default_rng().choice(vol_points.shape[0], self.num_samples, replace=False)
            vol_points = vol_points[ind]
            vol_label = vol_label[ind]

            ind = np.random.default_rng().choice(near_points.shape[0], self.num_samples, replace=False)
            near_points = near_points[ind]
            near_label = near_label[ind]
        vol_points = torch.from_numpy(vol_points)
        vol_label = torch.from_numpy(vol_label).float()
        if self.split == 'train':
            near_points = torch.from_numpy(near_points)
            near_label = torch.from_numpy(near_label).float()

            points = torch.cat([vol_points, near_points], dim=0)
            labels = torch.cat([vol_label, near_label], dim=0)
        else:
            ind = np.random.default_rng().choice(vol_points.shape[0], 100000, replace=False)
            points = vol_points[ind]
            labels = vol_label[ind]
        return points, labels

    def __getitem__(self, idx):
        if self.replica >= 1:
            idx = idx % len(self.models)
        else:
            random_segment = random.randint(0, int(1 / self.replica) - 1)
            idx = int(random_segment * self.replica * len(self.models) + idx)
        category = self.models[idx]['category']
        model = self.models[idx]['model']
        #image_filenames = self.model_images_names[model]
        image_filenames = self.models[idx]["image_filenames"]
        if self.split == "train":
            n_frames = np.random.randint(min(2, len(image_filenames)), min(len(image_filenames) + 1, self.max_img_length + 1))
            img_indexes = np.random.choice(len(image_filenames), n_frames,
                                           replace=(n_frames > len(image_filenames))).tolist()
        else:
            if self.eval_multiview:
                '''use all images'''
                n_frames = len(image_filenames)
                img_indexes = [i for i in range(n_frames)]
            else:
                n_frames = min(len(image_filenames), self.max_img_length)
                img_indexes = np.linspace(start=0, stop=len(image_filenames) - 1, num=n_frames).astype(np.int32)

        partial_filenames = self.models[idx]['partial_filenames']
        par_index = np.random.choice(len(partial_filenames), 1)[0]
        partial_name = partial_filenames[par_index]

        vol_points, vol_label, near_points, near_label = None, None, None, None
        points, labels = None, None
        point_path = os.path.join(self.point_folder, category, model + '.npz')
        if self.ret_sample:
            vol_points, vol_label, near_points, near_label = self.load_samples(point_path)
            points, labels = self.process_samples(vol_points, vol_label, near_points, near_label)

        with open(point_path.replace('.npz', '.npy'), 'rb') as f:
            scale = np.load(f).item()

        surface = None
        pc_path = os.path.join(self.mesh_folder, category, '4_pointcloud', model + '.npz')
        if self.return_surface:
            surface = self.load_surface(pc_path, scale)

        partial_path = os.path.join(self.mesh_folder, category, "5_partial_points", model, partial_name)
        if self.par_point_aug is not None and random.random() < self.par_point_aug: #add augmentation
            par_aug_path = os.path.join(self.mesh_folder, category, "5_partial_points", model, self.par_prefix + partial_name)
            #print(par_aug_path,os.path.exists(par_aug_path))
            if os.path.exists(par_aug_path):
                partial_path = par_aug_path
            else:
                raise FileNotFoundError
        par_points = self.load_par_points(partial_path, scale)

        image_list = []
        valid_frames = []
        image_namelist = []
        if self.load_img:
            for img_index in img_indexes:
                image_name = image_filenames[img_index]
                image_feat_path = os.path.join(self.mesh_folder, category, "7_img_features", model, image_name[:-4] + '.npz')
                image = np.load(image_feat_path)["img_features"]
                image_list.append(torch.from_numpy(image).float())
                valid_frames.append(True)
                image_namelist.append(image_name)
            while len(image_list) < self.max_img_length:
                image_list.append(torch.from_numpy(np.zeros(image_list[0].shape).astype(np.float32)).float())
                valid_frames.append(False)
        org_img_list = []
        if self.load_org_img:
            for img_index in img_indexes:
                image_name = image_filenames[img_index]
                image_path = os.path.join(self.mesh_folder, category, "6_images", model,
                                          image_name)
                org_image = cv2.imread(image_path)
                org_image = cv2.resize(org_image, dsize=(224, 224), interpolation=cv2.INTER_LINEAR)
                org_img_list.append(org_image)

        proj_mat = None
        proj_mat_list = []
        if self.load_proj_mat:
            for img_index in img_indexes:
                image_name = image_filenames[img_index]
                proj_mat_path = os.path.join(self.mesh_folder, category, "8_proj_matrix", model, image_name[:-4] + ".npy")
                proj_mat = np.load(proj_mat_path)
                proj_mat_list.append(proj_mat)
            while len(proj_mat_list) < self.max_img_length:
                proj_mat_list.append(np.eye(4))
        tran_mat = None
        if self.load_triplane:
            triplane_folder = os.path.join(self.mesh_folder, category, '9_triplane_kl25_64', model)
            triplane_list = os.listdir(triplane_folder)
            select_index = np.random.randint(0, len(triplane_list))
            triplane_path = os.path.join(triplane_folder, triplane_list[select_index])
            #triplane_path=os.path.join(triplane_folder,"triplane_feat_0.npz")
            triplane_content = np.load(triplane_path)
            triplane_mean, triplane_logvar, tran_mat = triplane_content['mean'], triplane_content['logvar'], triplane_content['tran_mat']
            tran_mat = torch.from_numpy(tran_mat).float()

        if self.transform:
            if not self.load_triplane:
                surface, points, par_points, proj_mat, tran_mat = self.transform(surface, points, par_points, proj_mat_list)
                tran_mat = torch.from_numpy(tran_mat).float()
            else:
                surface, points, par_points, proj_mat = self.transform(surface, points, par_points, proj_mat_list, tran_mat)

        category_id = category_ids[category]
        one_hot = torch.zeros((6)).float()
        one_hot[category_id] = 1.0
        ret_dict = {
            "category_ids": category_ids[category],
            "category": category,
            "category_code": one_hot,
            "model_id": model,
            "partial_name": partial_name[:-4],
            "class_name": category,
        }
        if tran_mat is not None:
            ret_dict["tran_mat"] = tran_mat
        if self.ret_sample:
            ret_dict["points"] = points
            ret_dict["labels"] = labels
        if self.return_surface:
            ret_dict["surface"] = surface
            ret_dict["par_points"] = par_points
        if self.load_img:
            ret_dict["image"] = torch.stack(image_list, dim=0)
            ret_dict["valid_frames"] = torch.tensor(valid_frames).bool()
        if self.load_org_img:
            ret_dict["org_image"] = org_img_list
            ret_dict["image_namelist"] = image_namelist
        if self.load_proj_mat:
            ret_dict["proj_mat"] = torch.stack([torch.from_numpy(mat) for mat in proj_mat_list], dim=0)
        if self.load_triplane:
            ret_dict['triplane_mean'] = torch.from_numpy(triplane_mean).float()
            ret_dict['triplane_logvar'] = torch.from_numpy(triplane_logvar).float()
        return ret_dict

    def __len__(self):
        if self.split != 'train':
            return len(self.models)
        else:
            return int(len(self.models) * self.replica)
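A minimal sketch of how `Object_Occ` might be driven by a standard PyTorch `DataLoader`; the dataset root and category are placeholders, and the printed keys follow the `data_dict` returned above:

```python
from torch.utils.data import DataLoader
from datasets.SingleView_dataset import Object_Occ

dataset = Object_Occ(
    "/path/to/LASA",            # placeholder: folder containing occ_data/ and other_data/
    split="val",
    categories=["03001627"],    # any keys of category_map_from_synthetic
    sampling=False,
    return_surface=True,
    surface_size=2048,
    replica=1,
)
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)
batch = next(iter(loader))
print(batch["points"].shape, batch["labels"].shape, batch["surface"].shape)
```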
datasets/__init__.py
ADDED
@@ -0,0 +1,91 @@
import torch.utils.data

from .SingleView_dataset import Object_Occ, Object_PartialPoints_MultiImg
from .transforms import Scale_Shift_Rotate, Aug_with_Tran, Augment_Points
from .taxonomy import synthetic_category_combined, synthetic_arkit_category_combined, arkit_category

def build_object_occ_dataset(split, args):
    transform = Scale_Shift_Rotate(rot_shift_surface=True, use_scale=True, use_whole_scale=True)
    category = args['category']
    #category_list=synthetic_category_combined[category]
    category_list = synthetic_arkit_category_combined[category]
    replica = args['replica']
    if split == "train":
        return Object_Occ(args['data_path'], split=split, categories=category_list,
                          transform=transform, sampling=True,
                          num_samples=args['num_samples'], return_surface=True,
                          surface_sampling=True, surface_size=args['surface_size'], replica=replica)
    elif split == "val":
        return Object_Occ(args['data_path'], split=split, categories=category_list,
                          transform=transform, sampling=False,
                          num_samples=args['num_samples'], return_surface=True,
                          surface_sampling=True, surface_size=args['surface_size'], replica=1)

def build_par_multiimg_dataset(split, args):
    #transform=Scale_Shift_Rotate(rot_shift_surface=False,use_scale=False,use_shift=False,use_rot=False) #fix the encoder into cannonical space
    #transform=Scale_Shift_Rotate(rot_shift_surface=True)
    transform = Aug_with_Tran(par_jitter_sigma=args['jitter_partial_train'])
    val_transform = Aug_with_Tran(par_jitter_sigma=args['jitter_partial_val'])
    category = args['category']
    category_list = synthetic_category_combined[category]
    if split == "train":
        return Object_PartialPoints_MultiImg(args['data_path'], split_filename="train_par_img.json", split=split,
                                             categories=category_list,
                                             transform=transform, sampling=True,
                                             num_samples=1024, return_surface=False, ret_sample=False,
                                             surface_sampling=True, par_pc_size=args['par_pc_size'], surface_size=args['surface_size'],
                                             load_proj_mat=args['load_proj_mat'], load_image=args['load_image'], load_triplane=True,
                                             par_prefix=args['par_prefix'], par_point_aug=args['par_point_aug'], replica=args['replica'],
                                             num_objects=args['num_objects'])
    elif split == "val":
        return Object_PartialPoints_MultiImg(args['data_path'], split_filename="val_par_img.json", split=split,
                                             categories=category_list,
                                             transform=val_transform, sampling=False,
                                             num_samples=1024, return_surface=False, ret_sample=True,
                                             surface_sampling=True, par_pc_size=args['par_pc_size'], surface_size=args['surface_size'],
                                             load_proj_mat=args['load_proj_mat'], load_image=args['load_image'], load_triplane=True,
                                             par_prefix=args['par_prefix'], par_point_aug=None, replica=1)

def build_finetune_par_multiimg_dataset(split, args):
    #transform=Scale_Shift_Rotate(rot_shift_surface=False,use_scale=False,use_shift=False,use_rot=False) #fix the encoder into cannonical space
    #transform=Scale_Shift_Rotate(rot_shift_surface=True)
    keyword = args['keyword']
    pretrain_transform = Aug_with_Tran(par_jitter_sigma=args['jitter_partial_pretrain']) #add more noise to partial points
    finetune_transform = Aug_with_Tran(par_jitter_sigma=args['jitter_partial_finetune'])
    val_transform = Aug_with_Tran(par_jitter_sigma=args['jitter_partial_val'])

    pretrain_cat = synthetic_category_combined[args['category']]
    arkit_cat = arkit_category[args['category']]
    use_pretrain_data = args["use_pretrain_data"]
    #print(arkit_cat,pretrain_cat)
    if split == "train":
        if use_pretrain_data:
            pretrain_dataset = Object_PartialPoints_MultiImg(args['data_path'], split_filename="train_par_img.json", categories=pretrain_cat,
                                                             split=split, transform=pretrain_transform, sampling=True, num_samples=1024, return_surface=False, ret_sample=False,
                                                             surface_sampling=True, par_pc_size=args['par_pc_size'], surface_size=args['surface_size'],
                                                             load_proj_mat=args['load_proj_mat'], load_image=args['load_image'], load_triplane=True, par_point_aug=args['par_point_aug'],
                                                             par_prefix=args['par_prefix'], replica=1)
        finetune_dataset = Object_PartialPoints_MultiImg(args['data_path'], split_filename=keyword + "_train_par_img.json", categories=arkit_cat,
                                                         split=split, transform=finetune_transform, sampling=True, num_samples=1024, return_surface=False, ret_sample=False,
                                                         surface_sampling=True, par_pc_size=args['par_pc_size'], surface_size=args['surface_size'],
                                                         load_proj_mat=args['load_proj_mat'], load_image=args['load_image'], load_triplane=True, par_point_aug=None, replica=args['replica'])
        if use_pretrain_data:
            return torch.utils.data.ConcatDataset([pretrain_dataset, finetune_dataset])
        else:
            return finetune_dataset
    elif split == "val":
        return Object_PartialPoints_MultiImg(args['data_path'], split_filename=keyword + "_val_par_img.json", categories=arkit_cat, split=split,
                                             transform=val_transform, sampling=False,
                                             num_samples=1024, return_surface=False, ret_sample=True,
                                             surface_sampling=True, par_pc_size=args['par_pc_size'], surface_size=args['surface_size'],
                                             load_proj_mat=args['load_proj_mat'], load_image=args['load_image'], load_triplane=True, par_point_aug=None, replica=1)

def build_dataset(split, args):
    if args['type'] == "Occ":
        return build_object_occ_dataset(split, args)
    elif args['type'] == "Occ_Par_MultiImg":
        return build_par_multiimg_dataset(split, args)
    elif args['type'] == "Occ_Par_MultiImg_Finetune":
        return build_finetune_par_multiimg_dataset(split, args)
    else:
        raise NotImplementedError
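A sketch of how `build_dataset` consumes the `dataset` block of `train_triplane_vae.yaml`; the `data_path` value is a placeholder, and `replica` is read by `build_object_occ_dataset` even though it is not listed in that YAML, so an assumed value is supplied here explicitly:

```python
from datasets import build_dataset

dataset_args = {
    "type": "Occ",                 # dispatches to build_object_occ_dataset
    "category": "chair",
    "data_path": "/path/to/LASA",  # placeholder
    "surface_size": 20000,
    "num_samples": 2048,
    "replica": 16,                 # assumed value; not present in the YAML
}
train_set = build_dataset("train", dataset_args)
val_set = build_dataset("val", dataset_args)
print(len(train_set), len(val_set))
```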
datasets/taxonomy.py
ADDED
@@ -0,0 +1,111 @@
category_map={
    "bathtub":0,
    "bed":1,
    "cabinet":2,
    "chair":3,
    "dishwasher":4,
    "fireplace":5,
    "oven":6,
    "refrigerator":7,
    "shelf":8,
    "sink":9,
    "sofa":10,
    "stool":11,
    "stove":12,
    "table":13,
    "toilet":14,
    "washer":15
}

category_map_from_synthetic={
    "03001627":0,
    "future_chair":0,
    "ABO_chair":0,
    "arkit_chair":0,
    "future_stool":0,
    "arkit_stool":0,

    "04256520":1,
    "future_sofa":1,
    "ABO_sofa":1,
    "arkit_sofa":1,

    "04379243":2,
    "ABO_table":2,
    "future_table":2,
    "arkit_table":2,

    "02933112":3,
    "future_cabinet":3,
    "ABO_cabinet":3,
    "arkit_cabinet":3,
    "arkit_oven":3,
    "arkit_refrigerator":3,
    "arkit_dishwasher":3,
    "03207941":3,

    "02818832":4,
    "future_bed":4,
    "ABO_bed":4,
    "arkit_bed":4,

    "02871439":5,
    "future_shelf":5,
    "ABO_shelf":5,
    "arkit_shelf":5,
}

synthetic_category_combined={
    "sofa":["future_sofa","ABO_sofa","04256520"],
    "chair":["03001627","future_chair","ABO_chair",
             "future_stool"],
    "table":[
        "04379243",
        "future_table",
        "ABO_table",
    ],
    "cabinet":["02933112","03207941","future_cabinet","ABO_cabinet"],
    "bed":["02818832","future_bed","ABO_bed"],
    "shelf":["02871439","future_shelf","ABO_shelf"],
    "all":["future_sofa","ABO_sofa","04256520",
           "03001627","future_chair","ABO_chair",
           "future_stool","04379243","future_table",
           "ABO_table","02933112","03207941","future_cabinet","ABO_cabinet",
           "02818832","future_bed","ABO_bed",
           "02871439","future_shelf","ABO_shelf"
           ]
}

synthetic_arkit_category_combined={
    "sofa":["future_sofa","ABO_sofa","04256520","arkit_sofa"],
    "chair":["03001627","future_chair","ABO_chair",
             "future_stool","arkit_chair","arkit_stool"],
    "table":["04379243","ABO_table","future_table","arkit_table"],
    "cabinet":["02933112","03207941","future_cabinet","ABO_cabinet","arkit_cabinet","arkit_stove","arkit_washer","arkit_dishwasher","arkit_refrigerator","arkit_oven"],
    "bed":["02818832","future_bed","ABO_bed","arkit_bed"],
    "shelf":["02871439","future_shelf","ABO_shelf","arkit_shelf"],
    "all":[
        "future_sofa","ABO_sofa","04256520","arkit_sofa",
        "03001627","future_chair","ABO_chair",
        "future_stool","arkit_chair","arkit_stool",
        "04379243","ABO_table","future_table","arkit_table",
        "02933112","03207941","future_cabinet","ABO_cabinet","arkit_cabinet","arkit_dishwasher","arkit_refrigerator","arkit_oven",
        "02818832","future_bed","ABO_bed","arkit_bed",
        "02871439","future_shelf","ABO_shelf","arkit_shelf"
    ]
}

arkit_category={
    "chair":["arkit_chair","arkit_stool"],
    "sofa":["arkit_sofa"],
    "table":["arkit_table"],
    "cabinet":["arkit_cabinet","arkit_dishwasher","arkit_refrigerator","arkit_oven"],
    "bed":["arkit_bed"],
    "shelf":["arkit_shelf"],
    "all":["arkit_chair","arkit_stool",
           "arkit_sofa","arkit_table",
           "arkit_cabinet","arkit_dishwasher","arkit_refrigerator","arkit_oven",
           "arkit_bed",
           "arkit_shelf"],
}
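These dictionaries are what the README's category arguments resolve to; for instance, the LASA chair evaluation mentioned above covers both `arkit_chair` and `arkit_stool`. A quick look-up sketch:

```python
from datasets.taxonomy import (category_map_from_synthetic,
                               synthetic_arkit_category_combined,
                               arkit_category)

print(arkit_category["chair"])                     # ['arkit_chair', 'arkit_stool']
print(synthetic_arkit_category_combined["chair"])  # synthetic + ARKit chair sub-categories
print(category_map_from_synthetic["arkit_stool"])  # 0 -> shares the chair class id
```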
datasets/transforms.py
ADDED
@@ -0,0 +1,180 @@
import torch
import numpy as np

def get_rot_from_yaw(angle):
    cy = torch.cos(angle)
    sy = torch.sin(angle)
    R = torch.tensor([[cy, 0, -sy],
                      [0, 1, 0],
                      [sy, 0, cy]]).float()
    return R

class Aug_with_Tran(object):
    def __init__(self, jitter_surface=True, jitter_partial=True, par_jitter_sigma=0.02):
        self.jitter_surface = jitter_surface
        self.jitter_partial = jitter_partial
        self.par_jitter_sigma = par_jitter_sigma

    def __call__(self, surface, point, par_points, proj_mat, tran_mat):
        if surface is not None: surface = torch.mm(surface, tran_mat[0:3, 0:3].transpose(0, 1)) + tran_mat[0:3, 3]
        if point is not None: point = torch.mm(point, tran_mat[0:3, 0:3].transpose(0, 1)) + tran_mat[0:3, 3]
        if par_points is not None: par_points = torch.mm(par_points, tran_mat[0:3, 0:3].transpose(0, 1)) + tran_mat[0:3, 3]
        if proj_mat is not None:
            '''need to put the augmentation back'''
            inv_tran_mat = np.linalg.inv(tran_mat)
            if isinstance(proj_mat, list):
                for idx, mat in enumerate(proj_mat):
                    mat = np.dot(mat, inv_tran_mat)
                    proj_mat[idx] = mat
            else:
                proj_mat = np.dot(proj_mat, inv_tran_mat)

        if self.jitter_surface and surface is not None:
            surface += 0.005 * torch.randn_like(surface)
            surface.clamp_(min=-1, max=1)
        if self.jitter_partial and par_points is not None:
            par_points += self.par_jitter_sigma * torch.randn_like(par_points)

        return surface, point, par_points, proj_mat


#add small augmentation
class Scale_Shift_Rotate(object):
    def __init__(self, interval=(0.75, 1.25), angle=(-5, 5), shift=(-0.1, 0.1), use_scale=True, use_whole_scale=False, use_rot=True,
                 use_shift=True, jitter=True, jitter_partial=True, par_jitter_sigma=0.02, rot_shift_surface=True):
        assert isinstance(interval, tuple)
        self.interval = interval
        self.angle = angle
        self.shift = shift
        self.jitter = jitter
        self.jitter_partial = jitter_partial
        self.rot_shift_surface = rot_shift_surface
        self.use_scale = use_scale
        self.use_rot = use_rot
        self.use_shift = use_shift
        self.par_jitter_sigma = par_jitter_sigma
        self.use_whole_scale = use_whole_scale

    def __call__(self, surface, point, par_points=None, proj_mat=None):
        if self.use_scale:
            scaling = torch.rand(1, 3) * 0.5 + 0.75
        else:
            scaling = torch.ones((1, 3)).float()
        if self.use_shift:
            shifting = torch.rand(1, 3) * (self.shift[1] - self.shift[0]) + self.shift[0]
        else:
            shifting = np.zeros((1, 3))
        if self.use_rot:
            angle = torch.rand(1) * (self.angle[1] - self.angle[0]) + self.angle[0]
        else:
            angle = torch.tensor((0))
        #print(angle)
        angle = angle / 180 * np.pi
        rot_mat = get_rot_from_yaw(angle)

        surface = surface * scaling
        point = point * scaling

        scale = (1 / torch.abs(surface).max().item()) * 0.999999
        if self.use_whole_scale:
            scale = scale * (np.random.random() * 0.3 + 0.7)
        surface *= scale
        point *= scale

        #scale = 1

        if self.rot_shift_surface:
            surface = torch.mm(surface, rot_mat.transpose(0, 1))
            surface = surface + shifting
            point = torch.mm(point, rot_mat.transpose(0, 1))
            point = point + shifting

        if par_points is not None:
            par_points = par_points * scaling
            par_points = torch.mm(par_points, rot_mat.transpose(0, 1))
            par_points += shifting
            par_points *= scale

        post_scale_tran = np.eye(4)
        post_scale_tran[0, 0], post_scale_tran[1, 1], post_scale_tran[2, 2] = scale, scale, scale
        shift_tran = np.eye(4)
        shift_tran[0:3, 3] = shifting
        rot_tran = np.eye(4)
        rot_tran[0:3, 0:3] = rot_mat
        scale_tran = np.eye(4)
        scale_tran[0, 0], scale_tran[1, 1], scale_tran[2, 2] = scaling[0, 0], scaling[
            0, 1], scaling[0, 2]

        #print(post_scale_tran,np.dot(np.dot(shift_tran,np.dot(rot_tran,scale_tran))))
        tran_mat = np.dot(post_scale_tran, np.dot(shift_tran, np.dot(rot_tran, scale_tran)))
        #tran_mat=np.dot(post_scale_tran,tran_mat)
        #print(np.linalg.norm(surface - (np.dot(org_surface,tran_mat[0:3,0:3].T)+tran_mat[0:3,3])))
        if proj_mat is not None:
            '''need to put the augmentation back'''
            inv_tran_mat = np.linalg.inv(tran_mat)
            if isinstance(proj_mat, list):
                for idx, mat in enumerate(proj_mat):
                    mat = np.dot(mat, inv_tran_mat)
                    proj_mat[idx] = mat
            else:
                proj_mat = np.dot(proj_mat, inv_tran_mat)

        if self.jitter:
            surface += 0.005 * torch.randn_like(surface)
            surface.clamp_(min=-1, max=1)
        if self.jitter_partial and par_points is not None:
            par_points += self.par_jitter_sigma * torch.randn_like(par_points)

        return surface, point, par_points, proj_mat, tran_mat


class Augment_Points(object):
    def __init__(self, interval=(0.75, 1.25), angle=(-5, 5), shift=(-0.1, 0.1), use_scale=True, use_rot=True,
                 use_shift=True, jitter=True, jitter_sigma=0.02):
        assert isinstance(interval, tuple)
        self.interval = interval
        self.angle = angle
        self.shift = shift
        self.jitter = jitter
        self.use_scale = use_scale
        self.use_rot = use_rot
        self.use_shift = use_shift
        self.jitter_sigma = jitter_sigma

    def __call__(self, points1, points2):
        if self.use_scale:
            scaling = torch.rand(1, 3) * 0.5 + 0.75
        else:
            scaling = torch.ones((1, 3)).float()
        if self.use_shift:
            shifting = torch.rand(1, 3) * (self.shift[1] - self.shift[0]) + self.shift[0]
        else:
            shifting = np.zeros((1, 3))
        if self.use_rot:
            angle = torch.rand(1) * (self.angle[1] - self.angle[0]) + self.angle[0]
        else:
            angle = torch.tensor((0))
        #print(angle)
        angle = angle / 180 * np.pi
        rot_mat = get_rot_from_yaw(angle)

        points1 = points1 * scaling
        points2 = points2 * scaling

        #scale = 1
        scale = min((1 / torch.abs(points1).max().item()) * 0.999999, (1 / torch.abs(points2).max().item()) * 0.999999)
        points1 *= scale
        points2 *= scale

        points1 = torch.mm(points1, rot_mat.transpose(0, 1))
        points1 = points1 + shifting
        points2 = torch.mm(points2, rot_mat.transpose(0, 1))
        points2 = points2 + shifting

        if self.jitter:
            points1 += self.jitter_sigma * torch.randn_like(points1)
            points2 += self.jitter_sigma * torch.randn_like(points2)

        return points1, points2
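A small usage sketch of `Scale_Shift_Rotate` on random tensors; the point counts are illustrative, and the returned `tran_mat` is the 4x4 composite of the sampled scale, rotation, and shift:

```python
import torch
from datasets.transforms import Scale_Shift_Rotate

aug = Scale_Shift_Rotate(rot_shift_surface=True, use_scale=True, jitter=False)
surface = torch.rand(2048, 3) * 2 - 1   # fake surface points
queries = torch.rand(4096, 3) * 2 - 1   # fake query points

new_surface, new_queries, _, _, tran_mat = aug(surface.clone(), queries.clone())
print(new_surface.shape, new_queries.shape)
print(tran_mat)   # 4x4 matrix recording the sampled augmentation
```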
engine/engine_triplane_dm.py
ADDED
@@ -0,0 +1,136 @@
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import math
import sys
from typing import Iterable

import torch
import torch.nn.functional as F

import util.misc as misc
import util.lr_sched as lr_sched
import numpy as np
import os
import pickle as p
import torch.distributed as dist
import time
from models.modules.encoder import DiagonalGaussianDistribution


def train_one_epoch(model: torch.nn.Module, ae: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    log_writer=None, log_dir=None, args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter
    use_cls_free = args.use_cls_free

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, data_batch in enumerate(
            metric_logger.log_every(data_loader, print_freq, header)):

        # we use a per iteration (instead of per epoch) lr scheduler
        if not args.constant_lr:
            if data_iter_step % accum_iter == 0:
                lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        input_dict = model.module.prepare_data(data_batch)
        with torch.cuda.amp.autocast(enabled=False):
            loss_all = criterion(model, input_dict, classifier_free=use_cls_free)
            loss = loss_all.mean()

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate_reconstruction(data_loader, model, ae, criterion, device):
    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()
    for data_batch in metric_logger.log_every(data_loader, 50, header):
        with torch.no_grad():
            input_dict = model.module.prepare_data(data_batch)
            loss_all = criterion(model, input_dict, classifier_free=False)
            loss = loss_all.mean()
            sample_input = model.module.prepare_sample_data(data_batch)
            sampled_array = model.module.sample(sample_input).float()
            sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")
            eval_input = model.module.prepare_eval_data(data_batch)
            samples = eval_input["samples"]
            labels = eval_input["labels"]
            for j in range(sampled_array.shape[0]):
                output = ae.decode(sampled_array[j:j + 1], samples[j:j + 1]).squeeze(-1)
                pred = torch.zeros_like(output)
                pred[output >= 0.0] = 1
                label = labels[j:j + 1]

                accuracy = (pred == label).float().sum(dim=1) / label.shape[1]
                accuracy = accuracy.mean()
                intersection = (pred * label).sum(dim=1)
                union = (pred + label).gt(0).sum(dim=1)
                iou = intersection * 1.0 / union + 1e-5
                iou = iou.mean()

                metric_logger.update(iou=iou.item())
                metric_logger.update(accuracy=accuracy.item())
                metric_logger.update(loss=loss.item())
    metric_logger.synchronize_between_processes()
    print('* iou {ious.global_avg:.3f}'
          .format(ious=metric_logger.iou))
    print('* accuracy {accuracies.global_avg:.3f}'
          .format(accuracies=metric_logger.accuracy))
    print('* loss {losses.global_avg:.3f}'
          .format(losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
engine/engine_triplane_vae.py
ADDED
@@ -0,0 +1,185 @@
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import math
import sys
sys.path.append("..")
from typing import Iterable

import torch
import torch.nn.functional as F

import util.misc as misc
import util.lr_sched as lr_sched


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    log_writer=None, args=None):
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    kl_weight = 25e-3  # TODO: tune this; it is 1e-3 originally. A larger KL weight eases training of the diffusion model but degrades VAE reconstruction quality.

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, data_batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        points = data_batch['points'].to(device, non_blocking=True)
        labels = data_batch['labels'].to(device, non_blocking=True)
        surface = data_batch['surface'].to(device, non_blocking=True)
        # print(points.shape)
        with torch.cuda.amp.autocast(enabled=False):
            outputs = model(surface, points)
            if 'kl' in outputs:
                loss_kl = outputs['kl']
                #print(loss_kl.shape)
                loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
            else:
                loss_kl = None

            outputs = outputs['logits']

            num_samples = outputs.shape[1] // 2
            #print(num_samples)
            loss_vol = criterion(outputs[:, :num_samples], labels[:, :num_samples])
            loss_near = criterion(outputs[:, num_samples:], labels[:, num_samples:])

            if loss_kl is not None:
                loss = loss_vol + 0.1 * loss_near + kl_weight * loss_kl
            else:
                loss = loss_vol + 0.1 * loss_near

        loss_value = loss.item()

        threshold = 0

        pred = torch.zeros_like(outputs[:, :num_samples])
        pred[outputs[:, :num_samples] >= threshold] = 1

        accuracy = (pred == labels[:, :num_samples]).float().sum(dim=1) / labels[:, :num_samples].shape[1]
        accuracy = accuracy.mean()
        intersection = (pred * labels[:, :num_samples]).sum(dim=1)
        union = (pred + labels[:, :num_samples]).gt(0).sum(dim=1) + 1e-5
        iou = intersection * 1.0 / union
        iou = iou.mean()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss /= accum_iter
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)

        metric_logger.update(loss_vol=loss_vol.item())
        metric_logger.update(loss_near=loss_near.item())

        if loss_kl is not None:
            metric_logger.update(loss_kl=loss_kl.item())

        metric_logger.update(iou=iou.item())

        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])

        metric_logger.update(lr=max_lr)

        loss_value_reduce = misc.all_reduce_mean(loss_value)
        iou_reduce = misc.all_reduce_mean(iou)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('iou', iou_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device):
    criterion = torch.nn.BCEWithLogitsLoss()

    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for data_batch in metric_logger.log_every(data_loader, 50, header):

        points = data_batch['points'].to(device, non_blocking=True)
        labels = data_batch['labels'].to(device, non_blocking=True)
        surface = data_batch['surface'].to(device, non_blocking=True)
        # compute output
        with torch.cuda.amp.autocast(enabled=False):

            outputs = model(surface, points)
            if 'kl' in outputs:
                loss_kl = outputs['kl']
                loss_kl = torch.sum(loss_kl) / loss_kl.shape[0]
            else:
                loss_kl = None

            outputs = outputs['logits']

            loss = criterion(outputs, labels)

            threshold = 0

            pred = torch.zeros_like(outputs)
            pred[outputs >= threshold] = 1

            accuracy = (pred == labels).float().sum(dim=1) / labels.shape[1]
            accuracy = accuracy.mean()
            intersection = (pred * labels).sum(dim=1)
            union = (pred + labels).gt(0).sum(dim=1)
            iou = intersection * 1.0 / union + 1e-5
            iou = iou.mean()

        batch_size = points.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['iou'].update(iou.item(), n=batch_size)

        if loss_kl is not None:
            metric_logger.update(loss_kl=loss_kl.item())

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* iou {iou.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(iou=metric_logger.iou, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
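For reference, the objective assembled in `train_one_epoch` above can be summarised as

$$\mathcal{L} \;=\; \mathrm{BCE}(o_{\mathrm{vol}}, y_{\mathrm{vol}}) \;+\; 0.1\,\mathrm{BCE}(o_{\mathrm{near}}, y_{\mathrm{near}}) \;+\; \lambda_{\mathrm{KL}}\,\mathcal{L}_{\mathrm{KL}},\qquad \lambda_{\mathrm{KL}} = 2.5\times10^{-2},$$

where the two BCE terms correspond to `loss_vol` and `loss_near` (volume and near-surface query points) and $\mathcal{L}_{\mathrm{KL}}$ is the per-sample KL returned by the model. This is only a restatement of the code above, not an additional definition.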
evaluation/dist_eval.sh
ADDED
@@ -0,0 +1,16 @@
CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15002 --nproc_per_node=2 \
evaluate_object_reconstruction.py \
--configs ../configs/finetune_triplane_diffusion.yaml \
--category arkit_chair arkit_stool \
--ae-pth ../output/ae/chair/best-checkpoint.pth \
--dm-pth ../output/finetune_dm/lowres_chair/best-checkpoint.pth \
--output_folder ../output_result/chair_result \
--data-pth ../data \
--eval_cd \
--reso 256 \
--save_mesh \
--save_par_points \
--save_image \
--save_surface

# check ./datasets/taxonomy to see how sub categories are defined
evaluation/evaluate_object_reconstruction.py
ADDED
@@ -0,0 +1,239 @@
import argparse
import sys
sys.path.append("..")
sys.path.append(".")
import numpy as np

import mcubes
import os
import torch

import trimesh

from datasets.SingleView_dataset import Object_PartialPoints_MultiImg
from datasets.transforms import Scale_Shift_Rotate
from models import get_model
from pathlib import Path
import open3d as o3d
from configs.config_utils import CONFIG
import cv2
from util.misc import MetricLogger
import scipy
from pyTorchChamferDistance.chamfer_distance import ChamferDistance
from util.projection_utils import draw_proj_image
from util import misc
import time

dist_chamfer = ChamferDistance()


def pc_metrics(p1, p2, space_ext=2, fscore_param=0.01, scale=.5):
    """ p2: reference points
        (B, N, 3)
    """
    p1, p2, space_ext = p1 * scale, p2 * scale, space_ext * scale
    f_thresh = space_ext * fscore_param

    #print(p1.shape,p2.shape)
    d1, d2, _, _ = dist_chamfer(p1, p2)
    #print(d1.shape,d2.shape)
    d1sqrt, d2sqrt = (d1 ** .5), (d2 ** .5)
    chamfer_L1 = d1sqrt.mean(axis=-1) + d2sqrt.mean(axis=-1)
    chamfer_L2 = d1.mean(axis=-1) + d2.mean(axis=-1)
    precision = (d1sqrt < f_thresh).sum(axis=-1).float() / p1.shape[1]
    recall = (d2sqrt < f_thresh).sum(axis=-1).float() / p2.shape[1]
    #print(precision,recall)
    fscore = 2 * torch.div(recall * precision, recall + precision)
    fscore[fscore == float("inf")] = 0
    return chamfer_L1, chamfer_L2, fscore


if __name__ == "__main__":

    parser = argparse.ArgumentParser('This script computes IoU, F-score, and Chamfer distance before ICP alignment', add_help=False)
    parser.add_argument('--configs', type=str, required=True)
    parser.add_argument('--output_folder', type=str, default="../output_result/Triplane_diff_parcond_0926")
    parser.add_argument('--dm-pth', type=str)
    parser.add_argument('--ae-pth', type=str)
    parser.add_argument('--data-pth', type=str, default="../")
    parser.add_argument('--save_mesh', action="store_true", default=False)
    parser.add_argument('--save_image', action="store_true", default=False)
    parser.add_argument('--save_par_points', action="store_true", default=False)
    parser.add_argument('--save_proj_img', action="store_true", default=False)
    parser.add_argument('--save_surface', action="store_true", default=False)
    parser.add_argument('--reso', default=128, type=int)
    parser.add_argument('--category', nargs="+", type=str)
    parser.add_argument('--eval_cd', action="store_true", default=False)
    parser.add_argument('--use_augmentation', action="store_true", default=False)

    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    args = parser.parse_args()
    misc.init_distributed_mode(args)
    config_path = args.configs
    config = CONFIG(config_path)
    dataset_config = config.config['dataset']
    dataset_config['data_path'] = args.data_pth
    if "arkit" in args.category[0]:
        split_filename = dataset_config['keyword'] + '_val_par_img.json'
    else:
        split_filename = 'val_par_img.json'

    transform = None
    if args.use_augmentation:
        transform = Scale_Shift_Rotate(jitter_partial=False, jitter=False, use_scale=False, angle=(-10, 10), shift=(-0.1, 0.1))
    dataset_val = Object_PartialPoints_MultiImg(dataset_config['data_path'], split_filename=split_filename, categories=args.category, split="val",
                                                transform=transform, sampling=False,
                                                num_samples=1024, return_surface=True, ret_sample=True,
                                                surface_sampling=True, par_pc_size=dataset_config['par_pc_size'], surface_size=100000,
                                                load_proj_mat=True, load_image=True, load_org_img=True, load_triplane=None, par_point_aug=None, replica=1)
    batch_size = 1

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    val_sampler = torch.utils.data.DistributedSampler(
        dataset_val, num_replicas=num_tasks, rank=global_rank,
        shuffle=False)  # shu
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        sampler=val_sampler,
        batch_size=batch_size,
        num_workers=10,
        shuffle=False,
    )
    output_folder = args.output_folder

    device = torch.device('cuda')

    ae_config = config.config['model']['ae']
    dm_config = config.config['model']['dm']
    ae_model = get_model(ae_config).to(device)
    if args.category[0] == "all":
        dm_config["use_cat_embedding"] = True
    else:
        dm_config["use_cat_embedding"] = False
    dm_model = get_model(dm_config).to(device)
    ae_model.eval()
    dm_model.eval()
    ae_model.load_state_dict(torch.load(args.ae_pth)['model'])
    dm_model.load_state_dict(torch.load(args.dm_pth)['model'])

    density = args.reso
    gap = 2.2 / density
    x = np.linspace(-1.1, 1.1, int(density + 1))
    y = np.linspace(-1.1, 1.1, int(density + 1))
    z = np.linspace(-1.1, 1.1, int(density + 1))
    xv, yv, zv = np.meshgrid(x, y, z, indexing='ij')
    grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].to(device, non_blocking=True)

    metric_logger = MetricLogger(delimiter="  ")
    header = 'Test:'

    with torch.no_grad():
        for data_batch in metric_logger.log_every(dataloader_val, 10, header):
            # if data_iter_step==100:
            #     break
            partial_name = data_batch['partial_name']
            class_name = data_batch['class_name']
            model_ids = data_batch['model_id']
            surface = data_batch['surface']
            proj_matrices = data_batch['proj_mat']
            sample_points = data_batch["points"].cuda().float()
            labels = data_batch["labels"].cuda().float()
            sample_input = dm_model.prepare_sample_data(data_batch)
            #t1 = time.time()
            sampled_array = dm_model.sample(sample_input, num_steps=36).float()
            #t2 = time.time()
            #sample_time = t2 - t1
            #print("sampling time %f" % (sample_time))
            sampled_array = torch.nn.functional.interpolate(sampled_array, scale_factor=2, mode="bilinear")
            for j in range(sampled_array.shape[0]):
                if args.save_mesh | args.save_par_points | args.save_image:
                    object_folder = os.path.join(output_folder, class_name[j], model_ids[j])
                    Path(object_folder).mkdir(parents=True, exist_ok=True)
                '''calculate iou'''
                sample_point = sample_points[j:j + 1]
                sample_output = ae_model.decode(sampled_array[j:j + 1], sample_point)
                sample_pred = torch.zeros_like(sample_output)
                sample_pred[sample_output >= 0.0] = 1
                label = labels[j:j + 1]
                intersection = (sample_pred * label).sum(dim=1)
                union = (sample_pred + label).gt(0).sum(dim=1)
                iou = intersection * 1.0 / union + 1e-5
                iou = iou.mean()
                metric_logger.update(iou=iou.item())

                if args.use_augmentation:
                    tran_mat = data_batch["tran_mat"][j].numpy()
                    mat_save_path = '{}/tran_mat.npy'.format(object_folder)
                    np.save(mat_save_path, tran_mat)

                if args.eval_cd:
                    grid_list = torch.split(grid, 128 ** 3, dim=1)
                    output_list = []
                    #t3=time.time()
                    for sub_grid in grid_list:
                        output_list.append(ae_model.decode(sampled_array[j:j + 1], sub_grid))
                    output = torch.cat(output_list, dim=1)
                    #t4=time.time()
                    #decoding_time=t4-t3
                    #print("decoding time:",decoding_time)
                    logits = output[j].detach()

                    volume = logits.view(density + 1, density + 1, density + 1).cpu().numpy()
                    verts, faces = mcubes.marching_cubes(volume, 0)

                    verts *= gap
                    verts -= 1.1
                    #print("vertice max min",np.amin(verts,axis=0),np.amax(verts,axis=0))

                    m = trimesh.Trimesh(verts, faces)
                    '''calculate fscore and chamfer distance'''
                    result_surface, _ = trimesh.sample.sample_surface(m, 100000)
                    gt_surface = surface[j]
                    assert gt_surface.shape[0] == result_surface.shape[0]

                    result_surface_gpu = torch.from_numpy(result_surface).float().cuda().unsqueeze(0)
                    gt_surface_gpu = gt_surface.float().cuda().unsqueeze(0)
                    _, chamfer_L2, fscore = pc_metrics(result_surface_gpu, gt_surface_gpu)
                    metric_logger.update(chamferl2=chamfer_L2 * 1000.0)
                    metric_logger.update(fscore=fscore)

                    if args.save_mesh:
                        m.export('{}/{}_mesh.ply'.format(object_folder, partial_name[j]))

                if args.save_par_points:
                    par_point_input = data_batch['par_points'][j].numpy()
                    #print("input max min", np.amin(par_point_input, axis=0), np.amax(par_point_input, axis=0))
                    par_point_o3d = o3d.geometry.PointCloud()
                    par_point_o3d.points = o3d.utility.Vector3dVector(par_point_input[:, 0:3])
                    o3d.io.write_point_cloud('{}/{}.ply'.format(object_folder, partial_name[j]), par_point_o3d)
                if args.save_image:
                    image_list = data_batch["org_image"]
                    for idx, image in enumerate(image_list):
                        image = image[0].numpy().astype(np.uint8)
                        if args.save_proj_img:
                            proj_mat = proj_matrices[j, idx].numpy()
                            proj_image = draw_proj_image(image, proj_mat, result_surface)
                            proj_save_path = '{}/proj_{}.jpg'.format(object_folder, idx)
                            cv2.imwrite(proj_save_path, proj_image)
                        save_path = '{}/{}.jpg'.format(object_folder, idx)
                        cv2.imwrite(save_path, image)
                if args.save_surface:
                    surface = gt_surface.numpy().astype(np.float32)
                    surface_o3d = o3d.geometry.PointCloud()
                    surface_o3d.points = o3d.utility.Vector3dVector(surface[:, 0:3])
                    o3d.io.write_point_cloud('{}/surface.ply'.format(object_folder), surface_o3d)
    metric_logger.synchronize_between_processes()
    print('* iou {ious.global_avg:.3f}'
          .format(ious=metric_logger.iou))
    if args.eval_cd:
        print('* chamferl2 {chamferl2s.global_avg:.3f}'
              .format(chamferl2s=metric_logger.chamferl2))
        print('* fscore {fscores.global_avg:.3f}'
              .format(fscores=metric_logger.fscore))
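A minimal sketch (not part of this commit) of how `pc_metrics` above is meant to be called, following the shapes and devices used in the evaluation loop: two `(1, N, 3)` float tensors on the GPU. The import path is an assumption; inside this repository the function simply lives in the same script.

```python
# Usage sketch only; requires a CUDA device and the JIT-compiled chamfer extension.
import torch
from evaluate_object_reconstruction import pc_metrics  # assumed: run from the evaluation/ folder

pred_surface = torch.rand(1, 100000, 3).cuda() * 2 - 1  # reconstructed surface samples
gt_surface = torch.rand(1, 100000, 3).cuda() * 2 - 1    # reference surface samples
chamfer_l1, chamfer_l2, fscore = pc_metrics(pred_surface, gt_surface)
print(chamfer_l2.item() * 1000.0, fscore.item())  # the script logs chamfer-L2 x 1000 and F-score
```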
evaluation/pyTorchChamferDistance/.gitignore
ADDED
@@ -0,0 +1,3 @@
.DS_Store
._*
evaluation/pyTorchChamferDistance/LICENSE.md
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) [year] [fullname]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
evaluation/pyTorchChamferDistance/README.md
ADDED
@@ -0,0 +1,23 @@
# Chamfer Distance for pyTorch

This is an implementation of the Chamfer Distance as a module for pyTorch. It is written as a custom C++/CUDA extension.

As it is using pyTorch's [JIT compilation](https://pytorch.org/tutorials/advanced/cpp_extension.html), there are no additional prerequisite steps that have to be taken. Simply import the module as shown below; CUDA and C++ code will be compiled on the first run.

### Usage
```python
import torch
from chamfer_distance import ChamferDistance
chamfer_dist = ChamferDistance()

# points and points_reconstructed are n_points x 3 matrices

# this bundled version also returns the nearest-neighbour indices
dist1, dist2, idx1, idx2 = chamfer_dist(points, points_reconstructed)
loss = (torch.mean(dist1)) + (torch.mean(dist2))
```

### Integration
This code has been integrated into the [Kaolin](https://github.com/NVIDIAGameWorks/kaolin) library for 3D Deep Learning by NVIDIAGameWorks. You should probably take a look at it if you are working on anything 3D :)
evaluation/pyTorchChamferDistance/__init__.py
ADDED
File without changes
evaluation/pyTorchChamferDistance/chamfer_distance/__init__.py
ADDED
@@ -0,0 +1 @@
from .chamfer_distance import ChamferDistance
evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp
ADDED
@@ -0,0 +1,185 @@
#include <torch/torch.h>

// CUDA forward declarations
int ChamferDistanceKernelLauncher(
    const int b, const int n,
    const float* xyz,
    const int m,
    const float* xyz2,
    float* result,
    int* result_i,
    float* result2,
    int* result2_i);

int ChamferDistanceGradKernelLauncher(
    const int b, const int n,
    const float* xyz1,
    const int m,
    const float* xyz2,
    const float* grad_dist1,
    const int* idx1,
    const float* grad_dist2,
    const int* idx2,
    float* grad_xyz1,
    float* grad_xyz2);


void chamfer_distance_forward_cuda(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor dist1,
    const at::Tensor dist2,
    const at::Tensor idx1,
    const at::Tensor idx2)
{
    ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
                                  xyz2.size(1), xyz2.data<float>(),
                                  dist1.data<float>(), idx1.data<int>(),
                                  dist2.data<float>(), idx2.data<int>());
}

void chamfer_distance_backward_cuda(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    at::Tensor gradxyz1,
    at::Tensor gradxyz2,
    at::Tensor graddist1,
    at::Tensor graddist2,
    at::Tensor idx1,
    at::Tensor idx2)
{
    ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data<float>(),
                                      xyz2.size(1), xyz2.data<float>(),
                                      graddist1.data<float>(), idx1.data<int>(),
                                      graddist2.data<float>(), idx2.data<int>(),
                                      gradxyz1.data<float>(), gradxyz2.data<float>());
}


void nnsearch(
    const int b, const int n, const int m,
    const float* xyz1,
    const float* xyz2,
    float* dist,
    int* idx)
{
    for (int i = 0; i < b; i++) {
        for (int j = 0; j < n; j++) {
            const float x1 = xyz1[(i*n+j)*3+0];
            const float y1 = xyz1[(i*n+j)*3+1];
            const float z1 = xyz1[(i*n+j)*3+2];
            double best = 0;
            int besti = 0;
            for (int k = 0; k < m; k++) {
                const float x2 = xyz2[(i*m+k)*3+0] - x1;
                const float y2 = xyz2[(i*m+k)*3+1] - y1;
                const float z2 = xyz2[(i*m+k)*3+2] - z1;
                const double d=x2*x2+y2*y2+z2*z2;
                if (k==0 || d < best){
                    best = d;
                    besti = k;
                }
            }
            dist[i*n+j] = best;
            idx[i*n+j] = besti;
        }
    }
}


void chamfer_distance_forward(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor dist1,
    const at::Tensor dist2,
    const at::Tensor idx1,
    const at::Tensor idx2)
{
    const int batchsize = xyz1.size(0);
    const int n = xyz1.size(1);
    const int m = xyz2.size(1);

    const float* xyz1_data = xyz1.data<float>();
    const float* xyz2_data = xyz2.data<float>();
    float* dist1_data = dist1.data<float>();
    float* dist2_data = dist2.data<float>();
    int* idx1_data = idx1.data<int>();
    int* idx2_data = idx2.data<int>();

    nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data);
    nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data);
}


void chamfer_distance_backward(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    at::Tensor gradxyz1,
    at::Tensor gradxyz2,
    at::Tensor graddist1,
    at::Tensor graddist2,
    at::Tensor idx1,
    at::Tensor idx2)
{
    const int b = xyz1.size(0);
    const int n = xyz1.size(1);
    const int m = xyz2.size(1);

    const float* xyz1_data = xyz1.data<float>();
    const float* xyz2_data = xyz2.data<float>();
    float* gradxyz1_data = gradxyz1.data<float>();
    float* gradxyz2_data = gradxyz2.data<float>();
    float* graddist1_data = graddist1.data<float>();
    float* graddist2_data = graddist2.data<float>();
    const int* idx1_data = idx1.data<int>();
    const int* idx2_data = idx2.data<int>();

    for (int i = 0; i < b*n*3; i++)
        gradxyz1_data[i] = 0;
    for (int i = 0; i < b*m*3; i++)
        gradxyz2_data[i] = 0;
    for (int i = 0;i < b; i++) {
        for (int j = 0; j < n; j++) {
            const float x1 = xyz1_data[(i*n+j)*3+0];
            const float y1 = xyz1_data[(i*n+j)*3+1];
            const float z1 = xyz1_data[(i*n+j)*3+2];
            const int j2 = idx1_data[i*n+j];

            const float x2 = xyz2_data[(i*m+j2)*3+0];
            const float y2 = xyz2_data[(i*m+j2)*3+1];
            const float z2 = xyz2_data[(i*m+j2)*3+2];
            const float g = graddist1_data[i*n+j]*2;

            gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2);
            gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2);
            gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2);
            gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2));
            gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2));
            gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2));
        }
        for (int j = 0; j < m; j++) {
            const float x1 = xyz2_data[(i*m+j)*3+0];
            const float y1 = xyz2_data[(i*m+j)*3+1];
            const float z1 = xyz2_data[(i*m+j)*3+2];
            const int j2 = idx2_data[i*m+j];
            const float x2 = xyz1_data[(i*n+j2)*3+0];
            const float y2 = xyz1_data[(i*n+j2)*3+1];
            const float z2 = xyz1_data[(i*n+j2)*3+2];
            const float g = graddist2_data[i*m+j]*2;
            gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2);
            gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2);
            gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2);
            gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2));
            gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2));
            gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2));
        }
    }
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
    m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
    m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
    m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
}
evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu
ADDED
@@ -0,0 +1,209 @@
#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

__global__
void ChamferDistanceKernel(
    int b,
    int n,
    const float* xyz,
    int m,
    const float* xyz2,
    float* result,
    int* result_i)
{
    const int batch=512;
    __shared__ float buf[batch*3];
    for (int i=blockIdx.x;i<b;i+=gridDim.x){
        for (int k2=0;k2<m;k2+=batch){
            int end_k=min(m,k2+batch)-k2;
            for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
                buf[j]=xyz2[(i*m+k2)*3+j];
            }
            __syncthreads();
            for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
                float x1=xyz[(i*n+j)*3+0];
                float y1=xyz[(i*n+j)*3+1];
                float z1=xyz[(i*n+j)*3+2];
                int best_i=0;
                float best=0;
                int end_ka=end_k-(end_k&3);
                if (end_ka==batch){
                    for (int k=0;k<batch;k+=4){
                        {
                            float x2=buf[k*3+0]-x1;
                            float y2=buf[k*3+1]-y1;
                            float z2=buf[k*3+2]-z1;
                            float d=x2*x2+y2*y2+z2*z2;
                            if (k==0 || d<best){
                                best=d;
                                best_i=k+k2;
                            }
                        }
                        {
                            float x2=buf[k*3+3]-x1;
                            float y2=buf[k*3+4]-y1;
                            float z2=buf[k*3+5]-z1;
                            float d=x2*x2+y2*y2+z2*z2;
                            if (d<best){
                                best=d;
                                best_i=k+k2+1;
                            }
                        }
                        {
                            float x2=buf[k*3+6]-x1;
                            float y2=buf[k*3+7]-y1;
                            float z2=buf[k*3+8]-z1;
                            float d=x2*x2+y2*y2+z2*z2;
                            if (d<best){
                                best=d;
                                best_i=k+k2+2;
                            }
                        }
                        {
                            float x2=buf[k*3+9]-x1;
                            float y2=buf[k*3+10]-y1;
                            float z2=buf[k*3+11]-z1;
                            float d=x2*x2+y2*y2+z2*z2;
                            if (d<best){
                                best=d;
                                best_i=k+k2+3;
                            }
                        }
                    }
                }else{
                    for (int k=0;k<end_ka;k+=4){
                        {
                            float x2=buf[k*3+0]-x1;
                            float y2=buf[k*3+1]-y1;
                            float z2=buf[k*3+2]-z1;
                            float d=x2*x2+y2*y2+z2*z2;
                            if (k==0 || d<best){
                                best=d;
                                best_i=k+k2;
                            }
                        }
                        {
                            float x2=buf[k*3+3]-x1;
                            float y2=buf[k*3+4]-y1;
                            float z2=buf[k*3+5]-z1;
                            float d=x2*x2+y2*y2+z2*z2;
                            if (d<best){
                                best=d;
                                best_i=k+k2+1;
                            }
                        }
                        {
                            float x2=buf[k*3+6]-x1;
                            float y2=buf[k*3+7]-y1;
                            float z2=buf[k*3+8]-z1;
                            float d=x2*x2+y2*y2+z2*z2;
                            if (d<best){
                                best=d;
                                best_i=k+k2+2;
                            }
                        }
                        {
                            float x2=buf[k*3+9]-x1;
                            float y2=buf[k*3+10]-y1;
                            float z2=buf[k*3+11]-z1;
                            float d=x2*x2+y2*y2+z2*z2;
                            if (d<best){
                                best=d;
                                best_i=k+k2+3;
                            }
                        }
                    }
                }
                for (int k=end_ka;k<end_k;k++){
                    float x2=buf[k*3+0]-x1;
                    float y2=buf[k*3+1]-y1;
                    float z2=buf[k*3+2]-z1;
                    float d=x2*x2+y2*y2+z2*z2;
                    if (k==0 || d<best){
                        best=d;
                        best_i=k+k2;
                    }
                }
                if (k2==0 || result[(i*n+j)]>best){
                    result[(i*n+j)]=best;
                    result_i[(i*n+j)]=best_i;
                }
            }
            __syncthreads();
        }
    }
}

void ChamferDistanceKernelLauncher(
    const int b, const int n,
    const float* xyz,
    const int m,
    const float* xyz2,
    float* result,
    int* result_i,
    float* result2,
    int* result2_i)
{
    ChamferDistanceKernel<<<dim3(32,16,1),512>>>(b, n, xyz, m, xyz2, result, result_i);
    ChamferDistanceKernel<<<dim3(32,16,1),512>>>(b, m, xyz2, n, xyz, result2, result2_i);

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err));
}


__global__
void ChamferDistanceGradKernel(
    int b, int n,
    const float* xyz1,
    int m,
    const float* xyz2,
    const float* grad_dist1,
    const int* idx1,
    float* grad_xyz1,
    float* grad_xyz2)
{
    for (int i = blockIdx.x; i<b; i += gridDim.x) {
        for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x*gridDim.y) {
            float x1=xyz1[(i*n+j)*3+0];
            float y1=xyz1[(i*n+j)*3+1];
            float z1=xyz1[(i*n+j)*3+2];
            int j2=idx1[i*n+j];
            float x2=xyz2[(i*m+j2)*3+0];
            float y2=xyz2[(i*m+j2)*3+1];
            float z2=xyz2[(i*m+j2)*3+2];
            float g=grad_dist1[i*n+j]*2;
            atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
            atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
            atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
            atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
            atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
            atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
        }
    }
}

void ChamferDistanceGradKernelLauncher(
    const int b, const int n,
    const float* xyz1,
    const int m,
    const float* xyz2,
    const float* grad_dist1,
    const int* idx1,
    const float* grad_dist2,
    const int* idx2,
    float* grad_xyz1,
    float* grad_xyz2)
{
    cudaMemset(grad_xyz1, 0, b*n*3*4);
    cudaMemset(grad_xyz2, 0, b*m*3*4);
    ChamferDistanceGradKernel<<<dim3(1,16,1), 256>>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2);
    ChamferDistanceGradKernel<<<dim3(1,16,1), 256>>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1);

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err));
}
evaluation/pyTorchChamferDistance/chamfer_distance/chamfer_distance.py
ADDED
@@ -0,0 +1,58 @@
import torch

from torch.utils.cpp_extension import load
cd = load(name="build",
          sources=["pyTorchChamferDistance/chamfer_distance/chamfer_distance.cpp",
                   "pyTorchChamferDistance/chamfer_distance/chamfer_distance.cu"],
          build_directory="pyTorchChamferDistance/build")

class ChamferDistanceFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xyz1, xyz2):
        batchsize, n, _ = xyz1.size()
        _, m, _ = xyz2.size()
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()
        dist1 = torch.zeros(batchsize, n)
        dist2 = torch.zeros(batchsize, m)

        idx1 = torch.zeros(batchsize, n, dtype=torch.int)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int)

        if not xyz1.is_cuda:
            cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
        else:
            dist1 = dist1.cuda()
            dist2 = dist2.cuda()
            idx1 = idx1.cuda()
            idx2 = idx2.cuda()
            cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2)

        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)

        return dist1, dist2, idx1, idx2

    @staticmethod
    def backward(ctx, graddist1, graddist2, *args):
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors

        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()

        gradxyz1 = torch.zeros(xyz1.size())
        gradxyz2 = torch.zeros(xyz2.size())

        if not graddist1.is_cuda:
            cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
        else:
            gradxyz1 = gradxyz1.cuda()
            gradxyz2 = gradxyz2.cuda()
            cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)

        return gradxyz1, gradxyz2


class ChamferDistance(torch.nn.Module):
    def forward(self, xyz1, xyz2):
        return ChamferDistanceFunction.apply(xyz1, xyz2)
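A quick usage sketch (not part of this commit): unlike the upstream README, this bundled version also returns the nearest-neighbour indices, so four values are unpacked. On CPU tensors it falls back to the slower C++ `nnsearch` path; on CUDA tensors it uses the kernels above. The relative source paths in the `load(...)` call assume the process is started from the `evaluation/` directory, as the evaluation script does.

```python
# Usage sketch only, assuming the working directory is evaluation/.
import torch
from pyTorchChamferDistance.chamfer_distance import ChamferDistance

chamfer = ChamferDistance()
a = torch.rand(4, 1024, 3)
b = torch.rand(4, 2048, 3)
dist1, dist2, idx1, idx2 = chamfer(a, b)  # squared NN distances and indices in both directions
loss = dist1.mean() + dist2.mean()        # symmetric (squared) chamfer loss
```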
finetune_diffusion.sh
ADDED
@@ -0,0 +1,18 @@
cd scripts
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' torchrun --master_port 15003 --nproc_per_node=8 \
train_triplane_diffusion.py \
--configs ../configs/finetune_triplane_diffusion.yaml \
--accum_iter 2 \
--output_dir ../output/finetune_dm/lowres_chair \
--log_dir ../output/finetune_dm/lowres_chair --num_workers 8 \
--batch_size 22 \
--blr 1e-4 \
--epochs 500 \
--dist_eval \
--warmup_epochs 20 \
--ae-pth ../output/ae/chair/best-checkpoint.pth \
--category chair \
--finetune \
--finetune-pth ../output/dm/chair/best-checkpoint.pth \
--data-pth ../data \
--replica 5
models/TriplaneVAE.py
ADDED
@@ -0,0 +1,94 @@
import torch.nn as nn
import sys, os
sys.path.append("..")
import torch
from datasets import build_dataset
from configs.config_utils import CONFIG
from torch.utils.data import DataLoader
from models.modules import PointEmbed
from models.modules import ConvPointnet_Encoder, ConvPointnet_Decoder
import numpy as np

class TriplaneVAE(nn.Module):
    def __init__(self, opt):
        super().__init__()
        self.point_embedder = PointEmbed(hidden_dim=opt['point_emb_dim'])

        encoder_args = opt['encoder']
        decoder_args = opt['decoder']
        self.encoder = ConvPointnet_Encoder(c_dim=encoder_args['plane_latent_dim'], dim=opt['point_emb_dim'], latent_dim=encoder_args['latent_dim'],
                                            plane_resolution=encoder_args['plane_reso'], unet_kwargs=encoder_args['unet'], unet=True, padding=opt['padding'])
        self.decoder = ConvPointnet_Decoder(latent_dim=decoder_args['latent_dim'], query_emb_dim=decoder_args['query_emb_dim'],
                                            hidden_dim=decoder_args['hidden_dim'], unet_kwargs=decoder_args['unet'], n_blocks=decoder_args['n_blocks'],
                                            plane_resolution=decoder_args['plane_reso'], padding=opt['padding'])

    def forward(self, p, query):
        '''
        :param p: surface point cloud of shape B,N,3
        :param query: sample points of shape B,N,3
        :return:
        '''
        point_emb = self.point_embedder(p)
        query_emb = self.point_embedder(query)
        kl, plane_feat, means, logvars = self.encoder(p, point_emb)
        if self.training:
            if np.random.random() < 0.5:
                '''randomly rescale the triplane, so that triplane diffusion can also be conducted on the 64x64x64 plane; this promotes robustness'''
                plane_feat = torch.nn.functional.interpolate(plane_feat, scale_factor=0.5, mode="bilinear")
                plane_feat = torch.nn.functional.interpolate(plane_feat, scale_factor=2, mode="bilinear")
        # if self.training:
        #     if np.random.random()<0.5:
        #         means = torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
        #         vars=torch.exp(logvars)
        #         vars = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")
        #         new_logvars=torch.log(vars)
        #         posterior = DiagonalGaussianDistribution(means, new_logvars)
        #         plane_feat=posterior.sample()
        #         plane_feat=torch.nn.functional.interpolate(plane_feat,scale_factor=2,mode='bilinear')

        #         mean_scale=torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
        #         vars = torch.exp(logvars)
        #         vars_scale = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear")/4
        #         logvars_scale=torch.log(vars_scale)
        #         scale_noise=torch.randn(mean_scale.shape).to(mean_scale.device)
        #         plane_feat_scale2=mean_scale+torch.exp(0.5*logvars_scale)*scale_noise
        #         plane_feat=torch.nn.functional.interpolate(plane_feat_scale2,scale_factor=2,mode='bilinear')
        o = self.decoder(plane_feat, query, query_emb)

        return {'logits': o, 'kl': kl}


    def decode(self, plane_feature, query):
        query_embedding = self.point_embedder(query)
        o = self.decoder(plane_feature, query, query_embedding)

        return o

    def encode(self, p):
        '''p is a point cloud of shape B,N,3'''
        point_emb = self.point_embedder(p)
        kl, plane_feat, mean, logvar = self.encoder(p, point_emb)
        return plane_feat, kl, mean, logvar

if __name__ == "__main__":
    configs = CONFIG("../configs/train_triplane_vae_64.yaml")
    config = configs.config
    dataset_config = config['datasets']
    model_config = config["model"]
    dataset = build_dataset("train", dataset_config)
    dataset.__getitem__(0)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=10,
        shuffle=True,
        num_workers=2,
    )
    net = TriplaneVAE(model_config).float().cuda()
    for idx, data_batch in enumerate(dataloader):
        if idx == 1:
            break
        surface = data_batch['surface'].float().cuda()
        query = data_batch['points'].float().cuda()
        net(surface, query)
models/Triplane_Diffusion.py
ADDED
@@ -0,0 +1,190 @@
import torch
import torch.nn as nn
from models.modules.resunet import ResUnet_DirectAttenMultiImg_Cond
from models.modules.parpoints_encoder import ParPoint_Encoder
from models.modules.PointEMB import PointEmbed
from models.modules.utils import StackedRandomGenerator
from models.modules.diffusion_sampler import edm_sampler
from models.modules.encoder import DiagonalGaussianDistribution
import numpy as np

class EDMLoss_MultiImgCond:
    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, use_par=False):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.use_par = use_par

    def __call__(self, net, data_batch, classifier_free=False):
        inputs = data_batch['input']
        image = data_batch['image']
        proj_mat = data_batch['proj_mat']
        valid_frames = data_batch['valid_frames']
        par_points = data_batch["par_points"]
        category_code = data_batch["category_code"]
        rnd_normal = torch.randn([inputs.shape[0], 1, 1, 1], device=inputs.device)

        sigma = (rnd_normal * self.P_std + self.P_mean).exp()  # B,1,1,1
        weight = (sigma ** 2 + self.sigma_data ** 2) / (self.sigma_data * sigma) ** 2
        y = inputs

        n = torch.randn_like(y) * sigma

        # if classifier_free and np.random.random()<0.5:
        #     net.par_feat=torch.zeros((inputs.shape[0],32,inputs.shape[2],inputs.shape[3])).float().to(inputs.device)
        if classifier_free and np.random.random() < 0.5:
            image = torch.zeros_like(image).float().cuda()
        net.module.extract_img_feat(image)
        net.module.set_proj_matrix(proj_mat)
        net.module.set_valid_frames(valid_frames)
        net.module.set_category_code(category_code)
        if self.use_par:
            net.module.extract_point_feat(par_points)

        D_yn = net(y + n, sigma)
        loss = weight * ((D_yn - y) ** 2)
        return loss

class Triplane_Diff_MultiImgCond_EDM(nn.Module):
    def __init__(self, opt):
        super().__init__()
        self.diff_reso = opt['diff_reso']
        self.diff_dim = opt['output_channel']
        self.use_cat_embedding = opt['use_cat_embedding']
        self.use_fp16 = False
        self.sigma_data = 0.5
        self.sigma_max = float("inf")
        self.sigma_min = 0
        self.use_par = opt['use_par']
        self.triplane_padding = opt['triplane_padding']
        self.block_type = opt['block_type']
        #self.use_bn=opt['use_bn']
        if opt['backbone'] == "resunet_multiimg_direct_atten":
            self.denoise_model = ResUnet_DirectAttenMultiImg_Cond(channel=opt['input_channel'],
                                    output_channel=opt['output_channel'], use_par=opt['use_par'], par_channel=opt['par_channel'],
                                    img_in_channels=opt['img_in_channels'], vit_reso=opt['vit_reso'], triplane_padding=self.triplane_padding,
                                    norm=opt['norm'], use_cat_embedding=self.use_cat_embedding, block_type=self.block_type)
        else:
            raise NotImplementedError
        if opt['use_par']:  # use partial point cloud as inputs
            par_emb_dim = opt['par_emb_dim']
            par_args = opt['par_point_encoder']
            self.point_embedder = PointEmbed(hidden_dim=par_emb_dim)
            self.par_points_encoder = ParPoint_Encoder(c_dim=par_args['plane_latent_dim'], dim=par_emb_dim,
                                                       plane_resolution=par_args['plane_reso'],
                                                       unet_kwargs=par_args['unet'])
        self.unflatten = torch.nn.Unflatten(1, (16, 16))

    def prepare_data(self, data_batch):
        #par_points = data_batch['par_points'].to(device, non_blocking=True)
        device = torch.device("cuda")
        means, logvars = data_batch['triplane_mean'].to(device, non_blocking=True), data_batch['triplane_logvar'].to(
            device, non_blocking=True)
        distribution = DiagonalGaussianDistribution(means, logvars)
        plane_feat = distribution.sample()

        image = data_batch["image"].to(device)
        proj_mat = data_batch['proj_mat'].to(device, non_blocking=True)
        valid_frames = data_batch["valid_frames"].to(device, non_blocking=True)
        par_points = data_batch["par_points"].to(device, non_blocking=True)
        category_code = data_batch["category_code"].to(device, non_blocking=True)
        input_dict = {"input": plane_feat.float(),
                      "image": image.float(),
                      "par_points": par_points.float(),
                      "proj_mat": proj_mat.float(),
                      "category_code": category_code.float(),
                      "valid_frames": valid_frames.float()}  # TODO: add image and proj matrix

        return input_dict

    def prepare_sample_data(self, data_batch):
        device = torch.device("cuda")
        image = data_batch['image'].to(device, non_blocking=True)
        proj_mat = data_batch['proj_mat'].to(device, non_blocking=True)
        valid_frames = data_batch["valid_frames"].to(device, non_blocking=True)
        par_points = data_batch["par_points"].to(device, non_blocking=True)
        category_code = data_batch["category_code"].to(device, non_blocking=True)
        sample_dict = {
            "image": image.float(),
            "proj_mat": proj_mat.float(),
            "valid_frames": valid_frames.float(),
            "category_code": category_code.float(),
            "par_points": par_points.float(),
        }
        return sample_dict

    def prepare_eval_data(self, data_batch):
        device = torch.device("cuda")
        samples = data_batch["points"].to(device, non_blocking=True)
        labels = data_batch['labels'].to(device, non_blocking=True)

        eval_dict = {
            "samples": samples,
            "labels": labels,
        }
        return eval_dict

    def extract_point_feat(self, par_points):
        par_emb = self.point_embedder(par_points)
        self.par_feat = self.par_points_encoder(par_points, par_emb)

    def extract_img_feat(self, image):
        self.image_emb = image

    def set_proj_matrix(self, proj_matrix):
        self.proj_matrix = proj_matrix

    def set_valid_frames(self, valid_frames):
        self.valid_frames = valid_frames

    def set_category_code(self, category_code):
        self.category_code = category_code

    def forward(self, x, sigma, force_fp32=False):
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)  # B,1,1,1
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_noise = sigma.log() / 4  # B,1,1,1, need to check how to add embedding into unet

        if self.use_par:
            F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(), self.image_emb, self.proj_matrix,
                                     self.valid_frames, self.category_code, self.par_feat)
        else:
            F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(), self.image_emb, self.proj_matrix,
                                     self.valid_frames, self.category_code)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

    @torch.no_grad()
    def sample(self, input_batch, batch_seeds=None, ret_all=False, num_steps=18):
        img_cond = input_batch['image']
        proj_mat = input_batch['proj_mat']
        valid_frames = input_batch["valid_frames"]
        category_code = input_batch["category_code"]
        if img_cond is not None:
            batch_size, device = img_cond.shape[0], img_cond.device
            if batch_seeds is None:
                batch_seeds = torch.arange(batch_size)
        else:
            device = batch_seeds.device
            batch_size = batch_seeds.shape[0]

        self.extract_img_feat(img_cond)
|
179 |
+
self.set_proj_matrix(proj_mat)
|
180 |
+
self.set_valid_frames(valid_frames)
|
181 |
+
self.set_category_code(category_code)
|
182 |
+
if self.use_par:
|
183 |
+
par_points=input_batch["par_points"]
|
184 |
+
self.extract_point_feat(par_points)
|
185 |
+
rnd = StackedRandomGenerator(device, batch_seeds)
|
186 |
+
latents = rnd.randn([batch_size, self.diff_dim, self.diff_reso*3,self.diff_reso], device=device)
|
187 |
+
|
188 |
+
return edm_sampler(self, latents, randn_like=rnd.randn_like,ret_all=ret_all,sigma_min=0.002, sigma_max=80,num_steps=num_steps)
|
189 |
+
|
190 |
+
|
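
The forward pass above follows the EDM preconditioning: the raw backbone output F_x is rescaled so that D_x = c_skip * x + c_out * F_x stays well-conditioned across noise levels. Below is a minimal, self-contained sketch of just that wrapping; the dummy denoiser and the tensor sizes are illustrative stand-ins, not the project's real ResUnet backbone or latent shape.

import torch

sigma_data = 0.5                      # matches self.sigma_data above
x = torch.randn(2, 32, 192, 64)       # hypothetical triplane latent: B, C, 3*reso, reso
sigma = torch.rand(2, 1, 1, 1) * 10   # per-sample noise level

c_skip = sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2)
c_out = sigma * sigma_data / (sigma ** 2 + sigma_data ** 2).sqrt()
c_in = 1 / (sigma_data ** 2 + sigma ** 2).sqrt()
c_noise = sigma.log() / 4             # fed to the backbone as the timestep embedding

dummy_backbone = lambda x_in, t: x_in          # stand-in for the conditional ResUnet
F_x = dummy_backbone(c_in * x, c_noise.flatten())
D_x = c_skip * x + c_out * F_x                 # denoised estimate, same shape as x
print(D_x.shape)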
models/__init__.py
ADDED
@@ -0,0 +1,20 @@
from .TriplaneVAE import TriplaneVAE
from .Triplane_Diffusion import Triplane_Diff_MultiImgCond_EDM
from .Triplane_Diffusion import EDMLoss_MultiImgCond
#from .Point_Diffusion_EDM import PointEDM,EDMLoss_PointAug

def get_model(model_args):
    if model_args['type']=="TriVAE":
        model=TriplaneVAE(model_args)
    elif model_args['type']=="triplane_diff_multiimg_cond":
        model=Triplane_Diff_MultiImgCond_EDM(model_args)
    else:
        raise NotImplementedError
    return model

def get_criterion(cri_args):
    if cri_args['type']=="EDMLoss_MultiImgCond":
        criterion=EDMLoss_MultiImgCond(use_par=cri_args['use_par'])
    else:
        raise NotImplementedError
    return criterion
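
A short usage sketch of the two factories. The configs shown are hypothetical and truncated (the real YAML files carry the full backbone, resolution, and conditioning options), so only the 'type'-based dispatch is illustrated; it assumes the repository root is on PYTHONPATH with dependencies installed.

from models import get_model, get_criterion

# Hypothetical, truncated config: enough for the criterion factory only.
criterion = get_criterion({"type": "EDMLoss_MultiImgCond", "use_par": True})

try:
    get_model({"type": "unknown"})          # model configs need the full option set
except NotImplementedError:
    print("unknown model types fail fast")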
models/modules/PointEMB.py
ADDED
@@ -0,0 +1,34 @@
import torch.nn as nn
import torch
import numpy as np

class PointEmbed(nn.Module):
    def __init__(self, hidden_dim=48):
        super().__init__()

        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6), e,
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6), e]),
        ])
        self.register_buffer('basis', e)  # 3 x 24

    @staticmethod
    def embed(input, basis):
        projections = torch.einsum(
            'bnd,de->bne', input, basis)  # N,24
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings

    def forward(self, input):
        # input: B x N x 3
        embed = self.embed(input, self.basis)
        return embed
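
As a sanity check, this embedding maps B x N x 3 points to B x N x hidden_dim features: each coordinate is projected onto the fixed frequency basis and sin/cos of the projections are concatenated. A minimal sketch, assuming the repository root is on PYTHONPATH so the class above imports as written:

import torch
from models.modules.PointEMB import PointEmbed

emb = PointEmbed(hidden_dim=48)
pts = torch.rand(2, 1024, 3) * 2 - 1   # B x N x 3, roughly in [-1, 1]
feat = emb(pts)
print(feat.shape)                      # torch.Size([2, 1024, 48]): 24 sin + 24 cos channels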
models/modules/Positional_Embedding.py
ADDED
@@ -0,0 +1,15 @@
import torch
class PositionalEmbedding(torch.nn.Module):
    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint

    def forward(self, x):
        freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
        freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        x = x.ger(freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x
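
This is the standard sinusoidal timestep embedding: a 1-D batch of noise levels is expanded to num_channels features via an outer product with the frequency vector, followed by cos/sin concatenation. A quick shape sketch, assuming the module imports as written:

import torch
from models.modules.Positional_Embedding import PositionalEmbedding

pe = PositionalEmbedding(num_channels=128)
c_noise = torch.randn(4)       # e.g. sigma.log() / 4, flattened to shape (B,)
print(pe(c_noise).shape)       # torch.Size([4, 128])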
models/modules/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .encoder import ConvPointnet_Encoder
from .resnet_block import ResnetBlockFC
from .unet import UNet,RollOut_Conv
from .PointEMB import PointEmbed
from .decoder import ConvPointnet_Decoder
models/modules/decoder.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_max
from .unet import UNet
from .resnet_block import ResnetBlockFC
import numpy as np

class ConvPointnet_Decoder(nn.Module):
    ''' PointNet-based decoder network with ResNet blocks for each query point.

    Args:
        latent_dim (int): dimension of the triplane latent code
        query_emb_dim (int): dimension of the query point embedding
        hidden_dim (int): hidden dimension of the network
        unet_kwargs (dict): U-Net parameters
        plane_resolution (int): defined resolution for plane feature
        plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
        padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
        n_blocks (int): number of ResnetBlockFC layers
    '''

    def __init__(self, latent_dim=32,query_emb_dim=51,hidden_dim=128, unet_kwargs=None,
                 plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
        super().__init__()

        self.latent_dim=32
        self.actvn = nn.ReLU()

        self.unet = UNet(unet_kwargs['output_dim'], in_channels=latent_dim, **unet_kwargs)

        self.reso_plane = plane_resolution
        self.plane_type = plane_type
        self.padding = padding
        self.n_blocks=n_blocks

        self.fc_c = nn.ModuleList([
            nn.Linear(latent_dim*3, hidden_dim) for i in range(n_blocks)
        ])
        self.fc_p=nn.Linear(query_emb_dim,hidden_dim)
        self.fc_out=nn.Linear(hidden_dim,1)

        self.blocks = nn.ModuleList([
            ResnetBlockFC(hidden_dim) for i in range(n_blocks)
        ])

    def forward(self, plane_features,query,query_emb): # , query2):
        plane_feature=self.unet(plane_features)
        H,W=plane_feature.shape[2:4]
        xz_feat,xy_feat,yz_feat=torch.split(plane_feature,dim=2,split_size_or_sections=H//3)
        xz_sample_feat=self.sample_plane_feature(query,xz_feat,'xz')
        xy_sample_feat=self.sample_plane_feature(query,xy_feat,'xy')
        yz_sample_feat=self.sample_plane_feature(query,yz_feat,'yz')

        sample_feat=torch.cat([xz_sample_feat,xy_sample_feat,yz_sample_feat],dim=1)
        sample_feat=sample_feat.transpose(1,2)

        net=self.fc_p(query_emb)
        for i in range(self.n_blocks):
            net=net+self.fc_c[i](sample_feat)
            net=self.blocks[i](net)
        out=self.fc_out(self.actvn(net)).squeeze(-1)
        return out


    def normalize_coordinate(self, p, padding=0.1, plane='xz'):
        ''' Normalize coordinate to [0, 1] for unit cube experiments

        Args:
            p (tensor): point
            padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
            plane (str): plane feature type, ['xz', 'xy', 'yz']
        '''
        if plane == 'xz':
            xy = p[:, :, [0, 2]]
        elif plane == 'xy':
            xy = p[:, :, [0, 1]]
        else:
            xy = p[:, :, [1, 2]]
        #print("origin",torch.amin(xy), torch.amax(xy))
        xy=xy/2 #xy is originally -1 ~ 1
        xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
        xy_new = xy_new + 0.5 # range (0, 1)
        #print("scale",torch.amin(xy_new),torch.amax(xy_new))

        # if there are outliers out of the range
        if xy_new.max() >= 1:
            xy_new[xy_new >= 1] = 1 - 10e-6
        if xy_new.min() < 0:
            xy_new[xy_new < 0] = 0.0
        return xy_new

    def coordinate2index(self, x, reso):
        ''' Maps normalized [0, 1] plane coordinates to flat plane indices.

        Args:
            x (tensor): coordinate
            reso (int): defined resolution
            coord_type (str): coordinate type
        '''
        x = (x * reso).long()
        index = x[:, :, 0] + reso * x[:, :, 1]
        index = index[:, None, :]
        return index

    # uses values from plane_feature and pixel locations from vgrid to interpolate feature
    def sample_plane_feature(self, query, plane_feature, plane):
        xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
        xy = xy[:, :, None].float()
        vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True,
                                     mode='bilinear').squeeze(-1)
        return sampled_feat
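
The core of the decoder is sample_plane_feature: query points are normalized into each plane's [0, 1] frame and features are picked up with bilinear grid_sample, then the three per-plane features are concatenated before the ResNet blocks. A standalone sketch of that sampling step with toy tensors (the shapes are illustrative, not the project's real feature maps):

import torch
import torch.nn.functional as F

B, C, R, N = 2, 32, 64, 512
plane_feature = torch.randn(B, C, R, R)   # one of the xz / xy / yz planes
query = torch.rand(B, N, 3) * 2 - 1       # query points, roughly in [-1, 1]^3

xy = query[:, :, [0, 2]] / 2              # pick the plane's two axes, as in normalize_coordinate
xy = xy / (1 + 0.1 + 10e-6) + 0.5         # -> (0, 1) with the same padding convention
vgrid = (2.0 * xy - 1.0)[:, :, None]      # B x N x 1 x 2 grid for grid_sample
feat = F.grid_sample(plane_feature, vgrid, padding_mode='border',
                     align_corners=True, mode='bilinear').squeeze(-1)
print(feat.shape)                         # torch.Size([2, 32, 512]): C features per query point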
models/modules/diffusion_sampler.py
ADDED
@@ -0,0 +1,89 @@
import torch
import numpy as np

def edm_sampler(
    net, latents, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    # S_churn=40, S_min=0.05, S_max=50, S_noise=1.003,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, ret_all=False
):
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    all_x=[]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat, t_hat).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            denoised = net(x_next, t_next).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
        all_x.append(x_next.clone()/(t_next**2+1).sqrt())

    if ret_all:
        return x_next,all_x

    return x_next

def edm_sampler_cond(
    net, latents,cond_points, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    # S_churn=40, S_min=0.05, S_max=50, S_noise=1.003,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, ret_all=False
):
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    all_x=[]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat, t_hat,cond_points).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            denoised = net(x_next, t_next,cond_points).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
        all_x.append(x_next.clone()/(t_next**2+1).sqrt())

    if ret_all:
        return x_next,all_x

    return x_next
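
Both samplers share the Karras noise schedule: sigma_i interpolates between sigma_max and sigma_min in rho-warped space, with a final step down to zero. A minimal sketch of just that schedule, using the default arguments above:

import torch

num_steps, sigma_min, sigma_max, rho = 18, 0.002, 80.0, 7

i = torch.arange(num_steps, dtype=torch.float64)
t_steps = (sigma_max ** (1 / rho)
           + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
t_steps = torch.cat([t_steps, torch.zeros(1, dtype=torch.float64)])  # t_N = 0

print(t_steps[0].item(), t_steps[-2].item())  # ~80.0 down to ~0.002, then 0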
models/modules/encoder.py
ADDED
@@ -0,0 +1,235 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch_scatter import scatter_mean, scatter_max
|
5 |
+
from .unet import UNet
|
6 |
+
from .resnet_block import ResnetBlockFC
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
class DiagonalGaussianDistribution(object):
|
10 |
+
def __init__(self, mean, logvar, deterministic=False):
|
11 |
+
self.mean = mean
|
12 |
+
self.logvar = logvar
|
13 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
14 |
+
self.deterministic = deterministic
|
15 |
+
self.std = torch.exp(0.5 * self.logvar)
|
16 |
+
self.var = torch.exp(self.logvar)
|
17 |
+
if self.deterministic:
|
18 |
+
self.var = self.std = torch.zeros_like(self.mean).to(device=self.mean.device)
|
19 |
+
|
20 |
+
def sample(self):
|
21 |
+
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.mean.device)
|
22 |
+
return x
|
23 |
+
|
24 |
+
def kl(self, other=None):
|
25 |
+
if self.deterministic:
|
26 |
+
return torch.Tensor([0.])
|
27 |
+
else:
|
28 |
+
if other is None:
|
29 |
+
return 0.5 * torch.mean(torch.pow(self.mean, 2)
|
30 |
+
+ self.var - 1.0 - self.logvar,
|
31 |
+
dim=[1, 2,3])
|
32 |
+
else:
|
33 |
+
return 0.5 * torch.mean(
|
34 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
35 |
+
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
36 |
+
dim=[1, 2, 3])
|
37 |
+
|
38 |
+
def nll(self, sample, dims=[1,2,3]):
|
39 |
+
if self.deterministic:
|
40 |
+
return torch.Tensor([0.])
|
41 |
+
logtwopi = np.log(2.0 * np.pi)
|
42 |
+
return 0.5 * torch.sum(
|
43 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
44 |
+
dim=dims)
|
45 |
+
|
46 |
+
def mode(self):
|
47 |
+
return self.mean
|
48 |
+
|
49 |
+
class ConvPointnet_Encoder(nn.Module):
|
50 |
+
''' PointNet-based encoder network with ResNet blocks for each point.
|
51 |
+
Number of input points are fixed.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
c_dim (int): dimension of latent code c
|
55 |
+
dim (int): input points dimension
|
56 |
+
hidden_dim (int): hidden dimension of the network
|
57 |
+
scatter_type (str): feature aggregation when doing local pooling
|
58 |
+
unet (bool): weather to use U-Net
|
59 |
+
unet_kwargs (str): U-Net parameters
|
60 |
+
plane_resolution (int): defined resolution for plane feature
|
61 |
+
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
|
62 |
+
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
63 |
+
n_blocks (int): number of blocks ResNetBlockFC layers
|
64 |
+
'''
|
65 |
+
|
66 |
+
def __init__(self, c_dim=128, dim=3, hidden_dim=128,latent_dim=32, scatter_type='max',
|
67 |
+
unet=False, unet_kwargs=None,
|
68 |
+
plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
|
69 |
+
super().__init__()
|
70 |
+
self.c_dim = c_dim
|
71 |
+
|
72 |
+
self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
|
73 |
+
self.blocks = nn.ModuleList([
|
74 |
+
ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)
|
75 |
+
])
|
76 |
+
self.fc_c = nn.Linear(hidden_dim, c_dim)
|
77 |
+
|
78 |
+
self.actvn = nn.ReLU()
|
79 |
+
self.hidden_dim = hidden_dim
|
80 |
+
|
81 |
+
if unet:
|
82 |
+
self.unet = UNet(unet_kwargs['output_dim'], in_channels=c_dim, **unet_kwargs)
|
83 |
+
else:
|
84 |
+
self.unet = None
|
85 |
+
|
86 |
+
self.reso_plane = plane_resolution
|
87 |
+
self.plane_type = plane_type
|
88 |
+
self.padding = padding
|
89 |
+
|
90 |
+
if scatter_type == 'max':
|
91 |
+
self.scatter = scatter_max
|
92 |
+
elif scatter_type == 'mean':
|
93 |
+
self.scatter = scatter_mean
|
94 |
+
|
95 |
+
self.mean_fc = nn.Conv2d(unet_kwargs['output_dim'], latent_dim,kernel_size=1)
|
96 |
+
self.logvar_fc = nn.Conv2d(unet_kwargs['output_dim'], latent_dim,kernel_size=1)
|
97 |
+
|
98 |
+
# takes in "p": point cloud and "query": sdf_xyz
|
99 |
+
# sample plane features for unlabeled_query as well
|
100 |
+
def forward(self, p,point_emb): # , query2):
|
101 |
+
batch_size, T, D = p.size()
|
102 |
+
#print('origin',torch.amin(p[0],dim=0),torch.amax(p[0],dim=0))
|
103 |
+
# acquire the index for each point
|
104 |
+
coord = {}
|
105 |
+
index = {}
|
106 |
+
if 'xz' in self.plane_type:
|
107 |
+
coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
|
108 |
+
index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
|
109 |
+
if 'xy' in self.plane_type:
|
110 |
+
coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
|
111 |
+
index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
|
112 |
+
if 'yz' in self.plane_type:
|
113 |
+
coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
|
114 |
+
index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)
|
115 |
+
net = self.fc_pos(point_emb)
|
116 |
+
|
117 |
+
net = self.blocks[0](net)
|
118 |
+
for block in self.blocks[1:]:
|
119 |
+
pooled = self.pool_local(coord, index, net)
|
120 |
+
net = torch.cat([net, pooled], dim=2)
|
121 |
+
net = block(net)
|
122 |
+
|
123 |
+
c = self.fc_c(net)
|
124 |
+
#print(c.shape)
|
125 |
+
|
126 |
+
fea = {}
|
127 |
+
plane_feat_sum = 0
|
128 |
+
# second_sum = 0
|
129 |
+
if 'xz' in self.plane_type:
|
130 |
+
fea['xz'] = self.generate_plane_features(p, c,
|
131 |
+
plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
|
132 |
+
if 'xy' in self.plane_type:
|
133 |
+
fea['xy'] = self.generate_plane_features(p, c, plane='xy')
|
134 |
+
if 'yz' in self.plane_type:
|
135 |
+
fea['yz'] = self.generate_plane_features(p, c, plane='yz')
|
136 |
+
cat_feature = torch.cat([fea['xz'], fea['xy'], fea['yz']],
|
137 |
+
dim=2) # concat at row dimension
|
138 |
+
#print(cat_feature.shape)
|
139 |
+
plane_feat=self.unet(cat_feature)
|
140 |
+
|
141 |
+
mean=self.mean_fc(plane_feat)
|
142 |
+
logvar=self.logvar_fc(plane_feat)
|
143 |
+
|
144 |
+
posterior = DiagonalGaussianDistribution(mean, logvar)
|
145 |
+
x = posterior.sample()
|
146 |
+
kl = posterior.kl()
|
147 |
+
|
148 |
+
return kl, x, mean, logvar
|
149 |
+
|
150 |
+
|
151 |
+
def normalize_coordinate(self, p, padding=0.1, plane='xz'):
|
152 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments
|
153 |
+
|
154 |
+
Args:
|
155 |
+
p (tensor): point
|
156 |
+
padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
157 |
+
plane (str): plane feature type, ['xz', 'xy', 'yz']
|
158 |
+
'''
|
159 |
+
if plane == 'xz':
|
160 |
+
xy = p[:, :, [0, 2]]
|
161 |
+
elif plane == 'xy':
|
162 |
+
xy = p[:, :, [0, 1]]
|
163 |
+
else:
|
164 |
+
xy = p[:, :, [1, 2]]
|
165 |
+
#print("origin",torch.amin(xy), torch.amax(xy))
|
166 |
+
xy=xy/2 #xy is originally -1 ~ 1
|
167 |
+
xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
|
168 |
+
xy_new = xy_new + 0.5 # range (0, 1)
|
169 |
+
#print("scale",torch.amin(xy_new),torch.amax(xy_new))
|
170 |
+
|
171 |
+
# f there are outliers out of the range
|
172 |
+
if xy_new.max() >= 1:
|
173 |
+
xy_new[xy_new >= 1] = 1 - 10e-6
|
174 |
+
if xy_new.min() < 0:
|
175 |
+
xy_new[xy_new < 0] = 0.0
|
176 |
+
return xy_new
|
177 |
+
|
178 |
+
def coordinate2index(self, x, reso):
|
179 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments.
|
180 |
+
Corresponds to our 3D model
|
181 |
+
|
182 |
+
Args:
|
183 |
+
x (tensor): coordinate
|
184 |
+
reso (int): defined resolution
|
185 |
+
coord_type (str): coordinate type
|
186 |
+
'''
|
187 |
+
x = (x * reso).long()
|
188 |
+
index = x[:, :, 0] + reso * x[:, :, 1]
|
189 |
+
index = index[:, None, :]
|
190 |
+
return index
|
191 |
+
|
192 |
+
# xy is the normalized coordinates of the point cloud of each plane
|
193 |
+
# I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
|
194 |
+
def pool_local(self, xy, index, c):
|
195 |
+
bs, fea_dim = c.size(0), c.size(2)
|
196 |
+
keys = xy.keys()
|
197 |
+
|
198 |
+
c_out = 0
|
199 |
+
for key in keys:
|
200 |
+
# scatter plane features from points
|
201 |
+
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2)
|
202 |
+
if self.scatter == scatter_max:
|
203 |
+
fea = fea[0]
|
204 |
+
# gather feature back to points
|
205 |
+
fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
|
206 |
+
c_out += fea
|
207 |
+
return c_out.permute(0, 2, 1)
|
208 |
+
|
209 |
+
def generate_plane_features(self, p, c, plane='xz'):
|
210 |
+
# acquire indices of features in plane
|
211 |
+
xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
|
212 |
+
index = self.coordinate2index(xy, self.reso_plane)
|
213 |
+
|
214 |
+
# scatter plane features from points
|
215 |
+
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2)
|
216 |
+
c = c.permute(0, 2, 1) # B x 512 x T
|
217 |
+
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
|
218 |
+
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane,
|
219 |
+
self.reso_plane) # sparce matrix (B x 512 x reso x reso)
|
220 |
+
#print(fea_plane.shape)
|
221 |
+
|
222 |
+
return fea_plane
|
223 |
+
|
224 |
+
# sample_plane_feature function copied from /src/conv_onet/models/decoder.py
|
225 |
+
# uses values from plane_feature and pixel locations from vgrid to interpolate feature
|
226 |
+
def sample_plane_feature(self, query, plane_feature, plane):
|
227 |
+
xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
|
228 |
+
xy = xy[:, :, None].float()
|
229 |
+
vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
|
230 |
+
sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True,
|
231 |
+
mode='bilinear').squeeze(-1)
|
232 |
+
return sampled_feat
|
233 |
+
|
234 |
+
|
235 |
+
|
models/modules/image_sampler.py
ADDED
@@ -0,0 +1,1046 @@
1 |
+
import sys
|
2 |
+
sys.path.append('../..')
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import math
|
6 |
+
from models.modules.unet import RollOut_Conv
|
7 |
+
from einops import rearrange, reduce
|
8 |
+
MB =1024.0*1024.0
|
9 |
+
def mask_kernel(x, sigma=1):
|
10 |
+
return torch.abs(x) < sigma #if the distance is smaller than the kernel size, return True
|
11 |
+
|
12 |
+
def mask_kernel_close_false(x, sigma=1):
|
13 |
+
return torch.abs(x) > sigma #if the distance is smaller than the kernel size, return False
|
14 |
+
|
15 |
+
class Image_Local_Sampler(nn.Module):
|
16 |
+
def __init__(self,reso,padding=0.1,in_channels=1280,out_channels=512):
|
17 |
+
super().__init__()
|
18 |
+
self.triplane_reso=reso
|
19 |
+
self.padding=padding
|
20 |
+
self.get_triplane_coord()
|
21 |
+
self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1)
|
22 |
+
def get_triplane_coord(self):
|
23 |
+
'''xz plane firstly, z is at the '''
|
24 |
+
x=torch.arange(self.triplane_reso)
|
25 |
+
z=torch.arange(self.triplane_reso)
|
26 |
+
X,Z=torch.meshgrid(x,z,indexing='xy')
|
27 |
+
xz_coords=torch.cat([X[:,:,None],torch.ones_like(X[:,:,None])*(self.triplane_reso-1)/2,Z[:,:,None]],dim=-1) #in xyz order
|
28 |
+
|
29 |
+
'''xy plane'''
|
30 |
+
x = torch.arange(self.triplane_reso)
|
31 |
+
y = torch.arange(self.triplane_reso)
|
32 |
+
X, Y = torch.meshgrid(x, y, indexing='xy')
|
33 |
+
xy_coords = torch.cat([X[:, :, None], Y[:, :, None],torch.ones_like(X[:, :, None])*(self.triplane_reso-1)/2], dim=-1) # in xyz order
|
34 |
+
|
35 |
+
'''yz plane'''
|
36 |
+
y = torch.arange(self.triplane_reso)
|
37 |
+
z = torch.arange(self.triplane_reso)
|
38 |
+
Y,Z = torch.meshgrid(y,z,indexing='xy')
|
39 |
+
yz_coords= torch.cat([torch.ones_like(Y[:, :, None])*(self.triplane_reso-1)/2,Y[:,:,None],Z[:,:,None]], dim=-1)
|
40 |
+
|
41 |
+
triplane_coords=torch.cat([xz_coords,xy_coords,yz_coords],dim=0)
|
42 |
+
triplane_coords=triplane_coords/(self.triplane_reso-1)
|
43 |
+
triplane_coords=(triplane_coords-0.5)*2*(1 + self.padding + 10e-6)
|
44 |
+
self.triplane_coords=triplane_coords.float().cuda()
|
45 |
+
|
46 |
+
def forward(self,image_feat,proj_mat):
|
47 |
+
image_feat=self.img_proj(image_feat)
|
48 |
+
batch_size=image_feat.shape[0]
|
49 |
+
triplane_coords=self.triplane_coords.unsqueeze(0).expand(batch_size,-1,-1,-1) #B,192,64,3
|
50 |
+
#print(torch.amin(triplane_coords),torch.amax(triplane_coords))
|
51 |
+
coord_homo=torch.cat([triplane_coords,torch.ones((batch_size,triplane_coords.shape[1],triplane_coords.shape[2],1)).float().cuda()],dim=-1)
|
52 |
+
coord_inimg=torch.einsum('bhwc,bck->bhwk',coord_homo,proj_mat.transpose(1,2))
|
53 |
+
x=coord_inimg[:,:,:,0]/coord_inimg[:,:,:,2]
|
54 |
+
y=coord_inimg[:,:,:,1]/coord_inimg[:,:,:,2]
|
55 |
+
x=(x/(224.0-1.0)-0.5)*2 #-1~1
|
56 |
+
y=(y/(224.0-1.0)-0.5)*2 #-1~1
|
57 |
+
dist=coord_inimg[:,:,:,2]
|
58 |
+
|
59 |
+
xy=torch.cat([x[:,:,:,None],y[:,:,:,None]],dim=-1)
|
60 |
+
#print(image_feat.shape,xy.shape)
|
61 |
+
sample_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear')
|
62 |
+
return sample_feat
|
63 |
+
|
64 |
+
def position_encoding(d_model, length):
|
65 |
+
if d_model % 2 != 0:
|
66 |
+
raise ValueError("Cannot use sin/cos positional encoding with "
|
67 |
+
"odd dim (got dim={:d})".format(d_model))
|
68 |
+
pe = torch.zeros(length, d_model)
|
69 |
+
position = torch.arange(0, length).unsqueeze(1) #length,1
|
70 |
+
div_term = torch.exp((torch.arange(0, d_model, 2, dtype=torch.float) *
|
71 |
+
-(math.log(10000.0) / d_model))) #d_model//2, this is the frequency
|
72 |
+
pe[:, 0::2] = torch.sin(position.float() * div_term) #length*(d_model//2)
|
73 |
+
pe[:, 1::2] = torch.cos(position.float() * div_term)
|
74 |
+
|
75 |
+
return pe
|
76 |
+
|
77 |
+
class Image_Vox_Local_Sampler(nn.Module):
|
78 |
+
def __init__(self,reso,padding=0.1,in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
|
79 |
+
super().__init__()
|
80 |
+
self.triplane_reso=reso
|
81 |
+
self.padding=padding
|
82 |
+
self.get_vox_coord()
|
83 |
+
self.out_channels=out_channels
|
84 |
+
self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=inner_channel,kernel_size=1)
|
85 |
+
|
86 |
+
self.vox_process=nn.Sequential(
|
87 |
+
nn.Conv3d(in_channels=inner_channel,out_channels=inner_channel,kernel_size=3,padding=1,),
|
88 |
+
)
|
89 |
+
self.k=nn.Linear(in_features=inner_channel,out_features=inner_channel)
|
90 |
+
self.q=nn.Linear(in_features=inner_channel,out_features=inner_channel)
|
91 |
+
self.v=nn.Linear(in_features=inner_channel,out_features=inner_channel)
|
92 |
+
self.attn = torch.nn.MultiheadAttention(
|
93 |
+
embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
|
94 |
+
|
95 |
+
self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
|
96 |
+
self.condition_pe = position_encoding(inner_channel, self.triplane_reso).unsqueeze(0)
|
97 |
+
def get_vox_coord(self):
|
98 |
+
x = torch.arange(self.triplane_reso)
|
99 |
+
y = torch.arange(self.triplane_reso)
|
100 |
+
z = torch.arange(self.triplane_reso)
|
101 |
+
|
102 |
+
X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
|
103 |
+
vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
|
104 |
+
vox_coor=vox_coor/(self.triplane_reso-1)
|
105 |
+
vox_coor=(vox_coor-0.5)*2*(1+self.padding+10e-6)
|
106 |
+
self.vox_coor=vox_coor.view(-1,3).float().cuda()
|
107 |
+
|
108 |
+
|
109 |
+
def forward(self,triplane_feat,image_feat,proj_mat):
|
110 |
+
xz_feat,xy_feat,yz_feat=torch.split(triplane_feat,triplane_feat.shape[2]//3,dim=2) #B,C,64,64
|
111 |
+
image_feat=self.img_proj(image_feat)
|
112 |
+
batch_size=image_feat.shape[0]
|
113 |
+
vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #B,64*64*64,3
|
114 |
+
vox_homo=torch.cat([vox_coords,torch.ones((batch_size,self.triplane_reso**3,1)).float().cuda()],dim=-1)
|
115 |
+
coord_inimg=torch.einsum('bhc,bck->bhk',vox_homo,proj_mat.transpose(1,2))
|
116 |
+
x=coord_inimg[:,:,0]/coord_inimg[:,:,2]
|
117 |
+
y=coord_inimg[:,:,1]/coord_inimg[:,:,2]
|
118 |
+
x=(x/(224.0-1.0)-0.5)*2 #-1~1
|
119 |
+
y=(y/(224.0-1.0)-0.5)*2 #-1~1
|
120 |
+
|
121 |
+
xy=torch.cat([x[:,:,None],y[:,:,None]],dim=-1).unsqueeze(1).contiguous() #B, 1,64**3,2
|
122 |
+
#print(image_feat.shape,xy.shape)
|
123 |
+
grid_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear').squeeze(2).\
|
124 |
+
view(batch_size,-1,self.triplane_reso,self.triplane_reso,self.triplane_reso) #B,C,1,64**3
|
125 |
+
|
126 |
+
grid_feat=self.vox_process(grid_feat)
|
127 |
+
xzy_grid=grid_feat.permute(0,4,2,3,1)
|
128 |
+
xz_as_query=xz_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
|
129 |
+
xz_as_key=xzy_grid.reshape(batch_size*self.triplane_reso**2,self.triplane_reso,-1)
|
130 |
+
|
131 |
+
xyz_grid=grid_feat.permute(0,3,2,4,1)
|
132 |
+
xy_as_query=xy_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
|
133 |
+
xy_as_key = xyz_grid.reshape(batch_size * self.triplane_reso ** 2, self.triplane_reso, -1)
|
134 |
+
|
135 |
+
yzx_grid = grid_feat.permute(0, 4, 3, 2, 1)
|
136 |
+
yz_as_query = yz_feat.permute(0,2,3,1).reshape(batch_size*self.triplane_reso**2,1,-1)
|
137 |
+
yz_as_key = yzx_grid.reshape(batch_size * self.triplane_reso ** 2, self.triplane_reso, -1)
|
138 |
+
|
139 |
+
query=self.q(torch.cat([xz_as_query,xy_as_query,yz_as_query],dim=0))
|
140 |
+
key=self.k(torch.cat([xz_as_key,xy_as_key,yz_as_key],dim=0))+self.condition_pe.to(xz_as_key.device)
|
141 |
+
value=self.v(torch.cat([xz_as_key,xy_as_key,yz_as_key],dim=0))+self.condition_pe.to(xz_as_key.device)
|
142 |
+
|
143 |
+
attn,_=self.attn(query,key,value)
|
144 |
+
xz_plane,xy_plane,yz_plane=torch.split(attn,dim=0,split_size_or_sections=batch_size*self.triplane_reso**2)
|
145 |
+
xz_plane=xz_plane.reshape(batch_size,self.triplane_reso,self.triplane_reso,-1).permute(0,3,1,2)
|
146 |
+
xy_plane = xy_plane.reshape(batch_size, self.triplane_reso, self.triplane_reso, -1).permute(0, 3, 1, 2)
|
147 |
+
yz_plane = yz_plane.reshape(batch_size, self.triplane_reso, self.triplane_reso, -1).permute(0, 3, 1, 2)
|
148 |
+
|
149 |
+
triplane_wImg=torch.cat([xz_plane,xy_plane,yz_plane],dim=2)
|
150 |
+
triplane_wImg=self.proj_out(triplane_wImg)
|
151 |
+
#print(triplane_wImg.shape)
|
152 |
+
|
153 |
+
return triplane_wImg
|
154 |
+
|
155 |
+
class Image_Direct_AttenwMask_Sampler(nn.Module):
|
156 |
+
def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
|
157 |
+
img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
|
158 |
+
super().__init__()
|
159 |
+
self.triplane_reso=reso
|
160 |
+
self.vit_reso=vit_reso
|
161 |
+
self.padding=padding
|
162 |
+
self.n_heads=n_heads
|
163 |
+
self.get_plane_expand_coord()
|
164 |
+
self.get_vit_coords()
|
165 |
+
self.out_channels=out_channels
|
166 |
+
self.kernel_func=mask_kernel
|
167 |
+
self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
|
168 |
+
self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
|
169 |
+
self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
|
170 |
+
self.attn = torch.nn.MultiheadAttention(
|
171 |
+
embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
|
172 |
+
|
173 |
+
self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
|
174 |
+
self.image_pe = position_encoding(inner_channel, self.vit_reso**2+1).unsqueeze(0).cuda().float() #1,n_img*reso*reso,channel
|
175 |
+
self.triplane_pe = position_encoding(inner_channel, 3*self.triplane_reso**2).unsqueeze(0).cuda().float()
|
176 |
+
def get_plane_expand_coord(self):
|
177 |
+
x = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
|
178 |
+
y = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
|
179 |
+
z = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
|
180 |
+
|
181 |
+
first,second,third=torch.meshgrid(x,y,z,indexing='xy')
|
182 |
+
xyz_coords=torch.stack([first,second,third],dim=-1)#reso,reso,reso,3
|
183 |
+
xyz_coords=(xyz_coords-0.5)*2*(1+self.padding+10e-6) #ordering yxz ->xyz
|
184 |
+
xzy_coords=xyz_coords.clone().permute(2,1,0,3) #ordering zxy ->xzy
|
185 |
+
yzx_coords=xyz_coords.clone().permute(2,0,1,3) #ordering zyx ->yzx
|
186 |
+
|
187 |
+
# print(xyz_coords[0,0,0],xyz_coords[0,0,1],xyz_coords[1,0,0],xyz_coords[0,1,0])
|
188 |
+
# print(xzy_coords[0, 0, 0], xzy_coords[0, 0, 1], xzy_coords[1, 0, 0], xzy_coords[0, 1, 0])
|
189 |
+
# print(yzx_coords[0, 0, 0], yzx_coords[0, 0, 1], yzx_coords[1, 0, 0], yzx_coords[0, 1, 0])
|
190 |
+
|
191 |
+
xyz_coords=xyz_coords.reshape(self.triplane_reso**3,-1)
|
192 |
+
xzy_coords=xzy_coords.reshape(self.triplane_reso**3,-1)
|
193 |
+
yzx_coords=yzx_coords.reshape(self.triplane_reso**3,-1)
|
194 |
+
|
195 |
+
coords=torch.cat([xzy_coords,xyz_coords,yzx_coords],dim=0)
|
196 |
+
self.plane_coords=coords.cuda().float()
|
197 |
+
# self.xzy_coords=xzy_coords.cuda().float() #reso**3,3
|
198 |
+
# self.xyz_coords=xyz_coords.cuda().float() #reso**3,3
|
199 |
+
# self.yzx_coords=yzx_coords.cuda().float() #reso**3,3
|
200 |
+
|
201 |
+
def get_vit_coords(self):
|
202 |
+
x=torch.arange(self.vit_reso)
|
203 |
+
y=torch.arange(self.vit_reso)
|
204 |
+
|
205 |
+
X,Y=torch.meshgrid(x,y,indexing='xy')
|
206 |
+
vit_coords=torch.stack([X,Y],dim=-1)
|
207 |
+
self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
|
208 |
+
|
209 |
+
def get_attn_mask(self,coords_proj,vit_coords,kernel_size=1.0):
|
210 |
+
'''
|
211 |
+
:param coords_proj: B,reso**3,2, in range of 0~1
|
212 |
+
:param vit_coords: B,vit_reso**2,2, in range of 0~vit_reso
|
213 |
+
:param kernel_size: 0.5, so that only one pixel will be select
|
214 |
+
:return:
|
215 |
+
'''
|
216 |
+
bs=coords_proj.shape[0]
|
217 |
+
coords_proj=coords_proj*(self.vit_reso-1)
|
218 |
+
#print(torch.amin(coords_proj[0,0:self.triplane_reso**3]),torch.amax(coords_proj[0,0:self.triplane_reso**3]))
|
219 |
+
dist=torch.cdist(coords_proj.float(),vit_coords.float())
|
220 |
+
mask=self.kernel_func(dist,sigma=kernel_size).float() #True if valid, B,3*reso**3,vit_reso**2
|
221 |
+
mask=mask.reshape(bs,3*self.triplane_reso**2,self.triplane_reso,self.vit_reso**2)
|
222 |
+
mask=torch.sum(mask,dim=2)
|
223 |
+
attn_mask=(mask==0)
|
224 |
+
return attn_mask
|
225 |
+
|
226 |
+
def forward(self,triplane_feat,image_feat,proj_mat):
|
227 |
+
#xz_feat,xy_feat,yz_feat=torch.split(triplane_feat,triplane_feat.shape[2]//3,dim=2) #B,C,64,64
|
228 |
+
batch_size=image_feat.shape[0]
|
229 |
+
#print(self.plane_coords.shape)
|
230 |
+
coords=self.plane_coords.unsqueeze(0).expand(batch_size,-1,-1)
|
231 |
+
|
232 |
+
coords_homo=torch.cat([coords,torch.ones(batch_size,self.triplane_reso**3*3,1).float().cuda()],dim=-1)
|
233 |
+
coords_inimg=torch.einsum('bhc,bck->bhk',coords_homo,proj_mat.transpose(1,2))
|
234 |
+
coords_x=coords_inimg[:,:,0]/coords_inimg[:,:,2]/(224.0-1) #0~1
|
235 |
+
coords_y=coords_inimg[:,:,1]/coords_inimg[:,:,2]/(224.0-1) #0~1
|
236 |
+
coords_x=torch.clamp(coords_x,min=0.0,max=1.0)
|
237 |
+
coords_y=torch.clamp(coords_y,min=0.0,max=1.0)
|
238 |
+
#print(torch.amin(coords_x),torch.amax(coords_x))
|
239 |
+
coords_proj=torch.stack([coords_x,coords_y],dim=-1)
|
240 |
+
vit_coords=self.vit_coords.unsqueeze(0).expand(batch_size,-1,-1)
|
241 |
+
attn_mask=torch.repeat_interleave(
|
242 |
+
self.get_attn_mask(coords_proj,vit_coords,kernel_size=1.0),self.n_heads, 0
|
243 |
+
)
|
244 |
+
attn_mask = torch.cat([torch.zeros([attn_mask.shape[0], attn_mask.shape[1], 1]).cuda().bool(), attn_mask],
|
245 |
+
dim=-1) # add global token
|
246 |
+
#print(attn_mask.shape,torch.sum(attn_mask.float()))
|
247 |
+
triplane_feat=triplane_feat.permute(0,2,3,1).view(batch_size,3*self.triplane_reso**2,-1)
|
248 |
+
#print(triplane_feat.shape,self.triplane_pe.shape)
|
249 |
+
query=self.q(triplane_feat)+self.triplane_pe
|
250 |
+
key=self.k(image_feat)+self.image_pe
|
251 |
+
value=self.v(image_feat)+self.image_pe
|
252 |
+
#print(query.shape,key.shape,value.shape)
|
253 |
+
attn,_=self.attn(query,key,value,attn_mask=attn_mask)
|
254 |
+
#print(attn.shape)
|
255 |
+
output=self.proj_out(attn).transpose(1,2).reshape(batch_size,-1,3*self.triplane_reso,self.triplane_reso)
|
256 |
+
|
257 |
+
return output
|
258 |
+
|
259 |
+
class MultiImage_Direct_AttenwMask_Sampler(nn.Module):
|
260 |
+
def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
|
261 |
+
img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
|
262 |
+
super().__init__()
|
263 |
+
self.triplane_reso=reso
|
264 |
+
self.vit_reso=vit_reso
|
265 |
+
self.padding=padding
|
266 |
+
self.n_heads=n_heads
|
267 |
+
self.get_plane_expand_coord()
|
268 |
+
self.get_vit_coords()
|
269 |
+
self.out_channels=out_channels
|
270 |
+
self.kernel_func=mask_kernel
|
271 |
+
self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
|
272 |
+
self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
|
273 |
+
self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
|
274 |
+
self.attn = torch.nn.MultiheadAttention(
|
275 |
+
embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
|
276 |
+
|
277 |
+
self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
|
278 |
+
self.image_pe = position_encoding(inner_channel, max_nimg*(self.vit_reso**2+1)).unsqueeze(0).cuda().float()
|
279 |
+
self.triplane_pe = position_encoding(inner_channel, 3*self.triplane_reso**2).unsqueeze(0).cuda().float()
|
280 |
+
def get_plane_expand_coord(self):
|
281 |
+
x = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
|
282 |
+
y = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
|
283 |
+
z = torch.arange(self.triplane_reso)/(self.triplane_reso-1)
|
284 |
+
|
285 |
+
first,second,third=torch.meshgrid(x,y,z,indexing='xy')
|
286 |
+
xyz_coords=torch.stack([first,second,third],dim=-1)#reso,reso,reso,3
|
287 |
+
xyz_coords=(xyz_coords-0.5)*2*(1+self.padding+10e-6) #ordering yxz ->xyz
|
288 |
+
xzy_coords=xyz_coords.clone().permute(2,1,0,3) #ordering zxy ->xzy
|
289 |
+
yzx_coords=xyz_coords.clone().permute(2,0,1,3) #ordering zyx ->yzx
|
290 |
+
|
291 |
+
xyz_coords=xyz_coords.reshape(self.triplane_reso**3,-1)
|
292 |
+
xzy_coords=xzy_coords.reshape(self.triplane_reso**3,-1)
|
293 |
+
yzx_coords=yzx_coords.reshape(self.triplane_reso**3,-1)
|
294 |
+
|
295 |
+
coords=torch.cat([xzy_coords,xyz_coords,yzx_coords],dim=0)
|
296 |
+
self.plane_coords=coords.cuda().float()
|
297 |
+
# self.xzy_coords=xzy_coords.cuda().float() #reso**3,3
|
298 |
+
# self.xyz_coords=xyz_coords.cuda().float() #reso**3,3
|
299 |
+
# self.yzx_coords=yzx_coords.cuda().float() #reso**3,3
|
300 |
+
|
301 |
+
def get_vit_coords(self):
|
302 |
+
x=torch.arange(self.vit_reso)
|
303 |
+
y=torch.arange(self.vit_reso)
|
304 |
+
|
305 |
+
X,Y=torch.meshgrid(x,y,indexing='xy')
|
306 |
+
vit_coords=torch.stack([X,Y],dim=-1)
|
307 |
+
self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
|
308 |
+
|
309 |
+
def get_attn_mask(self,coords_proj,vit_coords,valid_frames,kernel_size=1.0):
|
310 |
+
'''
|
311 |
+
:param coords_proj: B,n_img,3*reso**3,2, in range of 0~vit_reso
|
312 |
+
:param vit_coords: B,n_img,vit_reso**2,2, in range of 0~vit_reso
|
313 |
+
:param kernel_size: 0.5, so that only one pixel will be select
|
314 |
+
:return:
|
315 |
+
'''
|
316 |
+
bs,n_img=coords_proj.shape[0],coords_proj.shape[1]
|
317 |
+
coords_proj_flat=coords_proj.reshape(bs*n_img,3*self.triplane_reso**3,2)
|
318 |
+
vit_coords_flat=vit_coords.reshape(bs*n_img,self.vit_reso**2,2)
|
319 |
+
dist=torch.cdist(coords_proj_flat.float(),vit_coords_flat.float())
|
320 |
+
mask=self.kernel_func(dist,sigma=kernel_size).float() #True if valid, B*n_img,3*reso**3,vit_reso**2
|
321 |
+
mask=mask.reshape(bs,n_img,3*self.triplane_reso**2,self.triplane_reso,self.vit_reso**2)
|
322 |
+
mask=torch.sum(mask,dim=3) #B,n_img,3*reso**2,vit_reso**2
|
323 |
+
mask=torch.cat([torch.ones(size=mask.shape[0:3]).unsqueeze(3).float().cuda(),mask],dim=-1) #B,n_img,3*reso**2,vit_reso**2+1, add global mask
|
324 |
+
mask[valid_frames == 0, :, :] = False
|
325 |
+
mask=mask.permute(0,2,1,3).reshape(bs,3*self.triplane_reso**2,-1) #B,3*reso**2,n_img*(vit_resso**2+1)
|
326 |
+
attn_mask=(mask==0) #invert the mask, False indicates valid, True indicates invalid
|
327 |
+
return attn_mask
|
328 |
+
|
329 |
+
def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
|
330 |
+
'''image feat is bs,n_img,length,channel'''
|
331 |
+
batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
|
332 |
+
img_length=image_feat.shape[2]
|
333 |
+
image_feat_flat=image_feat.view(batch_size,n_img*img_length,-1)
|
334 |
+
coords=self.plane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
|
335 |
+
|
336 |
+
coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**3*3,1).float().cuda()],dim=-1)
|
337 |
+
#print(coord_homo.shape,proj_mat.shape)
|
338 |
+
coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
|
339 |
+
x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
|
340 |
+
y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
|
341 |
+
x = x/(224.0-1)
|
342 |
+
y = y/(224.0-1)
|
343 |
+
coords_x=torch.clamp(x,min=0.0,max=1.0)*(self.vit_reso-1)
|
344 |
+
coords_y=torch.clamp(y,min=0.0,max=1.0)*(self.vit_reso-1)
|
345 |
+
coords_proj=torch.stack([coords_x,coords_y],dim=-1)
|
346 |
+
vit_coords=self.vit_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
|
347 |
+
attn_mask=torch.repeat_interleave(
|
348 |
+
self.get_attn_mask(coords_proj,vit_coords,valid_frames,kernel_size=1.0),self.n_heads, 0
|
349 |
+
)
|
350 |
+
triplane_feat=triplane_feat.permute(0,2,3,1).view(batch_size,3*self.triplane_reso**2,-1)
|
351 |
+
query=self.q(triplane_feat)+self.triplane_pe
|
352 |
+
key=self.k(image_feat_flat)+self.image_pe
|
353 |
+
value=self.v(image_feat_flat)+self.image_pe
|
354 |
+
attn,_=self.attn(query,key,value,attn_mask=attn_mask)
|
355 |
+
output=self.proj_out(attn).transpose(1,2).reshape(batch_size,-1,3*self.triplane_reso,self.triplane_reso)
|
356 |
+
|
357 |
+
return output
|
358 |
+
|
359 |
+
class MultiImage_Fuse_Sampler(nn.Module):
|
360 |
+
def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
|
361 |
+
img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
|
362 |
+
super().__init__()
|
363 |
+
self.triplane_reso=reso
|
364 |
+
self.vit_reso=vit_reso
|
365 |
+
self.inner_channel=inner_channel
|
366 |
+
self.padding=padding
|
367 |
+
self.n_heads=n_heads
|
368 |
+
self.get_vox_coord()
|
369 |
+
self.get_vit_coords()
|
370 |
+
self.out_channels=out_channels
|
371 |
+
self.kernel_func=mask_kernel
|
372 |
+
self.image_unflatten=nn.Unflatten(2,(vit_reso,vit_reso))
|
373 |
+
self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
|
374 |
+
self.q=nn.Linear(in_features=triplane_in_channels*3,out_features=inner_channel)
|
375 |
+
self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
|
376 |
+
|
377 |
+
#self.cross_attn=CrossAttention(query_dim=inner_channel,heads=8,dim_head=inner_channel//8)
|
378 |
+
self.cross_attn = torch.nn.MultiheadAttention(
|
379 |
+
embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
|
380 |
+
self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
|
381 |
+
self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].cuda().float() #1,1,length,channel
|
382 |
+
#self.image_pe = self.image_pe.reshape(1,max_nimg,self.vit_reso,self.vit_reso,inner_channel)
|
383 |
+
self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 3).unsqueeze(0).cuda().float()
|
384 |
+
|
385 |
+
def get_vit_coords(self):
|
386 |
+
x = torch.arange(self.vit_reso)
|
387 |
+
y = torch.arange(self.vit_reso)
|
388 |
+
|
389 |
+
X, Y = torch.meshgrid(x, y, indexing='xy')
|
390 |
+
vit_coords = torch.stack([X, Y], dim=-1)
|
391 |
+
self.vit_coords = vit_coords.cuda().float() #reso,reso,2
|
392 |
+
|
393 |
+
def get_vox_coord(self):
|
394 |
+
x = torch.arange(self.triplane_reso)
|
395 |
+
y = torch.arange(self.triplane_reso)
|
396 |
+
z = torch.arange(self.triplane_reso)
|
397 |
+
|
398 |
+
X, Y, Z = torch.meshgrid(x, y, z, indexing='ij')
|
399 |
+
vox_coor = torch.cat([X[:, :, :, None], Y[:, :, :, None], Z[:, :, :, None]], dim=-1)
|
400 |
+
self.vox_index = vox_coor.view(-1, 3).long().cuda()
|
401 |
+
|
402 |
+
vox_coor = self.vox_index.float() / (self.triplane_reso - 1)
|
403 |
+
vox_coor = (vox_coor - 0.5) * 2 * (1 + self.padding + 10e-6)
|
404 |
+
self.vox_coor = vox_coor.view(-1, 3).float().cuda()
|
405 |
+
|
406 |
+
def get_attn_mask(self,valid_frames):
|
407 |
+
'''
|
408 |
+
:param valid_frames: of shape B,n_img
|
409 |
+
'''
|
410 |
+
#print(valid_frames)
|
411 |
+
#bs,n_img=valid_frames.shape[0:2]
|
412 |
+
attn_mask=(valid_frames.float()==0)
|
413 |
+
#attn_mask=attn_mask.unsqueeze(1).unsqueeze(2).expand(-1,self.triplane_reso**3,-1,-1) #B,1,n_img
|
414 |
+
#attn_mask=attn_mask.reshape(bs*self.triplane_reso**3,-1,n_img).bool()
|
415 |
+
attn_mask=torch.repeat_interleave(attn_mask.unsqueeze(1),self.triplane_reso**3,0)
|
416 |
+
# print(attn_mask[self.triplane_reso**3*1+10])
|
417 |
+
# print(attn_mask[self.triplane_reso ** 3 * 2+10])
|
418 |
+
# print(attn_mask[self.triplane_reso ** 3 * 3+10])
|
419 |
+
return attn_mask
|
420 |
+
|
421 |
+
def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
|
422 |
+
'''image feat is bs,n_img,length,channel'''
|
423 |
+
batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
|
424 |
+
image_feat=image_feat[:,:,1:,:] #discard global feature
|
425 |
+
|
426 |
+
#image_feat=image_feat.permute(0,1,3,4,2) #B,n_img,h,w,c
|
427 |
+
image_k=self.k(image_feat)+self.image_pe #B,n_img,h,w,c
|
428 |
+
image_v=self.v(image_feat)+self.image_pe #B,n_img,h,w,c
|
429 |
+
image_k_v=torch.cat([image_k,image_v],dim=-1) #B,n_img,h,w,c
|
430 |
+
unflat_k_v=self.image_unflatten(image_k_v).permute(0,4,1,2,3) #Bs,channel,n_img,reso,reso
|
431 |
+
#unflat_k_v=image_k_v.permute(0,4,1,2,3)
|
432 |
+
#vit_coords=self.vit_coords[None,None].expand(batch_size,n_img,-1,-1,-1) #Bs,n_img,reso,reso,2
|
433 |
+
|
434 |
+
coords=self.vox_coor.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
|
435 |
+
coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**3,1).float().cuda()],dim=-1)
|
436 |
+
coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
|
437 |
+
x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
|
438 |
+
y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
|
439 |
+
x = x/(224.0-1) #0~1
|
440 |
+
y = y/(224.0-1)
|
441 |
+
coords_proj=torch.stack([x,y],dim=-1)
|
442 |
+
coords_proj=(coords_proj-0.5)*2
|
443 |
+
img_index=((torch.arange(n_img)[None,:,None,None].expand(
|
444 |
+
batch_size,-1,self.triplane_reso**3,-1).float().cuda()/(n_img-1))-0.5)*2 #Bs,n_img,64**3,1
|
445 |
+
|
446 |
+
# img_index_feat=torch.arange(n_img)[None,:,None,None,None].expand(
|
447 |
+
# batch_size,-1,self.vit_reso,self.vit_reso,-1).float().cuda() #Bs,n_img,reso,reso,1
|
448 |
+
#coords_feat=torch.cat([vit_coords,img_index_feat],dim=-1).permute(0,4,1,2,3)#Bs,n_img,reso,reso,3
|
449 |
+
grid=torch.cat([coords_proj,img_index],dim=-1) #x,y,index
|
450 |
+
grid=torch.clamp(grid,min=-1.0,max=1.0)
|
451 |
+
sample_k_v = torch.nn.functional.grid_sample(unflat_k_v, grid.unsqueeze(1), align_corners=True, mode='bilinear').squeeze(2) #B,C,n_img,64**3
|
452 |
+
xz_feat, xy_feat, yz_feat = torch.split(triplane_feat, split_size_or_sections=triplane_feat.shape[2] // 3,
|
453 |
+
dim=2) # B,C,64,64
|
454 |
+
xz_vox_feat=xz_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,3,4,2).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #zxy
|
455 |
+
xz_vox_feat=rearrange(xz_vox_feat, 'b c z x y -> b (x y z) c')
|
456 |
+
xy_vox_feat=xy_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,3,2,4).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #yxz
|
457 |
+
xy_vox_feat=rearrange(xy_vox_feat, 'b c y x z -> b (x y z) c')
|
458 |
+
yz_vox_feat=yz_feat.unsqueeze(4).expand(-1,-1,-1,-1,self.triplane_reso)#.permute(0,1,4,3,2).reshape(batch_size,-1,self.triplane_reso**3).transpose(1,2) #zyx
|
459 |
+
yz_vox_feat=rearrange(yz_vox_feat, 'b c z y x -> b (x y z) c')
|
460 |
+
#xz_vox_feat = xz_feat[:, :, vox_index[:, 2], vox_index[:, 0]].transpose(1, 2) # B,C,64*64*64
|
461 |
+
#xy_vox_feat = xy_feat[:, :, vox_index[:, 1], vox_index[:, 0]].transpose(1, 2)
|
462 |
+
#yz_vox_feat = yz_feat[:, :, vox_index[:, 2], vox_index[:, 1]].transpose(1, 2)
|
463 |
+
|
464 |
+
triplane_expand_feat = torch.cat([xz_vox_feat, xy_vox_feat, yz_vox_feat], dim=-1) # B,64*64*64,3*C
|
465 |
+
triplane_query = self.q(triplane_expand_feat) + self.triplane_pe
|
466 |
+
k_v=rearrange(sample_k_v, 'b c n k -> (b k) n c')
|
467 |
+
#k_v=sample_k_v.permute(0,3,2,1).reshape(batch_size*self.triplane_reso**3,n_img,-1) #B*64**3,n_img,C
|
468 |
+
k=k_v[:,:,0:self.inner_channel]
|
469 |
+
v=k_v[:,:,self.inner_channel:]
|
470 |
+
q=rearrange(triplane_query,'b k c -> (b k) 1 c')
|
471 |
+
#q=triplane_query.view(batch_size*self.triplane_reso**3,1,-1)
|
472 |
+
#k,v is of shape, B*reso**3,k,channel, q is of shape B*reso**3,1,channel
|
473 |
+
#attn mask should be B*reso**3*n_heads,1,k
|
474 |
+
#attn_mask=torch.repeat_interleave(self.get_attn_mask(valid_frames),self.n_heads,0)
|
475 |
+
#print(q.shape,k.shape,v.shape)
|
476 |
+
attn_out,_=self.cross_attn(q,k,v)#attn_mask=attn_mask) #fuse multi-view feature
|
477 |
+
#volume=attn_out.view(batch_size,self.triplane_reso,self.triplane_reso,self.triplane_reso,-1) #B,reso,reso,reso,channel
|
478 |
+
#print(attn_out.shape)
|
479 |
+
volume=rearrange(attn_out,'(b x y z) 1 c -> b x y z c',x=self.triplane_reso,y=self.triplane_reso,z=self.triplane_reso)
|
480 |
+
#xz_feat = torch.mean(volume, dim=2).transpose(1,2) #B,reso,reso,C
|
481 |
+
xz_feat = reduce(volume, "b x y z c -> b z x c", 'mean')
|
482 |
+
#xy_feat = torch.mean(volume, dim=3).transpose(1,2) #B,reso,reso,C
|
483 |
+
xy_feat= reduce(volume, 'b x y z c -> b y x c', 'mean')
|
484 |
+
#yz_feat = torch.mean(volume, dim=1).transpose(1,2) #B,reso,reso,C
|
485 |
+
yz_feat=reduce(volume, 'b x y z c -> b z y c', 'mean')
|
486 |
+
triplane_out = torch.cat([xz_feat, xy_feat, yz_feat], dim=1) #B,reso*3,reso,C
|
487 |
+
#print(triplane_out.shape)
|
488 |
+
triplane_out = self.proj_out(triplane_out)
|
489 |
+
triplane_out = triplane_out.permute(0,3,1,2)
|
490 |
+
#print(triplane_out.shape)
|
491 |
+
return triplane_out
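Every sampler in this file projects 3D coordinates into the 224x224 image with the same recipe: append a homogeneous 1, multiply by `proj_mat`, divide by the depth, then rescale pixels to [-1, 1]. A minimal self-contained sketch of that mapping (the function name and test values below are illustrative, not part of the repository):

import torch

def project_to_pixels(coords, proj_mat, img_size=224.0):
    """Project (B, N, 3) points with a (B, 4, 4) matrix; return grid_sample-style coords in [-1, 1]."""
    B, N, _ = coords.shape
    ones = torch.ones(B, N, 1, dtype=coords.dtype, device=coords.device)
    homo = torch.cat([coords, ones], dim=-1)                      # (B, N, 4) homogeneous points
    cam = torch.einsum('bnc,bck->bnk', homo, proj_mat.transpose(1, 2))
    x = cam[..., 0] / cam[..., 2]                                 # perspective divide
    y = cam[..., 1] / cam[..., 2]
    x = (x / (img_size - 1.0) - 0.5) * 2.0                        # pixels -> [-1, 1]
    y = (y / (img_size - 1.0) - 0.5) * 2.0
    return torch.stack([x, y], dim=-1)                            # (B, N, 2)

pts = torch.rand(2, 5, 3) + 1.0                                   # keep the depth away from zero
proj = torch.eye(4).unsqueeze(0).repeat(2, 1, 1)                  # dummy projection matrices
print(project_to_pixels(pts, proj).shape)                         # torch.Size([2, 5, 2])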
|
492 |
+
|
493 |
+
class MultiImage_TriFuse_Sampler(nn.Module):
|
494 |
+
def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
|
495 |
+
img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
|
496 |
+
super().__init__()
|
497 |
+
self.triplane_reso=reso
|
498 |
+
self.vit_reso=vit_reso
|
499 |
+
self.inner_channel=inner_channel
|
500 |
+
self.padding=padding
|
501 |
+
self.n_heads=n_heads
|
502 |
+
self.get_triplane_coord()
|
503 |
+
self.get_vit_coords()
|
504 |
+
self.out_channels=out_channels
|
505 |
+
self.kernel_func=mask_kernel
|
506 |
+
self.image_unflatten=nn.Unflatten(2,(vit_reso,vit_reso))
|
507 |
+
self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
|
508 |
+
self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
|
509 |
+
self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
|
510 |
+
|
511 |
+
self.cross_attn = torch.nn.MultiheadAttention(
|
512 |
+
embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
|
513 |
+
self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
|
514 |
+
self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].expand(-1,max_nimg,-1,-1).cuda().float() #B,n_img,length,channel
|
515 |
+
self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 2*3).unsqueeze(0).cuda().float()
|
516 |
+
|
517 |
+
def get_vit_coords(self):
|
518 |
+
x = torch.arange(self.vit_reso)
|
519 |
+
y = torch.arange(self.vit_reso)
|
520 |
+
|
521 |
+
X, Y = torch.meshgrid(x, y, indexing='xy')
|
522 |
+
vit_coords = torch.stack([X, Y], dim=-1)
|
523 |
+
self.vit_coords = vit_coords.cuda().float() #reso,reso,2
|
524 |
+
|
525 |
+
def get_triplane_coord(self):
|
526 |
+
'''xz plane first: x and z vary over the grid, while y is fixed at the plane centre'''
|
527 |
+
x = torch.arange(self.triplane_reso)
|
528 |
+
z = torch.arange(self.triplane_reso)
|
529 |
+
X, Z = torch.meshgrid(x, z, indexing='xy')
|
530 |
+
xz_coords = torch.cat(
|
531 |
+
[X[:, :, None], torch.ones_like(X[:, :, None]) * (self.triplane_reso - 1) / 2, Z[:, :, None]],
|
532 |
+
dim=-1) # in xyz order
|
533 |
+
|
534 |
+
'''xy plane'''
|
535 |
+
x = torch.arange(self.triplane_reso)
|
536 |
+
y = torch.arange(self.triplane_reso)
|
537 |
+
X, Y = torch.meshgrid(x, y, indexing='xy')
|
538 |
+
xy_coords = torch.cat(
|
539 |
+
[X[:, :, None], Y[:, :, None], torch.ones_like(X[:, :, None]) * (self.triplane_reso - 1) / 2],
|
540 |
+
dim=-1) # in xyz order
|
541 |
+
|
542 |
+
'''yz plane'''
|
543 |
+
y = torch.arange(self.triplane_reso)
|
544 |
+
z = torch.arange(self.triplane_reso)
|
545 |
+
Y, Z = torch.meshgrid(y, z, indexing='xy')
|
546 |
+
yz_coords = torch.cat(
|
547 |
+
[torch.ones_like(Y[:, :, None]) * (self.triplane_reso - 1) / 2, Y[:, :, None], Z[:, :, None]], dim=-1)
|
548 |
+
|
549 |
+
triplane_coords = torch.cat([xz_coords, xy_coords, yz_coords], dim=0)
|
550 |
+
triplane_coords = triplane_coords / (self.triplane_reso - 1)
|
551 |
+
triplane_coords = (triplane_coords - 0.5) * 2 * (1 + self.padding + 10e-6)
|
552 |
+
self.triplane_coords = triplane_coords.view(-1,3).float().cuda()
|
553 |
+
def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
|
554 |
+
'''image feat is bs,n_img,length,channel'''
|
555 |
+
batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
|
556 |
+
image_feat=image_feat[:,:,1:,:] #discard global feature
|
557 |
+
#print(image_feat.shape)
|
558 |
+
|
559 |
+
#image_feat=image_feat.permute(0,1,3,4,2) #B,n_img,h,w,c
|
560 |
+
image_k=self.k(image_feat)+self.image_pe #B,n_img,h,w,c
|
561 |
+
image_v=self.v(image_feat)+self.image_pe #B,n_img,h,w,c
|
562 |
+
image_k_v=torch.cat([image_k,image_v],dim=-1) #B,n_img,h,w,c
|
563 |
+
unflat_k_v=self.image_unflatten(image_k_v).permute(0,4,1,2,3) #Bs,channel,n_img,reso,reso
|
564 |
+
|
565 |
+
coords=self.triplane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,n_img,-1,-1)
|
566 |
+
coord_homo=torch.cat([coords,torch.ones(batch_size,n_img,self.triplane_reso**2*3,1).float().cuda()],dim=-1)
|
567 |
+
coord_inimg = torch.einsum('bjnc,bjck->bjnk', coord_homo, proj_mat.transpose(2, 3))
|
568 |
+
x = coord_inimg[:, :, :, 0] / coord_inimg[:, :, :, 2]
|
569 |
+
y = coord_inimg[:, :, :, 1] / coord_inimg[:, :, :, 2]
|
570 |
+
x = x/(224.0-1) #0~1
|
571 |
+
y = y/(224.0-1)
|
572 |
+
coords_proj=torch.stack([x,y],dim=-1)
|
573 |
+
coords_proj=(coords_proj-0.5)*2
|
574 |
+
img_index=((torch.arange(n_img)[None,:,None,None].expand(
|
575 |
+
batch_size,-1,self.triplane_reso**2*3,-1).float().cuda()/(n_img-1))-0.5)*2 #Bs,n_img,64**3,1
|
576 |
+
|
577 |
+
grid=torch.cat([coords_proj,img_index],dim=-1) #x,y,index
|
578 |
+
grid=torch.clamp(grid,min=-1.0,max=1.0)
|
579 |
+
sample_k_v = torch.nn.functional.grid_sample(unflat_k_v, grid.unsqueeze(1), align_corners=True, mode='bilinear').squeeze(2) #B,C,n_img,64**3
|
580 |
+
|
581 |
+
triplane_flat_feat=rearrange(triplane_feat,'b c h w -> b (h w) c')
|
582 |
+
triplane_query = self.q(triplane_flat_feat) + self.triplane_pe
|
583 |
+
|
584 |
+
k_v=rearrange(sample_k_v, 'b c n k -> (b k) n c')
|
585 |
+
k=k_v[:,:,0:self.inner_channel]
|
586 |
+
v=k_v[:,:,self.inner_channel:]
|
587 |
+
q=rearrange(triplane_query,'b k c -> (b k) 1 c')
|
588 |
+
attn_out,_=self.cross_attn(q,k,v)
|
589 |
+
triplane_out=rearrange(attn_out,'(b h w) 1 c -> b c h w',b=batch_size,h=self.triplane_reso*3,w=self.triplane_reso)
|
590 |
+
triplane_out = self.proj_out(triplane_out)
|
591 |
+
return triplane_out
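A hedged usage sketch of `MultiImage_TriFuse_Sampler` with dummy tensors, following the shapes its `forward` expects (triplane features of size B x C x 3*reso x reso, ViT tokens with a leading global token, one 4x4 projection per view). It assumes a CUDA device, since the module moves its positional encodings to the GPU at construction, and that the number of views equals `max_nimg` so the image positional encoding broadcasts; all sizes are illustrative:

B, n_img, reso, vit_reso = 2, 5, 32, 16
sampler = MultiImage_TriFuse_Sampler(reso=reso, vit_reso=vit_reso, padding=0.1,
                                     triplane_in_channels=128, img_in_channels=1280,
                                     inner_channel=128, out_channels=64,
                                     n_heads=8, max_nimg=n_img).cuda().float()
triplane_feat = torch.randn(B, 128, reso * 3, reso).cuda()
image_feat = torch.randn(B, n_img, vit_reso ** 2 + 1, 1280).cuda()   # leading token is the ViT global token
proj_mat = torch.eye(4)[None, None].repeat(B, n_img, 1, 1).cuda()    # dummy projection matrices
valid_frames = torch.ones(B, n_img).cuda()
out = sampler(triplane_feat, image_feat, proj_mat, valid_frames)
print(out.shape)                                                     # (B, 64, reso * 3, reso)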
|
592 |
+
|
593 |
+
|
594 |
+
class MultiImage_Global_Sampler(nn.Module):
|
595 |
+
def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,
|
596 |
+
img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8,max_nimg=5):
|
597 |
+
super().__init__()
|
598 |
+
self.triplane_reso=reso
|
599 |
+
self.vit_reso=vit_reso
|
600 |
+
self.inner_channel=inner_channel
|
601 |
+
self.padding=padding
|
602 |
+
self.n_heads=n_heads
|
603 |
+
self.out_channels=out_channels
|
604 |
+
self.k=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
|
605 |
+
self.q=nn.Linear(in_features=triplane_in_channels,out_features=inner_channel)
|
606 |
+
self.v=nn.Linear(in_features=img_in_channels,out_features=inner_channel)
|
607 |
+
|
608 |
+
self.cross_attn = torch.nn.MultiheadAttention(
|
609 |
+
embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
|
610 |
+
self.proj_out=nn.Linear(in_features=inner_channel,out_features=out_channels)
|
611 |
+
self.image_pe = position_encoding(inner_channel, self.vit_reso**2)[None,None,:,:].expand(-1,max_nimg,-1,-1).cuda().float() #B,n_img,length,channel
|
612 |
+
self.triplane_pe = position_encoding(inner_channel, self.triplane_reso**2*3).unsqueeze(0).cuda().float()
|
613 |
+
def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
|
614 |
+
'''image feat is bs,n_img,length,channel
|
615 |
+
triplane feat is bs,C,H*3,W
|
616 |
+
'''
|
617 |
+
batch_size,n_img=image_feat.shape[0],image_feat.shape[1]
|
618 |
+
L=image_feat.shape[2]-1
|
619 |
+
image_feat=image_feat[:,:,1:,:] #discard global feature
|
620 |
+
|
621 |
+
image_k=self.k(image_feat)+self.image_pe #B,n_img,h*w,c
|
622 |
+
image_v=self.v(image_feat)+self.image_pe #B,n_img,h*w,c
|
623 |
+
image_k=image_k.view(batch_size,n_img*L,-1)
|
624 |
+
image_v=image_v.view(batch_size,n_img*L,-1)
|
625 |
+
|
626 |
+
triplane_flat_feat=rearrange(triplane_feat,"b c h w -> b (h w) c")
|
627 |
+
triplane_query = self.q(triplane_flat_feat) + self.triplane_pe
|
628 |
+
#print(triplane_query.shape,image_k.shape,image_v.shape)
|
629 |
+
attn_out,_=self.cross_attn(triplane_query,image_k,image_v)
|
630 |
+
triplane_flat_out = self.proj_out(attn_out)
|
631 |
+
triplane_out=rearrange(triplane_flat_out,"b (h w) c -> b c h w",h=self.triplane_reso*3,w=self.triplane_reso)
|
632 |
+
|
633 |
+
return triplane_out
|
634 |
+
|
635 |
+
class CrossAttention(nn.Module):
|
636 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
637 |
+
super().__init__()
|
638 |
+
inner_dim = dim_head * heads
|
639 |
+
|
640 |
+
if context_dim is None:
|
641 |
+
context_dim = query_dim
|
642 |
+
|
643 |
+
self.scale = dim_head ** -0.5
|
644 |
+
self.heads = heads
|
645 |
+
|
646 |
+
self.to_out = nn.Sequential(
|
647 |
+
nn.Linear(inner_dim, query_dim),
|
648 |
+
nn.Dropout(dropout)
|
649 |
+
)
|
650 |
+
|
651 |
+
def forward(self, q,k,v):
|
652 |
+
h = self.heads
|
653 |
+
|
654 |
+
q, k, v = map(lambda t: rearrange(
|
655 |
+
t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
656 |
+
|
657 |
+
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
658 |
+
|
659 |
+
# attention, what we cannot get enough of
|
660 |
+
attn = sim.softmax(dim=-1)
|
661 |
+
|
662 |
+
out = torch.einsum('b i j, b j d -> b i d', attn, v)
|
663 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
664 |
+
return self.to_out(out)
|
665 |
+
|
666 |
+
class Image_Vox_Local_Sampler_Pooling(nn.Module):
|
667 |
+
def __init__(self,reso,padding=0.1,in_channels=1280,inner_channel=128,out_channels=64,stride=4):
|
668 |
+
super().__init__()
|
669 |
+
self.triplane_reso=reso
|
670 |
+
self.padding=padding
|
671 |
+
self.get_vox_coord()
|
672 |
+
self.out_channels=out_channels
|
673 |
+
self.img_proj=nn.Conv2d(in_channels=in_channels,out_channels=inner_channel,kernel_size=1)
|
674 |
+
|
675 |
+
self.vox_process=nn.Sequential(
|
676 |
+
nn.Conv3d(in_channels=inner_channel,out_channels=inner_channel,kernel_size=3,padding=1)
|
677 |
+
)
|
678 |
+
self.xz_conv=nn.Sequential(
|
679 |
+
nn.BatchNorm3d(inner_channel),
|
680 |
+
nn.ReLU(),
|
681 |
+
nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
|
682 |
+
nn.AvgPool3d((1,stride,1),stride=(1,stride,1)), #8
|
683 |
+
nn.BatchNorm3d(inner_channel),
|
684 |
+
nn.ReLU(),
|
685 |
+
nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
|
686 |
+
nn.AvgPool3d((1,stride,1), stride=(1,stride,1)), #2
|
687 |
+
nn.BatchNorm3d(inner_channel),
|
688 |
+
nn.ReLU(),
|
689 |
+
nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
|
690 |
+
)
|
691 |
+
self.xy_conv = nn.Sequential(
|
692 |
+
nn.BatchNorm3d(inner_channel),
|
693 |
+
nn.ReLU(),
|
694 |
+
nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
|
695 |
+
nn.AvgPool3d((1, 1, stride), stride=(1, 1, stride)), # 8
|
696 |
+
nn.BatchNorm3d(inner_channel),
|
697 |
+
nn.ReLU(),
|
698 |
+
nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
|
699 |
+
nn.AvgPool3d((1, 1, stride), stride=(1, 1, stride)), # 2
|
700 |
+
nn.BatchNorm3d(inner_channel),
|
701 |
+
nn.ReLU(),
|
702 |
+
nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
|
703 |
+
)
|
704 |
+
self.yz_conv = nn.Sequential(
|
705 |
+
nn.BatchNorm3d(inner_channel),
|
706 |
+
nn.ReLU(),
|
707 |
+
nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
|
708 |
+
nn.AvgPool3d((stride, 1, 1), stride=(stride, 1, 1)), # 8
|
709 |
+
nn.BatchNorm3d(inner_channel),
|
710 |
+
nn.ReLU(),
|
711 |
+
nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
|
712 |
+
nn.AvgPool3d((stride, 1, 1), stride=(stride, 1, 1)), # 2
|
713 |
+
nn.BatchNorm3d(inner_channel),
|
714 |
+
nn.ReLU(),
|
715 |
+
nn.Conv3d(in_channels=inner_channel, out_channels=inner_channel, kernel_size=3, padding=1),
|
716 |
+
)
|
717 |
+
self.roll_out_conv=RollOut_Conv(in_channels=inner_channel,out_channels=out_channels)
|
718 |
+
#self.proj_out=nn.Conv2d(in_channels=inner_channel,out_channels=out_channels,kernel_size=1)
|
719 |
+
def get_vox_coord(self):
|
720 |
+
x = torch.arange(self.triplane_reso)
|
721 |
+
y = torch.arange(self.triplane_reso)
|
722 |
+
z = torch.arange(self.triplane_reso)
|
723 |
+
|
724 |
+
X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
|
725 |
+
vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
|
726 |
+
vox_coor=vox_coor/(self.triplane_reso-1)
|
727 |
+
vox_coor=(vox_coor-0.5)*2*(1+self.padding+10e-6)
|
728 |
+
self.vox_coor=vox_coor.view(-1,3).float().cuda()
|
729 |
+
|
730 |
+
|
731 |
+
def forward(self,image_feat,proj_mat):
|
732 |
+
image_feat=self.img_proj(image_feat)
|
733 |
+
batch_size=image_feat.shape[0]
|
734 |
+
vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #B,64*64*64,3
|
735 |
+
vox_homo=torch.cat([vox_coords,torch.ones((batch_size,self.triplane_reso**3,1)).float().cuda()],dim=-1)
|
736 |
+
coord_inimg=torch.einsum('bhc,bck->bhk',vox_homo,proj_mat.transpose(1,2))
|
737 |
+
x=coord_inimg[:,:,0]/coord_inimg[:,:,2]
|
738 |
+
y=coord_inimg[:,:,1]/coord_inimg[:,:,2]
|
739 |
+
x=(x/(224.0-1.0)-0.5)*2 #-1~1
|
740 |
+
y=(y/(224.0-1.0)-0.5)*2 #-1~1
|
741 |
+
|
742 |
+
xy=torch.cat([x[:,:,None],y[:,:,None]],dim=-1).unsqueeze(1).contiguous() #B, 1,64**3,2
|
743 |
+
#print(image_feat.shape,xy.shape)
|
744 |
+
grid_feat=torch.nn.functional.grid_sample(image_feat,xy,align_corners=True,mode='bilinear').squeeze(2).\
|
745 |
+
view(batch_size,-1,self.triplane_reso,self.triplane_reso,self.triplane_reso) #B,C,1,64**3
|
746 |
+
|
747 |
+
grid_feat=self.vox_process(grid_feat)
|
748 |
+
xz_feat=torch.mean(self.xz_conv(grid_feat),dim=3).transpose(2,3)
|
749 |
+
xy_feat=torch.mean(self.xy_conv(grid_feat),dim=4).transpose(2,3)
|
750 |
+
yz_feat=torch.mean(self.yz_conv(grid_feat),dim=2).transpose(2,3)
|
751 |
+
triplane_wImg=torch.cat([xz_feat,xy_feat,yz_feat],dim=2)
|
752 |
+
#print(triplane_wImg.shape)
|
753 |
+
|
754 |
+
return self.roll_out_conv(triplane_wImg)
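The pooling sampler above reads the 2D image feature map at each voxel's projected location with `grid_sample`. A small standalone sketch of that lookup (random stand-ins for the feature map and coordinates):

import torch
import torch.nn.functional as F

B, C, H, W, N = 2, 8, 16, 16, 64 ** 2        # illustrative sizes
feat2d = torch.randn(B, C, H, W)             # one image feature map per sample
xy = torch.rand(B, N, 2) * 2.0 - 1.0         # projected voxel coords, already in [-1, 1]

grid = xy.unsqueeze(1)                       # grid_sample wants (B, H_out, W_out, 2); use a 1 x N grid
sampled = F.grid_sample(feat2d, grid, mode='bilinear', align_corners=True)   # (B, C, 1, N)
sampled = sampled.squeeze(2)                 # (B, C, N): one image feature per voxel
print(sampled.shape)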
|
755 |
+
|
756 |
+
class Image_ExpandVox_attn_Sampler(nn.Module):
|
757 |
+
def __init__(self,reso,vit_reso=16,padding=0.1,triplane_in_channels=64,img_in_channels=1280,inner_channel=128,out_channels=64,n_heads=8):
|
758 |
+
super().__init__()
|
759 |
+
self.triplane_reso=reso
|
760 |
+
self.padding=padding
|
761 |
+
self.vit_reso=vit_reso
|
762 |
+
self.get_vox_coord()
|
763 |
+
self.get_vit_coords()
|
764 |
+
self.out_channels=out_channels
|
765 |
+
self.n_heads=n_heads
|
766 |
+
|
767 |
+
self.kernel_func = mask_kernel_close_false
|
768 |
+
self.k = nn.Linear(in_features=img_in_channels, out_features=inner_channel)
|
769 |
+
# self.q_xz = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
|
770 |
+
# self.q_xy = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
|
771 |
+
# self.q_yz = nn.Conv2d(in_channels=triplane_in_channels,out_channels=inner_channel,kernel_size=1)
|
772 |
+
self.q=nn.Linear(in_features=triplane_in_channels*3,out_features=inner_channel)
|
773 |
+
|
774 |
+
self.v = nn.Linear(in_features=img_in_channels, out_features=inner_channel)
|
775 |
+
self.attn = torch.nn.MultiheadAttention(
|
776 |
+
embed_dim=inner_channel, num_heads=n_heads, batch_first=True)
|
777 |
+
self.out_proj=nn.Linear(in_features=inner_channel,out_features=out_channels)
|
778 |
+
|
779 |
+
self.triplane_pe = position_encoding(inner_channel, self.triplane_reso ** 3).unsqueeze(0).cuda().float()
|
780 |
+
self.image_pe = position_encoding(inner_channel, self.vit_reso ** 2+1).unsqueeze(0).cuda().float()
|
781 |
+
def get_vox_coord(self):
|
782 |
+
x = torch.arange(self.triplane_reso)
|
783 |
+
y = torch.arange(self.triplane_reso)
|
784 |
+
z = torch.arange(self.triplane_reso)
|
785 |
+
|
786 |
+
X,Y,Z=torch.meshgrid(x,y,z,indexing='ij')
|
787 |
+
vox_coor=torch.cat([X[:,:,:,None],Y[:,:,:,None],Z[:,:,:,None]],dim=-1)
|
788 |
+
self.vox_index=vox_coor.view(-1,3).long().cuda()
|
789 |
+
|
790 |
+
|
791 |
+
vox_coor = self.vox_index.float() / (self.triplane_reso - 1)
|
792 |
+
vox_coor = (vox_coor - 0.5) * 2 * (1 + self.padding + 10e-6)
|
793 |
+
self.vox_coor = vox_coor.view(-1, 3).float().cuda()
|
794 |
+
# print(self.vox_coor[0])
|
795 |
+
# print(self.vox_coor[self.triplane_reso**2])#x should increase
|
796 |
+
# print(self.vox_coor[self.triplane_reso]) #y should increase
|
797 |
+
# print(self.vox_coor[1])#z should increase
|
798 |
+
|
799 |
+
def get_vit_coords(self):
|
800 |
+
x=torch.arange(self.vit_reso)
|
801 |
+
y=torch.arange(self.vit_reso)
|
802 |
+
|
803 |
+
X,Y=torch.meshgrid(x,y,indexing='xy')
|
804 |
+
vit_coords=torch.stack([X,Y],dim=-1)
|
805 |
+
self.vit_coords=vit_coords.view(self.vit_reso**2,2).cuda().float()
|
806 |
+
|
807 |
+
def compute_attn_mask(self,proj_coords,vit_coords,kernel_size=1.0):
|
808 |
+
dist = torch.cdist(proj_coords.float(), vit_coords.float())
|
809 |
+
mask = self.kernel_func(dist, sigma=kernel_size) # True if valid, B,reso**3,vit_reso**2
|
810 |
+
return mask
|
811 |
+
|
812 |
+
|
813 |
+
def forward(self,triplane_feat,image_feat,proj_mat):
|
814 |
+
xz_feat, xy_feat, yz_feat = torch.split(triplane_feat, split_size_or_sections=triplane_feat.shape[2] // 3, dim=2) # B,C,64,64
|
815 |
+
#xz_feat=self.q_xz(xz_feat)
|
816 |
+
#xy_feat=self.q_xy(xy_feat)
|
817 |
+
#yz_feat=self.q_yz(yz_feat)
|
818 |
+
batch_size=image_feat.shape[0]
|
819 |
+
vox_index=self.vox_index #64*64*64,3
|
820 |
+
xz_vox_feat=xz_feat[:,:,vox_index[:,2],vox_index[:,0]].transpose(1,2) #B,C,64*64*64
|
821 |
+
xy_vox_feat=xy_feat[:,:,vox_index[:,1],vox_index[:,0]].transpose(1,2)
|
822 |
+
yz_vox_feat=yz_feat[:,:,vox_index[:,2],vox_index[:,1]].transpose(1,2)
|
823 |
+
triplane_expand_feat=torch.cat([xz_vox_feat,xy_vox_feat,yz_vox_feat],dim=-1)#B,C,64*64*64,3
|
824 |
+
triplane_query=self.q(triplane_expand_feat)+self.triplane_pe
|
825 |
+
|
826 |
+
'''compute projection'''
|
827 |
+
vox_coords=self.vox_coor.unsqueeze(0).expand(batch_size,-1,-1) #
|
828 |
+
vox_homo = torch.cat([vox_coords, torch.ones((batch_size, self.triplane_reso ** 3, 1)).float().cuda()], dim=-1)
|
829 |
+
coord_inimg = torch.einsum('bhc,bck->bhk', vox_homo, proj_mat.transpose(1, 2))
|
830 |
+
x = coord_inimg[:, :, 0] / coord_inimg[:, :, 2]
|
831 |
+
y = coord_inimg[:, :, 1] / coord_inimg[:, :, 2]
|
832 |
+
#
|
833 |
+
x = x / (224.0 - 1.0) * (self.vit_reso-1) # 0~self.vit_reso-1
|
834 |
+
y = y / (224.0 - 1.0) * (self.vit_reso-1) # 0~self.vit_reso-1 #B,N
|
835 |
+
xy=torch.stack([x,y],dim=-1) #B,64*64*64,2
|
836 |
+
xy=torch.clamp(xy,min=0,max=self.vit_reso-1)
|
837 |
+
vit_coords=self.vit_coords.unsqueeze(0).expand(batch_size,-1,-1) #B, 16*16,2
|
838 |
+
attn_mask=torch.repeat_interleave(self.compute_attn_mask(xy,vit_coords,kernel_size=0.5),
|
839 |
+
self.n_heads,0) #B*n_heads, reso**3, vit_reso**2
|
840 |
+
|
841 |
+
k=self.k(image_feat)+self.image_pe
|
842 |
+
v=self.v(image_feat)+self.image_pe
|
843 |
+
attn_mask=torch.cat([torch.zeros([attn_mask.shape[0],attn_mask.shape[1],1]).cuda().bool(),attn_mask],dim=-1) #add empty token to each key and value
|
844 |
+
vox_feat,_=self.attn(triplane_query,k,v,attn_mask=attn_mask) #B,reso**3,C
|
845 |
+
feat_volume=self.out_proj(vox_feat).transpose(1,2).reshape(batch_size,-1,self.triplane_reso,
|
846 |
+
self.triplane_reso,self.triplane_reso)
|
847 |
+
xz_feat=torch.mean(feat_volume,dim=3).transpose(2,3)
|
848 |
+
xy_feat=torch.mean(feat_volume,dim=4).transpose(2,3)
|
849 |
+
yz_feat=torch.mean(feat_volume,dim=2).transpose(2,3)
|
850 |
+
triplane_out=torch.cat([xz_feat,xy_feat,yz_feat],dim=2)
|
851 |
+
return triplane_out
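The attention mask above compares each voxel's projected pixel position with the ViT token grid and restricts attention to nearby tokens. A minimal sketch of that idea; note that PyTorch's `nn.MultiheadAttention` treats a boolean mask entry of True as "do not attend", so the helper below (name illustrative) marks far-away tokens with True:

import torch

def distance_block_mask(proj_xy, token_xy, radius=0.5):
    """True where a projected point is farther than `radius` from a token centre (i.e. blocked)."""
    dist = torch.cdist(proj_xy.float(), token_xy.float())    # (B, Nq, Nk)
    return dist > radius

reso = 16
xs, ys = torch.meshgrid(torch.arange(reso), torch.arange(reso), indexing='xy')
token_xy = torch.stack([xs, ys], dim=-1).view(1, -1, 2).float()      # (1, 256, 2) token centres
proj_xy = torch.tensor([[[0.2, 0.1], [3.0, 3.4], [7.5, 8.5], [15.0, 15.0]]])
mask = distance_block_mask(proj_xy, token_xy, radius=0.5)
print(mask.shape, int(mask[0, 0].sum()))     # (1, 4, 256); almost all tokens blocked for the first query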
|
852 |
+
|
853 |
+
class Multi_Image_Fusion(nn.Module):
|
854 |
+
def __init__(self,reso,image_reso=16,padding=0.1,img_channels=1280,triplane_channel=64,inner_channels=128,output_channel=64,n_heads=8):
|
855 |
+
super().__init__()
|
856 |
+
self.triplane_reso=reso
|
857 |
+
self.image_reso=image_reso
|
858 |
+
self.padding=padding
|
859 |
+
self.get_triplane_coord()
|
860 |
+
self.get_vit_coords()
|
861 |
+
self.img_proj=nn.Conv3d(in_channels=img_channels,out_channels=512,kernel_size=1)
|
862 |
+
self.kernel_func=mask_kernel
|
863 |
+
|
864 |
+
self.q = nn.Linear(in_features=triplane_channel, out_features=inner_channels, bias=False)
|
865 |
+
self.k = nn.Linear(in_features=512, out_features=inner_channels)
|
866 |
+
self.v = nn.Linear(in_features=512, out_features=inner_channels)
|
867 |
+
|
868 |
+
self.attn = torch.nn.MultiheadAttention(
|
869 |
+
embed_dim=inner_channels, num_heads=n_heads, batch_first=True)
|
870 |
+
self.out_proj=nn.Linear(in_features=inner_channels,out_features=output_channel)
|
871 |
+
self.n_heads=n_heads
|
872 |
+
|
873 |
+
def get_triplane_coord(self):
|
874 |
+
'''xz plane first: x and z vary over the grid, while y is fixed at the plane centre'''
|
875 |
+
x=torch.arange(self.triplane_reso)
|
876 |
+
z=torch.arange(self.triplane_reso)
|
877 |
+
X,Z=torch.meshgrid(x,z,indexing='xy')
|
878 |
+
xz_coords=torch.cat([X[:,:,None],torch.ones_like(X[:,:,None])*(self.triplane_reso-1)/2,Z[:,:,None]],dim=-1) #in xyz order
|
879 |
+
|
880 |
+
'''xy plane'''
|
881 |
+
x = torch.arange(self.triplane_reso)
|
882 |
+
y = torch.arange(self.triplane_reso)
|
883 |
+
X, Y = torch.meshgrid(x, y, indexing='xy')
|
884 |
+
xy_coords = torch.cat([X[:, :, None], Y[:, :, None],torch.ones_like(X[:, :, None])*(self.triplane_reso-1)/2], dim=-1) # in xyz order
|
885 |
+
|
886 |
+
'''yz plane'''
|
887 |
+
y = torch.arange(self.triplane_reso)
|
888 |
+
z = torch.arange(self.triplane_reso)
|
889 |
+
Y,Z = torch.meshgrid(y,z,indexing='xy')
|
890 |
+
yz_coords= torch.cat([torch.ones_like(Y[:, :, None])*(self.triplane_reso-1)/2,Y[:,:,None],Z[:,:,None]], dim=-1)
|
891 |
+
|
892 |
+
triplane_coords=torch.cat([xz_coords,xy_coords,yz_coords],dim=0)
|
893 |
+
triplane_coords=triplane_coords/(self.triplane_reso-1)
|
894 |
+
triplane_coords=(triplane_coords-0.5)*2*(1 + self.padding + 10e-6)
|
895 |
+
self.triplane_coords=triplane_coords.float().cuda()
|
896 |
+
|
897 |
+
def get_vit_coords(self):
|
898 |
+
x=torch.arange(self.image_reso)
|
899 |
+
y=torch.arange(self.image_reso)
|
900 |
+
X,Y=torch.meshgrid(x,y,indexing='xy')
|
901 |
+
vit_coords=torch.cat([X[:,:,None],Y[:,:,None]],dim=-1)
|
902 |
+
self.vit_coords=vit_coords.float().cuda() #in x,y order
|
903 |
+
|
904 |
+
def compute_attn_mask(self,proj_coord,vit_coords,valid_frames,kernel_size=2.0):
|
905 |
+
'''
|
906 |
+
:param proj_coord: B,K,H,W,2
|
907 |
+
:param vit_coords: H,W,2
|
908 |
+
:return:
|
909 |
+
'''
|
910 |
+
B,K=proj_coord.shape[0:2]
|
911 |
+
vit_coords_expand=vit_coords[None,None,:,:,:].expand(B,K,-1,-1,-1)
|
912 |
+
|
913 |
+
proj_coord=proj_coord.view(B*K,proj_coord.shape[2]*proj_coord.shape[3],proj_coord.shape[4])
|
914 |
+
vit_coords_expand=vit_coords_expand.view(B*K,self.image_reso*self.image_reso,2)
|
915 |
+
attn_mask=self.kernel_func(torch.cdist(proj_coord,vit_coords_expand),sigma=float(kernel_size))
|
916 |
+
attn_mask=attn_mask.reshape(B,K,proj_coord.shape[1],vit_coords_expand.shape[1])
|
917 |
+
valid_expand=valid_frames[:,:,None,None]
|
918 |
+
attn_mask[valid_frames>0,:,:]=True
|
919 |
+
attn_mask=attn_mask.permute(0,2,1,3)
|
920 |
+
attn_mask=attn_mask.reshape(B,proj_coord.shape[1],K*vit_coords_expand.shape[1])
|
921 |
+
atten_index=torch.where(attn_mask[0,0]==False)
|
922 |
+
return attn_mask
|
923 |
+
|
924 |
+
|
925 |
+
def forward(self,triplane_feat,image_feat,proj_mat,valid_frames):
|
926 |
+
'''
|
927 |
+
:param image_feat: B,C,K,16,16
|
928 |
+
:param proj_mat: B,K,4,4
|
929 |
+
:param valid_frames: B,K, true if have image, used to compute attn_mask for transformer
|
930 |
+
:return:
|
931 |
+
'''
|
932 |
+
image_feat=self.img_proj(image_feat)
|
933 |
+
batch_size=image_feat.shape[0] #K is number of frames
|
934 |
+
K=image_feat.shape[2]
|
935 |
+
triplane_coords=self.triplane_coords.unsqueeze(0).unsqueeze(1).expand(batch_size,K,-1,-1,-1) #B,K,192,64,3
|
936 |
+
#print(torch.amin(triplane_coords),torch.amax(triplane_coords))
|
937 |
+
coord_homo=torch.cat([triplane_coords,torch.ones((batch_size,K,triplane_coords.shape[2],triplane_coords.shape[3],1)).float().cuda()],dim=-1)
|
938 |
+
#print(coord_homo.shape,proj_mat.shape)
|
939 |
+
coord_inimg=torch.einsum('bjhwc,bjck->bjhwk',coord_homo,proj_mat.transpose(2,3))
|
940 |
+
x=coord_inimg[:,:,:,:,0]/coord_inimg[:,:,:,:,2]
|
941 |
+
y=coord_inimg[:,:,:,:,1]/coord_inimg[:,:,:,:,2]
|
942 |
+
x=x/(224.0-1.0)*(self.image_reso-1)
|
943 |
+
y=y/(224.0-1.0)*(self.image_reso-1)
|
944 |
+
|
945 |
+
xy=torch.cat([x[...,None],y[...,None]],dim=-1) #B,K,H,W,2
|
946 |
+
image_value=image_feat.view(image_feat.shape[0],image_feat.shape[1],-1).transpose(1,2)
|
947 |
+
triplane_query=triplane_feat.view(triplane_feat.shape[0],triplane_feat.shape[1],-1).transpose(1,2)
|
948 |
+
valid_frames=1.0-valid_frames.float()
|
949 |
+
attn_mask=torch.repeat_interleave(self.compute_attn_mask(xy,self.vit_coords,valid_frames),
|
950 |
+
self.n_heads,dim=0)
|
951 |
+
|
952 |
+
q=self.q(triplane_query)
|
953 |
+
k=self.k(image_value)
|
954 |
+
v=self.v(image_value)
|
955 |
+
#print(q.shape,k.shape,v.shape)
|
956 |
+
|
957 |
+
attn,_=self.attn(q,k,v,attn_mask=attn_mask)
|
958 |
+
#print(attn.shape)
|
959 |
+
output=self.out_proj(attn).transpose(1,2).reshape(batch_size,-1,triplane_feat.shape[2],triplane_feat.shape[3])
|
960 |
+
#print(output.shape)
|
961 |
+
return output
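Several modules in this file hand boolean masks to `nn.MultiheadAttention`, where True marks keys that are excluded from attention. A tiny check with illustrative sizes:

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
q = torch.randn(1, 3, 16)                             # 3 queries
kv = torch.randn(1, 5, 16)                            # 5 keys / values
mask = torch.zeros(1 * 2, 3, 5, dtype=torch.bool)     # (batch * heads, Nq, Nk)
mask[:, :, -2:] = True                                # True = do NOT attend to the last two keys
out, weights = attn(q, kv, kv, attn_mask=mask)
print(weights[0, :, -2:])                             # the masked keys get exactly zero weight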
|
962 |
+
|
963 |
+
|
964 |
+
if __name__=="__main__":
|
965 |
+
# import sys
|
966 |
+
# sys.path.append("../..")
|
967 |
+
# from datasets.SingleView_dataset import Object_PartialPoints_Img
|
968 |
+
# from datasets.transforms import Aug_with_Tran
|
969 |
+
# #sampler=#Image_Vox_Local_Sampler_Pooling(reso=64,padding=0.1,out_channels=64,stride=4).cuda().float()
|
970 |
+
# sampler=Image_ExpandVox_attn_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=64,inner_channel=64
|
971 |
+
# ,out_channels=64,n_heads=8).cuda().float()
|
972 |
+
# # sampler=Image_Direct_AttenwMask_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128
|
973 |
+
# # ,out_channels=64,n_heads=8).cuda().float()
|
974 |
+
# dataset_config = {
|
975 |
+
# "data_path": "/data1/haolin/datasets",
|
976 |
+
# "surface_size": 20000,
|
977 |
+
# "par_pc_size": 4096,
|
978 |
+
# "load_proj_mat": True,
|
979 |
+
# }
|
980 |
+
# transform = Aug_with_Tran()
|
981 |
+
# datasets = Object_PartialPoints_Img(dataset_config['data_path'], split_filename="val_par_img.json", split='val',
|
982 |
+
# transform=transform, sampling=False,
|
983 |
+
# num_samples=1024, return_surface=True, ret_sample=True,
|
984 |
+
# surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],
|
985 |
+
# surface_size=dataset_config['surface_size'],
|
986 |
+
# load_proj_mat=dataset_config['load_proj_mat'], load_image=True,
|
987 |
+
# load_org_img=False, load_triplane=True, replica=1)
|
988 |
+
#
|
989 |
+
# dataloader = torch.utils.data.DataLoader(
|
990 |
+
# datasets=datasets,
|
991 |
+
# batch_size=64,
|
992 |
+
# shuffle=True
|
993 |
+
# )
|
994 |
+
# iterator = dataloader.__iter__()
|
995 |
+
# data_batch = iterator.next()
|
996 |
+
# unflatten = torch.nn.Unflatten(1, (16, 16))
|
997 |
+
# image = data_batch['image'][:,:,:].cuda().float()
|
998 |
+
# #image=unflatten(image).permute(0,3,1,2)
|
999 |
+
# proj_mat = data_batch['proj_mat'].cuda().float()
|
1000 |
+
# triplane_feat=torch.randn((64,64,32*3,32)).cuda().float()
|
1001 |
+
# sampler(triplane_feat,image,proj_mat)
|
1002 |
+
# memory_usage=torch.cuda.max_memory_allocated() / MB
|
1003 |
+
# print("memory usage %f mb"%(memory_usage))
|
1004 |
+
|
1005 |
+
|
1006 |
+
    import sys
    sys.path.append("../..")
    from datasets.SingleView_dataset import Object_PartialPoints_MultiImg
    from datasets.transforms import Aug_with_Tran

    dataset_config = {
        "data_path": "/data1/haolin/datasets",
        "surface_size": 20000,
        "par_pc_size": 4096,
        "load_proj_mat": True,
    }
    transform = Aug_with_Tran()
    dataset = Object_PartialPoints_MultiImg(dataset_config['data_path'], split_filename="train_par_img.json", split='train',
                                            transform=transform, sampling=False,
                                            num_samples=1024, return_surface=True, ret_sample=True,
                                            surface_sampling=True, par_pc_size=dataset_config['par_pc_size'],
                                            surface_size=dataset_config['surface_size'],
                                            load_proj_mat=dataset_config['load_proj_mat'], load_image=True,
                                            load_org_img=False, load_triplane=True, replica=1)

    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=10,
        shuffle=False
    )
    iterator = iter(dataloader)
    data_batch = next(iterator)
    #unflatten = torch.nn.Unflatten(2, (16, 16))
    image = data_batch['image'][:,:,:,:].cuda().float()
    #image=unflatten(image).permute(0,4,1,2,3)
    proj_mat = data_batch['proj_mat'].cuda().float()
    valid_frames = data_batch['valid_frames'].cuda().float()
    triplane_feat=torch.randn((10,128,32*3,32)).cuda().float()

    # fusion_module=MultiImage_Fuse_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128
    #                                        ,out_channels=64,n_heads=8).cuda().float()
    fusion_module=MultiImage_Global_Sampler(reso=32,vit_reso=16,padding=0.1,img_in_channels=1280,triplane_in_channels=128,inner_channel=128
                                            ,out_channels=64,n_heads=8).cuda().float()
    fusion_module(triplane_feat,image,proj_mat,valid_frames)
    MB = 1024.0 ** 2  # bytes per megabyte
    memory_usage=torch.cuda.max_memory_allocated() / MB
    print("memory usage %f mb"%(memory_usage))
models/modules/parpoints_encoder.py
ADDED
@@ -0,0 +1,168 @@
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch_scatter import scatter_mean, scatter_max
|
5 |
+
from .unet import UNet
|
6 |
+
from .resnet_block import ResnetBlockFC
|
7 |
+
from .PointEMB import PointEmbed
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
class ParPoint_Encoder(nn.Module):
|
11 |
+
''' PointNet-based encoder network with ResNet blocks for each point.
|
12 |
+
The number of input points is fixed.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
c_dim (int): dimension of latent code c
|
16 |
+
dim (int): input points dimension
|
17 |
+
hidden_dim (int): hidden dimension of the network
|
18 |
+
scatter_type (str): feature aggregation when doing local pooling
|
19 |
+
unet (bool): whether to use U-Net
|
20 |
+
unet_kwargs (str): U-Net parameters
|
21 |
+
plane_resolution (int): defined resolution for plane feature
|
22 |
+
plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
|
23 |
+
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
24 |
+
n_blocks (int): number of blocks ResNetBlockFC layers
|
25 |
+
'''
|
26 |
+
|
27 |
+
def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max', unet_kwargs=None,
|
28 |
+
plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
|
29 |
+
super().__init__()
|
30 |
+
self.c_dim = c_dim
|
31 |
+
|
32 |
+
self.fc_pos = nn.Linear(dim, 2 * hidden_dim)
|
33 |
+
self.blocks = nn.ModuleList([
|
34 |
+
ResnetBlockFC(2 * hidden_dim, hidden_dim) for i in range(n_blocks)
|
35 |
+
])
|
36 |
+
self.fc_c = nn.Linear(hidden_dim, c_dim)
|
37 |
+
|
38 |
+
self.actvn = nn.ReLU()
|
39 |
+
self.hidden_dim = hidden_dim
|
40 |
+
|
41 |
+
self.unet = UNet(unet_kwargs['output_dim'], in_channels=c_dim, **unet_kwargs)
|
42 |
+
|
43 |
+
self.reso_plane = plane_resolution
|
44 |
+
self.plane_type = plane_type
|
45 |
+
self.padding = padding
|
46 |
+
|
47 |
+
if scatter_type == 'max':
|
48 |
+
self.scatter = scatter_max
|
49 |
+
elif scatter_type == 'mean':
|
50 |
+
self.scatter = scatter_mean
|
51 |
+
|
52 |
+
# takes in "p": point cloud and "query": sdf_xyz
|
53 |
+
# sample plane features for unlabeled_query as well
|
54 |
+
def forward(self, p,point_emb): # , query2):
|
55 |
+
batch_size, T, D = p.size()
|
56 |
+
#print('origin',torch.amin(p[0],dim=0),torch.amax(p[0],dim=0))
|
57 |
+
# acquire the index for each point
|
58 |
+
coord = {}
|
59 |
+
index = {}
|
60 |
+
if 'xz' in self.plane_type:
|
61 |
+
coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
|
62 |
+
index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
|
63 |
+
if 'xy' in self.plane_type:
|
64 |
+
coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
|
65 |
+
index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
|
66 |
+
if 'yz' in self.plane_type:
|
67 |
+
coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
|
68 |
+
index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)
|
69 |
+
net = self.fc_pos(point_emb)
|
70 |
+
|
71 |
+
net = self.blocks[0](net)
|
72 |
+
for block in self.blocks[1:]:
|
73 |
+
pooled = self.pool_local(coord, index, net)
|
74 |
+
net = torch.cat([net, pooled], dim=2)
|
75 |
+
net = block(net)
|
76 |
+
|
77 |
+
c = self.fc_c(net)
|
78 |
+
#print(c.shape)
|
79 |
+
|
80 |
+
fea = {}
|
81 |
+
# second_sum = 0
|
82 |
+
if 'xz' in self.plane_type:
|
83 |
+
fea['xz'] = self.generate_plane_features(p, c,
|
84 |
+
plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
|
85 |
+
if 'xy' in self.plane_type:
|
86 |
+
fea['xy'] = self.generate_plane_features(p, c, plane='xy')
|
87 |
+
if 'yz' in self.plane_type:
|
88 |
+
fea['yz'] = self.generate_plane_features(p, c, plane='yz')
|
89 |
+
cat_feature = torch.cat([fea['xz'], fea['xy'], fea['yz']],
|
90 |
+
dim=2) # concat at row dimension
|
91 |
+
#print(cat_feature.shape)
|
92 |
+
plane_feat=self.unet(cat_feature)
|
93 |
+
|
94 |
+
return plane_feat
|
95 |
+
|
96 |
+
|
97 |
+
def normalize_coordinate(self, p, padding=0.1, plane='xz'):
|
98 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments
|
99 |
+
|
100 |
+
Args:
|
101 |
+
p (tensor): point
|
102 |
+
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
103 |
+
plane (str): plane feature type, ['xz', 'xy', 'yz']
|
104 |
+
'''
|
105 |
+
if plane == 'xz':
|
106 |
+
xy = p[:, :, [0, 2]]
|
107 |
+
elif plane == 'xy':
|
108 |
+
xy = p[:, :, [0, 1]]
|
109 |
+
else:
|
110 |
+
xy = p[:, :, [1, 2]]
|
111 |
+
#print("origin",torch.amin(xy), torch.amax(xy))
|
112 |
+
xy=xy/2 #xy is originally -1 ~ 1
|
113 |
+
xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
|
114 |
+
xy_new = xy_new + 0.5 # range (0, 1)
|
115 |
+
#print("scale",torch.amin(xy_new),torch.amax(xy_new))
|
116 |
+
|
117 |
+
# if there are outliers out of the range
|
118 |
+
if xy_new.max() >= 1:
|
119 |
+
xy_new[xy_new >= 1] = 1 - 10e-6
|
120 |
+
if xy_new.min() < 0:
|
121 |
+
xy_new[xy_new < 0] = 0.0
|
122 |
+
return xy_new
|
123 |
+
|
124 |
+
def coordinate2index(self, x, reso):
|
125 |
+
''' Convert coordinates normalized to [0, 1] into flat cell indices on a plane of the given resolution.
Corresponds to our 3D model
|
127 |
+
|
128 |
+
Args:
|
129 |
+
x (tensor): coordinate
|
130 |
+
reso (int): defined resolution
|
131 |
+
coord_type (str): coordinate type
|
132 |
+
'''
|
133 |
+
x = (x * reso).long()
|
134 |
+
index = x[:, :, 0] + reso * x[:, :, 1]
|
135 |
+
index = index[:, None, :]
|
136 |
+
return index
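As a worked example of `coordinate2index`: a coordinate normalized to [0, 1] is scaled by the plane resolution, truncated to an integer cell, and flattened as `x + reso * y`. With `reso = 64` (values illustrative):

import torch

reso = 64
xy = torch.tensor([[[0.0, 0.0],        # -> 0
                    [0.5, 0.25],       # -> 32 + 64 * 16 = 1056
                    [0.999, 0.999]]])  # -> 63 + 64 * 63 = 4095
cells = (xy * reso).long()
index = cells[:, :, 0] + reso * cells[:, :, 1]
print(index)                           # tensor([[   0, 1056, 4095]])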
|
137 |
+
|
138 |
+
# xy is the normalized coordinates of the point cloud of each plane
|
139 |
+
# I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
|
140 |
+
def pool_local(self, xy, index, c):
|
141 |
+
bs, fea_dim = c.size(0), c.size(2)
|
142 |
+
keys = xy.keys()
|
143 |
+
|
144 |
+
c_out = 0
|
145 |
+
for key in keys:
|
146 |
+
# scatter plane features from points
|
147 |
+
fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane ** 2)
|
148 |
+
if self.scatter == scatter_max:
|
149 |
+
fea = fea[0]
|
150 |
+
# gather feature back to points
|
151 |
+
fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
|
152 |
+
c_out += fea
|
153 |
+
return c_out.permute(0, 2, 1)
|
154 |
+
|
155 |
+
def generate_plane_features(self, p, c, plane='xz'):
|
156 |
+
# acquire indices of features in plane
|
157 |
+
xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
|
158 |
+
index = self.coordinate2index(xy, self.reso_plane)
|
159 |
+
|
160 |
+
# scatter plane features from points
|
161 |
+
fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane ** 2)
|
162 |
+
c = c.permute(0, 2, 1) # B x 512 x T
|
163 |
+
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
|
164 |
+
fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane,
|
165 |
+
self.reso_plane) # sparce matrix (B x 512 x reso x reso)
|
166 |
+
#print(fea_plane.shape)
|
167 |
+
|
168 |
+
return fea_plane
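`generate_plane_features` averages the features of all points that fall into the same plane cell with `scatter_mean`. A minimal sketch of that pooling step, assuming `torch_scatter` is installed (it is already imported at the top of this file); sizes are illustrative:

import torch
from torch_scatter import scatter_mean

B, C, T, reso = 2, 8, 100, 4                      # batch, channels, points, plane resolution
c = torch.randn(B, C, T)                          # per-point features, already permuted to B x C x T
index = torch.randint(0, reso ** 2, (B, 1, T))    # flat plane-cell index of every point

fea_plane = c.new_zeros(B, C, reso ** 2)
fea_plane = scatter_mean(c, index, out=fea_plane) # average the features landing in each cell
fea_plane = fea_plane.reshape(B, C, reso, reso)   # back to a 2D feature plane
print(fea_plane.shape)                            # torch.Size([2, 8, 4, 4])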
|
models/modules/point_transformer.py
ADDED
@@ -0,0 +1,442 @@
from torch import nn, einsum
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import rearrange,repeat
|
5 |
+
from timm.models.layers import DropPath
|
6 |
+
from torch_cluster import fps
|
7 |
+
import numpy as np
from functools import wraps  # used by the @wraps decorator in cache_fn below
|
8 |
+
|
9 |
+
def zero_module(module):
|
10 |
+
"""
|
11 |
+
Zero out the parameters of a module and return it.
|
12 |
+
"""
|
13 |
+
for p in module.parameters():
|
14 |
+
p.detach().zero_()
|
15 |
+
return module
|
16 |
+
|
17 |
+
class PositionalEmbedding(torch.nn.Module):
|
18 |
+
def __init__(self, num_channels, max_positions=10000, endpoint=False):
|
19 |
+
super().__init__()
|
20 |
+
self.num_channels = num_channels
|
21 |
+
self.max_positions = max_positions
|
22 |
+
self.endpoint = endpoint
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
freqs = torch.arange(start=0, end=self.num_channels//2, dtype=torch.float32, device=x.device)
|
26 |
+
freqs = freqs / (self.num_channels // 2 - (1 if self.endpoint else 0))
|
27 |
+
freqs = (1 / self.max_positions) ** freqs
|
28 |
+
x = x.ger(freqs.to(x.dtype))
|
29 |
+
x = torch.cat([x.cos(), x.sin()], dim=1)
|
30 |
+
return x
|
31 |
+
|
32 |
+
class CrossAttention(nn.Module):
|
33 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
34 |
+
super().__init__()
|
35 |
+
inner_dim = dim_head * heads
|
36 |
+
|
37 |
+
if context_dim is None:
|
38 |
+
context_dim = query_dim
|
39 |
+
|
40 |
+
self.scale = dim_head ** -0.5
|
41 |
+
self.heads = heads
|
42 |
+
|
43 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
44 |
+
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
45 |
+
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
46 |
+
|
47 |
+
self.to_out = nn.Sequential(
|
48 |
+
nn.Linear(inner_dim, query_dim),
|
49 |
+
nn.Dropout(dropout)
|
50 |
+
)
|
51 |
+
|
52 |
+
def forward(self, x, context=None, mask=None):
|
53 |
+
h = self.heads
|
54 |
+
|
55 |
+
q = self.to_q(x)
|
56 |
+
|
57 |
+
if context is None:
|
58 |
+
context = x
|
59 |
+
|
60 |
+
k = self.to_k(context)
|
61 |
+
v = self.to_v(context)
|
62 |
+
|
63 |
+
q, k, v = map(lambda t: rearrange(
|
64 |
+
t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
65 |
+
|
66 |
+
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
67 |
+
|
68 |
+
# attention, what we cannot get enough of
|
69 |
+
attn = sim.softmax(dim=-1)
|
70 |
+
|
71 |
+
out = torch.einsum('b i j, b j d -> b i d', attn, v)
|
72 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
73 |
+
return self.to_out(out)
|
74 |
+
|
75 |
+
|
76 |
+
class LayerScale(nn.Module):
|
77 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
78 |
+
super().__init__()
|
79 |
+
self.inplace = inplace
|
80 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
84 |
+
|
85 |
+
class GEGLU(nn.Module):
|
86 |
+
def __init__(self, dim_in, dim_out):
|
87 |
+
super().__init__()
|
88 |
+
self.proj = nn.Linear(dim_in, dim_out * 2)
|
89 |
+
|
90 |
+
def forward(self, x):
|
91 |
+
x, gate = self.proj(x).chunk(2, dim=-1)
|
92 |
+
return x * F.gelu(gate)
|
93 |
+
|
94 |
+
class FeedForward(nn.Module):
|
95 |
+
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
96 |
+
super().__init__()
|
97 |
+
inner_dim = int(dim * mult)
|
98 |
+
if dim_out is None:
|
99 |
+
dim_out = dim
|
100 |
+
|
101 |
+
project_in = nn.Sequential(
|
102 |
+
nn.Linear(dim, inner_dim),
|
103 |
+
nn.GELU()
|
104 |
+
) if not glu else GEGLU(dim, inner_dim)
|
105 |
+
|
106 |
+
self.net = nn.Sequential(
|
107 |
+
project_in,
|
108 |
+
nn.Dropout(dropout),
|
109 |
+
nn.Linear(inner_dim, dim_out)
|
110 |
+
)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
return self.net(x)
|
114 |
+
|
115 |
+
class AdaLayerNorm(nn.Module):
|
116 |
+
def __init__(self, n_embd):
|
117 |
+
super().__init__()
|
118 |
+
|
119 |
+
self.silu = nn.SiLU()
|
120 |
+
self.linear = nn.Linear(n_embd, n_embd*2)
|
121 |
+
self.layernorm = nn.LayerNorm(n_embd, elementwise_affine=False)
|
122 |
+
|
123 |
+
def forward(self, x, timestep):
|
124 |
+
emb = self.linear(timestep)
|
125 |
+
scale, shift = torch.chunk(emb, 2, dim=2)
|
126 |
+
x = self.layernorm(x) * (1 + scale) + shift
|
127 |
+
return x
|
128 |
+
|
129 |
+
class BasicTransformerBlock(nn.Module):
|
130 |
+
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
|
131 |
+
super().__init__()
|
132 |
+
self.attn1 = CrossAttention(
|
133 |
+
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
|
134 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
135 |
+
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
136 |
+
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
137 |
+
self.norm1 = AdaLayerNorm(dim)
|
138 |
+
self.norm2 = AdaLayerNorm(dim)
|
139 |
+
self.norm3 = AdaLayerNorm(dim)
|
140 |
+
self.checkpoint = checkpoint
|
141 |
+
|
142 |
+
init_values = 0
|
143 |
+
drop_path = 0.0
|
144 |
+
|
145 |
+
|
146 |
+
self.ls1 = LayerScale(
|
147 |
+
dim, init_values=init_values) if init_values else nn.Identity()
|
148 |
+
self.drop_path1 = DropPath(
|
149 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
150 |
+
|
151 |
+
self.ls2 = LayerScale(
|
152 |
+
dim, init_values=init_values) if init_values else nn.Identity()
|
153 |
+
self.drop_path2 = DropPath(
|
154 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
155 |
+
|
156 |
+
self.ls3 = LayerScale(
|
157 |
+
dim, init_values=init_values) if init_values else nn.Identity()
|
158 |
+
self.drop_path3 = DropPath(
|
159 |
+
drop_path) if drop_path > 0. else nn.Identity()
|
160 |
+
|
161 |
+
def forward(self, x, t, context=None):
|
162 |
+
x = self.drop_path1(self.ls1(self.attn1(self.norm1(x, t)))) + x
|
163 |
+
x = self.drop_path2(self.ls2(self.attn2(self.norm2(x, t), context=context))) + x
|
164 |
+
x = self.drop_path3(self.ls3(self.ff(self.norm3(x, t)))) + x
|
165 |
+
return x
|
166 |
+
|
167 |
+
class LatentArrayTransformer(nn.Module):
|
168 |
+
"""
|
169 |
+
Transformer block for image-like data.
|
170 |
+
First, project the input (aka embedding)
|
171 |
+
and reshape to b, t, d.
|
172 |
+
Then apply standard transformer action.
|
173 |
+
Finally, reshape to image
|
174 |
+
"""
|
175 |
+
|
176 |
+
def __init__(self, in_channels, t_channels, n_heads, d_head,
|
177 |
+
depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None,
|
178 |
+
block=BasicTransformerBlock):
|
179 |
+
super().__init__()
|
180 |
+
self.in_channels = in_channels
|
181 |
+
inner_dim = n_heads * d_head
|
182 |
+
|
183 |
+
self.t_channels = t_channels
|
184 |
+
|
185 |
+
self.proj_in = nn.Linear(in_channels, inner_dim, bias=False)
|
186 |
+
|
187 |
+
self.transformer_blocks = nn.ModuleList(
|
188 |
+
[block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
189 |
+
for _ in range(depth)]
|
190 |
+
)
|
191 |
+
|
192 |
+
self.norm = nn.LayerNorm(inner_dim)
|
193 |
+
|
194 |
+
if out_channels is None:
|
195 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False))
|
196 |
+
else:
|
197 |
+
self.num_cls = out_channels
|
198 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False))
|
199 |
+
|
200 |
+
self.context_dim = context_dim
|
201 |
+
|
202 |
+
self.map_noise = PositionalEmbedding(t_channels)
|
203 |
+
|
204 |
+
self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim)
|
205 |
+
self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim)
|
206 |
+
|
207 |
+
# ###
|
208 |
+
# self.pos_emb = nn.Embedding(512, inner_dim)
|
209 |
+
# ###
|
210 |
+
|
211 |
+
def forward(self, x, t, cond, class_emb):
|
212 |
+
|
213 |
+
t_emb = self.map_noise(t)[:, None]
|
214 |
+
t_emb = F.silu(self.map_layer0(t_emb))
|
215 |
+
t_emb = F.silu(self.map_layer1(t_emb))
|
216 |
+
|
217 |
+
x = self.proj_in(x)
|
218 |
+
#print(class_emb.shape,t_emb.shape)
|
219 |
+
for block in self.transformer_blocks:
|
220 |
+
x = block(x, t_emb+class_emb[:,None,:], context=cond)
|
221 |
+
|
222 |
+
x = self.norm(x)
|
223 |
+
|
224 |
+
x = self.proj_out(x)
|
225 |
+
return x
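A hedged usage sketch of `LatentArrayTransformer` as a latent-set denoiser: it takes a batch of latent tokens, one diffusion timestep per sample, an optional cross-attention context, and a class embedding whose width must equal `n_heads * d_head`. It assumes this module's dependencies (timm, torch_cluster) are installed; all sizes are illustrative:

import torch

model = LatentArrayTransformer(in_channels=32, t_channels=256, n_heads=8, d_head=32,
                               depth=2, context_dim=1280)
x = torch.randn(4, 512, 32)          # B x N latent tokens x in_channels
t = torch.rand(4)                    # one diffusion timestep per sample
cond = torch.randn(4, 257, 1280)     # e.g. ViT tokens as cross-attention context
class_emb = torch.randn(4, 256)      # width must equal n_heads * d_head
out = model(x, t, cond, class_emb)
print(out.shape)                     # torch.Size([4, 512, 32])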
|
226 |
+
|
227 |
+
class PointTransformer(nn.Module):
|
228 |
+
"""
|
229 |
+
Transformer block for image-like data.
|
230 |
+
First, project the input (aka embedding)
|
231 |
+
and reshape to b, t, d.
|
232 |
+
Then apply standard transformer action.
|
233 |
+
Finally, reshape to image
|
234 |
+
"""
|
235 |
+
|
236 |
+
def __init__(self, in_channels, t_channels, n_heads, d_head,
|
237 |
+
depth=1, dropout=0., context_dim=None, out_channels=None, context_dim2=None,
|
238 |
+
block=BasicTransformerBlock):
|
239 |
+
super().__init__()
|
240 |
+
self.in_channels = in_channels
|
241 |
+
inner_dim = n_heads * d_head
|
242 |
+
|
243 |
+
self.t_channels = t_channels
|
244 |
+
|
245 |
+
self.proj_in = nn.Linear(in_channels, inner_dim, bias=False)
|
246 |
+
|
247 |
+
self.transformer_blocks = nn.ModuleList(
|
248 |
+
[block(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
249 |
+
for _ in range(depth)]
|
250 |
+
)
|
251 |
+
|
252 |
+
self.norm = nn.LayerNorm(inner_dim)
|
253 |
+
|
254 |
+
if out_channels is None:
|
255 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels, bias=False))
|
256 |
+
else:
|
257 |
+
self.num_cls = out_channels
|
258 |
+
self.proj_out = zero_module(nn.Linear(inner_dim, out_channels, bias=False))
|
259 |
+
|
260 |
+
self.context_dim = context_dim
|
261 |
+
|
262 |
+
self.map_noise = PositionalEmbedding(t_channels)
|
263 |
+
|
264 |
+
self.map_layer0 = nn.Linear(in_features=t_channels, out_features=inner_dim)
|
265 |
+
self.map_layer1 = nn.Linear(in_features=inner_dim, out_features=inner_dim)
|
266 |
+
|
267 |
+
# ###
|
268 |
+
        # self.pos_emb = nn.Embedding(512, inner_dim)
        # ###

    def forward(self, x, t, cond=None):

        t_emb = self.map_noise(t)[:, None]
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))

        x = self.proj_in(x)

        for block in self.transformer_blocks:
            x = block(x, t_emb, context=cond)

        x = self.norm(x)

        x = self.proj_out(x)
        return x

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cache_fn(f):
    cache = None
    @wraps(f)
    def cached_fn(*args, _cache=True, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if cache is not None:
            return cache
        cache = f(*args, **kwargs)
        return cache
    return cached_fn

class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)

        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context=normed_context)

        return self.fn(x, **kwargs)

class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, drop_path_rate=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.drop_path(self.to_out(out))


class PointEmbed(nn.Module):
    def __init__(self, hidden_dim=48, dim=128):
        super().__init__()

        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6), e,
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6), e]),
        ])
        self.register_buffer('basis', e)  # 3 x 16

        self.mlp = nn.Linear(self.embedding_dim + 3, dim)

    @staticmethod
    def embed(input, basis):
        projections = torch.einsum(
            'bnd,de->bne', input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings

    def forward(self, input):
        # input: B x N x 3
        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2))  # B x N x C
        return embed


class PointEncoder(nn.Module):
    def __init__(self,
                 dim=512,
                 num_inputs=2048,
                 num_latents=512,
                 latent_dim=512):
        super().__init__()

        self.num_inputs = num_inputs
        self.num_latents = num_latents

        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(dim, Attention(dim, dim, heads=1, dim_head=dim), context_dim=dim),
            PreNorm(dim, FeedForward(dim))
        ])

        self.point_embed = PointEmbed(dim=dim)
        self.proj = nn.Linear(dim, latent_dim)

    def encode(self, pc):
        # pc: B x N x 3
        B, N, D = pc.shape
        assert N == self.num_inputs

        ###### fps
        flattened = pc.view(B * N, D)

        batch = torch.arange(B).to(pc.device)
        batch = torch.repeat_interleave(batch, N)

        pos = flattened

        ratio = 1.0 * self.num_latents / self.num_inputs

        idx = fps(pos, batch, ratio=ratio)

        sampled_pc = pos[idx]
        sampled_pc = sampled_pc.view(B, -1, 3)
        ######

        sampled_pc_embeddings = self.point_embed(sampled_pc)

        pc_embeddings = self.point_embed(pc)

        cross_attn, cross_ff = self.cross_attend_blocks

        x = cross_attn(sampled_pc_embeddings, context=pc_embeddings, mask=None) + sampled_pc_embeddings
        x = cross_ff(x) + x

        return self.proj(x)
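A minimal standalone sketch of the Fourier point embedding that `PointEmbed` above implements: xyz coordinates are projected onto a per-axis frequency basis, sin/cos are taken, and the raw coordinates are concatenated before the final linear layer (omitted here). The `hidden_dim` and tensor sizes below are illustrative only, not values fixed by the repo:

```python
import torch
import numpy as np

def fourier_point_embed(xyz, hidden_dim=48):
    # xyz: (B, N, 3); hidden_dim must be divisible by 6
    n_freq = hidden_dim // 6
    e = torch.pow(2, torch.arange(n_freq)).float() * np.pi        # frequencies 2^k * pi
    basis = torch.zeros(3, 3 * n_freq)
    for axis in range(3):
        basis[axis, axis * n_freq:(axis + 1) * n_freq] = e        # one frequency block per axis
    proj = torch.einsum('bnd,de->bne', xyz, basis)                # (B, N, 3 * n_freq)
    emb = torch.cat([proj.sin(), proj.cos(), xyz], dim=2)         # (B, N, hidden_dim + 3)
    return emb

pts = torch.rand(2, 2048, 3)
print(fourier_point_embed(pts).shape)  # torch.Size([2, 2048, 51])
```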
models/modules/pointnet2_backbone.py
ADDED
@@ -0,0 +1,188 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import sys
import os
from external.pointnet2.pointnet2_modules import PointnetSAModuleVotes, PointnetFPModule
from .utils import zero_module
from .Positional_Embedding import PositionalEmbedding

class Pointnet2Encoder(nn.Module):
    def __init__(self, input_feature_dim=0, npoints=[2048, 1024, 512, 256], radius=[0.2, 0.4, 0.6, 1.2], nsample=[64, 32, 16, 8]):
        super().__init__()
        self.sa1 = PointnetSAModuleVotes(
            npoint=npoints[0],
            radius=radius[0],
            nsample=nsample[0],
            mlp=[input_feature_dim, 64, 64, 128],
            use_xyz=True,
            normalize_xyz=True
        )

        self.sa2 = PointnetSAModuleVotes(
            npoint=npoints[1],
            radius=radius[1],
            nsample=nsample[1],
            mlp=[128, 128, 128, 256],
            use_xyz=True,
            normalize_xyz=True
        )

        self.sa3 = PointnetSAModuleVotes(
            npoint=npoints[2],
            radius=radius[2],
            nsample=nsample[2],
            mlp=[256, 256, 256, 512],
            use_xyz=True,
            normalize_xyz=True
        )

        self.sa4 = PointnetSAModuleVotes(
            npoint=npoints[3],
            radius=radius[3],
            nsample=nsample[3],
            mlp=[512, 512, 512, 512],
            use_xyz=True,
            normalize_xyz=True
        )

    def _break_up_pc(self, pc):
        xyz = pc[..., 0:3].contiguous()
        features = (
            pc[..., 3:].transpose(1, 2).contiguous()
            if pc.size(-1) > 3 else None
        )
        return xyz, features

    def forward(self, pointcloud, end_points=None):
        if not end_points: end_points = {}
        batch_size = pointcloud.shape[0]

        xyz, features = self._break_up_pc(pointcloud)

        end_points['org_xyz'] = xyz
        # --------- 4 SET ABSTRACTION LAYERS ---------
        xyz1, features1, _ = self.sa1(xyz, features)
        end_points['sa1_xyz'] = xyz1
        end_points['sa1_features'] = features1

        xyz2, features2, _ = self.sa2(xyz1, features1)  # this fps_inds is just 0,1,...,1023
        end_points['sa2_xyz'] = xyz2
        end_points['sa2_features'] = features2

        xyz3, features3, _ = self.sa3(xyz2, features2)  # this fps_inds is just 0,1,...,511
        end_points['sa3_xyz'] = xyz3
        end_points['sa3_features'] = features3
        #print(xyz3.shape,features3.shape)
        xyz4, features4, _ = self.sa4(xyz3, features3)  # this fps_inds is just 0,1,...,255
        end_points['sa4_xyz'] = xyz4
        end_points['sa4_features'] = features4
        #print(xyz4.shape,features4.shape)
        return end_points


class PointUNet(nn.Module):
    r"""
    Backbone network for point cloud feature learning.
    Based on Pointnet++ single-scale grouping network.

    Parameters
    ----------
    input_feature_dim: int
        Number of input channels in the feature descriptor for each point.
        e.g. 3 for RGB.
    """

    def __init__(self):
        super().__init__()

        self.noisy_encoder = Pointnet2Encoder()
        self.cond_encoder = Pointnet2Encoder()
        self.fp1_cross = PointnetFPModule(mlp=[512 + 512, 512, 512])
        self.fp1 = PointnetFPModule(mlp=[512 + 512, 512, 512])
        #self.fp1 = PointnetFPModule(mlp=[512 + 512, 512, 512])
        self.fp2_cross = PointnetFPModule(mlp=[512 + 512, 512, 256])
        self.fp2 = PointnetFPModule(mlp=[256 + 256, 512, 256])
        #self.fp2=PointnetFPModule(mlp=[512 + 256, 512, 256])
        self.fp3_cross = PointnetFPModule(mlp=[256 + 256, 256, 128])
        self.fp3 = PointnetFPModule(mlp=[128 + 128, 256, 128])
        #self.fp3 = PointnetFPModule(mlp=[256 + 128, 256, 128])
        self.fp4_cross = PointnetFPModule(mlp=[128 + 128, 128, 128])
        self.fp4 = PointnetFPModule(mlp=[128, 128, 128])
        #self.fp4 = PointnetFPModule(mlp=[128, 128, 128])

        self.output_layer = nn.Sequential(
            nn.LayerNorm(128),
            zero_module(nn.Linear(in_features=128, out_features=3, bias=False))
        )
        self.t_emb_layer = PositionalEmbedding(256)
        self.map_layer0 = nn.Linear(in_features=256, out_features=512)
        self.map_layer1 = nn.Linear(in_features=512, out_features=512)

    def forward(self, noise_points, t, cond_points):
        r"""
        Forward pass of the network

        Parameters
        ----------
        pointcloud: Variable(torch.cuda.FloatTensor)
            (B, N, 3 + input_feature_dim) tensor
            Point cloud to run predicts on
            Each point in the point-cloud MUST
            be formatted as (x, y, z, features...)

        Returns
        ----------
        end_points: {XXX_xyz, XXX_features, XXX_inds}
            XXX_xyz: float32 Tensor of shape (B,K,3)
            XXX_features: float32 Tensor of shape (B,K,D)
            XXX_inds: int64 Tensor of shape (B,K) values in [0,N-1]
        """
        t_emb = self.t_emb_layer(t)
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))  # B,512
        t_emb = t_emb[:, :, None]  # B,512,1, broadcast over the K points
        noise_end_points = self.noisy_encoder(noise_points)
        cond = self.cond_encoder(cond_points)
        # --------- FEATURE UPSAMPLING LAYERS ---------
        features = self.fp1_cross(noise_end_points['sa4_xyz'], cond['sa4_xyz'], noise_end_points['sa4_features'] + t_emb,
                                  cond['sa4_features'])
        features = self.fp1(noise_end_points['sa3_xyz'], noise_end_points['sa4_xyz'], noise_end_points['sa3_features'],
                            features)
        features = self.fp2_cross(noise_end_points['sa3_xyz'], cond['sa3_xyz'], features,
                                  cond["sa3_features"])
        features = self.fp2(noise_end_points['sa2_xyz'], noise_end_points['sa3_xyz'], noise_end_points['sa2_features'],
                            features)
        features = self.fp3_cross(noise_end_points['sa2_xyz'], cond['sa2_xyz'], features,
                                  cond['sa2_features'])
        features = self.fp3(noise_end_points['sa1_xyz'], noise_end_points['sa2_xyz'], noise_end_points['sa1_features'], features)
        features = self.fp4_cross(noise_end_points['sa1_xyz'], cond['sa1_xyz'], features,
                                  cond['sa1_features'])
        features = self.fp4(noise_end_points['org_xyz'], noise_end_points['sa1_xyz'], None, features)
        features = features.transpose(1, 2)

        # features = self.fp1_cross(noise_end_points['sa4_xyz'], cond_end_points['sa4_xyz'],
        #                           noise_end_points['sa4_features']+t_emb, cond_end_points['sa4_features'])
        # features = self.fp1(noise_end_points['sa3_xyz'].clone(), noise_end_points['sa4_xyz'].clone(), noise_end_points['sa3_features'],
        #                     features)
        # features = self.fp2(noise_end_points['sa2_xyz'], noise_end_points['sa3_xyz'], noise_end_points['sa2_features'],
        #                     features)
        # features = self.fp3(noise_end_points['sa1_xyz'],noise_end_points['sa2_xyz'],noise_end_points['sa1_features'],features)
        # features = self.fp4(noise_end_points['org_xyz'], noise_end_points['sa1_xyz'], None, features)
        # features = features.transpose(1,2)
        output_points = self.output_layer(features)

        return output_points


if __name__ == '__main__':
    net = PointUNet().cuda().float()
    net = net.eval()
    noise_points = torch.randn(16, 4096, 3).cuda().float()
    cond_points = torch.randn(16, 4096, 3).cuda().float()
    t = torch.randn(16).cuda().float()
    cond_encoder = Pointnet2Encoder().cuda().float()

    out = net(noise_points, t, cond_points)  # forward expects (noise_points, t, cond_points)
    print(out.shape)
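In `PointUNet.forward` above, the diffusion timestep is turned into a sinusoidal embedding, pushed through two SiLU-activated linear layers, and broadcast-added to the deepest set-abstraction features. A small self-contained sketch of that broadcast; the 256/512 sizes mirror the module above, while the sinusoidal embedding here is a generic stand-in rather than the repo's `PositionalEmbedding` class:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def sinusoidal_embedding(t, dim=256):
    # t: (B,) diffusion timesteps -> (B, dim) generic sin/cos embedding
    half = dim // 2
    freqs = torch.exp(-math.log(10000) * torch.arange(half).float() / half)
    angles = t[:, None] * freqs[None, :]
    return torch.cat([angles.sin(), angles.cos()], dim=1)

map_layer0 = nn.Linear(256, 512)
map_layer1 = nn.Linear(512, 512)

t = torch.rand(16)                         # one timestep per sample
sa4_features = torch.randn(16, 512, 256)   # (B, C, K) deepest SA features

t_emb = sinusoidal_embedding(t)            # (B, 256)
t_emb = F.silu(map_layer0(t_emb))
t_emb = F.silu(map_layer1(t_emb))          # (B, 512)
conditioned = sa4_features + t_emb[:, :, None]  # broadcast over the K points
print(conditioned.shape)                   # torch.Size([16, 512, 256])
```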
models/modules/resnet_block.py
ADDED
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Resnet Blocks
class ResnetBlockFC(nn.Module):
    ''' Fully connected ResNet Block class.
    Args:
        size_in (int): input dimension
        size_out (int): output dimension
        size_h (int): hidden dimension
    '''

    def __init__(self, size_in, size_out=None, size_h=None):
        super().__init__()
        # Attributes
        if size_out is None:
            size_out = size_in

        if size_h is None:
            size_h = min(size_in, size_out)

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out
        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)
        # Initialization
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        net = self.fc_0(self.actvn(x))
        dx = self.fc_1(self.actvn(net))

        if self.shortcut is not None:
            x_s = self.shortcut(x)
        else:
            x_s = x

        return x_s + dx
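A quick usage sketch for `ResnetBlockFC`, assuming it is imported from the module added above; because `fc_1` is zero-initialized, the residual branch contributes only its bias at initialization, so the block starts close to the (linear) shortcut mapping:

```python
import torch
from models.modules.resnet_block import ResnetBlockFC

block = ResnetBlockFC(size_in=35, size_out=128)  # shortcut becomes a Linear(35 -> 128)
feat = torch.randn(4, 2048, 35)                  # e.g. per-point features
out = block(feat)
print(out.shape)                                 # torch.Size([4, 2048, 128])
```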
models/modules/resunet.py
ADDED
@@ -0,0 +1,440 @@
import torch
import torch.nn as nn
from .unet import RollOut_Conv
from .Positional_Embedding import PositionalEmbedding
import torch.nn.functional as F
from .utils import zero_module
from .image_sampler import MultiImage_Fuse_Sampler, MultiImage_Global_Sampler, MultiImage_TriFuse_Sampler

class ResidualConv_MultiImgAtten(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding, reso=64,
                 vit_reso=16, t_input_dim=256, img_in_channels=1280, use_attn=True, triplane_padding=0.1,
                 norm="batch"):
        super(ResidualConv_MultiImgAtten, self).__init__()
        self.use_attn = use_attn

        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity

        self.conv_block = nn.Sequential(
            norm_layer(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, padding=padding
            )
        )
        self.out_layer = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
            norm_layer(output_dim),
        )
        self.roll_out_conv = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        if self.use_attn:
            self.img_sampler = MultiImage_Fuse_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
                                                       img_in_channels=img_in_channels, reso=reso, vit_reso=vit_reso,
                                                       out_channels=output_dim, padding=triplane_padding)
        self.down_conv = nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)

        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb, img_feat, proj_mat, valid_frames):
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]

        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)
        feature = out + self.conv_skip(x)
        feature = self.roll_out_conv(feature)
        if self.use_attn:
            feature = self.img_sampler(feature, img_feat, proj_mat, valid_frames) + feature  # skip connect
        feature = self.down_conv(feature)

        return feature

class ResidualConv_TriMultiImgAtten(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding, reso=64,
                 vit_reso=16, t_input_dim=256, img_in_channels=1280, use_attn=True, triplane_padding=0.1,
                 norm="batch"):
        super(ResidualConv_TriMultiImgAtten, self).__init__()
        self.use_attn = use_attn

        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity

        self.conv_block = nn.Sequential(
            norm_layer(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, padding=padding
            )
        )
        self.out_layer = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
            norm_layer(output_dim),
        )
        self.roll_out_conv = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        if self.use_attn:
            self.img_sampler = MultiImage_TriFuse_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
                                                          img_in_channels=img_in_channels, reso=reso, vit_reso=vit_reso,
                                                          out_channels=output_dim, max_nimg=5, padding=triplane_padding)
        self.down_conv = nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)

        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb, img_feat, proj_mat, valid_frames):
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]

        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)
        feature = out + self.conv_skip(x)
        feature = self.roll_out_conv(feature)
        if self.use_attn:
            feature = self.img_sampler(feature, img_feat, proj_mat, valid_frames) + feature  # skip connect
        feature = self.down_conv(feature)

        return feature


class ResidualConv_GlobalAtten(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding, reso=64,
                 vit_reso=16, t_input_dim=256, img_in_channels=1280, use_attn=True, triplane_padding=0.1,
                 norm="batch"):
        super(ResidualConv_GlobalAtten, self).__init__()
        self.use_attn = use_attn

        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity

        self.conv_block = nn.Sequential(
            norm_layer(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, padding=padding
            )
        )
        self.out_layer = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, padding=1),
            norm_layer(output_dim),
        )
        self.roll_out_conv = nn.Sequential(
            norm_layer(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        if self.use_attn:
            self.img_sampler = MultiImage_Global_Sampler(inner_channel=output_dim, triplane_in_channels=output_dim,
                                                         img_in_channels=img_in_channels, reso=reso, vit_reso=vit_reso,
                                                         out_channels=output_dim, max_nimg=5, padding=triplane_padding)
        self.down_conv = nn.Conv2d(output_dim, output_dim, kernel_size=3, stride=stride, padding=padding)

        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb, img_feat, proj_mat, valid_frames):
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]

        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)
        feature = out + self.conv_skip(x)
        feature = self.roll_out_conv(feature)
        if self.use_attn:
            feature = self.img_sampler(feature, img_feat, proj_mat, valid_frames) + feature  # skip connect
        feature = self.down_conv(feature)

        return feature

class ResidualConv(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding, t_input_dim=256):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            RollOut_Conv(output_dim, output_dim),
        )
        self.out_layer = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

        self.map_layer0 = nn.Linear(in_features=t_input_dim, out_features=output_dim)
        self.map_layer1 = nn.Linear(in_features=output_dim, out_features=output_dim)

    def forward(self, x, t_emb):
        t_emb = F.silu(self.map_layer0(t_emb))
        t_emb = F.silu(self.map_layer1(t_emb))
        t_emb = t_emb[:, :, None, None]

        out = self.conv_block(x) + t_emb
        out = self.out_layer(out)

        return out + self.conv_skip(x)

class Upsample(nn.Module):
    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()

        self.upsample = nn.ConvTranspose2d(
            input_dim, output_dim, kernel_size=kernel, stride=stride
        )

    def forward(self, x):
        return self.upsample(x)


class ResUnet_Par_cond(nn.Module):
    def __init__(self, channel, filters=[64, 128, 256, 512, 1024], output_channel=32, par_channel=32):
        super(ResUnet_Par_cond, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )

        self.residual_conv_1 = ResidualConv(filters[0] + par_channel, filters[1], 2, 1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], 2, 1)
        self.residual_conv_3 = ResidualConv(filters[2], filters[3], 2, 1)
        self.bridge = ResidualConv(filters[3], filters[4], 2, 1)

        self.upsample_1 = Upsample(filters[4], filters[4], 2, 2)
        self.up_residual_conv1 = ResidualConv(filters[4] + filters[3], filters[3], 1, 1)

        self.upsample_2 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv2 = ResidualConv(filters[3] + filters[2], filters[2], 1, 1)

        self.upsample_3 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv3 = ResidualConv(filters[2] + filters[1], filters[1], 1, 1)

        self.upsample_4 = Upsample(filters[1], filters[1], 2, 2)
        self.up_residual_conv4 = ResidualConv(filters[1] + filters[0] + par_channel, filters[0], 1, 1)

        self.output_layer = nn.Sequential(
            #nn.LayerNorm(filters[0]),
            nn.LayerNorm(64),  # normalize along the width dimension; usually one would normalize along the channel
            # dimension, but empirically this improves finetuning performance significantly
            zero_module(nn.Conv2d(filters[0], output_channel, 1, 1, bias=False)),
        )
        self.par_channel = par_channel
        self.par_conv = nn.Sequential(
            nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1),
        )
        self.t_emb_layer = PositionalEmbedding(256)
        self.cat_emb = nn.Linear(
            in_features=6,
            out_features=256,
        )

    def forward(self, x, t, category_code, par_point_feat):
        # Encode
        t_emb = self.t_emb_layer(t)
        cat_emb = self.cat_emb(category_code)
        t_emb = t_emb + cat_emb
        #print(t_emb.shape)
        x1 = self.input_layer(x) + self.input_skip(x)
        if par_point_feat is not None:
            par_point_feat = self.par_conv(par_point_feat)
        else:
            bs, _, H, W = x1.shape
            #print(x1.shape)
            par_point_feat = torch.zeros((bs, self.par_channel, H, W)).float().to(x1.device)
        x1 = torch.cat([x1, par_point_feat], dim=1)
        x2 = self.residual_conv_1(x1, t_emb)
        x3 = self.residual_conv_2(x2, t_emb)
        # Bridge
        x4 = self.residual_conv_3(x3, t_emb)
        x5 = self.bridge(x4, t_emb)

        x6 = self.upsample_1(x5)
        x6 = torch.cat([x6, x4], dim=1)
        x7 = self.up_residual_conv1(x6, t_emb)

        x7 = self.upsample_2(x7)
        x7 = torch.cat([x7, x3], dim=1)
        x8 = self.up_residual_conv2(x7, t_emb)

        x8 = self.upsample_3(x8)
        x8 = torch.cat([x8, x2], dim=1)
        #print(x8.shape)
        x9 = self.up_residual_conv3(x8, t_emb)

        x9 = self.upsample_4(x9)
        x9 = torch.cat([x9, x1], dim=1)
        x10 = self.up_residual_conv4(x9, t_emb)

        output = self.output_layer(x10)

        return output

class ResUnet_DirectAttenMultiImg_Cond(nn.Module):
    def __init__(self, channel, filters=[64, 128, 256, 512, 1024],
                 img_in_channels=1024, vit_reso=16, output_channel=32,
                 use_par=False, par_channel=32, triplane_padding=0.1, norm='batch',
                 use_cat_embedding=False,
                 block_type="multiview_local"):
        super(ResUnet_DirectAttenMultiImg_Cond, self).__init__()

        if block_type == "multiview_local":
            block = ResidualConv_MultiImgAtten
        elif block_type == "multiview_global":
            block = ResidualConv_GlobalAtten
        elif block_type == "multiview_tri":
            block = ResidualConv_TriMultiImgAtten
        else:
            raise NotImplementedError

        if norm == "batch":
            norm_layer = nn.BatchNorm2d
        elif norm is None:
            norm_layer = nn.Identity

        self.use_cat_embedding = use_cat_embedding
        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            norm_layer(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )
        self.use_par = use_par
        input_1_channels = filters[0]
        if self.use_par:
            self.par_conv = nn.Sequential(
                nn.Conv2d(par_channel, par_channel, kernel_size=3, padding=1),
            )
            input_1_channels = filters[0] + par_channel
        self.residual_conv_1 = block(input_1_channels, filters[1], 2, 1, reso=64,
                                     use_attn=False, triplane_padding=triplane_padding, norm=norm)
        self.residual_conv_2 = block(filters[1], filters[2], 2, 1, reso=32,
                                     use_attn=False, triplane_padding=triplane_padding, norm=norm)
        self.residual_conv_3 = block(filters[2], filters[3], 2, 1, reso=16,
                                     use_attn=False, triplane_padding=triplane_padding, norm=norm)
        self.bridge = block(filters[3], filters[4], 2, 1, reso=8,
                            use_attn=False, triplane_padding=triplane_padding, norm=norm)  # input reso is 8, output reso is 4

        self.upsample_1 = Upsample(filters[4], filters[4], 2, 2)
        self.up_residual_conv1 = block(filters[4] + filters[3], filters[3], 1, 1, reso=8, img_in_channels=img_in_channels, vit_reso=vit_reso,
                                       use_attn=True, triplane_padding=triplane_padding, norm=norm)

        self.upsample_2 = Upsample(filters[3], filters[3], 2, 2)
        self.up_residual_conv2 = block(filters[3] + filters[2], filters[2], 1, 1, reso=16, img_in_channels=img_in_channels, vit_reso=vit_reso,
                                       use_attn=True, triplane_padding=triplane_padding, norm=norm)

        self.upsample_3 = Upsample(filters[2], filters[2], 2, 2)
        self.up_residual_conv3 = block(filters[2] + filters[1], filters[1], 1, 1, reso=32, img_in_channels=img_in_channels, vit_reso=vit_reso,
                                       use_attn=True, triplane_padding=triplane_padding, norm=norm)

        self.upsample_4 = Upsample(filters[1], filters[1], 2, 2)
        self.up_residual_conv4 = block(filters[1] + input_1_channels, filters[0], 1, 1, reso=64,
                                       use_attn=False, triplane_padding=triplane_padding, norm=norm)

        self.output_layer = nn.Sequential(
            nn.LayerNorm(64),  # normalize along the width dimension; usually one would normalize along the channel
            # dimension, but empirically this improves finetuning performance significantly
            #nn.LayerNorm([filters[0], 192, 64]),
            zero_module(nn.Conv2d(filters[0], output_channel, 1, 1, bias=False)),
        )
        self.t_emb_layer = PositionalEmbedding(256)
        if use_cat_embedding:
            self.cat_emb = nn.Linear(
                in_features=6,
                out_features=256,
            )

    def forward(self, x, t, image_emb, proj_mat, valid_frames, category_code, par_point_feat=None):
        # Encode
        t_emb = self.t_emb_layer(t)
        if self.use_cat_embedding:
            cat_emb = self.cat_emb(category_code)
            t_emb = t_emb + cat_emb
        x1 = self.input_layer(x) + self.input_skip(x)
        if self.use_par:
            par_point_feat = self.par_conv(par_point_feat)
            x1 = torch.cat([x1, par_point_feat], dim=1)
        x2 = self.residual_conv_1(x1, t_emb, image_emb, proj_mat, valid_frames)
        x3 = self.residual_conv_2(x2, t_emb, image_emb, proj_mat, valid_frames)
        x4 = self.residual_conv_3(x3, t_emb, image_emb, proj_mat, valid_frames)
        x5 = self.bridge(x4, t_emb, image_emb, proj_mat, valid_frames)

        x6 = self.upsample_1(x5)
        x6 = torch.cat([x6, x4], dim=1)
        x7 = self.up_residual_conv1(x6, t_emb, image_emb, proj_mat, valid_frames)

        x7 = self.upsample_2(x7)
        x7 = torch.cat([x7, x3], dim=1)
        x8 = self.up_residual_conv2(x7, t_emb, image_emb, proj_mat, valid_frames)

        x8 = self.upsample_3(x8)
        x8 = torch.cat([x8, x2], dim=1)
        #print(x8.shape)
        x9 = self.up_residual_conv3(x8, t_emb, image_emb, proj_mat, valid_frames)

        x9 = self.upsample_4(x9)
        x9 = torch.cat([x9, x1], dim=1)
        x10 = self.up_residual_conv4(x9, t_emb, image_emb, proj_mat, valid_frames)

        output = self.output_layer(x10)

        return output


if __name__ == "__main__":
    # quick smoke test of the parallel-point-conditioned variant
    net = ResUnet_Par_cond(32, output_channel=32).float().cuda()
    n_parameters = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print("Model = %s" % str(net))
    print('number of params (M): %.2f' % (n_parameters / 1.e6))
    par_point_feat = torch.randn((10, 32, 64 * 3, 64)).float().cuda()
    input = torch.randn((10, 32, 64 * 3, 64)).float().cuda()
    t = torch.randn((10, 1, 1, 1)).float().cuda()
    category_code = torch.randn((10, 6)).float().cuda()
    output = net(input, t.flatten(), category_code, par_point_feat)
    #print(output.shape)
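A small sketch of what the `nn.LayerNorm(64)` in the output heads above does on a rolled-out triplane tensor: with input of shape `(B, C, 3*64, 64)`, a LayerNorm whose normalized shape is just 64 normalizes over the last (width) dimension rather than over channels, which is the behaviour the inline comment refers to. The 64-wide planes here are an assumption matching the `reso=64` used above:

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(64)                  # normalized_shape covers only the last dim
x = torch.randn(2, 64, 3 * 64, 64)       # (B, C, 3*H, W) rolled-out triplane features
y = norm(x)
# mean/std are computed independently for every (b, c, h) row of length W=64
print(y.mean(dim=-1).abs().max())             # close to 0
print(y.std(dim=-1, unbiased=False).mean())   # close to 1
```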
models/modules/unet.py
ADDED
@@ -0,0 +1,304 @@
'''
Codes are from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
import numpy as np


def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)


def upconv2x2(in_channels, out_channels, mode='transpose'):
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    else:
        # out_channels is always going to be the same
        # as in_channels
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))


def conv1x1(in_channels, out_channels, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1)

class RollOut_Conv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(RollOut_Conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = conv3x3(self.in_channels * 3, self.out_channels)

    def forward(self, row_features):
        H, W = row_features.shape[2], row_features.shape[3]
        H_per = H // 3
        xz_feature, xy_feature, yz_feature = torch.split(row_features, dim=2, split_size_or_sections=H_per)
        xy_row_pool = torch.mean(xy_feature, dim=2, keepdim=True).expand(-1, -1, H_per, -1)
        yz_col_pool = torch.mean(yz_feature, dim=3, keepdim=True).expand(-1, -1, -1, W)
        cat_xz_feat = torch.cat([xz_feature, xy_row_pool, yz_col_pool], dim=1)

        xz_row_pool = torch.mean(xz_feature, dim=2, keepdim=True).expand(-1, -1, H_per, -1)
        zy_feature = yz_feature.transpose(2, 3)  # switch z y axis, for reduced confusion
        zy_col_pool = torch.mean(zy_feature, dim=3, keepdim=True).expand(-1, -1, -1, W)
        cat_xy_feat = torch.cat([xy_feature, xz_row_pool, zy_col_pool], dim=1)

        xz_col_pool = torch.mean(xz_feature, dim=3, keepdim=True).expand(-1, -1, -1, W)
        yx_feature = xy_feature.transpose(2, 3)
        yx_row_pool = torch.mean(yx_feature, dim=2, keepdim=True).expand(-1, -1, H_per, -1)
        cat_yz_feat = torch.cat([yz_feature, yx_row_pool, xz_col_pool], dim=1)

        fuse_row_feat = torch.cat([cat_xz_feat, cat_xy_feat, cat_yz_feat], dim=2)  # concat at row dimension

        x = self.conv(fuse_row_feat)

        return x


class DownConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 MaxPool.
    A ReLU activation follows each convolution.
    """

    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.Rollout_conv = RollOut_Conv(self.out_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.Rollout_conv(x))
        x = F.relu(self.conv2(x))
        before_pool = x
        if self.pooling:
            x = self.pool(x)
        return x, before_pool


class UpConv(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 UpConvolution.
    A ReLU activation follows each convolution.
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels,
                                mode=self.up_mode)

        if self.merge_mode == 'concat':
            self.conv1 = conv3x3(
                2 * self.out_channels, self.out_channels)
        else:
            # num of input channels to conv2 is same
            self.conv1 = conv3x3(self.out_channels, self.out_channels)
        self.Rollout_conv = RollOut_Conv(self.out_channels, self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: upconv'd tensor from the decoder pathway
        """
        from_up = self.upconv(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down
        x = F.relu(self.conv1(x))
        x = F.relu(self.Rollout_conv(x))
        x = F.relu(self.conv2(x))
        return x


class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597

    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).

    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by upmode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the transpose convolution (specified by upmode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the
                first conv.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for nearest neighbour
                upsampling.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(up_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            ins = self.in_channels if i == 0 else outs
            outs = self.start_filts * (2 ** i)
            pooling = True if i < depth - 1 else False

            down_conv = DownConv(ins, outs, pooling=pooling)
            self.down_convs.append(down_conv)

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for i in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv(ins, outs, up_mode=up_mode,
                             merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)
        self.conv_final = conv1x1(outs, self.num_classes)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)

    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, feature_plane):
        #cat_feature=torch.cat([feature_plane['xz'],feature_plane['xy'],feature_plane['yz']],dim=2) #concat at row dimension
        x = feature_plane
        encoder_outs = []
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)
        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i + 2)]
            x = module(before_pool, x)

        # No softmax is used, so you need to use
        # nn.CrossEntropyLoss in your training script,
        # as that loss includes a softmax already.
        x = self.conv_final(x)
        return x


if __name__ == "__main__":
    # """
    # testing
    # """
    # model = UNet(1, depth=5, merge_mode='concat', in_channels=1, start_filts=32)
    # print(model)
    # print(sum(p.numel() for p in model.parameters()))
    #
    # reso = 176
    # x = np.zeros((1, 1, reso, reso))
    # x[:, :, int(reso / 2 - 1), int(reso / 2 - 1)] = np.nan
    # x = torch.FloatTensor(x)
    #
    # out = model(x)
    # print('%f' % (torch.sum(torch.isnan(out)).detach().cpu().numpy() / (reso * reso)))
    #
    # # loss = torch.sum(out)
    # # loss.backward()
    #roll_out_conv=RollOut_Conv(in_channels=32,out_channels=32).cuda().float()
    model = UNet(32, depth=5, merge_mode='concat', in_channels=32, start_filts=32).cuda().float()
    row_feature = torch.randn((10, 32, 128 * 3, 128)).cuda().float()
    output = model(row_feature)
    #output_feature=roll_out_conv(row_feature)
    #print(output_feature.shape)
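`RollOut_Conv` above expects the three feature planes (xz, xy, yz) stacked along the height axis into a single `(B, C, 3*H, W)` map; each plane is then augmented with row/column averages of the other two before a 3x3 convolution. A minimal sketch of building that rolled-out layout from three separate planes; the plane order follows the `torch.split` in `RollOut_Conv.forward`, and the resolution is illustrative:

```python
import torch

B, C, H, W = 2, 32, 128, 128
xz_plane = torch.randn(B, C, H, W)
xy_plane = torch.randn(B, C, H, W)
yz_plane = torch.randn(B, C, H, W)

# stack along the row (height) dimension, matching torch.split(..., dim=2) in RollOut_Conv.forward
rolled_out = torch.cat([xz_plane, xy_plane, yz_plane], dim=2)   # (B, C, 3*H, W)
print(rolled_out.shape)                                         # torch.Size([2, 32, 384, 128])

# this is the layout the UNet / RollOut_Conv modules above consume, e.g.:
# model = UNet(32, depth=5, merge_mode='concat', in_channels=32, start_filts=32)
# out = model(rolled_out)
```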
models/modules/utils.py
ADDED
@@ -0,0 +1,25 @@
import torch

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

class StackedRandomGenerator:
    def __init__(self, device, seeds):
        super().__init__()
        self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]

    def randn(self, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])

    def randn_like(self, input):
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])
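Two short usage sketches for the helpers above: `zero_module` is typically wrapped around the last layer of a conditioning or output branch so that it contributes nothing at the start of training, and `StackedRandomGenerator` draws one independently seeded noise tensor per sample so that sampling can be reproduced per object. The seed values and shapes below are arbitrary examples:

```python
import torch
import torch.nn as nn
from models.modules.utils import zero_module, StackedRandomGenerator

# output head that starts as an exact zero mapping
head = zero_module(nn.Conv2d(64, 32, kernel_size=1, bias=False))
print(head(torch.randn(1, 64, 8, 8)).abs().sum())    # tensor(0.)

# per-sample seeded Gaussian noise, shape (batch, C, 3*H, W)
rng = StackedRandomGenerator(device='cpu', seeds=[0, 1, 2, 3])
noise = rng.randn((4, 32, 192, 64))
print(noise.shape)                                    # torch.Size([4, 32, 192, 64])
```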
output/put_checkpoints_here
ADDED
@@ -0,0 +1 @@
1
process_scripts/augment_arkit_partial_point.py
ADDED
@@ -0,0 +1,64 @@
import numpy as np
import scipy
import os
import trimesh
from sklearn.cluster import KMeans
import random
import glob
import tqdm
import argparse
import multiprocessing as mp
import sys
sys.path.append("..")
from datasets.taxonomy import arkit_category

parser = argparse.ArgumentParser()
parser.add_argument('--category', nargs="+", type=str)
parser.add_argument("--keyword", type=str, default="lowres")  # augment only the low resolution points
parser.add_argument("--data_root", type=str, default="../data/other_data")
args = parser.parse_args()
category = args.category
if category[0] == "all":
    category = arkit_category["all"]
kmeans = KMeans(
    init="random",
    n_clusters=20,
    n_init=10,
    max_iter=300,
    random_state=42
)

def process_data(src_point_path, save_folder, keyword):
    src_point_tri = trimesh.load(src_point_path)
    src_point = np.asarray(src_point_tri.vertices)
    kmeans.fit(src_point)
    point_cluster_index = kmeans.labels_

    '''choose 10~19 clusters to form the augmented new point'''
    for i in range(10):
        n_cluster = random.randint(14, 19)  # 14,19 for lowres, 10,19 for highres
        choose_cluster = np.random.choice(20, n_cluster, replace=False)
        aug_point_list = []
        for cluster_index in choose_cluster:
            cluster_point = src_point[point_cluster_index == cluster_index]
            aug_point_list.append(cluster_point)
        aug_point = np.concatenate(aug_point_list, axis=0)
        save_path = os.path.join(save_folder, "%s_partial_points_%d.ply" % (keyword, i + 1))
        print("saving to %s" % (save_path))
        aug_point_tri = trimesh.PointCloud(vertices=aug_point)
        aug_point_tri.export(save_path)

pool = mp.Pool(10)
for cat in category[0:]:
    keyword = args.keyword
    point_dir = os.path.join(args.data_root, cat, "5_partial_points")
    folder_list = os.listdir(point_dir)
    for folder in tqdm.tqdm(folder_list[0:]):
        folder_path = os.path.join(point_dir, folder)
        src_point_path = os.path.join(point_dir, folder, "%s_partial_points_0.ply" % (keyword))
        if os.path.exists(src_point_path) == False:
            continue
        save_folder = folder_path
        pool.apply_async(process_data, (src_point_path, save_folder, keyword))
pool.close()
pool.join()
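The augmentation script above simulates additional incomplete scans by clustering a partial point cloud with KMeans and keeping only a random subset of the clusters. A condensed, self-contained sketch of that idea on a synthetic point cloud; the cluster counts follow the arkit script, and file I/O plus multiprocessing are omitted:

```python
import random
import numpy as np
from sklearn.cluster import KMeans

points = np.random.rand(5000, 3)                      # stand-in for a loaded .ply point cloud
kmeans = KMeans(init="random", n_clusters=20, n_init=10, max_iter=300, random_state=42)
labels = kmeans.fit(points).labels_

n_keep = random.randint(14, 19)                       # keep 14-19 of the 20 clusters
keep_ids = np.random.choice(20, n_keep, replace=False)
aug_points = np.concatenate([points[labels == cid] for cid in keep_ids], axis=0)
print(points.shape, '->', aug_points.shape)
```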
process_scripts/augment_synthetic_partial_points.py
ADDED
@@ -0,0 +1,64 @@
import numpy as np
import scipy
import os
import trimesh
from sklearn.cluster import KMeans
import random
import glob
import tqdm
import multiprocessing as mp
import sys
sys.path.append("..")
from datasets.taxonomy import synthetic_category_combined

import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--category", nargs="+", type=str)
parser.add_argument("--root_dir", type=str, default="../data/other_data")
args = parser.parse_args()
categories = args.category
if categories[0] == "all":
    categories = synthetic_category_combined["all"]

kmeans = KMeans(
    init="random",
    n_clusters=7,
    n_init=10,
    max_iter=300,
    random_state=42
)

def process_data(src_filepath, save_path):
    #print("processing %s"%(src_filepath))
    src_point_tri = trimesh.load(src_filepath)
    src_point = np.asarray(src_point_tri.vertices)
    kmeans.fit(src_point)
    point_cluster_index = kmeans.labels_

    n_cluster = random.randint(3, 6)
    choose_cluster = np.random.choice(7, n_cluster, replace=False)
    aug_point_list = []
    for cluster_index in choose_cluster:
        cluster_point = src_point[point_cluster_index == cluster_index]
        aug_point_list.append(cluster_point)
    aug_point = np.concatenate(aug_point_list, axis=0)
    aug_point_tri = trimesh.PointCloud(vertices=aug_point)
    print("saving to %s" % (save_path))
    aug_point_tri.export(save_path)

pool = mp.Pool(10)
for cat in categories:
    print("processing %s" % cat)
    point_dir = os.path.join(args.root_dir, cat, "5_partial_points")
    folder_list = os.listdir(point_dir)
    for folder in folder_list[:]:
        folder_path = os.path.join(point_dir, folder)
        src_filelist = glob.glob(folder_path + "/partial_points_*.ply")
        for src_filepath in src_filelist:
            basename = os.path.basename(src_filepath)
            save_path = os.path.join(point_dir, folder, "aug7_" + basename)
            pool.apply_async(process_data, (src_filepath, save_path))
pool.close()
pool.join()
process_scripts/dist_export_triplane_features.sh
ADDED
@@ -0,0 +1,8 @@
CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15002 --nproc_per_node=2 \
export_triplane_features.py \
--configs ../configs/train_triplane_vae.yaml \
--batch_size 10 \
--ae-pth ../output/ae/chair/best-checkpoint.pth \
--data-pth ../data \
--category arkit_chair 03001627 future_chair arkit_stool future_stool ABO_chair
#sub category
process_scripts/dist_extract_vit.sh
ADDED
@@ -0,0 +1,6 @@
CUDA_VISIBLE_DEVICES='0,1,2,3' torchrun --master_port 15000 --nproc_per_node=2 \
extract_img_vit_features.py \
--batch_size 24 \
--ckpt_path ../data/open_clip_pytorch_model.bin \
--category arkit_chair 03001627 future_chair arkit_stool future_stool ABO_chair #sub category
#--category 02871439 future_shelf ABO_shelf arkit_shelf \
process_scripts/export_triplane_features.py
ADDED
@@ -0,0 +1,122 @@
import argparse
import math
import sys
sys.path.append("..")
import numpy as np
import os
import torch

import trimesh

from datasets import Object_Occ, Scale_Shift_Rotate
from models import get_model
from pathlib import Path
import open3d as o3d
from configs.config_utils import CONFIG
import tqdm
from util import misc
from datasets.taxonomy import synthetic_arkit_category_combined

if __name__ == "__main__":

    parser = argparse.ArgumentParser('', add_help=False)
    parser.add_argument('--configs', type=str, required=True)
    parser.add_argument('--ae-pth', type=str)
    parser.add_argument("--category", nargs='+', type=str)
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--data-pth", default="../data", type=str)

    args = parser.parse_args()
    misc.init_distributed_mode(args)
    device = torch.device(args.device)

    config_path = args.configs
    config = CONFIG(config_path)
    dataset_config = config.config['dataset']
    dataset_config['data_path'] = args.data_pth
    #transform = AxisScaling((0.75, 1.25), True)
    transform = Scale_Shift_Rotate(rot_shift_surface=True, use_scale=True)
    if len(args.category) == 1 and args.category[0] == "all":
        category = synthetic_arkit_category_combined["all"]
    else:
        category = args.category
    train_dataset = Object_Occ(dataset_config['data_path'], split="train",
                               categories=category,
                               transform=transform, sampling=True,
                               num_samples=1024, return_surface=True,
                               surface_sampling=True, surface_size=dataset_config['surface_size'], replica=1)
    val_dataset = Object_Occ(dataset_config['data_path'], split="val",
                             categories=category,
                             transform=transform, sampling=True,
                             num_samples=1024, return_surface=True,
                             surface_sampling=True, surface_size=dataset_config['surface_size'], replica=1)
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset, num_replicas=num_tasks, rank=global_rank,
        shuffle=False)  # shuffle=True to reduce monitor bias
    val_sampler = torch.utils.data.DistributedSampler(
        val_dataset, num_replicas=num_tasks, rank=global_rank,
        shuffle=False)
    #dataset=val_dataset
    batch_size = args.batch_size
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, sampler=train_sampler,
        batch_size=batch_size,
        num_workers=10,
        shuffle=False,
        drop_last=False,
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, sampler=val_sampler,
        batch_size=batch_size,
        num_workers=10,
        shuffle=False,
        drop_last=False,
    )
    dataloader_list = [train_dataloader, val_dataloader]
    #dataloader_list=[val_dataloader]
    output_dir = os.path.join(dataset_config['data_path'], "other_data")
    #output_dir="/data1/haolin/datasets/ShapeNetV2_watertight"

    model_config = config.config['model']
    model = get_model(model_config)
    model.load_state_dict(torch.load(args.ae_pth)['model'])
    model.eval().float().to(device)
    #model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)

    with torch.no_grad():
        for e in range(5):
            for dataloader in dataloader_list:
                for data_iter_step, data_batch in tqdm.tqdm(enumerate(dataloader)):
                    surface = data_batch['surface'].to(device, non_blocking=True)
                    model_ids = data_batch['model_id']
                    tran_mats = data_batch['tran_mat']
                    categories = data_batch['category']
                    with torch.no_grad():
                        plane_feat, _, means, logvars = model.encode(surface)
                        plane_feat = torch.nn.functional.interpolate(plane_feat, scale_factor=0.5, mode='bilinear')
                        vars = torch.exp(logvars)
                        means = torch.nn.functional.interpolate(means, scale_factor=0.5, mode="bilinear")
                        vars = torch.nn.functional.interpolate(vars, scale_factor=0.5, mode="bilinear") / 4
                        sample_logvars = torch.log(vars)

                    for j in range(means.shape[0]):
                        #plane_dist=plane_feat[j].float().cpu().numpy()
                        mean = means[j].float().cpu().numpy()
                        logvar = sample_logvars[j].float().cpu().numpy()
                        tran_mat = tran_mats[j].float().cpu().numpy()

                        output_folder = os.path.join(output_dir, categories[j], '9_triplane_kl25_64', model_ids[j])
                        Path(output_folder).mkdir(parents=True, exist_ok=True)
|
120 |
+
exist_len=len(os.listdir(output_folder))
|
121 |
+
save_filepath=os.path.join(output_folder,"triplane_feat_%d.npz"%(exist_len))
|
122 |
+
np.savez_compressed(save_filepath,mean=mean,logvar=logvar,tran_mat=tran_mat)
|
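
Each pass over the (augmented) train/val sets appends another `triplane_feat_%d.npz` per object, holding `mean`, `logvar`, and the augmentation transform `tran_mat`. Below is a hedged sketch of how a downstream consumer could read one file back and draw a latent sample with the reparameterization trick; the path is illustrative, and the sampling step is an assumption about how the diffusion stage uses these statistics rather than something this commit shows.

```python
# Hedged sketch: load one exported triplane feature file and draw a latent
# sample. The file path is a placeholder; the key names match what
# export_triplane_features.py saves with np.savez_compressed.
import numpy as np

feat = np.load("../data/other_data/arkit_chair/9_triplane_kl25_64/<model_id>/triplane_feat_0.npz")
mean, logvar, tran_mat = feat["mean"], feat["logvar"], feat["tran_mat"]

# mean/logvar parameterize a diagonal Gaussian over the downsampled triplane latent
std = np.exp(0.5 * logvar)
sample = mean + std * np.random.randn(*mean.shape).astype(np.float32)

print(sample.shape, tran_mat.shape)  # tran_mat is the transform saved alongside the latent
```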
process_scripts/extract_img_vit_features.py ADDED
@@ -0,0 +1,73 @@
+import os,sys
+sys.path.append("..")
+from util.simple_image_loader import Image_dataset
+from torch.utils.data import DataLoader
+import timm
+import torch
+from tqdm import tqdm
+import numpy as np
+from transformers import DPTForDepthEstimation, DPTFeatureExtractor
+import argparse
+from util import misc
+from datasets.taxonomy import synthetic_arkit_category_combined
+parser=argparse.ArgumentParser()
+
+parser.add_argument("--category",nargs="+",type=str)
+parser.add_argument("--root_dir",type=str, default="../data")
+parser.add_argument("--ckpt_path",type=str,default="../open_clip_pytorch_model.bin")
+parser.add_argument("--batch_size",type=int,default=24)
+parser.add_argument('--world_size', default=1, type=int,
+                    help='number of distributed processes')
+parser.add_argument('--local_rank', default=-1, type=int)
+parser.add_argument('--dist_on_itp', action='store_true')
+parser.add_argument('--dist_url', default='env://',
+                    help='url used to set up distributed training')
+args= parser.parse_args()
+misc.init_distributed_mode(args)
+category=args.category
+
+#dataset=Image_dataset(categories=['03001627','ABO_chair','future_chair'])
+if args.category[0]=="all":
+    category=synthetic_arkit_category_combined["all"]
+print("loading dataset")
+dataset=Image_dataset(dataset_folder=args.root_dir,categories=category,n_px=224)
+num_tasks = misc.get_world_size()
+global_rank = misc.get_rank()
+sampler = torch.utils.data.DistributedSampler(
+    dataset, num_replicas=num_tasks, rank=global_rank,
+    shuffle=False)  # shuffle=True to reduce monitor bias
+
+dataloader=DataLoader(
+    dataset,
+    sampler=sampler,
+    batch_size=args.batch_size,
+    num_workers=4,
+    pin_memory=True,
+    drop_last=False
+)
+print("loading model")
+VIT_MODEL = 'vit_huge_patch14_224_clip_laion2b'
+model=timm.create_model(VIT_MODEL, pretrained=True,pretrained_cfg_overlay=dict(file=args.ckpt_path))
+model=model.eval().float().cuda()
+save_dir=os.path.join(args.root_dir,"other_data")
+for idx,data_batch in enumerate(dataloader):
+    if idx%50==0:
+        print("{}/{}".format(dataloader.__len__(),idx))
+    images = data_batch["images"].cuda().float()
+    model_id= data_batch["model_id"]
+    image_name=data_batch["image_name"]
+    category=data_batch["category"]
+    with torch.no_grad():
+        #output=model(images,output_hidden_states=True)
+        output_features=model.forward_features(images)
+        #predict_depth=output.predicted_depth
+        #print(predict_depth.shape)
+    for j in range(output_features.shape[0]):
+        save_folder=os.path.join(save_dir,category[j],"7_img_features",model_id[j])
+        os.makedirs(save_folder,exist_ok=True)
+        save_path=os.path.join(save_folder,image_name[j]+".npz")
+        #print("saving to",save_path)
+        np.savez_compressed(save_path,img_features=output_features[j].detach().cpu().numpy().astype(np.float32))
+
+
+
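
The extractor writes one `.npz` per rendered image under `7_img_features`, keyed `img_features`, containing the token sequence from the ViT's `forward_features`. A small sketch for reading one back follows; the path is a placeholder, and the shape comment is a general expectation for a ViT-H/14 at 224 px rather than a value taken from this commit.

```python
# Hedged sketch: read back one saved image-feature file. The path is
# illustrative; the key name "img_features" matches what the script saves.
import numpy as np

feat = np.load("../data/other_data/arkit_chair/7_img_features/<model_id>/<image_name>.npz")
tokens = feat["img_features"]

# For a ViT-H/14 at 224 px, forward_features typically yields a token grid of
# roughly (257, 1280): 16x16 patch tokens plus the class token.
print(tokens.shape, tokens.dtype)
```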
process_scripts/generate_split_for_arkit.py ADDED
@@ -0,0 +1,102 @@
+import os
+import numpy as np
+import glob
+import open3d as o3d
+import json
+import argparse
+import glob
+
+parser=argparse.ArgumentParser()
+parser.add_argument("--cat",required=True,type=str,nargs="+")
+parser.add_argument("--keyword",default="lowres",type=str)
+parser.add_argument("--root_dir",type=str,default="../data")
+args=parser.parse_args()
+
+keyword=args.keyword
+sdf_folder="occ_data"
+other_folder="other_data"
+data_dir=args.root_dir
+
+align_dir=os.path.join(args.root_dir,"align_mat_all") # this alignment matrix aligns the highres scan to the lowres scan
+# the alignment matrices are still being cleaned; not all models have a proper alignment matrix yet.
+align_filelist=glob.glob(align_dir+"/*/*.txt")
+valid_model_list=[]
+for align_filepath in align_filelist:
+    if "-v" in align_filepath:
+        align_mat=np.loadtxt(align_filepath)
+        if align_mat.shape[0]!=4:
+            continue
+        model_id=os.path.basename(align_filepath).split("-")[0]
+        valid_model_list.append(model_id)
+
+print("there are %d valid lowres models"%(len(valid_model_list)))
+
+category_list=args.cat
+for category in category_list:
+    train_path=os.path.join(data_dir,sdf_folder,category,"train.lst")
+    with open(train_path,'r') as f:
+        train_list=f.readlines()
+    train_list=[item.rstrip() for item in train_list]
+    if ".npz" in train_list[0]:
+        train_list=[item[:-4] for item in train_list]
+    val_path=os.path.join(data_dir,sdf_folder,category,"val.lst")
+    with open(val_path,'r') as f:
+        val_list=f.readlines()
+    val_list=[item.rstrip() for item in val_list]
+    if ".npz" in val_list[0]:
+        val_list=[item[:-4] for item in val_list]
+
+
+    sdf_dir=os.path.join(data_dir,sdf_folder,category)
+    filelist=os.listdir(sdf_dir)
+    model_id_list=[item[:-4] for item in filelist if ".npz" in item]
+
+    train_par_img_list=[]
+    val_par_img_list=[]
+    for model_id in model_id_list:
+        if model_id not in valid_model_list:
+            continue
+        image_dir=os.path.join(data_dir,other_folder,category,"6_images",model_id)
+        partial_dir=os.path.join(data_dir,other_folder,category,"5_partial_points",model_id)
+        if os.path.exists(image_dir)==False and os.path.exists(partial_dir)==False:
+            continue
+        if os.path.exists(image_dir):
+            image_list=glob.glob(image_dir+"/*.jpg")+glob.glob(image_dir+"/*.png")
+            image_list=[os.path.basename(image_path) for image_path in image_list]
+        else:
+            image_list=[]
+
+        if os.path.exists(partial_dir):
+            partial_list=glob.glob(partial_dir+"/%s_partial_points_*.ply"%(keyword))
+        else:
+            partial_list=[]
+        partial_valid_list=[]
+        for partial_filepath in partial_list:
+            par_o3d=o3d.io.read_point_cloud(partial_filepath)
+            par_xyz=np.asarray(par_o3d.points)
+            if par_xyz.shape[0]>2048:
+                partial_valid_list.append(os.path.basename(partial_filepath))
+        if model_id in val_list:
+            if "%s_partial_points_0.ply"%(keyword) in partial_valid_list:
+                partial_valid_list=["%s_partial_points_0.ply"%(keyword)]
+            else:
+                partial_valid_list=[]
+        if len(image_list)==0 and len(partial_valid_list)==0:
+            continue
+        ret_dict={
+            "model_id":model_id,
+            "image_filenames":image_list[:],
+            "partial_filenames":partial_valid_list[:]
+        }
+        if model_id in train_list:
+            train_par_img_list.append(ret_dict)
+        elif model_id in val_list:
+            val_par_img_list.append(ret_dict)
+
+    train_save_path=os.path.join(sdf_dir,"%s_train_par_img.json"%(keyword))
+    with open(train_save_path,'w') as f:
+        json.dump(train_par_img_list,f,indent=4)
+
+    val_save_path=os.path.join(sdf_dir,"%s_val_par_img.json"%(keyword))
+    with open(val_save_path,'w') as f:
+        json.dump(val_par_img_list,f,indent=4)
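
The resulting `<keyword>_train_par_img.json` and `<keyword>_val_par_img.json` files are lists of per-model records with `model_id`, `image_filenames`, and `partial_filenames`. A quick hedged sketch for inspecting one (the category path is illustrative):

```python
# Hedged sketch: count usable images and partial point clouds in a generated split.
# The record fields match what generate_split_for_arkit.py writes above.
import json, os

split_path = os.path.join("../data/occ_data/arkit_chair", "lowres_train_par_img.json")
with open(split_path) as f:
    records = json.load(f)

n_imgs = sum(len(r["image_filenames"]) for r in records)
n_partial = sum(len(r["partial_filenames"]) for r in records)
print(len(records), "models,", n_imgs, "images,", n_partial, "partial point clouds")
```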
process_scripts/generate_split_for_synthetic_data.py ADDED
@@ -0,0 +1,78 @@
+import os,sys
+import numpy as np
+import glob
+import open3d as o3d
+import json
+import argparse
+
+parser=argparse.ArgumentParser()
+parser.add_argument("--cat",required=True,type=str,nargs="+")
+parser.add_argument("--root_dir",type=str,default="../data")
+args=parser.parse_args()
+
+sdf_folder="occ_data"
+other_folder="other_data"  # images and partial points live under other_data
+data_dir=args.root_dir
+category_list=args.cat
+for category in category_list:
+    # args.cat is a list of categories, so the train/val lists are read per category
+    train_path=os.path.join(data_dir,sdf_folder,category,"train.lst")
+    with open(train_path,'r') as f:
+        train_list=f.readlines()
+    train_list=[item.rstrip() for item in train_list]
+    if ".npz" in train_list[0]:
+        train_list=[item[:-4] for item in train_list]
+    val_path=os.path.join(data_dir,sdf_folder,category,"val.lst")
+    with open(val_path,'r') as f:
+        val_list=f.readlines()
+    val_list=[item.rstrip() for item in val_list]
+    if ".npz" in val_list[0]:
+        val_list=[item[:-4] for item in val_list]
+
+    sdf_dir=os.path.join(data_dir,sdf_folder,category)
+    filelist=os.listdir(sdf_dir)
+    model_id_list=[item[:-4] for item in filelist if ".npz" in item]
+
+    train_par_img_list=[]
+    val_par_img_list=[]
+    for model_id in model_id_list:
+        image_dir=os.path.join(data_dir,other_folder,category,"6_images",model_id)
+        partial_dir=os.path.join(data_dir,other_folder,category,"5_partial_points",model_id)
+        if os.path.exists(image_dir)==False and os.path.exists(partial_dir)==False:
+            continue
+        if os.path.exists(image_dir):
+            image_list=glob.glob(image_dir+"/*.jpg")+glob.glob(image_dir+"/*.png")
+            image_list=[os.path.basename(image_path) for image_path in image_list]
+        else:
+            image_list=[]
+
+        if os.path.exists(partial_dir):
+            partial_list=glob.glob(partial_dir+"/partial_points_*.ply")
+        else:
+            partial_list=[]
+        partial_valid_list=[]
+        for partial_filepath in partial_list:
+            par_o3d=o3d.io.read_point_cloud(partial_filepath)
+            par_xyz=np.asarray(par_o3d.points)
+            if par_xyz.shape[0]>2048:
+                partial_valid_list.append(os.path.basename(partial_filepath))
+        if len(image_list)==0 and len(partial_valid_list)==0:
+            continue
+        ret_dict={
+            "model_id":model_id,
+            "image_filenames":image_list[:],
+            "partial_filenames":partial_valid_list[:]
+        }
+        if model_id in train_list:
+            train_par_img_list.append(ret_dict)
+        elif model_id in val_list:
+            val_par_img_list.append(ret_dict)
+
+    #print(train_par_img_list)
+    train_save_path=os.path.join(sdf_dir,"train_par_img.json")
+    with open(train_save_path,'w') as f:
+        json.dump(train_par_img_list,f,indent=4)
+
+    val_save_path=os.path.join(sdf_dir,"val_par_img.json")
+    with open(val_save_path,'w') as f:
+        json.dump(val_par_img_list,f,indent=4)
process_scripts/unzip_all_data.py ADDED
@@ -0,0 +1,38 @@
+import os
+import glob
+import argparse
+
+parser = argparse.ArgumentParser("unzip the prepared data")
+parser.add_argument("--occ_root", type=str, default="../data/occ_data")
+parser.add_argument("--other_root", type=str,default="../data/other_data")
+parser.add_argument("--unzip_occ",default=False,action="store_true")
+parser.add_argument("--unzip_other",default=False,action="store_true")
+
+args=parser.parse_args()
+if args.unzip_occ:
+    filelist=os.listdir(args.occ_root)
+    for filename in filelist:
+        filepath=os.path.join(args.occ_root,filename)
+        if ".rar" in filename:
+            unrar_command="unrar x %s %s"%(filepath,args.occ_root)
+            os.system(unrar_command)
+        elif ".zip" in filename:
+            unzip_command="7z x %s -o%s"%(filepath,args.occ_root)
+            os.system(unzip_command)
+
+
+if args.unzip_other:
+    category_list=os.listdir(args.other_root)
+    for category in category_list:
+        category_folder=os.path.join(args.other_root,category)
+        #print(category_folder)
+        rar_filelist=glob.glob(category_folder+"/*.rar")
+        zip_filelist=glob.glob(category_folder+"/*.zip")
+
+        for rar_filepath in rar_filelist:
+            unrar_command="unrar x %s %s"%(rar_filepath,category_folder)
+            os.system(unrar_command)
+        for zip_filepath in zip_filelist:
+            unzip_command="7z x %s -o%s"%(zip_filepath,category_folder)
+            os.system(unzip_command)
+
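
After extraction, the occupancy archives should leave one `.npz` per model under `occ_data/<category>/`, which is the layout the split-generation scripts above expect. A small, assumption-labeled sanity check:

```python
# Hedged sketch of a post-extraction sanity check; the directory layout
# (occ_data/<category>/*.npz) follows the paths used by the split scripts above.
import glob, os

data_root = "../data"
for category in sorted(os.listdir(os.path.join(data_root, "occ_data"))):
    cat_dir = os.path.join(data_root, "occ_data", category)
    if not os.path.isdir(cat_dir):
        continue
    n_occ = len(glob.glob(os.path.join(cat_dir, "*.npz")))
    print(category, ":", n_occ, "occupancy files")
```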