Spaces: Running on Zero
daidedou committed · e321b92
Parent(s): 458efe2
forgot a few things lol
Browse files
- config/diffusion/dfaust_fmap.yaml +65 -0
- config/matching/diff_mask.yaml +28 -0
- config/matching/lap_mask.yaml +23 -0
- config/matching/resol_mask.yaml +23 -0
- config/matching/sds.yaml +35 -0
- config/matching/sds_dt4d.yaml +35 -0
- config/matching/sds_slow.yaml +40 -0
- config/matching/sds_smal.yaml +35 -0
- config/matching/snk.yaml +27 -0
- diffu_models/basis_dataset.py +314 -0
- diffu_models/dit_models.py +383 -0
- diffu_models/losses.py +96 -0
- diffu_models/precond.py +152 -0
- diffu_models/sds.py +77 -0
- shape_data/__init__.py +74 -0
- shape_data/data_utils.py +270 -0
- shape_data/dt4dinter.py +50 -0
- shape_data/dt4dintra.py +57 -0
- shape_data/faust.py +408 -0
- shape_data/scape.py +17 -0
- shape_data/shrec19.py +42 -0
- shape_data/smalr.py +31 -0
- shape_data/tosca.py +46 -0
- snk/__init__.py +0 -0
- snk/loss.py +119 -0
- snk/prism_decoder.py +86 -0
config/diffusion/dfaust_fmap.yaml
ADDED
@@ -0,0 +1,65 @@
# misc
misc:
  cuda: True
  device: 0
  checkpoint_interval: 1
  log_interval: 812
  desc: null
  precond: False
  dry_run: False

data:
  root_dir: "data_cache"
  name: DFAUST_fmap_30
  n_fmap: 30
  out: "fmap_exps"
  cond: False
  template_path: "data/template.ply"
  normalize: True
  pairs: False
  abs: True

add_name:
  do: False
  name: "bis"

architecture:
  model: "DiT"
  name_arch: "DiT-S/4"
  input_type: "img"
  cond: False  # Conditioning with 3D-CODED

## loss params
#loss:
#  w_gt: False   # set w_gt: True to train as a supervised method
#  w_ortho: 1    # orthogonality loss for the functional map (default: 1)
#  w_Qortho: 0   # orthogonality loss for the complex functional map (default: 1)
#  w_bij: 1
#  w_res: 1      # residual loss for the functional map (default: 1)
#  w_rank: -0.1
#  w_srnf: 1
#  min_alpha: 1
#  max_alpha: 100
#

hyper_params:
  iterations: 200
  batch_size: 256
  lr: 0.001
  lr_rampup_kimg: 10000  # learning-rate ramp-up duration
  ema_halflife_nshape: 500  # half-life (in shapes seen) of the exponential moving average (EMA) of model weights
  ema_rampup_ratio: 0.05  # EMA ramp-up coefficient, None = no ramp-up
  dropout: 0
  loss_name: 'VPLoss'
  ls: 1  # loss scaling

perfs:
  fp16: False
  workers: 1

resume:
  pkl: null
  transfer: null
kimg_per_tick: 5
snapshot_ticks: 50
state_dump_ticks: 50
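As a minimal sketch (assuming PyYAML and the repo layout above; the actual training entry point may load configs differently), this file parses into plain nested dictionaries:

# Minimal config-loading sketch; paths and keys follow the YAML above.
import yaml

with open("config/diffusion/dfaust_fmap.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["architecture"]["name_arch"])   # -> "DiT-S/4"
print(cfg["hyper_params"]["loss_name"])   # -> "VPLoss"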
config/matching/diff_mask.yaml
ADDED
@@ -0,0 +1,28 @@
gpu: 0
cache: "cache/fmaps"

sds: True
optimize: False

sds_conf:
  train_dir: pretrained
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128  ## Doesn't change
    lambda_: 2
    use_diff: True
    diffusion:
      abs: True
      normalize: False
      time: 1
      batch_sds: 32
      batch_mask: 200


zo_shot: 150
config/matching/lap_mask.yaml
ADDED
@@ -0,0 +1,23 @@
gpu: 0
cache: "cache/fmaps"

sds: True
optimize: False

sds_conf:
  train_dir: pretrained
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128  ## Doesn't change
    lambda_: 1e-3
    use_resolvent: False  ## Don't forget to change lambda_ (to around 100) if you enable the resolvent mask
    resolvent_gamma: 0.5


zo_shot: 150
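For context, a hedged sketch of the two eigenvalue masks that lap_mask.yaml and resol_mask.yaml toggle between: the plain Laplacian-commutativity mask and the resolvent mask (Ren et al. 2019), in the form commonly used in functional-map codebases. The repo's own implementation may differ in detail.

import torch

def commutativity_mask(evals1, evals2):
    # Penalizes ||C diag(evals1) - diag(evals2) C||: M[i, j] = (evals2[i] - evals1[j])^2.
    return (evals2[:, None] - evals1[None, :]) ** 2

def resolvent_mask(evals1, evals2, gamma=0.5):
    # Resolvent variant: eigenvalues are rescaled, raised to gamma, and
    # compared through the complex resolvent's real/imaginary parts.
    scale = max(evals1.max(), evals2.max())
    e1 = (evals1 / scale) ** gamma
    e2 = (evals2 / scale) ** gamma
    re = e2[:, None] / (e2[:, None] ** 2 + 1) - e1[None, :] / (e1[None, :] ** 2 + 1)
    im = 1 / (e2[:, None] ** 2 + 1) - 1 / (e1[None, :] ** 2 + 1)
    return re ** 2 + im ** 2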
config/matching/resol_mask.yaml
ADDED
@@ -0,0 +1,23 @@
gpu: 0
cache: "cache/fmaps"

sds: True
optimize: False

sds_conf:
  train_dir: pretrained
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128  ## Doesn't change
    lambda_: 100
    use_resolvent: True  ## Don't forget to change lambda_ (to around 100) if you enable the resolvent mask
    resolvent_gamma: 0.5


zo_shot: 150
config/matching/sds.yaml
ADDED
@@ -0,0 +1,35 @@
gpu: 0
cache: "cache/fmaps"

sds: True
refine: True
optimize: True
oriented: True

sds_conf:
  train_dir: pretrained
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128  ## Doesn't change
    lambda_: 1e-1
    use_diff: True
    diffusion:
      abs: True
      normalize: False
      time: 1
      batch_sds: 32
      batch_mask: 200

opt:
  n_loop: 300
  soft_p2p: False

loss:
  sds: 1.
  proper: 1.
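For context, a hedged sketch of the score-distillation (SDS) step that the `sds` weight above controls, in the DreamFusion spirit; diffu_models/sds.py in this commit is the authoritative version. Here `denoiser` stands in for a preconditioned network (precond.py) and `loss_fn` for one of the losses in diffu_models/losses.py.

import torch

def sds_grad(denoiser, loss_fn, C):
    # Sample a noise level restricted to the SDS range, perturb the current
    # functional map C, and pull it toward the denoiser's prediction.
    sigma, weight = loss_fn.noise_and_weight(C.shape[0], C.device, sds=True)
    n = torch.randn_like(C) * sigma
    with torch.no_grad():
        D_Cn = denoiser(C + n, sigma)
    # Gradient of the SDS objective w.r.t. C: weight * (C - denoised estimate).
    return weight * (C - D_Cn)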
config/matching/sds_dt4d.yaml
ADDED
@@ -0,0 +1,35 @@
gpu: 0
cache: "cache/fmaps"

sds: True
refine: True
optimize: True
rotate: True

sds_conf:
  train_dir: fmap_exps
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128  ## Doesn't change
    lambda_: 1e-3
    use_diff: True
    diffusion:
      abs: True
      normalize: False
      time: 1
      batch_sds: 32
      batch_mask: 200

opt:
  n_loop: 1000
  soft_p2p: False

loss:
  sds: 0.1
  proper: 1.
config/matching/sds_slow.yaml
ADDED
@@ -0,0 +1,40 @@
gpu: 0
cache: "cache/fmaps"

sds: True
refine: True
optimize: True
oriented: True

diff_model:
  train_dir: pretrained

# diff_model:
#   train_dir: fmap_exps
#   diff_num_exp: 53

sds_conf:
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128  ## Doesn't change
    lambda_: 1e-3
    use_diff: True
    diffusion:
      abs: True
      normalize: False
      time: 1
      batch_sds: 32
      batch_mask: 200

opt:
  n_loop: 1000
  soft_p2p: False

loss:
  sds: 0.1
  proper: 1.
config/matching/sds_smal.yaml
ADDED
@@ -0,0 +1,35 @@
gpu: 0
cache: "cache/fmaps"

sds: True
refine: True
optimize: True
oriented: True

sds_conf:
  train_dir: pretrained
  diff_num_exp: 53
  zoomout: 40

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128  ## Doesn't change
    lambda_: 1e-1
    use_diff: True
    diffusion:
      abs: True
      normalize: False
      time: 1
      batch_sds: 32
      batch_mask: 200

opt:
  n_loop: 1000
  soft_p2p: False

loss:
  sds: 0.1
  proper: 1.
config/matching/snk.yaml
ADDED
@@ -0,0 +1,27 @@
gpu: 0
cache: "cache/fmaps"

snk: True
refine: True
optimize: True

deepfeat_conf:
  fmap:
    feat: "xyz"
    n_fmap: 30
    C_in: 3
    n_feat: 128  ## Doesn't change
    lambda_: 100
    use_resolvent: True  ## Don't forget to change lambda_ (to around 100) if you enable the resolvent mask
    resolvent_gamma: 0.5

opt:
  n_loop: 1000
  soft_p2p: True

loss:
  bij: 1.
  ortho: 1.
  cycle: 1
  mse_rec: 1
  prism_rec: 1
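For context, a hedged sketch of the standard bijectivity and orthogonality energies that the `bij` and `ortho` weights above refer to; snk/loss.py in this commit is the authoritative version.

import torch

def bij_loss(C12, C21):
    # Composing the two functional maps should approximate the identity.
    I = torch.eye(C12.shape[-1], device=C12.device)
    return ((C12 @ C21 - I) ** 2).sum() + ((C21 @ C12 - I) ** 2).sum()

def ortho_loss(C12):
    # Area-preserving maps yield orthonormal functional maps: C^T C = I.
    I = torch.eye(C12.shape[-1], device=C12.device)
    return ((C12.transpose(-1, -2) @ C12 - I) ** 2).sum()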
diffu_models/basis_dataset.py
ADDED
@@ -0,0 +1,314 @@
import os

import numpy as np
import torch
from torch.utils.data import DataLoader

N_POSES = 21


class AMASSDataset(torch.utils.data.Dataset):
    def __init__(self, root_path, version='version0', subset='train', basis_path='base_amass.npy',
                 sample_interval=None, num_coeffs=100, return_shape=False,
                 normalize=True, min_max=False):

        self.root_path = root_path
        self.version = version
        assert subset in ['train', 'valid', 'test']
        self.subset = subset
        self.sample_interval = sample_interval
        self.return_shape = return_shape
        self.normalize = normalize
        self.min_max = min_max
        self.num_coeffs = num_coeffs
        self.poses, self.shapes = self.read_data()

        if self.sample_interval:
            self._sample(sample_interval)
        if self.normalize:
            if self.min_max:
                self.min_poses, self.max_poses, self.min_shapes, self.max_shapes = self.Normalize()
            else:
                self.mean_poses, self.std_poses, self.mean_shapes, self.std_shapes = self.Normalize()

        self.real_data_len = len(self.poses)

    def __getitem__(self, idx):
        """
        Return:
            [21, 3] or [21, 6] for poses including body and root orient
            [10] for shapes (betas) [optional]
        """
        data_poses = self.poses[idx % self.real_data_len]
        if self.return_shape:
            return data_poses, self.shapes[idx % self.real_data_len]
        return data_poses

    def __len__(self):
        return len(self.poses)

    def _sample(self, sample_interval):
        print(f'Class AMASSDataset({self.subset}): sample dataset every {sample_interval} frames')
        self.poses = self.poses[::sample_interval]

    def read_data(self):
        data_path = os.path.join(self.root_path, self.subset)
        # root_orient = torch.load(os.path.join(data_path, 'root_orient.pt'))
        coeffs = torch.load(os.path.join(data_path, 'train_coeffs.pt'))
        shapes = torch.load(os.path.join(data_path, 'betas.pt')) if self.return_shape else None
        # poses = torch.cat([root_orient, pose_body], dim=1)
        if self.num_coeffs < 300:
            coeffs = coeffs[:, -self.num_coeffs:]

        return coeffs, shapes

    def Normalize(self):
        # Normalization statistics are computed on the train split (z-score or min-max).
        if self.min_max:
            normalize_path = os.path.join(self.root_path, 'train', 'coeffs_' + str(self.num_coeffs) + '_normalize1.pt')
        else:
            normalize_path = os.path.join(self.root_path, 'train', 'coeffs_' + str(self.num_coeffs) + '_normalize2.pt')

        if os.path.exists(normalize_path):
            normalize_params = torch.load(normalize_path)
            if self.min_max:
                min_poses, max_poses, min_shapes, max_shapes = (
                    normalize_params['min_poses'],
                    normalize_params['max_poses'],
                    normalize_params['min_shapes'],
                    normalize_params['max_shapes']
                )
            else:
                mean_poses, std_poses, mean_shapes, std_shapes = (
                    normalize_params['mean_poses'],
                    normalize_params['std_poses'],
                    normalize_params['mean_shapes'],
                    normalize_params['std_shapes']
                )
        else:
            if self.min_max:
                min_poses = torch.min(self.poses, dim=0)[0]
                max_poses = torch.max(self.poses, dim=0)[0]

                min_shapes = torch.min(self.shapes, dim=0)[0] if self.return_shape else None
                max_shapes = torch.max(self.shapes, dim=0)[0] if self.return_shape else None

                torch.save({
                    'min_poses': min_poses,
                    'max_poses': max_poses,
                    'min_shapes': min_shapes,
                    'max_shapes': max_shapes
                }, normalize_path)
            else:
                mean_poses = torch.mean(self.poses, dim=0)
                std_poses = torch.std(self.poses, dim=0)

                mean_shapes = torch.mean(self.shapes, dim=0) if self.return_shape else None
                std_shapes = torch.std(self.shapes, dim=0) if self.return_shape else None

                torch.save({
                    'mean_poses': mean_poses,
                    'std_poses': std_poses,
                    'mean_shapes': mean_shapes,
                    'std_shapes': std_shapes
                }, normalize_path)

        if self.min_max:
            self.poses = 2 * (self.poses - min_poses) / (max_poses - min_poses) - 1
            if self.return_shape:
                self.shapes = 2 * (self.shapes - min_shapes) / (max_shapes - min_shapes) - 1
            return min_poses, max_poses, min_shapes, max_shapes
        else:
            self.poses = (self.poses - mean_poses) / std_poses
            if self.return_shape:
                self.shapes = (self.shapes - mean_shapes) / std_shapes
            return mean_poses, std_poses, mean_shapes, std_shapes

    def Denormalize(self, poses, shapes=None):
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]

        if self.min_max:
            min_poses = self.min_poses.view(1, -1).to(poses.device)
            max_poses = self.max_poses.view(1, -1).to(poses.device)

            if len(poses.shape) == 3:  # [t, b, data_dim]
                min_poses = min_poses.unsqueeze(0)
                max_poses = max_poses.unsqueeze(0)

            denormalized_poses = 0.5 * ((poses + 1) * (max_poses - min_poses) + 2 * min_poses)

            if shapes is not None and self.min_shapes is not None:
                min_shapes = self.min_shapes.view(1, -1).to(shapes.device)
                max_shapes = self.max_shapes.view(1, -1).to(shapes.device)

                if len(shapes.shape) == 3:
                    min_shapes = min_shapes.unsqueeze(0)
                    max_shapes = max_shapes.unsqueeze(0)

                denormalized_shapes = 0.5 * ((shapes + 1) * (max_shapes - min_shapes) + 2 * min_shapes)
                return denormalized_poses, denormalized_shapes
            else:
                return denormalized_poses
        else:
            mean_poses = self.mean_poses.view(1, -1).to(poses.device)
            std_poses = self.std_poses.view(1, -1).to(poses.device)

            if len(poses.shape) == 3:  # [t, b, data_dim]
                mean_poses = mean_poses.unsqueeze(0)
                std_poses = std_poses.unsqueeze(0)

            denormalized_poses = poses * std_poses + mean_poses

            if shapes is not None and self.mean_shapes is not None:
                mean_shapes = self.mean_shapes.view(1, -1).to(shapes.device)
                std_shapes = self.std_shapes.view(1, -1).to(shapes.device)

                if len(shapes.shape) == 3:
                    mean_shapes = mean_shapes.unsqueeze(0)
                    std_shapes = std_shapes.unsqueeze(0)

                denormalized_shapes = shapes * std_shapes + mean_shapes
                return denormalized_poses, denormalized_shapes
            else:
                return denormalized_poses

    def eval(self, preds):
        pass


class Posenormalizer:
    def __init__(self, data_path, device='cuda:0', normalize=True, min_max=True, rot_rep=None):
        assert rot_rep in ['rot6d', 'axis']
        self.normalize = normalize
        self.min_max = min_max
        self.rot_rep = rot_rep
        normalize_params = torch.load(os.path.join(data_path, '{}_normalize1.pt'.format(rot_rep)))
        self.min_poses, self.max_poses = normalize_params['min_poses'].to(device), normalize_params['max_poses'].to(device)
        normalize_params = torch.load(os.path.join(data_path, '{}_normalize2.pt'.format(rot_rep)))
        self.mean_poses, self.std_poses = normalize_params['mean_poses'].to(device), normalize_params['std_poses'].to(device)

    def offline_normalize(self, poses, from_axis=False):
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]

        if not self.normalize:
            return poses

        if self.min_max:
            min_poses = self.min_poses.view(1, -1)
            max_poses = self.max_poses.view(1, -1)

            if len(poses.shape) == 3:  # [t, b, data_dim]
                min_poses = min_poses.unsqueeze(0)
                max_poses = max_poses.unsqueeze(0)

            normalized_poses = 2 * (poses - min_poses) / (max_poses - min_poses) - 1

        else:
            mean_poses = self.mean_poses.view(1, -1)
            std_poses = self.std_poses.view(1, -1)

            if len(poses.shape) == 3:  # [t, b, data_dim]
                mean_poses = mean_poses.unsqueeze(0)
                std_poses = std_poses.unsqueeze(0)

            normalized_poses = (poses - mean_poses) / std_poses

        return normalized_poses

    def offline_denormalize(self, poses, to_axis=False):
        assert len(poses.shape) == 2 or len(poses.shape) == 3  # [b, data_dim] or [t, b, data_dim]

        if not self.normalize:
            denormalized_poses = poses
        else:
            if self.min_max:
                min_poses = self.min_poses.view(1, -1)
                max_poses = self.max_poses.view(1, -1)

                if len(poses.shape) == 3:  # [t, b, data_dim]
                    min_poses = min_poses.unsqueeze(0)
                    max_poses = max_poses.unsqueeze(0)

                denormalized_poses = 0.5 * ((poses + 1) * (max_poses - min_poses) + 2 * min_poses)

            else:
                mean_poses = self.mean_poses.view(1, -1)
                std_poses = self.std_poses.view(1, -1)

                if len(poses.shape) == 3:  # [t, b, data_dim]
                    mean_poses = mean_poses.unsqueeze(0)
                    std_poses = std_poses.unsqueeze(0)

                denormalized_poses = poses * std_poses + mean_poses

        return denormalized_poses


# a simple eval process for the completion task
class Evaler:
    # NOTE: BodyPartIndices and BodySegIndices are not imported in this file;
    # they are assumed to come from the body-model utilities used elsewhere.
    def __init__(self, body_model, part=None):
        self.body_model = body_model
        self.part = part

        if self.part is not None:
            self.joint_idx = np.array(getattr(BodyPartIndices, self.part)) + 1  # skip pelvis
            self.vert_idx = np.array(getattr(BodySegIndices, self.part))
        else:
            self.joint_idx = slice(None)
            self.vert_idx = slice(None)

    def eval_bodys(self, outs, gts):
        '''
        :param outs: [b, j*3] axis-angle results of body poses
        :param gts: [b, j*3] axis-angle ground truth of body poses
        :return: result dict for every sample
        '''
        sample_num = len(outs)
        eval_result = {'mpvpe_all': [], 'mpjpe_body': []}
        body_gt = self.body_model(pose_body=gts)
        body_out = self.body_model(pose_body=outs)

        for n in range(sample_num):
            # MPVPE over all vertices
            mesh_gt = body_gt.v.detach().cpu().numpy()[n, self.vert_idx]
            mesh_out = body_out.v.detach().cpu().numpy()[n, self.vert_idx]
            eval_result['mpvpe_all'].append(np.sqrt(np.sum((mesh_out - mesh_gt) ** 2, 1)).mean() * 1000)

            joint_gt_body = body_gt.Jtr.detach().cpu().numpy()[n, self.joint_idx]
            joint_out_body = body_out.Jtr.detach().cpu().numpy()[n, self.joint_idx]

            eval_result['mpjpe_body'].append(
                np.sqrt(np.sum((joint_out_body - joint_gt_body) ** 2, 1)).mean() * 1000)

        return eval_result

    def multi_eval_bodys(self, outs, gts):
        '''
        :param outs: [b, hypo, j*3] axis-angle results of body poses, multiple hypotheses
        :param gts: [b, j*3] axis-angle ground truth of body poses
        :return: result dict
        '''
        hypo_num = outs.shape[1]
        eval_result = {'mpvpe_all': [], 'mpjpe_body': []}
        for hypo in range(hypo_num):
            result = self.eval_bodys(outs[:, hypo], gts)
            eval_result['mpvpe_all'].append(result['mpvpe_all'])
            eval_result['mpjpe_body'].append(result['mpjpe_body'])

        eval_result['mpvpe_all'] = np.min(eval_result['mpvpe_all'], axis=0)
        eval_result['mpjpe_body'] = np.min(eval_result['mpjpe_body'], axis=0)

        return eval_result

    def print_eval_result(self, eval_result):
        print('MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
        print('MPJPE (Body): %.2f mm' % np.mean(eval_result['mpjpe_body']))

    def print_multi_eval_result(self, eval_result, hypo_num):
        print(f'multihypo {hypo_num} MPVPE (All): %.2f mm' % np.mean(eval_result['mpvpe_all']))
        print(f'multihypo {hypo_num} MPJPE (Body): %.2f mm' % np.mean(eval_result['mpjpe_body']))
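A usage sketch (hypothetical root path; assumes the preprocessed train_coeffs.pt / betas.pt files exist under the split directories):

from torch.utils.data import DataLoader

# z-score-normalized basis coefficients from the train split
dataset = AMASSDataset(root_path="data_cache/amass", subset="train",
                       num_coeffs=100, normalize=True, min_max=False)
loader = DataLoader(dataset, batch_size=256, shuffle=True)
coeffs = next(iter(loader))               # [256, num_coeffs]
recovered = dataset.Denormalize(coeffs)   # back to the original coefficient scale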
diffu_models/dit_models.py
ADDED
@@ -0,0 +1,383 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


#################################################################################
#                                 Core DiT Model                                #
#################################################################################

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        conditioning=False
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.conditioning = conditioning
        if conditioning:
            self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize label embedding table:
        if self.conditioning:
            nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, t, y=None):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                   # (N, D)

        c = t                                    # (N, D)
        if self.conditioning:
            y = self.y_embedder(y, self.training)  # (N, D)
            c = c + y
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


#################################################################################
#                                   DiT Configs                                 #
#################################################################################

def DiT_XL_2(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)

def DiT_XL_4(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)

def DiT_XL_8(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)

def DiT_L_2(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)

def DiT_L_4(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)

def DiT_L_5(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=5, num_heads=16, **kwargs)

def DiT_L_8(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)

def DiT_B_2(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)

def DiT_B_4(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)

def DiT_B_5(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=5, num_heads=12, **kwargs)

def DiT_B_8(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def DiT_S_2(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

def DiT_S_4(**kwargs):
    # NB: this entry uses patch_size=5 despite the "/4" name, presumably so
    # that 30x30 functional-map inputs tile evenly into 6x6 patches.
    return DiT(depth=12, hidden_size=384, patch_size=5, num_heads=6, **kwargs)

def DiT_S_8(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


DiT_models = {
    'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
    'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8, 'DiT-L/5': DiT_L_5,
    'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8, 'DiT-B/5': DiT_B_5,
    'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
}
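A smoke-test sketch of the 'DiT-S/4' entry that dfaust_fmap.yaml selects; in_channels=1 and learn_sigma=False are illustrative assumptions for 30x30 functional-map inputs:

import torch

model = DiT_models['DiT-S/4'](input_size=30, in_channels=1,
                              learn_sigma=False, conditioning=False)
x = torch.randn(8, 1, 30, 30)   # batch of 30x30 fmap "images"
t = torch.rand(8)               # diffusion timesteps
out = model(x, t)
assert out.shape == (8, 1, 30, 30)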
diffu_models/losses.py
ADDED
@@ -0,0 +1,96 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Loss functions used in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import torch
from edm.torch_utils import persistence

#----------------------------------------------------------------------------
# Loss function corresponding to the variance preserving (VP) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VPLoss:
    def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5):
        self.beta_d = beta_d
        self.beta_min = beta_min
        self.epsilon_t = epsilon_t

    def noise_and_weight(self, shape, device, sds=False):
        rnd_uniform = torch.rand([shape, 1, 1, 1], device=device)
        if sds:
            # restrict to [0.02, 0.98], see https://github.com/ashawkey/stable-dreamfusion/blob/5550b91862a3af7842bb04875b7f1211e5095a63/guidance/sd_utils.py#L180
            rnd_uniform = 0.02 + rnd_uniform * 0.96
        sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1))
        weight = 1 / sigma ** 2
        return sigma, weight

    def __call__(self, net, x, latents, augment_pipe=None):
        sigma, weight = self.noise_and_weight(x.shape[0], x.device)
        n = torch.randn_like(x) * sigma
        D_xn = net(x + n, sigma, latents)
        loss = weight * ((D_xn - x) ** 2)
        return loss

    def sigma(self, t):
        t = torch.as_tensor(t)
        return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt()

#----------------------------------------------------------------------------
# Loss function corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

@persistence.persistent_class
class VELoss:
    def __init__(self, sigma_min=0.02, sigma_max=100):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def noise_and_weight(self, shape, device, sds=False):
        rnd_uniform = torch.rand([shape, 1, 1, 1], device=device)
        sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)
        weight = 1 / sigma ** 2
        return sigma, weight

    def __call__(self, net, x, latents, augment_pipe=None):
        sigma, weight = self.noise_and_weight(x.shape[0], x.device)
        n = torch.randn_like(x) * sigma
        D_xn = net(x + n, sigma, latents)
        loss = weight * ((D_xn - x) ** 2)
        return loss

#----------------------------------------------------------------------------
# Improved loss function proposed in the paper "Elucidating the Design Space
# of Diffusion-Based Generative Models" (EDM).

@persistence.persistent_class
class EDMLoss:
    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.sigma_min = 0.4
        self.sigma_max = 10
        self.rho = 3

    def noise_and_weight(self, shape, device, sds=False):
        rnd_normal = torch.randn([shape, 1, 1, 1], device=device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        return sigma.float(), weight.float()

    def __call__(self, net, x, latents, augment_pipe=None):
        sigma, weight = self.noise_and_weight(x.shape[0], x.device)
        n = torch.randn_like(x) * sigma
        D_xn = net(x + n, sigma, latents)
        loss = weight * ((D_xn - x) ** 2)
        return loss

#----------------------------------------------------------------------------
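A hedged sketch of how VPLoss is consumed: `net` is expected to take (noisy_x, sigma, latents) and return a denoised estimate; the identity lambda below is a stand-in for a preconditioned denoiser such as those in precond.py.

import torch

loss_fn = VPLoss()
net = lambda x_noisy, sigma, latents: x_noisy   # toy stand-in denoiser
x = torch.randn(4, 1, 30, 30)
loss = loss_fn(net, x, latents=None)   # weighted per-element squared error
print(loss.mean())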
diffu_models/precond.py
ADDED
@@ -0,0 +1,152 @@
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
#----------------------------------------------------------------------------
|
| 5 |
+
# Preconditioning corresponding to the variance exploding (VE) formulation
|
| 6 |
+
# from the paper "Score-Based Generative Modeling through Stochastic
|
| 7 |
+
# Differential Equations".
|
| 8 |
+
|
| 9 |
+
class VEPrecond(torch.nn.Module):
|
| 10 |
+
def __init__(self,
|
| 11 |
+
model,
|
| 12 |
+
label_dim = 0, # Number of class labels, 0 = unconditional.
|
| 13 |
+
        use_fp16  = False,         # Execute the underlying model at FP16 precision?
        sigma_min = 0.02,          # Minimum supported noise level.
        sigma_max = 100,           # Maximum supported noise level.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        x = x.to(torch.float32)
        class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = sigma
        c_in = 1
        c_noise = (0.5 * sigma).log()

        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

#----------------------------------------------------------------------------
# Preconditioning corresponding to improved DDPM (iDDPM) formulation from
# the paper "Improved Denoising Diffusion Probabilistic Models".

class iDDPMPrecond(torch.nn.Module):
    def __init__(self,
        model,
        label_dim = 0,             # Number of class labels, 0 = unconditional.
        use_fp16  = False,         # Execute the underlying model at FP16 precision?
        C_1       = 0.001,         # Timestep adjustment at low noise levels.
        C_2       = 0.008,         # Timestep adjustment at high noise levels.
        M         = 1000,          # Original number of timesteps in the DDPM formulation.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.C_1 = C_1
        self.C_2 = C_2
        self.M = M
        self.model = model
        u = torch.zeros(M + 1)
        for j in range(M, 0, -1):  # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        self.register_buffer('u', u)
        self.sigma_min = float(u[M - 1])
        self.sigma_max = float(u[0])

    def forward(self, x, sigma, class_labels=None, lamb=None, force_fp32=False, **model_kwargs):
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        x = x.to(torch.float32)
        class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + 1).sqrt()
        c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)
        # if class_labels is not None:
        #     F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        # else:
        if lamb is not None:
            F_x = self.model((c_in * x).to(dtype), lamb, c_noise.flatten(), **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def alpha_bar(self, j):
        j = torch.as_tensor(j)
        return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2

    def round_sigma(self, sigma, return_index=False):
        sigma = torch.as_tensor(sigma)
        index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
        result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
        return result.reshape(sigma.shape).to(sigma.device)

#----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models" (EDM).

class EDMPrecond(torch.nn.Module):
    def __init__(self,
        model,
        label_dim  = 0,             # Number of class labels, 0 = unconditional.
        use_fp16   = False,         # Execute the underlying model at FP16 precision?
        sigma_min  = 0,             # Minimum supported noise level.
        sigma_max  = float('inf'),  # Maximum supported noise level.
        sigma_data = 0.5,           # Expected standard deviation of the training data.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if class_labels is not None:
            if self.label_dim == 0:
                class_labels = None
            else:
                class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_in = c_in.to(x.device)
        c_noise = sigma.log() / 4
        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), c_latent=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

#----------------------------------------------------------------------------
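For orientation, a minimal usage sketch (not part of the commit): each preconditioner wraps a raw network and exposes the same denoiser interface D(x, sigma). The toy network below is hypothetical; any module taking (x, noise_labels) would work in its place.

import torch

class ToyNet(torch.nn.Module):  # hypothetical stand-in for the actual backbone
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 3, padding=1)

    def forward(self, x, noise_labels):
        return self.conv(x)  # noise_labels ignored in this toy example

denoiser = EDMPrecond(ToyNet())   # unconditional: label_dim=0
x = torch.randn(4, 1, 32, 32)     # a batch of noisy inputs
sigma = torch.full((4,), 0.7)     # one noise level per sample
D_x = denoiser(x, sigma)          # denoised estimate, same shape as x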
diffu_models/sds.py
ADDED
@@ -0,0 +1,77 @@
import numpy as np  # was missing in the original; needed by the cosine decay below
import torch
from torch.autograd import grad
from torch.optim.lr_scheduler import _LRScheduler
import torch.nn.functional as F


class WarmupCosineDecayScheduler(_LRScheduler):
    def __init__(self, optimizer, warmup_steps, total_steps, warmup_start_lr=1e-9, max_lr=1e-4, min_lr=1e-6, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.warmup_start_lr = warmup_start_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super(WarmupCosineDecayScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup from warmup_start_lr to max_lr
            lr = self.max_lr * self.last_epoch / self.warmup_steps + (1 - self.last_epoch / self.warmup_steps) * self.warmup_start_lr
        else:
            # Cosine decay from max_lr down to min_lr
            cosine_decay = 0.5 * (1 + np.cos(torch.pi * (self.last_epoch - self.warmup_steps) / (self.total_steps - self.warmup_steps)))
            decayed = (1 - self.min_lr / self.max_lr) * cosine_decay + self.min_lr / self.max_lr
            lr = self.max_lr * decayed
        return [lr for _ in self.base_lrs]


def guidance_grad(pred_shape, net, scale_noise, grad_scale=1, batch_size=32, device="cpu", save_guidance_path=None):
    # noise level sigma ~ U(0.01, 0.01 + scale_noise) to avoid degenerate noise levels
    sigma = 0.01 + torch.rand([batch_size, 1, 1, 1], device=device) * scale_noise
    # predict the denoised shape with the network, NO grad!
    with torch.no_grad():
        # sample noise
        noise = torch.randn_like(pred_shape) * sigma
        # denoise the perturbed input
        x = pred_shape + noise
        denoised = net(x, sigma)
    # w(t), sigma_t^2
    grad = torch.mean(grad_scale * (pred_shape - denoised), dim=0)  # / sigma**2
    #print(sigma.item()**2, weight.item(), torch.norm(pred_shape-denoised).item())
    #print(grad)
    grad = torch.nan_to_num(grad)

    # if save_guidance_path:
    #     with torch.no_grad():
    #         if as_latent:
    #             pred_rgb_512 = self.decode_latents(latents)

    #         # visualize predicted denoised image
    #         # The following block of code is equivalent to `predict_start_from_noise`...
    #         # see zero123_utils.py's version for a simpler implementation.
    #         alphas = self.scheduler.alphas.to(latents)
    #         total_timesteps = self.max_step - self.min_step + 1
    #         index = total_timesteps - t.to(latents.device) - 1
    #         b = len(noise_pred)
    #         a_t = alphas[index].reshape(b,1,1,1).to(self.device)
    #         sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
    #         sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device)
    #         pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt()  # current prediction for x_0
    #         result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))

    #         # visualize noisier image
    #         result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))

    #         # TODO: also denoise all-the-way

    #         # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
    #         viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image], dim=0)
    #         save_image(viz_images, save_guidance_path)

    return grad, denoised


def guidance_loss(pred_shape, loss_sde, net, grad_scale=1, device="cpu", save_guidance_path=None):
    # Fixed from the original, which passed these arguments positionally in an
    # order that did not match guidance_grad's signature and ignored the returned
    # (grad, denoised) tuple. Following the original positional order, `loss_sde`
    # is treated as the denoiser network and `net` as the noise scale.
    grad, _ = guidance_grad(pred_shape, loss_sde, net, grad_scale=grad_scale, device=device, save_guidance_path=save_guidance_path)
    targets = (pred_shape - grad).detach()
    loss = 0.5 * F.mse_loss(pred_shape.float(), targets, reduction='sum') / pred_shape.shape[0]
    return loss
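A hedged sketch (not part of the commit) of how these pieces compose into a score-distillation loop. Here `denoiser` is an assumption: any preconditioned network with the net(x, sigma) interface from diffu_models/precond.py; shapes are illustrative.

import torch
import torch.nn.functional as F

pred_shape = torch.randn(1, 1, 32, 32, requires_grad=True)   # optimizable variable
optimizer = torch.optim.Adam([pred_shape], lr=1e-3)
scheduler = WarmupCosineDecayScheduler(optimizer, warmup_steps=100, total_steps=1000)

for step in range(1000):
    g, _ = guidance_grad(pred_shape, denoiser, scale_noise=0.5, batch_size=8)
    targets = (pred_shape - g).detach()                       # same trick as guidance_loss:
    loss = 0.5 * F.mse_loss(pred_shape, targets, reduction='sum')  # d(loss)/d(pred_shape) == g
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()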
shape_data/__init__.py
ADDED
@@ -0,0 +1,74 @@
import sys
import os.path as osp
import numpy as np
import torch
from collections import defaultdict

ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

DATA_DIRS = {
    'faust': 'FAUST_r',
    'faust_ori': 'FAUST_r_ori',
    'scape': 'SCAPE_r',
    'scape_ori': 'SCAPE_r_ori',
    'smalr': 'SMAL_r',
    'smalr_ori': 'SMAL_r_ori',
    'shrec19': 'SHREC_r',
    'shrec19_ori': 'SHREC_r_ori',
    'dt4d': 'DT4D_r',
    'dt4dintra': 'DT4D_r',
    'dt4dintra_ori': 'DT4D_r_ori',
    'dt4dinter': 'DT4D_r',
    'dt4dinter_ori': 'DT4D_r_ori',
    'tosca': 'TOSCA_r',
    'tosca_ori': 'TOSCA_r',
}


def get_data_dirs(root, name, mode):
    prefix = osp.join(root, DATA_DIRS[name])
    shape_dir = osp.join(prefix, 'shapes')
    corr_dir = osp.join(prefix, 'correspondences')
    return shape_dir, DATA_DIRS[name], corr_dir


# def collate_default(data_list):
#     data_dict = defaultdict(list)
#     for pair_dict in data_list:
#         for k, v in pair_dict.items():
#             data_dict[k].append(v)
#     for k in data_dict.keys():
#         if k.startswith('fmap') or k.startswith('evals') or k.endswith('_sub'):
#             data_dict[k] = np.stack(data_dict[k], axis=0)
#     batch_size = len(data_list)
#     for k, v in data_dict.items():
#         assert len(v) == batch_size
#     return data_dict


def prepare_batch(data_dict, device):
    for k in data_dict.keys():
        if isinstance(data_dict[k], np.ndarray):
            data_dict[k] = torch.from_numpy(data_dict[k]).to(device)
        else:
            if k.startswith('gradX') or \
               k.startswith('gradY') or \
               k.startswith('L'):
                from diffusion_net.utils import sparse_np_to_torch
                tmp_list = [sparse_np_to_torch(st).to(device) for st in data_dict[k]]
                if len(data_dict[k]) == 1:
                    data_dict[k] = torch.stack(tmp_list, dim=0)
                else:
                    data_dict[k] = tmp_list
            else:
                if isinstance(data_dict[k][0], np.ndarray):
                    tmp_list = [torch.from_numpy(st).to(device) for st in data_dict[k]]
                    if len(data_dict[k]) == 1:
                        data_dict[k] = torch.stack(tmp_list, dim=0).to(device)
                    else:
                        data_dict[k] = tmp_list

    return data_dict
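A small illustration (not part of the commit) of what prepare_batch does with a plain array versus a one-element list of arrays:

import numpy as np
import torch

batch = {
    'vertices': np.zeros((100, 3), dtype=np.float32),   # array -> tensor on device
    'evals': [np.zeros(30, dtype=np.float32)],          # 1-element list -> stacked tensor
}
batch = prepare_batch(batch, torch.device('cpu'))
print(batch['vertices'].shape, batch['evals'].shape)    # torch.Size([100, 3]) torch.Size([1, 30])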
shape_data/data_utils.py
ADDED
@@ -0,0 +1,270 @@
import scipy
import scipy.linalg  # imported explicitly for lstsq() below; relied on an implicit import in the original
import scipy.sparse
import scipy.sparse.linalg
from scipy.io import loadmat
import sys
import os
import os.path as osp
import math
import numpy as np
import open3d as o3d
import potpourri3d as pp3d
import torch
from pathlib import Path


class CorrLoader(object):

    def __init__(self, root_dir, data_type='mat'):
        self.root_dir = root_dir
        self.data_type = data_type

    def get_by_names(self, sname0, sname1):
        if self.data_type.endswith('mat'):
            pmap10 = self._load_mat(osp.join(self.root_dir, f'{sname0}-{sname1}.mat'))
            return np.stack((pmap10, np.arange(len(pmap10))), axis=1)
        else:
            raise RuntimeError(f'Data type {self.data_type} is not supported.')

    def _load_mat(self, filepath):
        data = loadmat(filepath)
        pmap10 = np.squeeze(np.asarray(data['pmap10'], dtype=np.int32))
        return pmap10


# https://github.com/RobinMagnet/pyFM/blob/master/pyFM/signatures/HKS_functions.py
def HKS(evals, evects, time_list, scaled=False):
    evals_s = np.asarray(evals).flatten()
    t_list = np.asarray(time_list).flatten()

    coefs = np.exp(-np.outer(t_list, evals_s))
    weighted_evects = evects[None, :, :] * coefs[:, None, :]
    natural_HKS = np.einsum('tnk,nk->nt', weighted_evects, evects)

    if scaled:
        inv_scaling = coefs.sum(1)
        return (1 / inv_scaling)[None, :] * natural_HKS
    else:
        return natural_HKS


def lm_HKS(evals, evects, landmarks, time_list, scaled=False):
    evals_s = np.asarray(evals).flatten()
    t_list = np.asarray(time_list).flatten()

    coefs = np.exp(-np.outer(t_list, evals_s))
    weighted_evects = evects[None, landmarks, :] * coefs[:, None, :]

    landmarks_HKS = np.einsum('tpk,nk->ptn', weighted_evects, evects)

    if scaled:
        inv_scaling = coefs.sum(1)
        landmarks_HKS = (1 / inv_scaling)[None, :, None] * landmarks_HKS

    return landmarks_HKS.reshape(-1, evects.shape[0]).T


def auto_HKS(evals, evects, num_T, landmarks=None, scaled=True):
    abs_ev = sorted(np.abs(evals))
    t_list = np.geomspace(4 * np.log(10) / abs_ev[-1], 4 * np.log(10) / abs_ev[1], num_T)

    if landmarks is None:
        return HKS(abs_ev, evects, t_list, scaled=scaled)
    else:
        return lm_HKS(abs_ev, evects, landmarks, t_list, scaled=scaled)


# https://github.com/RobinMagnet/pyFM/blob/master/pyFM/signatures/WKS_functions.py
def WKS(evals, evects, energy_list, sigma, scaled=False):
    assert sigma > 0, f"Sigma should be positive ! Given value : {sigma}"

    evals = np.asarray(evals).flatten()
    indices = np.where(evals > 1e-5)[0].flatten()
    evals = evals[indices]
    evects = evects[:, indices]

    e_list = np.asarray(energy_list)
    coefs = np.exp(-np.square(e_list[:, None] - np.log(np.abs(evals))[None, :]) / (2 * sigma**2))

    weighted_evects = evects[None, :, :] * coefs[:, None, :]

    natural_WKS = np.einsum('tnk,nk->nt', weighted_evects, evects)

    if scaled:
        inv_scaling = coefs.sum(1)
        return (1 / inv_scaling)[None, :] * natural_WKS
    else:
        return natural_WKS


def lm_WKS(evals, evects, landmarks, energy_list, sigma, scaled=False):
    assert sigma > 0, f"Sigma should be positive ! Given value : {sigma}"

    evals = np.asarray(evals).flatten()
    indices = np.where(evals > 1e-2)[0].flatten()
    evals = evals[indices]
    evects = evects[:, indices]

    e_list = np.asarray(energy_list)
    coefs = np.exp(-np.square(e_list[:, None] - np.log(np.abs(evals))[None, :]) / (2 * sigma**2))
    weighted_evects = evects[None, landmarks, :] * coefs[:, None, :]

    landmarks_WKS = np.einsum('tpk,nk->ptn', weighted_evects, evects)

    if scaled:
        inv_scaling = coefs.sum(1)
        landmarks_WKS = ((1 / inv_scaling)[None, :, None] * landmarks_WKS)

    return landmarks_WKS.reshape(-1, evects.shape[0]).T


def auto_WKS(evals, evects, num_E, landmarks=None, scaled=True):
    abs_ev = sorted(np.abs(evals))

    e_min, e_max = np.log(abs_ev[1]), np.log(abs_ev[-1])
    sigma = 7 * (e_max - e_min) / num_E

    e_min += 2 * sigma
    e_max -= 2 * sigma

    energy_list = np.linspace(e_min, e_max, num_E)

    if landmarks is None:
        return WKS(abs_ev, evects, energy_list, sigma, scaled=scaled)
    else:
        return lm_WKS(abs_ev, evects, landmarks, energy_list, sigma, scaled=scaled)


def compute_hks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35):
    feats = auto_HKS(evals[:n_eig], evecs[:, :n_eig], n_descr, scaled=True)
    feats = feats[:, np.arange(0, feats.shape[1], subsample_step)]
    feats_norm2 = np.einsum('np,np->p', feats, np.expand_dims(mass, 1) * feats).flatten()
    feats /= np.expand_dims(np.sqrt(feats_norm2), 0)
    return feats.astype(np.float32)


def compute_wks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=35):
    feats = auto_WKS(evals[:n_eig], evecs[:, :n_eig], n_descr, scaled=True)
    feats = feats[:, np.arange(0, feats.shape[1], subsample_step)]
    feats_norm2 = np.einsum('np,np->p', feats, np.expand_dims(mass, 1) * feats).flatten()
    feats /= np.expand_dims(np.sqrt(feats_norm2), 0)
    return feats.astype(np.float32)


def compute_geodesic_distance(V, F, vindices):
    solver = pp3d.MeshHeatMethodDistanceSolver(np.asarray(V, dtype=np.float32), np.asarray(F, dtype=np.int32))
    dists = [solver.compute_distance(vid)[vindices] for vid in vindices]
    dists = np.stack(dists, axis=0)
    assert dists.ndim == 2
    return dists.astype(np.float32)


def compute_vertex_normals(vertices, faces):
    mesh = o3d.geometry.TriangleMesh(o3d.utility.Vector3dVector(vertices), o3d.utility.Vector3iVector(faces))
    mesh.compute_vertex_normals()
    return np.asarray(mesh.vertex_normals, dtype=np.float32)


def compute_surface_area(vertices, faces):
    mesh = o3d.geometry.TriangleMesh(o3d.utility.Vector3dVector(vertices), o3d.utility.Vector3iVector(faces))
    return mesh.get_surface_area()


def numpy_to_open3d_mesh(V, F):
    # Create an empty TriangleMesh object
    mesh = o3d.geometry.TriangleMesh()
    # Set vertices
    mesh.vertices = o3d.utility.Vector3dVector(V)
    # Set triangles
    mesh.triangles = o3d.utility.Vector3iVector(F)
    return mesh


def load_mesh(filepath, scale=True, return_vnormals=False):
    if os.path.splitext(filepath)[1] == ".obj":  # avoid Open3D's preprocessing of OBJ files
        V, F = pp3d.read_mesh(filepath)
        mesh = numpy_to_open3d_mesh(V, F)
    else:
        mesh = o3d.io.read_triangle_mesh(filepath)

    tmat = np.identity(4, dtype=np.float32)
    center = mesh.get_center()
    tmat[:3, 3] = -center
    if scale:
        smat = np.identity(4, dtype=np.float32)
        area = mesh.get_surface_area()
        smat[:3, :3] = np.identity(3, dtype=np.float32) / math.sqrt(area)
        tmat = smat @ tmat
    mesh.transform(tmat)

    vertices = np.asarray(mesh.vertices, dtype=np.float32)
    faces = np.asarray(mesh.triangles, dtype=np.int32)
    if return_vnormals:
        mesh.compute_vertex_normals()
        vnormals = np.asarray(mesh.vertex_normals, dtype=np.float32)
        return vertices, faces, vnormals
    else:
        return vertices, faces


def save_mesh(filepath, vertices, faces):
    mesh = o3d.geometry.TriangleMesh(o3d.utility.Vector3dVector(vertices), o3d.utility.Vector3iVector(faces))
    o3d.io.write_triangle_mesh(filepath, mesh)


def load_geodist(filepath):
    data = loadmat(filepath)
    if 'geodist' in data and 'sqrt_area' in data:
        geodist = np.asarray(data['geodist'], dtype=np.float32)
        sqrt_area = data['sqrt_area'].toarray().flatten()[0]
    elif 'G' in data and 'SQRarea' in data:
        geodist = np.asarray(data['G'], dtype=np.float32)
        sqrt_area = data['SQRarea'].flatten()[0]
    else:
        raise RuntimeError(f'File {filepath} does not have geodesics data.')
    return geodist, sqrt_area


def farthest_point_sampling(points, max_points, random_start=True):
    import torch_cluster

    if torch.is_tensor(points):
        device = points.device
        is_batch = points.dim() == 3
        if not is_batch:
            points = torch.unsqueeze(points, dim=0)
        assert points.dim() == 3

        B, N, D = points.size()
        assert N >= max_points
        bindices = torch.flatten(torch.unsqueeze(torch.arange(B), 1).repeat(1, N)).long().to(device)
        points = torch.reshape(points, (B * N, D)).float()
        sindices = torch_cluster.fps(points, bindices, ratio=float(max_points) / N, random_start=random_start)
        if is_batch:
            sindices = torch.reshape(sindices, (B, max_points)) - torch.unsqueeze(torch.arange(B), 1).long().to(device) * N
    elif isinstance(points, np.ndarray):
        device = torch.device('cpu')
        is_batch = points.ndim == 3
        if not is_batch:
            points = np.expand_dims(points, axis=0)
        assert points.ndim == 3

        B, N, D = points.shape
        assert N >= max_points
        bindices = np.tile(np.expand_dims(np.arange(B), 1), (1, N)).flatten()
        bindices = torch.as_tensor(bindices, device=device).long()
        points = torch.as_tensor(np.reshape(points, (B * N, D)), device=device).float()
        sindices = torch_cluster.fps(points, bindices, ratio=float(max_points) / N, random_start=random_start)
        sindices = sindices.cpu().numpy()
        if is_batch:
            sindices = np.reshape(sindices, (B, max_points)) - np.expand_dims(np.arange(B), 1) * N
    else:
        raise NotImplementedError
    return sindices


def lstsq(A, B):
    assert A.ndim == B.ndim == 2
    sols = scipy.linalg.lstsq(A, B)[0]
    return sols
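To make the descriptor pipeline concrete, a sketch (not part of the commit) that runs compute_wks on a made-up spectrum; in actual use, evals/evecs/mass come from the cached Laplacian operators of the datasets below.

import numpy as np

n_verts, n_eig = 500, 35
evals = np.linspace(1e-3, 10.0, n_eig)     # fake ascending spectrum
evecs = np.random.randn(n_verts, n_eig)    # fake eigenvectors
mass = np.full(n_verts, 1.0 / n_verts)     # fake lumped vertex areas

feats = compute_wks(evecs, evals, mass, n_descr=100, subsample_step=5, n_eig=n_eig)
print(feats.shape)                         # (500, 20): 100 energies, every 5th kept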
shape_data/dt4dinter.py
ADDED
@@ -0,0 +1,50 @@
import os.path as osp
import sys
import numpy as np
import itertools
from pathlib import Path
from collections import defaultdict

ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

from .dt4dintra import IGNORED_CATEGORIES
from .dt4dintra import ShapeDataset
from .faust import ShapePairDataset as FaustShapePairDataset
from utils.mesh import list_files

#IGNORED_CATEGORIES = ["drake", "mannequin", "ninja", "prisoner", "zlorp", "pumpkinhulk"]
IGNORED_CATEGORIES = ["pumpkinhulk"]


class ShapePairDataset(FaustShapePairDataset):

    def _init(self):
        self.name_id_map = self.shape_data.get_name_id_map()
        categories = defaultdict(list)
        for sname in self.name_id_map.keys():
            categories[sname.split('/')[0]].append(sname)
        self.pair_indices = list()
        for filename in list_files(osp.join(self.corr_dir, 'cross_category_corres'), '*.vts', alphanum_sort=False):
            cname0, cname1 = filename[:-4].split('_')
            if cname0 in IGNORED_CATEGORIES or cname1 in IGNORED_CATEGORIES:
                continue
            for sname0 in categories[cname0]:
                for sname1 in categories[cname1]:
                    self.pair_indices.append((self.name_id_map[sname0], self.name_id_map[sname1]))

    def _load_corr_gt(self, sdict0, sdict1):
        sname0 = sdict0['name']
        sname1 = sdict1['name']
        cname0 = sname0.split('/')[0]
        cname1 = sname1.split('/')[0]
        assert cname0 != cname1
        lmk01 = self._load_corr_file(f'cross_category_corres/{cname0}_{cname1}')
        corr0 = self._load_corr_file(sname0)
        corr1 = self._load_corr_file(sname1)
        corr_gt = np.stack((corr0, corr1[lmk01]), axis=1)
        return corr_gt

    def _load_corr_file(self, sname):
        corr_path = osp.join(self.corr_dir, f'{sname}.vts')
        corr = np.loadtxt(corr_path, dtype=np.int32)
        return corr - 1
shape_data/dt4dintra.py
ADDED
@@ -0,0 +1,57 @@
import os.path as osp
import sys
import numpy as np
import itertools
from pathlib import Path
from collections import defaultdict

ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

from .faust import ShapeDataset as FaustShapeDataset
from .faust import ShapePairDataset as FaustShapePairDataset
from utils.utils_legacy import read_lines

IGNORED_CATEGORIES = ['pumpkinhulk']


class ShapeDataset(FaustShapeDataset):
    TRAIN_IDX = None
    TEST_IDX = None
    NAME = "DT4D"

    def _get_file_list(self):
        if self.mode.startswith('train'):
            file_list = read_lines(osp.join(self.shape_dir, '..', 'train.txt'))
        elif self.mode.startswith('test'):
            file_list = read_lines(osp.join(self.shape_dir, '..', 'test.txt'))
        else:
            raise RuntimeError(f'Mode {self.mode} is not supported.')
        shape_list = [fn + '.ply' for fn in file_list]
        return shape_list


class ShapePairDataset(FaustShapePairDataset):

    def _init(self):
        self.name_id_map = self.shape_data.get_name_id_map()
        categories = defaultdict(list)
        for sname in self.name_id_map.keys():
            categories[sname.split('/')[0]].append(sname)
        self.pair_indices = list()
        for cname, clist in categories.items():
            if cname in IGNORED_CATEGORIES:
                continue
            for pname in itertools.combinations(clist, 2):
                self.pair_indices.append((self.name_id_map[pname[0]], self.name_id_map[pname[1]]))

    def _load_corr_gt(self, sdict0, sdict1):
        corr0 = self._load_corr_file(sdict0['name'])
        corr1 = self._load_corr_file(sdict1['name'])
        corr_gt = np.stack((corr0, corr1), axis=1)
        return corr_gt

    def _load_corr_file(self, sname):
        corr_path = osp.join(self.corr_dir, f'{sname}.vts')
        corr = np.loadtxt(corr_path, dtype=np.int32)
        return corr - 1
shape_data/faust.py
ADDED
@@ -0,0 +1,408 @@
import os.path as osp
import sys
import itertools
import math
import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib import Path
import potpourri3d as pp3d
import open3d as o3d
from utils.geometry import get_operators, load_operators
from utils.surfaces import Surface
from utils.utils_func import may_create_folder
from utils.mesh import find_mesh_files

# def opt_rot_points(pts_1, pts_2, device="cuda:0"):
#     center_1 = pts_1.mean(dim=0)
#     pts_c1 = pts_1 - center_1
#     center_2 = pts_2.mean(dim=0)
#     pts_c2 = pts_2 - center_2
#     to_sum = pts_c1[:, :, None] * pts_c2[:, None, :]
#     A = pts_c1.T @ pts_c2
#     #A = to_sum.sum(axis=0)
#     u, _, v = torch.linalg.svd(A)
#     a = torch.Tensor([[1, 0, 0], [0, 1, 0], [0, 0, torch.sign(torch.linalg.det(A))]]).float().to(device)
#     O = u @ a @ v
#     return O.T


def opt_rot_points(pts_1, pts_2):
    center_1 = pts_1.mean(axis=0)
    pts_c1 = pts_1 - center_1
    center_2 = pts_2.mean(axis=0)
    pts_c2 = pts_2 - center_2

    A = np.dot(pts_c1.T, pts_c2)
    u, _, v = np.linalg.svd(A)
    a = np.array([[1, 0, 0], [0, 1, 0], [0, 0, np.sign(np.linalg.det(A))]])
    O = u @ a @ v
    return O.T


def compute_vertex_normals(vertices, faces):
    mesh = o3d.geometry.TriangleMesh(o3d.utility.Vector3dVector(vertices), o3d.utility.Vector3iVector(faces))
    mesh.compute_vertex_normals()
    return np.asarray(mesh.vertex_normals, dtype=np.float32)


def numpy_to_open3d_mesh(V, F):
    # Create an empty TriangleMesh object
    mesh = o3d.geometry.TriangleMesh()
    # Set vertices
    mesh.vertices = o3d.utility.Vector3dVector(V)
    # Set triangles
    mesh.triangles = o3d.utility.Vector3iVector(F)
    return mesh


def open_mesh(path):
    """
    Tries to open a mesh.
    If that fails, tries the .ply, .obj, and .off alternatives.

    Parameters
    ----------
    path : str
        Path of the mesh
    Returns
    -------
    mesh or None
        Loaded mesh (V, F format) if successful, else None
    """
    p = Path(path)
    base, ext = p.with_suffix(""), p.suffix
    tried_exts = [ext, ".ply", ".obj", ".off"]
    for e in tried_exts:
        path = base.with_suffix(e)
        if Path.exists(path):
            try:
                temp = pp3d.read_mesh(str(path))
                return temp
            except Exception as err:
                print(f"Failed loading {path}: {err}")
    return None


KEYS = ['vertices', 'faces', 'frames', 'mass', 'L', 'evals', 'evecs', 'gradX', 'gradY', 'hks', 'wks', 'idx', 'name']


class ShapeDataset(Dataset):
    TRAIN_IDX = np.arange(0, 80)
    TEST_IDX = np.arange(80, 100)
    NAME = "FAUST"

    def __init__(self,
                 shape_dir,
                 cache_dir,
                 mode,
                 oriented=False,
                 rot_auto=False,
                 num_eigenbasis=256,
                 laplacian_type='mesh',
                 feature_type=None,
                 **kwargs):
        super().__init__()

        self.shape_dir = shape_dir
        self.cache_dir = cache_dir
        self.mode = mode
        self.oriented = oriented
        if self.oriented:
            self.NAME = self.NAME + "_ori"
        self.num_eigenbasis = num_eigenbasis
        self.laplacian_type = laplacian_type
        self.feature_type = feature_type
        for k, w in kwargs.items():
            setattr(self, k, w)

        print(f'Loading {mode} data from {shape_dir}')
        self.shape_list = self._get_file_list()
        self._prepare()

        self.randg = np.random.RandomState(0)

    def _get_file_list(self):
        path_list = find_mesh_files(Path(self.shape_dir), alphanum_sort=True)
        file_list = [f.name for f in path_list]
        if self.mode.startswith('train'):
            assert self.TRAIN_IDX is not None
            shape_list = [file_list[idx] for idx in self.TRAIN_IDX]
        elif self.mode.startswith('test'):
            assert self.TEST_IDX is not None
            shape_list = [file_list[idx] for idx in self.TEST_IDX]
        else:
            raise RuntimeError(f'Mode {self.mode} is not supported.')
        return shape_list

    def _load_mesh(self, filepath, scale=True, return_vnormals=False):
        V, F = open_mesh(filepath)
        mesh = numpy_to_open3d_mesh(V, F)

        tmat = np.identity(4, dtype=np.float32)
        center = mesh.get_center()
        tmat[:3, 3] = -center
        if scale:
            smat = np.identity(4, dtype=np.float32)
            area = mesh.get_surface_area()
            smat[:3, :3] = np.identity(3, dtype=np.float32) / math.sqrt(area)
            tmat = smat @ tmat
        mesh.transform(tmat)

        vertices = np.asarray(mesh.vertices, dtype=np.float32)
        faces = np.asarray(mesh.triangles, dtype=np.int32)
        if return_vnormals:
            mesh.compute_vertex_normals()
            vnormals = np.asarray(mesh.vertex_normals, dtype=np.float32)
            return vertices, faces, vnormals
        else:
            return vertices, faces

    def _prepare(self):
        may_create_folder(self.cache_dir)
        for sid, sname in enumerate(self.shape_list):
            cache_prefix = osp.join(self.cache_dir, self.NAME, f'{sname[:-4]}_{self.laplacian_type}_{self.num_eigenbasis}k')
            cache_path = cache_prefix + '_0n.npz'
            if not Path(cache_path).is_file():
                vertices_np, faces_np, vertex_normals_np = self._load_mesh(osp.join(self.shape_dir, sname),
                                                                           scale=True,
                                                                           return_vnormals=True)

                if self.laplacian_type == 'mesh':
                    _ = get_operators(torch.from_numpy(vertices_np).float(), torch.from_numpy(faces_np).long(), self.num_eigenbasis, cache_path=cache_path)
                # elif self.laplacian_type == 'pcd':
                #     compute_operators(vertices_np, np.asarray([], dtype=np.int32), vertex_normals_np, self.num_eigenbasis,
                #                       cache_path)
                else:
                    raise RuntimeError(f'Basis type {self.laplacian_type} is not supported')

            # if self.aug_noise_type is not None and self.aug_noise_type != 'naive':
            #     max_magnitude, max_levels = self.aug_noise_args[:2]
            #     randg = np.random.RandomState(sid)
            #     for nlevel in range(1, max_levels + 1):
            #         cache_path = cache_prefix + f'_{nlevel}n.npz'
            #         if Path(cache_path).is_file():
            #             continue
            #         noise_mag = max_magnitude * nlevel / max_levels
            #         noise_mat = np.clip(noise_mag * randg.randn(vertices_np.shape[0], vertices_np.shape[1]), -noise_mag,
            #                             noise_mag)
            #         vertices_noise_np = vertices_np + noise_mat.astype(vertices_np.dtype)
            #         vertex_normals_noise_np = compute_vertex_normals(vertices_noise_np, faces_np)

            #         if self.laplacian_type == 'mesh':
            #             compute_operators(vertices_noise_np, faces_np, vertex_normals_noise_np, self.num_eigenbasis, cache_path)
            #         elif self.laplacian_type == 'pcd':
            #             compute_operators(vertices_noise_np, np.asarray([], dtype=np.int32), vertex_normals_noise_np,
            #                               self.num_eigenbasis, cache_path)
            #         else:
            #             raise RuntimeError(f'Basis type {self.laplacian_type} is not supported')

    def __getitem__(self, idx):
        sname = self.shape_list[idx]

        cache_prefix = osp.join(self.cache_dir, self.NAME, f'{sname[:-4]}_{self.laplacian_type}_{self.num_eigenbasis}k')
        cache_path = cache_prefix + '_0n.npz'

        assert Path(cache_path).is_file()

        sdict = load_operators(cache_path)
        sdict['idx'] = idx
        sdict['name'] = sname[:-4]

        if self.feature_type is not None:
            sdict['feats'] = np.concatenate([sdict[ft] for ft in self.feature_type.split('_')], axis=-1)
        vertices_np, _, _ = self._load_mesh(osp.join(self.shape_dir, sname), scale=True, return_vnormals=True)
        sdict['vertices'] = vertices_np
        sdict = self._centering(sdict)
        return sdict

    def __len__(self):
        return len(self.shape_list)

    def _centering(self, sdict):
        vertices, areas = sdict['vertices'], sdict["mass"]
        # area-weighted centroid; the original summed over all entries (a scalar),
        # fixed here to sum per coordinate
        center = (vertices * areas[:, None]).sum(axis=0) / areas.sum()
        sdict['vertices'] = vertices - center
        return sdict

    def _random_noise_naive(self, sdict, randg, args):
        vertices = sdict['vertices']
        dtype = vertices.dtype
        shape = vertices.shape
        std, clip = args

        noise = np.clip(std * randg.randn(*shape), -clip, clip)
        sdict['vertices'] = vertices + noise.astype(dtype)
        return sdict

    def _random_rotation(self, sdict, randg, axes, args):
        vertices = sdict['vertices']
        dtype = vertices.dtype

        max_x, max_y, max_z = args
        if 'x' in axes:
            anglex = randg.rand() * max_x * np.pi / 180.0
            cosx = np.cos(anglex)
            sinx = np.sin(anglex)
            Rx = np.asarray([[1, 0, 0], [0, cosx, -sinx], [0, sinx, cosx]], dtype=dtype)
        else:
            Rx = np.eye(3, dtype=dtype)

        if 'y' in axes:
            angley = randg.rand() * max_y * np.pi / 180.0
            cosy = np.cos(angley)
            siny = np.sin(angley)
            Ry = np.asarray([[cosy, 0, siny], [0, 1, 0], [-siny, 0, cosy]], dtype=dtype)
        else:
            Ry = np.eye(3, dtype=dtype)

        if 'z' in axes:
            anglez = randg.rand() * max_z * np.pi / 180.0
            cosz = np.cos(anglez)
            sinz = np.sin(anglez)
            Rz = np.asarray([[cosz, -sinz, 0], [sinz, cosz, 0], [0, 0, 1]], dtype=dtype)
        else:
            Rz = np.eye(3, dtype=dtype)

        Rxyz = randg.permutation(np.stack((Rx, Ry, Rz), axis=0))
        R = Rxyz[2] @ Rxyz[1] @ Rxyz[0]
        sdict['vertices'] = vertices @ R.T

        return sdict

    def _random_scaling(self, sdict, randg, args):
        scale_min, scale_max = args
        vertices = sdict['vertices']
        scale = scale_min + randg.rand(1, 3) * (scale_max - scale_min)
        sdict['vertices'] = vertices * scale
        return sdict

    def get_name_id_map(self):
        return {sname[:-4]: sid for sid, sname in enumerate(self.shape_list)}


class ShapePairDataset(Dataset):

    def __init__(self, corr_dir, mode, shape_data, rotate=False, **kwargs):
        super().__init__()
        self.corr_dir = corr_dir
        self.mode = mode
        self.shape_data = shape_data
        self.rotate = rotate
        if self.shape_data.oriented and self.rotate:
            self.rotate = False
        for k, w in kwargs.items():
            setattr(self, k, w)

        self._init()

        self.randg = np.random.RandomState(0)

    def _init(self):
        self.name_id_map = self.shape_data.get_name_id_map()
        self.pair_indices = list(itertools.combinations(range(len(self.shape_data)), 2))

    def __getitem__(self, idx):
        pidx = self.pair_indices[idx]
        sdict0 = self.shape_data[pidx[0]]
        sdict1 = self.shape_data[pidx[1]]
        return self._prepare_pair(sdict0, sdict1)

    def get_by_names(self, sname0, sname1):
        sdict0 = self.shape_data[self.name_id_map[sname0]]
        sdict1 = self.shape_data[self.name_id_map[sname1]]
        return self._prepare_pair(sdict0, sdict1)

    def _prepare_pair(self, sdict0, sdict1):
        corr_gt = self._load_corr_gt(sdict0, sdict1)
        # for fmap_size in self.fmap_sizes:
        #     fmap01_gt = pmap_to_fmap(sdict0['evecs'][:, :fmap_size], sdict1['evecs'][:, :fmap_size], corr_gt)
        #     pdict[f'fmap01_{fmap_size}_gt'] = fmap01_gt

        # for idx in range(2):
        #     indices_sel = farthest_point_sampling(pdict[f'vertices{idx}'], self.num_corrs, random_start=is_train)
        #     for k in ['vertices', 'evecs', 'feats']:
        #         kid = f'{k}{idx}'
        #         if kid in pdict:
        #             pdict[kid + '_sub'] = pdict[kid][indices_sel, :]
        #     if self.use_geodists:
        #         geodists = compute_geodesic_distance(pdict[f'vertices{idx}'], pdict[f'faces{idx}'], indices_sel)
        #         pdict[f'geodists{idx}_sub'] = geodists
        #     pdict[f'vindices{idx}_sub'] = indices_sel

        # fmap_size = self.fmap_sizes[-1]
        # corr_gt_sub = fmap_to_pmap(pdict['evecs0_sub'][:, :fmap_size], pdict['evecs1_sub'][:, :fmap_size],
        #                            pdict[f'fmap01_{fmap_size}_gt'])
        # pdict['corr_gt_sub'] = corr_gt_sub

        # if is_train:
        #     fmap_size = self.fmap_sizes[0]
        #     axis = self.randg.choice([0, 1]).item()
        #     max_bases = fmap_size // 2
        #     noise_ratio = 0.5
        #     if self.randg.rand() > 0.5:
        #         pdict[f'fmap01_{fmap_size}'] = self._random_scale(pdict[f'fmap01_{fmap_size}_gt'], self.randg, axis, max_bases)
        #     else:
        #         pdict[f'fmap01_{fmap_size}'] = self._random_noise(pdict[f'fmap01_{fmap_size}_gt'], self.randg, axis, max_bases,
        #                                                           noise_ratio)
        # else:
        #     if self.corr_loader is not None:
        #         corr_init = self.corr_loader.get_by_names(sdict0['name'], sdict1['name'])
        #         assert corr_init.ndim == 2 and len(corr_init) == len(sdict1['vertices'])
        #         fmap_size = self.fmap_sizes[0]
        #         fmap01_init = pmap_to_fmap(sdict0['evecs'][:, :fmap_size], sdict1['evecs'][:, :fmap_size], corr_init)
        #         pdict[f'fmap01_{fmap_size}'] = fmap01_init
        #         pdict['pmap10'] = corr_init[:, 0]

        vts_1, vts_2 = corr_gt[:, 0], corr_gt[:, 1]
        shape_dict, target_dict = sdict0, sdict1

        if self.rotate:
            pts_1, pts_2 = shape_dict['vertices'][vts_1], target_dict['vertices'][vts_2]
            rot = opt_rot_points(pts_1, pts_2).astype(np.float32)
            target_dict['vertices'] = target_dict['vertices'] @ rot
        target_surf = Surface(FV=[target_dict['faces'], target_dict['vertices']])
        target_normals = torch.from_numpy(target_surf.surfel / np.linalg.norm(target_surf.surfel, axis=-1, keepdims=True)).float().cuda()

        shape_surf = Surface(FV=[shape_dict['faces'], shape_dict['vertices']])
        map_info = (shape_dict['name'], vts_1, vts_2)
        return shape_dict, shape_surf, target_dict, target_surf, target_normals, map_info

    def _random_scale(self, fmap, randg, axis, max_bases):
        assert max_bases > 1
        assert axis in [0, 1]
        num_bases = randg.randint(1, max_bases)
        ids = randg.choice(fmap.shape[axis], num_bases, replace=False)
        fmap_out = np.copy(fmap)
        if axis == 0:
            fmap_out[ids, :] *= (randg.rand(num_bases, 1) * 2 - 1)
        else:
            fmap_out[:, ids] *= (randg.rand(1, num_bases) * 2 - 1)
        return fmap_out

    def _random_noise(self, fmap, randg, axis, max_bases, max_ratio):
        assert max_bases > 1
        assert axis in [0, 1]
        num_bases = randg.randint(1, max_bases)
        ids = randg.choice(fmap.shape[axis], num_bases, replace=False)
        fmap_out = np.copy(fmap)
        ratio = randg.rand() * max_ratio
        if axis == 0:
            maxvals = np.amax(np.abs(fmap_out[ids, :]), axis=1 - axis, keepdims=True)
            noise = ratio * maxvals * randg.randn(num_bases, fmap.shape[1 - axis])
            fmap_out[ids, :] += noise
        else:
            maxvals = np.amax(np.abs(fmap_out[:, ids]), axis=1 - axis, keepdims=True)
            noise = ratio * maxvals * randg.randn(fmap.shape[1 - axis], num_bases)
            fmap_out[:, ids] += noise
        return fmap_out

    def _load_corr_gt(self, sdict0, sdict1):
        corr0 = self._load_corr_file(sdict0['name'])
        corr1 = self._load_corr_file(sdict1['name'])
        corr_gt = np.stack((corr0, corr1), axis=1)
        return corr_gt

    def _load_corr_file(self, sname):
        corr_path = osp.join(self.corr_dir, f'{sname}.vts')
        corr = np.loadtxt(corr_path, dtype=np.int32)
        return corr - 1

    def __len__(self):
        return len(self.pair_indices)
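How the two dataset classes compose, as a hedged sketch (not part of the commit): the paths below are placeholders for a local data layout, and since _prepare_pair moves the target normals to CUDA, a GPU is assumed.

shapes = ShapeDataset(shape_dir='data/FAUST_r/shapes', cache_dir='data_cache',
                      mode='train', num_eigenbasis=256)
pairs = ShapePairDataset(corr_dir='data/FAUST_r/correspondences',
                         mode='train', shape_data=shapes)
sdict0, surf0, sdict1, surf1, target_normals, map_info = pairs[0]
print(map_info[0], sdict0['vertices'].shape)   # source shape name and its (n, 3) vertices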
shape_data/scape.py
ADDED
@@ -0,0 +1,17 @@
import os.path as osp
import sys
import numpy as np
from pathlib import Path

ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

from shape_data.faust import ShapeDataset as FaustShapeDataset
from shape_data.faust import ShapePairDataset


class ShapeDataset(FaustShapeDataset):
    TRAIN_IDX = np.arange(0, 51)
    TEST_IDX = np.arange(51, 71)
    NAME = "SCAPE"
shape_data/shrec19.py
ADDED
@@ -0,0 +1,42 @@
import os.path as osp
import sys
import numpy as np
from pathlib import Path

ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

from shape_data.faust import ShapeDataset as FaustShapeDataset
from shape_data.faust import ShapePairDataset as FaustShapePairDataset
from utils.io import list_files


class ShapeDataset(FaustShapeDataset):
    TRAIN_IDX = None
    TEST_IDX = np.arange(44)
    NAME = "SHREC"


class ShapePairDataset(FaustShapePairDataset):

    def _init(self):
        assert self.mode.startswith('test')

        self.name_id_map = self.shape_data.get_name_id_map()
        self.pair_indices = list()
        for corr_filename in list_files(self.corr_dir, '*.map', alphanum_sort=True):
            sname0, sname1 = corr_filename[:-4].split('_')
            if sname0 == '40' or sname1 == '40':
                continue
            self.pair_indices.append((self.name_id_map[sname1], self.name_id_map[sname0]))

    def _load_corr_gt(self, sdict0, sdict1):
        pmap10 = self._load_corr_file(sdict1['name'], sdict0['name'])
        corr_gt = np.stack((pmap10, np.arange(len(pmap10))), axis=1)
        return corr_gt

    def _load_corr_file(self, sname0, sname1):
        corr_path = osp.join(self.corr_dir, f'{sname0}_{sname1}.map')
        corr = np.loadtxt(corr_path, dtype=np.int32)
        return corr - 1
shape_data/smalr.py
ADDED
@@ -0,0 +1,31 @@
import os.path as osp
import sys
import numpy as np
from pathlib import Path

ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

from shape_data.faust import ShapeDataset as FaustShapeDataset
from shape_data.faust import ShapePairDataset
from utils.mesh import find_mesh_files


class ShapeDataset(FaustShapeDataset):
    TRAIN_IDX = None
    TEST_IDX = None
    NAME = "SMAL"

    def _get_file_list(self):
        if self.mode.startswith('train'):
            categories = ['cow', 'dog', 'fox', 'lion', 'wolf']
        elif self.mode.startswith('test'):
            categories = ['cougar', 'hippo', 'horse']
        else:
            raise RuntimeError(f'Mode {self.mode} is not supported.')

        path_list = find_mesh_files(Path(self.shape_dir), alphanum_sort=True)
        file_list = [f.name for f in path_list]
        shape_list = [fn for fn in file_list if fn.split('_')[0] in categories]
        return shape_list
shape_data/tosca.py
ADDED
@@ -0,0 +1,46 @@
import os.path as osp
import sys
import numpy as np
import re
from pathlib import Path
from itertools import permutations as pmt

ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

from shape_data.faust import ShapeDataset as FaustShapeDataset
from shape_data.faust import ShapePairDataset as FaustShapePairDataset
from utils.io import list_files


def contains_any_regex(substrings, ext, texts):
    pattern = re.compile('|'.join(map(re.escape, substrings)))  # compile the regex once
    return [text for text in texts if bool(pattern.search(text)) and (ext in text)]  # apply to all texts efficiently


class ShapeDataset(FaustShapeDataset):
    TRAIN_IDX = None
    TEST_IDX = None

    def _get_file_list(self):
        if self.mode.startswith('train'):
            categories = None
        elif self.mode.startswith('test'):
            categories = ['cat', 'dog', 'horse', 'wolf']
        else:
            raise RuntimeError(f'Mode {self.mode} is not supported.')
        file_list = list_files(self.shape_dir, '*.off', alphanum_sort=True)
        if categories is None:
            # guard added: the original passed None into contains_any_regex, which
            # would crash in re.escape; train mode keeps every .off file instead
            return file_list
        shape_list = contains_any_regex(categories, ".off", file_list)
        return shape_list


class ShapePairDataset(FaustShapePairDataset):
    categories = ['cat', 'dog', 'horse', 'wolf']

    def _init(self):
        assert self.mode.startswith('test')
        self.name_id_map = self.shape_data.get_name_id_map()
        self.pair_indices = list()
        for cat in self.categories:
            shape_list_temp = [self.name_id_map[fn] for fn in self.name_id_map if cat in fn]
            self.pair_indices += list(pmt(shape_list_temp, 2))
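For reference (not part of the commit), contains_any_regex keeps the entries that match any category substring and carry the extension:

names = ['cat0.off', 'dog3.off', 'horse.ply', 'centaur1.off']
print(contains_any_regex(['cat', 'dog'], '.off', names))   # ['cat0.off', 'dog3.off']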
snk/__init__.py
ADDED
File without changes
snk/loss.py
ADDED
|
@@ -0,0 +1,119 @@
from trimesh.graph import face_adjacency
import torch
import torch.nn as nn


class PrismRegularizationLoss(nn.Module):
    """
    Calculate the loss based on the PriMo energy, as described in the paper:
    PriMo: Coupled Prisms for Intuitive Surface Modeling
    """
    def __init__(self, primo_h):
        super().__init__()
        self.h = primo_h

        # precompute the coefficients 2^{-|i - k| - |j - l|} of the energy
        indices = torch.tensor([(i, j) for i in range(2) for j in range(2)])
        indices_A = indices.repeat_interleave(4, dim=0)
        indices_B = indices.repeat(4, 1)
        self.coeff = (torch.ones(1) * 2).pow(((indices_A - indices_B).abs() * -1).sum(dim=1))[None, :]

    def forward(self, transformed_prism, rotations, verts, faces, normals):
        # transformed_prism is (n_faces, 3, 3)
        # verts and faces are from the template (shape 2)
        # * for now assumes there is only one batch
        # todo: add batch support
        bs = 1
        verts = verts.reshape(-1, 3)
        normals = normals.reshape(-1, 3)

        # get the area of each face
        face_areas = self.get_face_areas(verts, faces)  # (n_faces,)

        # get the list of edges and the pair of faces sharing each edge
        face_ids, edges = face_adjacency(faces.cpu().numpy(), return_edges=True)  # (n_edges, 2), (n_edges, 2)
        face_ids, edges = torch.from_numpy(face_ids).to(verts.device), torch.from_numpy(edges).to(verts.device)

        # normals and rotations of the faces that share each edge
        normals1, normals2 = normals[edges[:, 0]], normals[edges[:, 1]]  # (n_edges, 3), normals are per vertex
        rotations1, rotations2 = rotations[face_ids[:, 0]], rotations[face_ids[:, 1]]  # (n_edges, 3, 3), rotations are per face

        # locate the edge endpoints inside each adjacent face
        face_id1, face_id2 = face_ids[:, 0], face_ids[:, 1]  # (n_edges,)
        faces_to_verts = self.get_verts_id_face(faces, edges, face_ids)  # (n_edges, 4)
        verts1_p1, verts2_p1 = transformed_prism[face_id1, faces_to_verts[:, 0]], transformed_prism[face_id1, faces_to_verts[:, 1]]  # (n_edges, 3)
        verts1_p2, verts2_p2 = transformed_prism[face_id2, faces_to_verts[:, 2]], transformed_prism[face_id2, faces_to_verts[:, 3]]  # (n_edges, 3)

        # rotate the per-vertex normals into the frame of each prism
        prism1_n1, prism1_n2 = (normals1[:, None] @ rotations1).squeeze(1), (normals2[:, None] @ rotations1).squeeze(1)  # todo: check if this is correct
        prism2_n1, prism2_n2 = (normals1[:, None] @ rotations2).squeeze(1), (normals2[:, None] @ rotations2).squeeze(1)

        # get the coordinates of the shared face of each prism
        # prism1 (1 -> 2)
        f_p1_00, f_p1_01 = verts1_p1 + prism1_n1 * self.h, verts2_p1 + prism1_n2 * self.h  # (n_edges, 3)
        f_p1_10, f_p1_11 = verts1_p1 - prism1_n1 * self.h, verts2_p1 - prism1_n2 * self.h  # (n_edges, 3)
        # prism2 (2 -> 1)
        f_p2_00, f_p2_01 = verts1_p2 + prism2_n1 * self.h, verts2_p2 + prism2_n2 * self.h  # (n_edges, 3)
        f_p2_10, f_p2_11 = verts1_p2 - prism2_n1 * self.h, verts2_p2 - prism2_n2 * self.h  # (n_edges, 3)

        # compute the energy between the two coupled prism faces
        A = torch.stack((f_p1_00, f_p1_01, f_p1_10, f_p1_11), dim=1)  # (n_edges, 4, 3)
        B = torch.stack((f_p2_00, f_p2_01, f_p2_10, f_p2_11), dim=1)  # (n_edges, 4, 3)
        energy = self.compute_energy(A - B, A - B)  # (n_edges,)

        # weight each edge by its squared length over the areas of the two adjacent faces
        area1, area2 = face_areas[face_id1], face_areas[face_id2]  # (n_edges,)
        weight = torch.norm(verts[edges[:, 0]] - verts[edges[:, 1]], dim=1).square() / (area1 + area2)  # (n_edges,)
        energy = energy * weight  # (n_edges,)

        loss = energy.sum() / bs  # todo: once batching is enabled, divide by the batch size
        return loss

    def compute_energy(self, A, B):
        """
        Computes the formula sum_{i,j,k,l=0}^{1} a_{ij} b_{kl} 2^{-|i - k| - |j - l|}.
        Assumes that A and B are tensors of size bs x 4 x 3, where bs is the batch size.
        """
        self.coeff = self.coeff.to(A.device)

        A_repeated = A.repeat_interleave(4, dim=1)
        B_repeated = B.repeat(1, 4, 1)

        energy = (A_repeated * B_repeated).sum(dim=-1)
        energy = (energy * self.coeff).sum(dim=1)
        energy = energy / 9

        return energy

    def get_face_areas(self, verts, faces):
        # area of each face from the cross product of two edge vectors
        v1, v2, v3 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
        area = 0.5 * torch.cross(v2 - v1, v3 - v1, dim=-1).norm(dim=1)

        return area

    def get_verts_id_face(self, F, E, Q):
        # for each edge (E) and its two adjacent faces (Q), return the local index
        # (0, 1 or 2) of both edge endpoints within each face
        e = E.shape[0]
        Z = torch.zeros((e, 4), dtype=torch.long)

        v1 = F[:, 0][Q[:, 0]]
        v2 = F[:, 1][Q[:, 0]]
        v3 = F[:, 2][Q[:, 0]]
        v4 = F[:, 0][Q[:, 1]]
        v5 = F[:, 1][Q[:, 1]]
        v6 = F[:, 2][Q[:, 1]]

        idx1 = torch.where(v1 == E[:, 0], 0, torch.where(v2 == E[:, 0], 1, torch.where(v3 == E[:, 0], 2, -1)))
        idx2 = torch.where(v1 == E[:, 1], 0, torch.where(v2 == E[:, 1], 1, torch.where(v3 == E[:, 1], 2, -1)))
        idx3 = torch.where(v4 == E[:, 0], 0, torch.where(v5 == E[:, 0], 1, torch.where(v6 == E[:, 0], 2, -1)))
        idx4 = torch.where(v4 == E[:, 1], 0, torch.where(v5 == E[:, 1], 1, torch.where(v6 == E[:, 1], 2, -1)))

        Z[:, 0:2] = torch.stack((idx1, idx2), dim=1)
        Z[:, 2:4] = torch.stack((idx3, idx4), dim=1)
        Z = Z.to(F.device)

        return Z
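The coefficient table built in __init__ encodes the 2^(-|i - k| - |j - l|) weights from the compute_energy docstring, with the four prism-face corners ordered (0,0), (0,1), (1,0), (1,1). A quick check of the vectorized version against a naive quadruple loop, assuming snk/loss.py is importable (the tensors are random, purely for testing):

    import torch
    from snk.loss import PrismRegularizationLoss

    indices = [(i, j) for i in range(2) for j in range(2)]  # corner order: 00, 01, 10, 11
    A, B = torch.randn(5, 4, 3), torch.randn(5, 4, 3)
    naive = torch.zeros(5)
    for a, (i, j) in enumerate(indices):
        for b, (k, l) in enumerate(indices):
            naive += (A[:, a] * B[:, b]).sum(dim=-1) * 2.0 ** (-abs(i - k) - abs(j - l))
    naive /= 9

    loss_fn = PrismRegularizationLoss(primo_h=0.05)  # the primo_h value is arbitrary here
    assert torch.allclose(loss_fn.compute_energy(A, B), naive, atol=1e-6)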
snk/prism_decoder.py
ADDED
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn
import roma
from shape_models.layers import DiffusionNet


class PrismDecoder(torch.nn.Module):
    def __init__(self, dim_in=1024, dim_out=512, n_width=256, n_block=4, pairwise_dot=True, dropout=False, dot_linear_complex=True, neig=128):
        super().__init__()

        self.diffusion_net = DiffusionNet(
            C_in=dim_in,
            C_out=dim_out,
            C_width=n_width,
            N_block=n_block,
            dropout=dropout,
            with_gradient_features=pairwise_dot,
            with_gradient_rotations=dot_linear_complex,
        )

        self.mlp_refine = nn.Sequential(
            nn.Linear(dim_out, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 12),
        )

    def forward(self, batch_dict, latent):
        # original prism: one triangle per face of the template
        verts = batch_dict["vertices"] if "vertices" in batch_dict else batch_dict["verts"]
        faces = batch_dict["faces"]
        prism_base = verts[faces]  # (n_faces, 3, 3)
        bs = 1

        # forward through diffusion net
        features = self.diffusion_net(latent, batch_dict["mass"], batch_dict["L"], evals=batch_dict["evals"],
                                      evecs=batch_dict["evecs"], gradX=batch_dict["gradX"], gradY=batch_dict["gradY"], faces=batch_dict["faces"])  # (n_verts, dim)

        # average the per-vertex features over each face
        x_gather = features.unsqueeze(-1).expand(-1, -1, 3)
        faces_gather = faces.unsqueeze(1).expand(-1, features.shape[-1], -1)
        xf = torch.gather(x_gather, 0, faces_gather)
        features = torch.mean(xf, dim=-1)  # (n_faces, dim)

        # refine features with an MLP
        features = self.mlp_refine(features)  # (n_faces, 12)

        # split into a rotation (projected onto SO(3)) and a translation
        rotations = features[:, :9].reshape(-1, 3, 3)
        rotations = roma.special_procrustes(rotations)  # (n_faces, 3, 3)
        translations = features[:, 9:].reshape(-1, 3)  # (n_faces, 3)

        # rigidly transform each prism
        transformed_prism = (prism_base @ rotations) + translations[:, None]

        # average the prism corners back into per-vertex positions
        features = self.prism_to_vertices(transformed_prism, faces, verts)

        out_features = features.reshape(bs, -1, 3)
        return out_features, transformed_prism, rotations

    def prism_to_vertices(self, prism, faces, verts):
        # initialize the per-vertex features tensor
        N = verts.shape[0]
        d = prism.shape[-1]
        device = prism.device
        features = torch.zeros((N, d), device=device)

        # scatter-add the per-corner features onto the vertices they belong to
        features.scatter_add_(0, faces[:, :, None].repeat(1, 1, d).reshape(-1, d), prism.reshape(-1, d))

        # divide each row by the number of faces the corresponding vertex appears in
        num_faces_per_vertex = torch.zeros(N, dtype=torch.float32, device=device)
        num_faces_per_vertex.index_add_(0, faces.reshape(-1), torch.ones(faces.shape[0] * 3, device=device))
        features /= num_faces_per_vertex.unsqueeze(1).clamp(min=1)

        return features
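prism_to_vertices is a scatter-mean: each vertex ends up with the average of its copies over all incident transformed faces. A tiny check against an explicit loop, assuming snk/prism_decoder.py is importable; the method never reads self, so it can be called unbound, and all values below are illustrative:

    import torch
    from snk.prism_decoder import PrismDecoder

    faces = torch.tensor([[0, 1, 2], [0, 2, 3]])  # two triangles sharing the edge (0, 2)
    prism = torch.arange(18, dtype=torch.float32).reshape(2, 3, 3)
    verts = torch.zeros(4, 3)  # only the vertex count is read

    expected = torch.zeros(4, 3)
    counts = torch.zeros(4)
    for f in range(faces.shape[0]):  # naive reference implementation
        for c in range(3):
            expected[faces[f, c]] += prism[f, c]
            counts[faces[f, c]] += 1
    expected /= counts.clamp(min=1).unsqueeze(1)

    out = PrismDecoder.prism_to_vertices(None, prism, faces, verts)  # self is unused
    assert torch.allclose(out, expected)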