Spaces:
Sleeping
Sleeping
DveloperY0115
commited on
Commit
•
801501a
1
Parent(s):
7d3169e
init repo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- README.md +5 -4
- app.py +116 -0
- checkpoints/lang_phase1/hparams.yaml +48 -0
- checkpoints/lang_phase1/state_only.ckpt +3 -0
- checkpoints/lang_phase2/hparams.yaml +47 -0
- checkpoints/lang_phase2/state_only.ckpt +3 -0
- checkpoints/phase1/hparams.yaml +39 -0
- checkpoints/phase1/state_only.ckpt +3 -0
- checkpoints/phase2/hparams.yaml +41 -0
- checkpoints/phase2/state_only.ckpt +3 -0
- custom_wheels/salad-0.1-py3-none-any.whl +0 -0
- data/autosdf_spaghetti_intersec_game_data.csv +0 -0
- data/spaghetti_airplane_latents.hdf5 +3 -0
- data/spaghetti_airplane_latents_mean_std.hdf5 +3 -0
- data/spaghetti_chair_latents.hdf5 +3 -0
- data/spaghetti_chair_latents_mean_std.hdf5 +3 -0
- data/spaghetti_table_latents.hdf5 +3 -0
- data/spaghetti_table_latents_mean_std.hdf5 +3 -0
- requirements.txt +1 -0
- salad.egg-info/PKG-INFO +5 -0
- salad.egg-info/SOURCES.txt +7 -0
- salad.egg-info/dependency_links.txt +1 -0
- salad.egg-info/not-zip-safe +1 -0
- salad.egg-info/top_level.txt +1 -0
- salad/data/__pycache__/dataset.cpython-39.pyc +0 -0
- salad/data/dataset.py +149 -0
- salad/model_components/__pycache__/lstm.cpython-39.pyc +0 -0
- salad/model_components/__pycache__/network.cpython-39.pyc +0 -0
- salad/model_components/__pycache__/simple_module.cpython-39.pyc +0 -0
- salad/model_components/__pycache__/transformer.cpython-39.pyc +0 -0
- salad/model_components/__pycache__/variance_schedule.cpython-39.pyc +0 -0
- salad/model_components/lstm.py +56 -0
- salad/model_components/network.py +229 -0
- salad/model_components/simple_module.py +125 -0
- salad/model_components/transformer.py +308 -0
- salad/model_components/variance_schedule.py +57 -0
- salad/models/__init__.py +0 -0
- salad/models/__pycache__/__init__.cpython-39.pyc +0 -0
- salad/models/__pycache__/base_model.cpython-39.pyc +0 -0
- salad/models/__pycache__/language_phase1.cpython-39.pyc +0 -0
- salad/models/__pycache__/language_phase2.cpython-39.pyc +0 -0
- salad/models/__pycache__/phase1.cpython-39.pyc +0 -0
- salad/models/__pycache__/phase2.cpython-39.pyc +0 -0
- salad/models/base_model.py +147 -0
- salad/models/language_phase1.py +340 -0
- salad/models/language_phase2.py +201 -0
- salad/models/phase1.py +65 -0
- salad/models/phase2.py +183 -0
- salad/spaghetti/.gitignore +9 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.hdf5 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: blue
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
app_file: app.py
|
9 |
pinned: false
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: Test
|
3 |
+
emoji: 🦀
|
4 |
colorFrom: blue
|
5 |
+
colorTo: red
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.36.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: mit
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
app.py
|
3 |
+
|
4 |
+
An interactive demo of text-guided shape generation.
|
5 |
+
"""
|
6 |
+
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Literal
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
import plotly.graph_objects as go
|
12 |
+
|
13 |
+
from salad.utils.spaghetti_util import (
|
14 |
+
get_mesh_from_spaghetti,
|
15 |
+
generate_zc_from_sj_gaus,
|
16 |
+
load_mesher,
|
17 |
+
load_spaghetti,
|
18 |
+
)
|
19 |
+
import hydra
|
20 |
+
from omegaconf import OmegaConf
|
21 |
+
import torch
|
22 |
+
from pytorch_lightning import seed_everything
|
23 |
+
|
24 |
+
|
25 |
+
def load_model(
|
26 |
+
model_class: Literal["phase1", "phase2", "lang_phase1", "lang_phase2"],
|
27 |
+
device,
|
28 |
+
):
|
29 |
+
checkpoint_dir = Path(__file__).parent / "checkpoints"
|
30 |
+
c = OmegaConf.load(checkpoint_dir / f"{model_class}/hparams.yaml")
|
31 |
+
model = hydra.utils.instantiate(c)
|
32 |
+
ckpt = torch.load(checkpoint_dir / f"{model_class}/state_only.ckpt")
|
33 |
+
model.load_state_dict(ckpt)
|
34 |
+
model.eval()
|
35 |
+
for p in model.parameters(): p.requires_grad_(False)
|
36 |
+
model = model.to(device)
|
37 |
+
return model
|
38 |
+
|
39 |
+
|
40 |
+
def run_inference(prompt: str):
|
41 |
+
"""The entry point of the demo."""
|
42 |
+
|
43 |
+
device: torch.device = torch.device("cuda")
|
44 |
+
"""Device to run the demo on."""
|
45 |
+
seed: int = 63
|
46 |
+
"""Random seed for reproducibility."""
|
47 |
+
|
48 |
+
# set random seed
|
49 |
+
seed_everything(seed)
|
50 |
+
|
51 |
+
# load SPAGHETTI and mesher
|
52 |
+
spaghetti = load_spaghetti(device)
|
53 |
+
mesher = load_mesher(device)
|
54 |
+
|
55 |
+
# load SALAD
|
56 |
+
lang_phase1_model = load_model("lang_phase1", device)
|
57 |
+
lang_phase2_model = load_model("phase2", device)
|
58 |
+
lang_phase1_model._build_dataset("val")
|
59 |
+
|
60 |
+
# run phase 1
|
61 |
+
extrinsics = lang_phase1_model.sampling_gaussians([prompt])
|
62 |
+
|
63 |
+
# run phase 2
|
64 |
+
intrinsics = lang_phase2_model.sample(extrinsics)
|
65 |
+
|
66 |
+
# generate mesh
|
67 |
+
zcs = generate_zc_from_sj_gaus(spaghetti, intrinsics, extrinsics)
|
68 |
+
vertices, faces = get_mesh_from_spaghetti(
|
69 |
+
spaghetti,
|
70 |
+
mesher,
|
71 |
+
zcs[0],
|
72 |
+
res=256,
|
73 |
+
)
|
74 |
+
|
75 |
+
# plot
|
76 |
+
figure = go.Figure(
|
77 |
+
data=[
|
78 |
+
go.Mesh3d(
|
79 |
+
x=vertices[:, 0], # flip front-back
|
80 |
+
y=-vertices[:, 2],
|
81 |
+
z=vertices[:, 1],
|
82 |
+
i=faces[:, 0],
|
83 |
+
j=faces[:, 1],
|
84 |
+
k=faces[:, 2],
|
85 |
+
color="gray",
|
86 |
+
)
|
87 |
+
],
|
88 |
+
layout=dict(
|
89 |
+
scene=dict(
|
90 |
+
xaxis=dict(visible=False),
|
91 |
+
yaxis=dict(visible=False),
|
92 |
+
zaxis=dict(visible=False),
|
93 |
+
)
|
94 |
+
),
|
95 |
+
)
|
96 |
+
|
97 |
+
return figure
|
98 |
+
|
99 |
+
if __name__ == "__main__":
|
100 |
+
|
101 |
+
# create UI
|
102 |
+
demo = gr.Interface(
|
103 |
+
fn=run_inference,
|
104 |
+
inputs="text",
|
105 |
+
outputs=gr.Plot(),
|
106 |
+
title="SALAD: Text-Guided Shape Generation",
|
107 |
+
description="Describe a chair",
|
108 |
+
examples=[
|
109 |
+
"an office chair",
|
110 |
+
"a chair with armrests",
|
111 |
+
"a chair without armrests",
|
112 |
+
]
|
113 |
+
)
|
114 |
+
# initiate
|
115 |
+
demo.queue(max_size=30)
|
116 |
+
demo.launch()
|
checkpoints/lang_phase1/hparams.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: salad.models.language_phase1.LangPhase1Model
|
2 |
+
|
3 |
+
network:
|
4 |
+
_target_: salad.model_components.network.CondDiffNetwork
|
5 |
+
input_dim: 16
|
6 |
+
residual: true
|
7 |
+
context_dim: 768
|
8 |
+
context_embedding_dim: 1024
|
9 |
+
embedding_dim: 512
|
10 |
+
encoder_use_time: false
|
11 |
+
encoder_type: pointwise
|
12 |
+
decoder_type: transformer_encoder
|
13 |
+
enc_num_layers: 2
|
14 |
+
dec_num_layers: 6
|
15 |
+
use_timestep_embedder: true
|
16 |
+
timestep_embedder_dim: 128
|
17 |
+
|
18 |
+
variance_schedule:
|
19 |
+
_target_: salad.model_components.variance_schedule.VarianceSchedule
|
20 |
+
num_steps: &time_steps 1000
|
21 |
+
beta_1: 1e-4
|
22 |
+
beta_T: 0.05
|
23 |
+
mode: linear
|
24 |
+
|
25 |
+
# optimizer
|
26 |
+
lr: 1e-4
|
27 |
+
batch_size: 64
|
28 |
+
|
29 |
+
# dataset
|
30 |
+
dataset_kwargs:
|
31 |
+
data_path: spaghetti_chair_latents.hdf5
|
32 |
+
repeat: 1
|
33 |
+
data_keys: ["g_js_affine"]
|
34 |
+
only_easy_context: false
|
35 |
+
global_normalization: &normalization partial
|
36 |
+
|
37 |
+
global_normalization: *normalization
|
38 |
+
num_timesteps: *time_steps
|
39 |
+
faster: true
|
40 |
+
validation_step: 10
|
41 |
+
no_run_validation: false
|
42 |
+
spaghetti_tag: "chairs_large" # or airplanes, tables
|
43 |
+
|
44 |
+
text_encoder_freeze: false
|
45 |
+
use_lstm: true
|
46 |
+
classifier_free_guidance: true
|
47 |
+
conditioning_dropout_prob: 0.2
|
48 |
+
|
checkpoints/lang_phase1/state_only.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bf46454eaaabbb7f3008c51beaae5b16794b189f3cae48f79db70fcdf5413380
|
3 |
+
size 318782397
|
checkpoints/lang_phase2/hparams.yaml
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: salad.models.language_phase2.LangPhase2Model
|
2 |
+
network:
|
3 |
+
_target_: salad.model_components.network.CondDiffNetwork
|
4 |
+
input_dim: 512
|
5 |
+
residual: true
|
6 |
+
context_dim: 784 # concat of 768 lang feat and gaussian.
|
7 |
+
context_embedding_dim: 1024
|
8 |
+
embedding_dim: 512
|
9 |
+
encoder_use_time: false
|
10 |
+
encoder_type: transformer
|
11 |
+
decoder_type: transformer_encoder
|
12 |
+
enc_num_layers: 6
|
13 |
+
dec_num_layers: 6
|
14 |
+
use_timestep_embedder: true
|
15 |
+
timestep_embedder_dim: 128
|
16 |
+
|
17 |
+
variance_schedule:
|
18 |
+
_target_: salad.model_components.variance_schedule.VarianceSchedule
|
19 |
+
num_steps: &time_steps 1000
|
20 |
+
beta_1: 1e-4
|
21 |
+
beta_T: 0.05
|
22 |
+
mode: linear
|
23 |
+
|
24 |
+
# optimizer
|
25 |
+
lr: 1e-4
|
26 |
+
batch_size: 64
|
27 |
+
|
28 |
+
# dataset
|
29 |
+
dataset_kwargs:
|
30 |
+
data_path: spaghetti_chair_latents.hdf5
|
31 |
+
repeat: 1
|
32 |
+
data_keys: ["s_j_affine", "g_js_affine"]
|
33 |
+
only_easy_context: false
|
34 |
+
global_normalization: &normalization false
|
35 |
+
|
36 |
+
global_normalization: *normalization
|
37 |
+
num_timesteps: *time_steps
|
38 |
+
faster: true
|
39 |
+
validation_step: 10
|
40 |
+
no_run_validation: false
|
41 |
+
spaghetti_tag: "chairs_large" # or airplanes, tables
|
42 |
+
|
43 |
+
text_encoder_freeze: false
|
44 |
+
use_lstm: true
|
45 |
+
classifier_free_guidance: true
|
46 |
+
conditioning_dropout_prob: 0.2
|
47 |
+
|
checkpoints/lang_phase2/state_only.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4105dd24201fa8aad3fc4db2f74376f98f4df53b38ae749d944bfdb6552ea40f
|
3 |
+
size 455307461
|
checkpoints/phase1/hparams.yaml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: salad.models.phase1.Phase1Model
|
2 |
+
|
3 |
+
network:
|
4 |
+
_target_: salad.model_components.network.UnCondDiffNetwork
|
5 |
+
input_dim: 16
|
6 |
+
embedding_dim: 512
|
7 |
+
num_heads: 4
|
8 |
+
use_timestep_embedder: true
|
9 |
+
timestep_embedder_dim: 128
|
10 |
+
enc_num_layers: 6
|
11 |
+
residual: true
|
12 |
+
encoder_type: transformer
|
13 |
+
attn_dropout: 0.0
|
14 |
+
|
15 |
+
variance_schedule:
|
16 |
+
_target_: salad.model_components.variance_schedule.VarianceSchedule
|
17 |
+
num_steps: &time_steps 1000
|
18 |
+
beta_1: 1e-4
|
19 |
+
beta_T: 0.05
|
20 |
+
mode: linear
|
21 |
+
|
22 |
+
# optimizer
|
23 |
+
lr: 1e-4
|
24 |
+
batch_size: 64
|
25 |
+
|
26 |
+
# dataset
|
27 |
+
dataset_kwargs:
|
28 |
+
data_path: spaghetti_chair_latents.hdf5
|
29 |
+
repeat: 3
|
30 |
+
data_keys: ["g_js_affine"]
|
31 |
+
global_normalization: &normalization partial
|
32 |
+
|
33 |
+
global_normalization: *normalization # normalize pi, eigenvalues.
|
34 |
+
num_timesteps: *time_steps
|
35 |
+
faster: true
|
36 |
+
validation_step: 10
|
37 |
+
no_run_validation: false
|
38 |
+
spaghetti_tag: "chairs_large" # or airplanes, tables
|
39 |
+
|
checkpoints/phase1/state_only.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5f616fa657723de4855e8571f3ef828ff25221b86cb516a755aaa93538b0c7de
|
3 |
+
size 60275831
|
checkpoints/phase2/hparams.yaml
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_target_: salad.models.phase2.Phase2Model
|
2 |
+
|
3 |
+
network:
|
4 |
+
_target_: salad.model_components.network.CondDiffNetwork
|
5 |
+
input_dim: 512
|
6 |
+
residual: true
|
7 |
+
context_dim: 16 # gaussian condition dim.
|
8 |
+
context_embedding_dim: 512
|
9 |
+
embedding_dim: 512
|
10 |
+
encoder_use_time: false
|
11 |
+
encoder_type: transformer
|
12 |
+
decoder_type: transformer_encoder # we don't use cross attention.
|
13 |
+
enc_num_layers: 6
|
14 |
+
dec_num_layers: 6
|
15 |
+
use_timestep_embedder: true
|
16 |
+
timestep_embedder_dim: 128
|
17 |
+
|
18 |
+
variance_schedule:
|
19 |
+
_target_: salad.model_components.variance_schedule.VarianceSchedule
|
20 |
+
num_steps: &time_steps 1000
|
21 |
+
beta_1: 1e-4
|
22 |
+
beta_T: 0.05
|
23 |
+
mode: linear
|
24 |
+
|
25 |
+
# optimizer
|
26 |
+
lr: 1e-4
|
27 |
+
batch_size: 64
|
28 |
+
|
29 |
+
# dataset
|
30 |
+
dataset_kwargs:
|
31 |
+
data_path: spaghetti_chair_latents.hdf5
|
32 |
+
repeat: 3
|
33 |
+
data_keys: ["s_j_affine", "g_js_affine"]
|
34 |
+
global_normalization: &normalization null
|
35 |
+
|
36 |
+
global_normalization: *normalization # normalize pi, eigenvalues.
|
37 |
+
num_timesteps: *time_steps
|
38 |
+
faster: true
|
39 |
+
validation_step: 10
|
40 |
+
no_run_validation: false
|
41 |
+
spaghetti_tag: "chairs_large" # or airplanes, tables
|
checkpoints/phase2/state_only.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:aed08103f6eebbd84fac523affaab2cd493f8f2a1d5e81e9d298cc0a7a807ed2
|
3 |
+
size 150592331
|
custom_wheels/salad-0.1-py3-none-any.whl
ADDED
Binary file (994 Bytes). View file
|
|
data/autosdf_spaghetti_intersec_game_data.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data/spaghetti_airplane_latents.hdf5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c242271687d13159b0df44a3179a0d460c9e87c577851d8d0282f0369a529f46
|
3 |
+
size 222017536
|
data/spaghetti_airplane_latents_mean_std.hdf5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c32e24c7786593ffdd918e05fdd1148634c3c707a4948e0a3bf6a6c002b540e1
|
3 |
+
size 12544
|
data/spaghetti_chair_latents.hdf5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7bfa1533a0366e9271f6bf96d4f7a135f8763ba66ce26b5cc952e6af14e5bfe4
|
3 |
+
size 1255457792
|
data/spaghetti_chair_latents_mean_std.hdf5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f6be55ae235fe77aa821146122f0e911e9593e78a001b8fa63dea041c49095fa
|
3 |
+
size 8320
|
data/spaghetti_table_latents.hdf5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:20ef1da19e47e2c23782defa2c5d2172d7322c5476c92f7ed3fee271e3893f91
|
3 |
+
size 1127843840
|
data/spaghetti_table_latents_mean_std.hdf5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:718825073ede0c52ccd5c277b1114425558a2a45258dd952a4707fecb3dc5d57
|
3 |
+
size 8320
|
requirements.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
./custom_wheels/salad-0.1-py3-none-any.whl
|
salad.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: salad
|
3 |
+
Version: 0.1
|
4 |
+
Summary: SALAD: Part-Level Latent Diffusion for 3D Shape Generation and Manipulation
|
5 |
+
Home-page: https://github.com/63days/SALAD
|
salad.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
README.md
|
2 |
+
setup.py
|
3 |
+
salad.egg-info/PKG-INFO
|
4 |
+
salad.egg-info/SOURCES.txt
|
5 |
+
salad.egg-info/dependency_links.txt
|
6 |
+
salad.egg-info/not-zip-safe
|
7 |
+
salad.egg-info/top_level.txt
|
salad.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
salad.egg-info/not-zip-safe
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
salad.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
salad
|
salad/data/__pycache__/dataset.cpython-39.pyc
ADDED
Binary file (4.61 kB). View file
|
|
salad/data/dataset.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import h5py
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import torch
|
5 |
+
from dotmap import DotMap
|
6 |
+
|
7 |
+
from salad.utils.paths import DATA_DIR
|
8 |
+
from salad.utils import thutil
|
9 |
+
|
10 |
+
|
11 |
+
class SALADDataset(torch.utils.data.Dataset):
|
12 |
+
def __init__(self, data_path, repeat=None, **kwargs):
|
13 |
+
super().__init__()
|
14 |
+
self.data_path = str(DATA_DIR / data_path)
|
15 |
+
self.repeat = repeat
|
16 |
+
self.__dict__.update(kwargs)
|
17 |
+
self.hparams = DotMap(self.__dict__)
|
18 |
+
|
19 |
+
"""
|
20 |
+
Global Data statistics.
|
21 |
+
"""
|
22 |
+
if self.hparams.get("global_normalization"):
|
23 |
+
with h5py.File(self.data_path.replace(".hdf5", "_mean_std.hdf5")) as f:
|
24 |
+
self.global_mean = f["mean"][:].astype(np.float32)
|
25 |
+
self.global_std = f["std"][:].astype(np.float32)
|
26 |
+
|
27 |
+
self.data = dict()
|
28 |
+
with h5py.File(self.data_path) as f:
|
29 |
+
for k in self.hparams.data_keys:
|
30 |
+
self.data[k] = f[k][:].astype(np.float32)
|
31 |
+
|
32 |
+
"""
|
33 |
+
global_normalization arg is for gaussians only.
|
34 |
+
"""
|
35 |
+
if k == "g_js_affine":
|
36 |
+
if self.hparams.get("global_normalization") == "partial":
|
37 |
+
assert k == "g_js_affine"
|
38 |
+
if self.hparams.get("verbose"):
|
39 |
+
print("[*] Normalize data only for pi and eigenvalues.")
|
40 |
+
# 3: mu, 9: eigvec, 1: pi, 3: eigval
|
41 |
+
self.data[k] = self.normalize_global_static(
|
42 |
+
self.data[k], slice(12, None)
|
43 |
+
)
|
44 |
+
elif self.hparams.get("global_normalization") == "all":
|
45 |
+
assert k == "g_js_affine"
|
46 |
+
if self.hparams.get("verbose"):
|
47 |
+
print("[*] Normalize data for all elements.")
|
48 |
+
self.data[k] = self.normalize_global_static(
|
49 |
+
self.data[k], slice(None)
|
50 |
+
)
|
51 |
+
|
52 |
+
def __getitem__(self, idx):
|
53 |
+
if self.repeat is not None and self.repeat > 1:
|
54 |
+
idx = int(idx / self.repeat)
|
55 |
+
|
56 |
+
items = []
|
57 |
+
for k in self.hparams.data_keys:
|
58 |
+
data = torch.from_numpy(self.data[k][idx])
|
59 |
+
items.append(data)
|
60 |
+
|
61 |
+
if self.hparams.get("concat_data"):
|
62 |
+
return torch.cat(items, -1) # [16,528]
|
63 |
+
if len(items) == 1:
|
64 |
+
return items[0]
|
65 |
+
return items
|
66 |
+
|
67 |
+
def __len__(self):
|
68 |
+
k = self.hparams.data_keys[0]
|
69 |
+
if self.repeat is not None and self.repeat > 1:
|
70 |
+
return len(self.data[k]) * self.repeat
|
71 |
+
return len(self.data[k])
|
72 |
+
|
73 |
+
def get_other_latents(self, key):
|
74 |
+
with h5py.File(self.data_path) as f:
|
75 |
+
return f[key][:].astype(np.float32)
|
76 |
+
|
77 |
+
def normalize_global_static(self, data: np.ndarray, normalize_indices=slice(None)):
|
78 |
+
"""
|
79 |
+
Input:
|
80 |
+
np.ndarray or torch.Tensor. [16,16] or [B,16,16]
|
81 |
+
slice(None) -> full
|
82 |
+
slice(12, None) -> partial
|
83 |
+
Output:
|
84 |
+
[16,16] or [B,16,16]
|
85 |
+
"""
|
86 |
+
assert normalize_indices == slice(None) or normalize_indices == slice(
|
87 |
+
12, None
|
88 |
+
), print(f"{normalize_indices} is wrong.")
|
89 |
+
data = thutil.th2np(data).copy()
|
90 |
+
data[..., normalize_indices] = (
|
91 |
+
data[..., normalize_indices] - self.global_mean[normalize_indices]
|
92 |
+
) / self.global_std[normalize_indices]
|
93 |
+
return data
|
94 |
+
|
95 |
+
def unnormalize_global_static(
|
96 |
+
self, data: np.ndarray, unnormalize_indices=slice(None)
|
97 |
+
):
|
98 |
+
"""
|
99 |
+
Input:
|
100 |
+
np.ndarray or torch.Tensor. [16,16] or [B,16,16]
|
101 |
+
slice(None) -> full
|
102 |
+
slice(12, None) -> partial
|
103 |
+
Output:
|
104 |
+
[16,16] or [B,16,16]
|
105 |
+
"""
|
106 |
+
assert unnormalize_indices == slice(None) or unnormalize_indices == slice(
|
107 |
+
12, None
|
108 |
+
), print(f"{unnormalize_indices} is wrong.")
|
109 |
+
data = thutil.th2np(data).copy()
|
110 |
+
data[..., unnormalize_indices] = (
|
111 |
+
data[..., unnormalize_indices]
|
112 |
+
) * self.global_std[unnormalize_indices] + self.global_mean[unnormalize_indices]
|
113 |
+
return data
|
114 |
+
|
115 |
+
|
116 |
+
class LangSALADDataset(SALADDataset):
|
117 |
+
def __init__(self, data_path, repeat=None, **kwargs):
|
118 |
+
super().__init__(data_path, repeat, **kwargs)
|
119 |
+
|
120 |
+
# self.game_data = pd.read_csv(self.hparams.lang_data_path)
|
121 |
+
self.game_data = pd.read_csv(DATA_DIR / "autosdf_spaghetti_intersec_game_data.csv")
|
122 |
+
self.shapenet_ids = np.array(self.game_data["sn"])
|
123 |
+
self.spaghetti_indices = np.array(self.game_data["spaghetti_idx"]) # for 5401
|
124 |
+
self.texts = np.array(self.game_data["text"])
|
125 |
+
|
126 |
+
assert len(self.shapenet_ids) == len(self.spaghetti_indices) == len(self.texts)
|
127 |
+
|
128 |
+
def __getitem__(self, idx):
|
129 |
+
if self.repeat is not None and self.repeat > 1:
|
130 |
+
idx = int(idx / self.repeat)
|
131 |
+
|
132 |
+
spa_idx = self.spaghetti_indices[idx]
|
133 |
+
text = self.texts[idx]
|
134 |
+
latents = []
|
135 |
+
for k in self.hparams.data_keys:
|
136 |
+
data = torch.from_numpy(self.data[k][spa_idx])
|
137 |
+
latents.append(data)
|
138 |
+
|
139 |
+
item = latents + [text]
|
140 |
+
if self.hparams.get("concat_data"):
|
141 |
+
latents = torch.cat(latents, -1)
|
142 |
+
return latents, text
|
143 |
+
|
144 |
+
return item
|
145 |
+
|
146 |
+
def __len__(self):
|
147 |
+
if self.repeat is not None and self.repeat > 1:
|
148 |
+
return len(self.texts) * self.repeat
|
149 |
+
return len(self.texts)
|
salad/model_components/__pycache__/lstm.cpython-39.pyc
ADDED
Binary file (2.38 kB). View file
|
|
salad/model_components/__pycache__/network.cpython-39.pyc
ADDED
Binary file (4.73 kB). View file
|
|
salad/model_components/__pycache__/simple_module.cpython-39.pyc
ADDED
Binary file (3.9 kB). View file
|
|
salad/model_components/__pycache__/transformer.cpython-39.pyc
ADDED
Binary file (8.63 kB). View file
|
|
salad/model_components/__pycache__/variance_schedule.cpython-39.pyc
ADDED
Binary file (1.99 kB). View file
|
|
salad/model_components/lstm.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
5 |
+
|
6 |
+
|
7 |
+
class LSTM(nn.Module):
|
8 |
+
def __init__(self, text_dim, embedding_dim, vocab_size, padding_idx=0):
|
9 |
+
super().__init__()
|
10 |
+
self.padding_idx = padding_idx
|
11 |
+
self.word_embedding = nn.Embedding(
|
12 |
+
vocab_size, embedding_dim, padding_idx=padding_idx
|
13 |
+
)
|
14 |
+
self.rnn = nn.LSTM(embedding_dim, text_dim, batch_first=True)
|
15 |
+
self.w_attn = nn.Parameter(torch.Tensor(1, text_dim))
|
16 |
+
nn.init.xavier_uniform_(self.w_attn)
|
17 |
+
|
18 |
+
def forward(self, padded_tokens, dropout=0.5):
|
19 |
+
w_emb = self.word_embedding(padded_tokens)
|
20 |
+
w_emb = F.dropout(w_emb, dropout, self.training)
|
21 |
+
len_seq = (padded_tokens != self.padding_idx).sum(dim=1).cpu()
|
22 |
+
x_packed = pack_padded_sequence(
|
23 |
+
w_emb, len_seq, enforce_sorted=False, batch_first=True
|
24 |
+
)
|
25 |
+
B = padded_tokens.shape[0]
|
26 |
+
rnn_out, _ = self.rnn(x_packed)
|
27 |
+
rnn_out, dummy = pad_packed_sequence(rnn_out, batch_first=True)
|
28 |
+
h = rnn_out[torch.arange(B), len_seq - 1]
|
29 |
+
final_feat, attn = self.word_attention(rnn_out, h, len_seq)
|
30 |
+
return final_feat, attn
|
31 |
+
|
32 |
+
def word_attention(self, R, h, len_seq):
|
33 |
+
"""
|
34 |
+
Input:
|
35 |
+
R: hidden states of the entire words
|
36 |
+
h: the final hidden state after processing the entire words
|
37 |
+
len_seq: the length of the sequence
|
38 |
+
Output:
|
39 |
+
final_feat: the final feature after the bilinear attention
|
40 |
+
attn: word attention weights
|
41 |
+
"""
|
42 |
+
B, N, D = R.shape
|
43 |
+
device = R.device
|
44 |
+
len_seq = len_seq.to(device)
|
45 |
+
|
46 |
+
W_attn = (self.w_attn * torch.eye(D).to(device))[None].repeat(B, 1, 1)
|
47 |
+
score = torch.bmm(torch.bmm(R, W_attn), h.unsqueeze(-1))
|
48 |
+
|
49 |
+
mask = torch.arange(N).reshape(1, N, 1).repeat(B, 1, 1).to(device)
|
50 |
+
mask = mask < len_seq.reshape(B, 1, 1)
|
51 |
+
|
52 |
+
score = score.masked_fill(mask == 0, -1e9)
|
53 |
+
attn = F.softmax(score, 1)
|
54 |
+
final_feat = torch.bmm(R.transpose(1, 2), attn).squeeze(-1)
|
55 |
+
|
56 |
+
return final_feat, attn.squeeze(-1)
|
salad/model_components/network.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from dotmap import DotMap
|
5 |
+
from salad.model_components.simple_module import TimePointWiseEncoder, TimestepEmbedder
|
6 |
+
|
7 |
+
|
8 |
+
from salad.model_components.transformer import (
|
9 |
+
PositionalEncoding,
|
10 |
+
TimeTransformerDecoder,
|
11 |
+
TimeTransformerEncoder,
|
12 |
+
)
|
13 |
+
|
14 |
+
class UnCondDiffNetwork(nn.Module):
|
15 |
+
def __init__(self, input_dim, residual, **kwargs):
|
16 |
+
"""
|
17 |
+
Transformer Encoder.
|
18 |
+
"""
|
19 |
+
super().__init__()
|
20 |
+
self.input_dim = input_dim
|
21 |
+
self.residual = residual
|
22 |
+
self.__dict__.update(kwargs)
|
23 |
+
self.hparams = DotMap(self.__dict__)
|
24 |
+
|
25 |
+
self._build_model()
|
26 |
+
|
27 |
+
def _build_model(self):
|
28 |
+
self.act = F.leaky_relu
|
29 |
+
if self.hparams.get("use_timestep_embedder"):
|
30 |
+
self.time_embedder = TimestepEmbedder(self.hparams.timestep_embedder_dim)
|
31 |
+
dim_ctx = self.hparams.timestep_embedder_dim
|
32 |
+
else:
|
33 |
+
dim_ctx = 3
|
34 |
+
|
35 |
+
"""
|
36 |
+
Encoder part
|
37 |
+
"""
|
38 |
+
enc_dim = self.hparams.embedding_dim
|
39 |
+
self.embedding = nn.Linear(self.hparams.input_dim, enc_dim)
|
40 |
+
if not self.hparams.get("encoder_type"):
|
41 |
+
self.encoder = TimeTransformerEncoder(
|
42 |
+
enc_dim,
|
43 |
+
dim_ctx=dim_ctx,
|
44 |
+
num_heads=self.hparams.num_heads
|
45 |
+
if self.hparams.get("num_heads")
|
46 |
+
else 4,
|
47 |
+
use_time=True,
|
48 |
+
num_layers=self.hparams.enc_num_layers,
|
49 |
+
last_fc=True,
|
50 |
+
last_fc_dim_out=self.hparams.input_dim,
|
51 |
+
)
|
52 |
+
else:
|
53 |
+
if self.hparams.encoder_type == "transformer":
|
54 |
+
self.encoder = TimeTransformerEncoder(
|
55 |
+
enc_dim,
|
56 |
+
dim_ctx=dim_ctx,
|
57 |
+
num_heads=self.hparams.num_heads
|
58 |
+
if self.hparams.get("num_heads")
|
59 |
+
else 4,
|
60 |
+
use_time=True,
|
61 |
+
num_layers=self.hparams.enc_num_layers,
|
62 |
+
last_fc=True,
|
63 |
+
last_fc_dim_out=self.hparams.input_dim,
|
64 |
+
dropout=self.hparams.get("attn_dropout", 0.0)
|
65 |
+
)
|
66 |
+
else:
|
67 |
+
raise ValueError
|
68 |
+
|
69 |
+
def forward(self, x, beta):
|
70 |
+
"""
|
71 |
+
Input:
|
72 |
+
x: [B,G,D] latent
|
73 |
+
beta: B
|
74 |
+
Output:
|
75 |
+
eta: [B,G,D]
|
76 |
+
"""
|
77 |
+
B, G = x.shape[:2]
|
78 |
+
if self.hparams.get("use_timestep_embedder"):
|
79 |
+
time_emb = self.time_embedder(beta).unsqueeze(1)
|
80 |
+
else:
|
81 |
+
beta = beta.view(B, 1, 1)
|
82 |
+
time_emb = torch.cat(
|
83 |
+
[beta, torch.sin(beta), torch.cos(beta)], dim=-1
|
84 |
+
) # [B,1,3]
|
85 |
+
|
86 |
+
ctx = time_emb
|
87 |
+
x_emb = self.embedding(x)
|
88 |
+
|
89 |
+
out = self.encoder(x_emb, ctx=ctx)
|
90 |
+
|
91 |
+
if self.hparams.residual:
|
92 |
+
out = out + x
|
93 |
+
return out
|
94 |
+
|
95 |
+
|
96 |
+
class CondDiffNetwork(nn.Module):
|
97 |
+
def __init__(self, input_dim, residual, **kwargs):
|
98 |
+
"""
|
99 |
+
Transformer Encoder + Decoder.
|
100 |
+
"""
|
101 |
+
super().__init__()
|
102 |
+
self.input_dim = input_dim
|
103 |
+
self.residual = residual
|
104 |
+
self.__dict__.update(kwargs)
|
105 |
+
self.hparams = DotMap(self.__dict__)
|
106 |
+
|
107 |
+
self._build_model()
|
108 |
+
|
109 |
+
def _build_model(self):
|
110 |
+
self.act = F.leaky_relu
|
111 |
+
if self.hparams.get("use_timestep_embedder"):
|
112 |
+
self.time_embedder = TimestepEmbedder(self.hparams.timestep_embedder_dim)
|
113 |
+
dim_ctx = self.hparams.timestep_embedder_dim
|
114 |
+
else:
|
115 |
+
dim_ctx = 3
|
116 |
+
"""
|
117 |
+
Encoder part
|
118 |
+
"""
|
119 |
+
enc_dim = self.hparams.context_embedding_dim
|
120 |
+
self.context_embedding = nn.Linear(self.hparams.context_dim, enc_dim)
|
121 |
+
if self.hparams.encoder_type == "transformer":
|
122 |
+
self.encoder = TimeTransformerEncoder(
|
123 |
+
enc_dim,
|
124 |
+
3,
|
125 |
+
num_heads=4,
|
126 |
+
use_time=self.hparams.encoder_use_time,
|
127 |
+
num_layers=self.hparams.enc_num_layers
|
128 |
+
if self.hparams.get("enc_num_layers")
|
129 |
+
else 3,
|
130 |
+
last_fc=False,
|
131 |
+
)
|
132 |
+
|
133 |
+
elif self.hparams.encoder_type == "pointwise":
|
134 |
+
self.encoder = TimePointWiseEncoder(
|
135 |
+
enc_dim,
|
136 |
+
dim_ctx=None,
|
137 |
+
use_time=self.hparams.encoder_use_time,
|
138 |
+
num_layers=self.hparams.enc_num_layers,
|
139 |
+
)
|
140 |
+
else:
|
141 |
+
raise ValueError
|
142 |
+
|
143 |
+
"""
|
144 |
+
Decoder part
|
145 |
+
"""
|
146 |
+
dec_dim = self.hparams.embedding_dim
|
147 |
+
input_dim = self.hparams.input_dim
|
148 |
+
self.query_embedding = nn.Linear(self.hparams.input_dim, dec_dim)
|
149 |
+
if self.hparams.decoder_type == "transformer_decoder":
|
150 |
+
self.decoder = TimeTransformerDecoder(
|
151 |
+
dec_dim,
|
152 |
+
enc_dim,
|
153 |
+
dim_ctx=dim_ctx,
|
154 |
+
num_heads=4,
|
155 |
+
last_fc=True,
|
156 |
+
last_fc_dim_out=input_dim,
|
157 |
+
num_layers=self.hparams.dec_num_layers
|
158 |
+
if self.hparams.get("dec_num_layers")
|
159 |
+
else 3,
|
160 |
+
)
|
161 |
+
elif self.hparams.decoder_type == "transformer_encoder":
|
162 |
+
self.decoder = TimeTransformerEncoder(
|
163 |
+
dec_dim,
|
164 |
+
dim_ctx=enc_dim + dim_ctx,
|
165 |
+
num_heads=4,
|
166 |
+
last_fc=True,
|
167 |
+
last_fc_dim_out=input_dim,
|
168 |
+
num_layers=self.hparams.dec_num_layers
|
169 |
+
if self.hparams.get("dec_num_layers")
|
170 |
+
else 3,
|
171 |
+
)
|
172 |
+
else:
|
173 |
+
raise ValueError
|
174 |
+
|
175 |
+
def forward(self, x, beta, context):
|
176 |
+
"""
|
177 |
+
Input:
|
178 |
+
x: [B,G,D] intrinsic
|
179 |
+
beta: B
|
180 |
+
context: [B,G,D2] or [B, D2] condition
|
181 |
+
Output:
|
182 |
+
eta: [B,G,D]
|
183 |
+
"""
|
184 |
+
# print(f"x: {x.shape} context: {context.shape} beta: {beta.shape}")
|
185 |
+
B, G = x.shape[:2]
|
186 |
+
|
187 |
+
if self.hparams.get("use_timestep_embedder"):
|
188 |
+
time_emb = self.time_embedder(beta).unsqueeze(1)
|
189 |
+
else:
|
190 |
+
beta = beta.view(B, 1, 1)
|
191 |
+
time_emb = torch.cat(
|
192 |
+
[beta, torch.sin(beta), torch.cos(beta)], dim=-1
|
193 |
+
) # [B,1,3]
|
194 |
+
ctx = time_emb
|
195 |
+
"""
|
196 |
+
Encoding
|
197 |
+
"""
|
198 |
+
cout = self.context_embedding(context)
|
199 |
+
cout = self.encoder(cout, ctx=ctx if self.hparams.encoder_use_time else None)
|
200 |
+
|
201 |
+
if cout.ndim == 2:
|
202 |
+
cout = cout.unsqueeze(1).expand(-1, G, -1)
|
203 |
+
|
204 |
+
"""
|
205 |
+
Decoding
|
206 |
+
"""
|
207 |
+
out = self.query_embedding(x)
|
208 |
+
if self.hparams.get("use_pos_encoding"):
|
209 |
+
out = self.pos_encoding(out)
|
210 |
+
|
211 |
+
if self.hparams.decoder_type == "transformer_encoder":
|
212 |
+
try:
|
213 |
+
ctx = ctx.expand(-1, G, -1)
|
214 |
+
if cout.ndim == 2:
|
215 |
+
cout = cout.unsqueeze(1)
|
216 |
+
cout = cout.expand(-1, G, -1)
|
217 |
+
ctx = torch.cat([ctx, cout], -1)
|
218 |
+
except Exception as e:
|
219 |
+
print(e, G, ctx.shape, cout.shape)
|
220 |
+
out = self.decoder(out, ctx=ctx)
|
221 |
+
else:
|
222 |
+
out = self.decoder(out, cout, ctx=ctx)
|
223 |
+
|
224 |
+
# if hasattr(self, "last_fc"):
|
225 |
+
# out = self.last_fc(out)
|
226 |
+
|
227 |
+
if self.hparams.residual:
|
228 |
+
out = out + x
|
229 |
+
return out
|
salad/model_components/simple_module.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import math
|
5 |
+
|
6 |
+
from salad.model_components.transformer import TimeMLP
|
7 |
+
|
8 |
+
|
9 |
+
class TimePointwiseLayer(nn.Module):
|
10 |
+
def __init__(
|
11 |
+
self,
|
12 |
+
dim_in,
|
13 |
+
dim_ctx,
|
14 |
+
mlp_ratio=2,
|
15 |
+
act=F.leaky_relu,
|
16 |
+
dropout=0.0,
|
17 |
+
use_time=False,
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
self.use_time = use_time
|
21 |
+
self.act = act
|
22 |
+
self.mlp1 = TimeMLP(
|
23 |
+
dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time
|
24 |
+
)
|
25 |
+
self.norm1 = nn.LayerNorm(dim_in)
|
26 |
+
|
27 |
+
self.mlp2 = TimeMLP(
|
28 |
+
dim_in, dim_in * mlp_ratio, dim_in, dim_ctx, use_time=use_time
|
29 |
+
)
|
30 |
+
self.norm2 = nn.LayerNorm(dim_in)
|
31 |
+
self.dropout = nn.Dropout(dropout)
|
32 |
+
|
33 |
+
def forward(self, x, ctx=None):
|
34 |
+
res = x
|
35 |
+
x = self.mlp1(x, ctx=ctx)
|
36 |
+
x = self.norm1(x + res)
|
37 |
+
|
38 |
+
res = x
|
39 |
+
x = self.mlp2(x, ctx=ctx)
|
40 |
+
x = self.norm2(x + res)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class TimePointWiseEncoder(nn.Module):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
dim_in,
|
48 |
+
dim_ctx=None,
|
49 |
+
mlp_ratio=2,
|
50 |
+
act=F.leaky_relu,
|
51 |
+
dropout=0.0,
|
52 |
+
use_time=True,
|
53 |
+
num_layers=6,
|
54 |
+
last_fc=False,
|
55 |
+
last_fc_dim_out=None,
|
56 |
+
):
|
57 |
+
super().__init__()
|
58 |
+
self.last_fc = last_fc
|
59 |
+
if last_fc:
|
60 |
+
self.fc = nn.Linear(dim_in, last_fc_dim_out)
|
61 |
+
self.layers = nn.ModuleList(
|
62 |
+
[
|
63 |
+
TimePointwiseLayer(
|
64 |
+
dim_in,
|
65 |
+
dim_ctx=dim_ctx,
|
66 |
+
mlp_ratio=mlp_ratio,
|
67 |
+
act=act,
|
68 |
+
dropout=dropout,
|
69 |
+
use_time=use_time,
|
70 |
+
)
|
71 |
+
for _ in range(num_layers)
|
72 |
+
]
|
73 |
+
)
|
74 |
+
|
75 |
+
def forward(self, x, ctx=None):
|
76 |
+
for i, layer in enumerate(self.layers):
|
77 |
+
x = layer(x, ctx=ctx)
|
78 |
+
if self.last_fc:
|
79 |
+
x = self.fc(x)
|
80 |
+
return x
|
81 |
+
|
82 |
+
|
83 |
+
class TimestepEmbedder(nn.Module):
|
84 |
+
"""
|
85 |
+
Embeds scalar timesteps into vector representations.
|
86 |
+
"""
|
87 |
+
|
88 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
89 |
+
super().__init__()
|
90 |
+
self.mlp = nn.Sequential(
|
91 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
92 |
+
nn.SiLU(),
|
93 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
94 |
+
)
|
95 |
+
self.frequency_embedding_size = frequency_embedding_size
|
96 |
+
|
97 |
+
@staticmethod
|
98 |
+
def timestep_embedding(t, dim, max_period=10000):
|
99 |
+
"""
|
100 |
+
Create sinusoidal timestep embeddings.
|
101 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
102 |
+
These may be fractional.
|
103 |
+
:param dim: the dimension of the output.
|
104 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
105 |
+
:return: an (N, D) Tensor of positional embeddings.
|
106 |
+
"""
|
107 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
108 |
+
half = dim // 2
|
109 |
+
freqs = torch.exp(
|
110 |
+
-math.log(max_period)
|
111 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
112 |
+
/ half
|
113 |
+
).to(device=t.device)
|
114 |
+
args = t[:, None].float() * freqs[None]
|
115 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
116 |
+
if dim % 2:
|
117 |
+
embedding = torch.cat(
|
118 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
119 |
+
)
|
120 |
+
return embedding
|
121 |
+
|
122 |
+
def forward(self, t):
|
123 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
124 |
+
t_emb = self.mlp(t_freq)
|
125 |
+
return t_emb
|
salad/model_components/transformer.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of time conditioned Transformer.
|
3 |
+
"""
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
|
10 |
+
class PositionalEncoding(nn.Module):
|
11 |
+
def __init__(self, d_hid, n_position=200):
|
12 |
+
super(PositionalEncoding, self).__init__()
|
13 |
+
|
14 |
+
# Not a parameter
|
15 |
+
self.register_buffer(
|
16 |
+
"pos_table", self._get_sinusoid_encoding_table(n_position, d_hid)
|
17 |
+
)
|
18 |
+
|
19 |
+
def _get_sinusoid_encoding_table(self, n_position, d_hid):
|
20 |
+
"""Sinusoid position encoding table"""
|
21 |
+
# TODO: make it with torch instead of numpy
|
22 |
+
|
23 |
+
def get_position_angle_vec(position):
|
24 |
+
return [
|
25 |
+
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
|
26 |
+
for hid_j in range(d_hid)
|
27 |
+
]
|
28 |
+
|
29 |
+
sinusoid_table = np.array(
|
30 |
+
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
|
31 |
+
)
|
32 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
33 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
34 |
+
|
35 |
+
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
"""
|
39 |
+
Input:
|
40 |
+
x: [B,N,D]
|
41 |
+
"""
|
42 |
+
return x + self.pos_table[:, : x.size(1)].clone().detach()
|
43 |
+
|
44 |
+
|
45 |
+
class ConcatSquashLinear(nn.Module):
|
46 |
+
def __init__(self, dim_in, dim_out, dim_ctx):
|
47 |
+
super(ConcatSquashLinear, self).__init__()
|
48 |
+
self._layer = nn.Linear(dim_in, dim_out)
|
49 |
+
self._hyper_bias = nn.Linear(dim_ctx, dim_out, bias=False)
|
50 |
+
self._hyper_gate = nn.Linear(dim_ctx, dim_out)
|
51 |
+
|
52 |
+
def forward(self, ctx, x):
|
53 |
+
assert ctx.dim() == x.dim()
|
54 |
+
gate = torch.sigmoid(self._hyper_gate(ctx))
|
55 |
+
bias = self._hyper_bias(ctx)
|
56 |
+
ret = self._layer(x) * gate + bias
|
57 |
+
return ret
|
58 |
+
|
59 |
+
|
60 |
+
class TimeMLP(nn.Module):
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
dim_in,
|
64 |
+
dim_h,
|
65 |
+
dim_out,
|
66 |
+
dim_ctx=None,
|
67 |
+
act=F.relu,
|
68 |
+
dropout=0.0,
|
69 |
+
use_time=False,
|
70 |
+
):
|
71 |
+
super().__init__()
|
72 |
+
self.act = act
|
73 |
+
self.use_time = use_time
|
74 |
+
|
75 |
+
dim_h = int(dim_h)
|
76 |
+
if use_time:
|
77 |
+
self.fc1 = ConcatSquashLinear(dim_in, dim_h, dim_ctx)
|
78 |
+
self.fc2 = ConcatSquashLinear(dim_h, dim_out, dim_ctx)
|
79 |
+
else:
|
80 |
+
self.fc1 = nn.Linear(dim_in, dim_h)
|
81 |
+
self.fc2 = nn.Linear(dim_h, dim_out)
|
82 |
+
self.dropout = nn.Dropout(dropout)
|
83 |
+
|
84 |
+
def forward(self, x, ctx=None):
|
85 |
+
if self.use_time:
|
86 |
+
x = self.fc1(x=x, ctx=ctx)
|
87 |
+
else:
|
88 |
+
x = self.fc1(x)
|
89 |
+
|
90 |
+
x = self.act(x)
|
91 |
+
x = self.dropout(x)
|
92 |
+
if self.use_time:
|
93 |
+
x = self.fc2(x=x, ctx=ctx)
|
94 |
+
else:
|
95 |
+
x = self.fc2(x)
|
96 |
+
|
97 |
+
x = self.dropout(x)
|
98 |
+
return x
|
99 |
+
|
100 |
+
|
101 |
+
class MultiHeadAttention(nn.Module):
|
102 |
+
def __init__(self, dim_self, dim_ref, num_heads, dropout=0.0):
|
103 |
+
super().__init__()
|
104 |
+
self.num_heads = num_heads
|
105 |
+
head_dim = dim_self // num_heads
|
106 |
+
self.scale = head_dim**-0.5
|
107 |
+
self.to_queries = nn.Linear(dim_self, dim_self)
|
108 |
+
self.to_keys_values = nn.Linear(dim_ref, dim_self * 2)
|
109 |
+
self.project = nn.Linear(dim_self, dim_self)
|
110 |
+
self.dropout = nn.Dropout(dropout)
|
111 |
+
|
112 |
+
def forward(
|
113 |
+
self,
|
114 |
+
x,
|
115 |
+
y=None,
|
116 |
+
mask=None,
|
117 |
+
alpha=None,
|
118 |
+
):
|
119 |
+
y = y if y is not None else x
|
120 |
+
b_a, n, c = x.shape
|
121 |
+
b, m, d = y.shape
|
122 |
+
# b n h dh
|
123 |
+
queries = self.to_queries(x).reshape(
|
124 |
+
b_a, n, self.num_heads, c // self.num_heads
|
125 |
+
)
|
126 |
+
# b m 2 h dh
|
127 |
+
keys_values = self.to_keys_values(y).reshape(
|
128 |
+
b, m, 2, self.num_heads, c // self.num_heads
|
129 |
+
)
|
130 |
+
keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
|
131 |
+
if alpha is not None:
|
132 |
+
out, attention = self.forward_interpolation(
|
133 |
+
queries, keys, values, alpha, mask
|
134 |
+
)
|
135 |
+
else:
|
136 |
+
attention = torch.einsum("bnhd,bmhd->bnmh", queries, keys) * self.scale
|
137 |
+
if mask is not None:
|
138 |
+
if mask.dim() == 2:
|
139 |
+
mask = mask.unsqueeze(1)
|
140 |
+
attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
|
141 |
+
attention = attention.softmax(dim=2)
|
142 |
+
attention = self.dropout(attention)
|
143 |
+
out = torch.einsum("bnmh,bmhd->bnhd", attention, values).reshape(b, n, c)
|
144 |
+
out = self.project(out)
|
145 |
+
return out, attention
|
146 |
+
|
147 |
+
|
148 |
+
class TimeTransformerEncoderLayer(nn.Module):
|
149 |
+
def __init__(
|
150 |
+
self,
|
151 |
+
dim_self,
|
152 |
+
dim_ctx=None,
|
153 |
+
num_heads=1,
|
154 |
+
mlp_ratio=2.0,
|
155 |
+
act=F.leaky_relu,
|
156 |
+
dropout=0.0,
|
157 |
+
use_time=True,
|
158 |
+
):
|
159 |
+
super().__init__()
|
160 |
+
self.use_time = use_time
|
161 |
+
self.act = act
|
162 |
+
self.attn = MultiHeadAttention(dim_self, dim_self, num_heads, dropout)
|
163 |
+
self.attn_norm = nn.LayerNorm(dim_self)
|
164 |
+
|
165 |
+
mlp_ratio = int(mlp_ratio)
|
166 |
+
self.mlp = TimeMLP(
|
167 |
+
dim_self, dim_self * mlp_ratio, dim_self, dim_ctx, use_time=use_time
|
168 |
+
)
|
169 |
+
self.norm = nn.LayerNorm(dim_self)
|
170 |
+
self.dropout = nn.Dropout(dropout)
|
171 |
+
|
172 |
+
def forward(self, x, ctx=None):
|
173 |
+
res = x
|
174 |
+
x, attn = self.attn(x)
|
175 |
+
x = self.attn_norm(x + res)
|
176 |
+
|
177 |
+
res = x
|
178 |
+
x = self.mlp(x, ctx=ctx)
|
179 |
+
x = self.norm(x + res)
|
180 |
+
|
181 |
+
return x, attn
|
182 |
+
|
183 |
+
|
184 |
+
class TimeTransformerDecoderLayer(TimeTransformerEncoderLayer):
|
185 |
+
def __init__(
|
186 |
+
self,
|
187 |
+
dim_self,
|
188 |
+
dim_ref,
|
189 |
+
dim_ctx=None,
|
190 |
+
num_heads=1,
|
191 |
+
mlp_ratio=2,
|
192 |
+
act=F.leaky_relu,
|
193 |
+
dropout=0.0,
|
194 |
+
use_time=True,
|
195 |
+
):
|
196 |
+
super().__init__(
|
197 |
+
dim_self=dim_self,
|
198 |
+
dim_ctx=dim_ctx,
|
199 |
+
num_heads=num_heads,
|
200 |
+
mlp_ratio=mlp_ratio,
|
201 |
+
act=act,
|
202 |
+
dropout=dropout,
|
203 |
+
use_time=use_time,
|
204 |
+
)
|
205 |
+
self.cross_attn = MultiHeadAttention(dim_self, dim_ref, num_heads, dropout)
|
206 |
+
self.cross_attn_norm = nn.LayerNorm(dim_self)
|
207 |
+
|
208 |
+
def forward(self, x, y, ctx=None):
|
209 |
+
res = x
|
210 |
+
x, attn = self.attn(x)
|
211 |
+
x = self.attn_norm(x + res)
|
212 |
+
|
213 |
+
res = x
|
214 |
+
x, attn = self.cross_attn(x, y)
|
215 |
+
x = self.cross_attn_norm(x + res)
|
216 |
+
|
217 |
+
res = x
|
218 |
+
x = self.mlp(x, ctx=ctx)
|
219 |
+
x = self.norm(x + res)
|
220 |
+
|
221 |
+
return x, attn
|
222 |
+
|
223 |
+
|
224 |
+
class TimeTransformerEncoder(nn.Module):
|
225 |
+
def __init__(
|
226 |
+
self,
|
227 |
+
dim_self,
|
228 |
+
dim_ctx=None,
|
229 |
+
num_heads=1,
|
230 |
+
mlp_ratio=2.0,
|
231 |
+
act=F.leaky_relu,
|
232 |
+
dropout=0.0,
|
233 |
+
use_time=True,
|
234 |
+
num_layers=3,
|
235 |
+
last_fc=False,
|
236 |
+
last_fc_dim_out=None,
|
237 |
+
):
|
238 |
+
super().__init__()
|
239 |
+
self.last_fc = last_fc
|
240 |
+
if last_fc:
|
241 |
+
self.fc = nn.Linear(dim_self, last_fc_dim_out)
|
242 |
+
self.layers = nn.ModuleList(
|
243 |
+
[
|
244 |
+
TimeTransformerEncoderLayer(
|
245 |
+
dim_self,
|
246 |
+
dim_ctx=dim_ctx,
|
247 |
+
num_heads=num_heads,
|
248 |
+
mlp_ratio=mlp_ratio,
|
249 |
+
act=act,
|
250 |
+
dropout=dropout,
|
251 |
+
use_time=use_time,
|
252 |
+
)
|
253 |
+
for _ in range(num_layers)
|
254 |
+
]
|
255 |
+
)
|
256 |
+
|
257 |
+
def forward(self, x, ctx=None):
|
258 |
+
for i, layer in enumerate(self.layers):
|
259 |
+
x, attn = layer(x, ctx=ctx)
|
260 |
+
|
261 |
+
if self.last_fc:
|
262 |
+
x = self.fc(x)
|
263 |
+
return x
|
264 |
+
|
265 |
+
|
266 |
+
class TimeTransformerDecoder(nn.Module):
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
dim_self,
|
270 |
+
dim_ref,
|
271 |
+
dim_ctx=None,
|
272 |
+
num_heads=1,
|
273 |
+
mlp_ratio=2.0,
|
274 |
+
act=F.leaky_relu,
|
275 |
+
dropout=0.0,
|
276 |
+
use_time=True,
|
277 |
+
num_layers=3,
|
278 |
+
last_fc=True,
|
279 |
+
last_fc_dim_out=None,
|
280 |
+
):
|
281 |
+
super().__init__()
|
282 |
+
self.last_fc = last_fc
|
283 |
+
if last_fc:
|
284 |
+
self.fc = nn.Linear(dim_self, last_fc_dim_out)
|
285 |
+
|
286 |
+
self.layers = nn.ModuleList(
|
287 |
+
[
|
288 |
+
TimeTransformerDecoderLayer(
|
289 |
+
dim_self,
|
290 |
+
dim_ref,
|
291 |
+
dim_ctx,
|
292 |
+
num_heads,
|
293 |
+
mlp_ratio,
|
294 |
+
act,
|
295 |
+
dropout,
|
296 |
+
use_time,
|
297 |
+
)
|
298 |
+
for _ in range(num_layers)
|
299 |
+
]
|
300 |
+
)
|
301 |
+
|
302 |
+
def forward(self, x, y, ctx=None):
|
303 |
+
for i, layer in enumerate(self.layers):
|
304 |
+
x, attn = layer(x, y=y, ctx=ctx)
|
305 |
+
if self.last_fc:
|
306 |
+
x = self.fc(x)
|
307 |
+
|
308 |
+
return x
|
salad/model_components/variance_schedule.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from torch.nn import Linear, Module
|
4 |
+
|
5 |
+
class VarianceSchedule(Module):
|
6 |
+
def __init__(self, num_steps, beta_1, beta_T, mode="linear"):
|
7 |
+
super().__init__()
|
8 |
+
# assert mode in ("linear",)
|
9 |
+
self.num_steps = num_steps
|
10 |
+
self.beta_1 = beta_1
|
11 |
+
self.beta_T = beta_T
|
12 |
+
self.mode = mode
|
13 |
+
|
14 |
+
if mode == "linear":
|
15 |
+
betas = torch.linspace(beta_1, beta_T, steps=num_steps)
|
16 |
+
elif mode == "quad":
|
17 |
+
betas = torch.linspace(beta_1 ** 0.5, beta_T ** 0.5, num_steps) ** 2
|
18 |
+
elif mode == "cosine":
|
19 |
+
cosine_s = 8e-3
|
20 |
+
timesteps = torch.arange(num_steps + 1) / num_steps + cosine_s
|
21 |
+
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
22 |
+
alphas = torch.cos(alphas).pow(2)
|
23 |
+
betas = 1 - alphas[1:] / alphas[:-1]
|
24 |
+
betas = betas.clamp(max=0.999)
|
25 |
+
|
26 |
+
betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding
|
27 |
+
|
28 |
+
alphas = 1 - betas
|
29 |
+
log_alphas = torch.log(alphas)
|
30 |
+
for i in range(1, log_alphas.size(0)): # 1 to T
|
31 |
+
log_alphas[i] += log_alphas[i - 1]
|
32 |
+
alpha_bars = log_alphas.exp()
|
33 |
+
|
34 |
+
sigmas_flex = torch.sqrt(betas)
|
35 |
+
sigmas_inflex = torch.zeros_like(sigmas_flex)
|
36 |
+
for i in range(1, sigmas_flex.size(0)):
|
37 |
+
sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[
|
38 |
+
i
|
39 |
+
]
|
40 |
+
sigmas_inflex = torch.sqrt(sigmas_inflex)
|
41 |
+
|
42 |
+
self.register_buffer("betas", betas)
|
43 |
+
self.register_buffer("alphas", alphas)
|
44 |
+
self.register_buffer("alpha_bars", alpha_bars)
|
45 |
+
self.register_buffer("sigmas_flex", sigmas_flex)
|
46 |
+
self.register_buffer("sigmas_inflex", sigmas_inflex)
|
47 |
+
|
48 |
+
def uniform_sample_t(self, batch_size):
|
49 |
+
ts = np.random.choice(np.arange(1, self.num_steps + 1), batch_size)
|
50 |
+
return ts.tolist()
|
51 |
+
|
52 |
+
def get_sigmas(self, t, flexibility):
|
53 |
+
assert 0 <= flexibility and flexibility <= 1
|
54 |
+
sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (
|
55 |
+
1 - flexibility
|
56 |
+
)
|
57 |
+
return sigmas
|
salad/models/__init__.py
ADDED
File without changes
|
salad/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (163 Bytes). View file
|
|
salad/models/__pycache__/base_model.cpython-39.pyc
ADDED
Binary file (4.6 kB). View file
|
|
salad/models/__pycache__/language_phase1.cpython-39.pyc
ADDED
Binary file (8.83 kB). View file
|
|
salad/models/__pycache__/language_phase2.cpython-39.pyc
ADDED
Binary file (6.12 kB). View file
|
|
salad/models/__pycache__/phase1.cpython-39.pyc
ADDED
Binary file (2.12 kB). View file
|
|
salad/models/__pycache__/phase2.cpython-39.pyc
ADDED
Binary file (5.37 kB). View file
|
|
salad/models/base_model.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from salad.data.dataset import SALADDataset
|
5 |
+
from salad.utils.train_util import PolyDecayScheduler
|
6 |
+
|
7 |
+
|
8 |
+
class BaseModel(pl.LightningModule):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
network,
|
12 |
+
variance_schedule,
|
13 |
+
**kwargs,
|
14 |
+
):
|
15 |
+
super().__init__()
|
16 |
+
self.save_hyperparameters(logger=False)
|
17 |
+
self.net = network
|
18 |
+
self.var_sched = variance_schedule
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
return self.get_loss(x)
|
22 |
+
|
23 |
+
def step(self, x, stage: str):
|
24 |
+
loss = self(x)
|
25 |
+
self.log(
|
26 |
+
f"{stage}/loss",
|
27 |
+
loss,
|
28 |
+
on_step=stage == "train",
|
29 |
+
prog_bar=True,
|
30 |
+
)
|
31 |
+
return loss
|
32 |
+
|
33 |
+
def training_step(self, batch, batch_idx):
|
34 |
+
x = batch
|
35 |
+
return self.step(x, "train")
|
36 |
+
|
37 |
+
def add_noise(self, x, t):
|
38 |
+
"""
|
39 |
+
Input:
|
40 |
+
x: [B,D] or [B,G,D]
|
41 |
+
t: list of size B
|
42 |
+
Output:
|
43 |
+
x_noisy: [B,D]
|
44 |
+
beta: [B]
|
45 |
+
e_rand: [B,D]
|
46 |
+
"""
|
47 |
+
alpha_bar = self.var_sched.alpha_bars[t]
|
48 |
+
beta = self.var_sched.betas[t]
|
49 |
+
|
50 |
+
c0 = torch.sqrt(alpha_bar).view(-1, 1) # [B,1]
|
51 |
+
c1 = torch.sqrt(1 - alpha_bar).view(-1, 1)
|
52 |
+
|
53 |
+
e_rand = torch.randn_like(x)
|
54 |
+
if e_rand.dim() == 3:
|
55 |
+
c0 = c0.unsqueeze(1)
|
56 |
+
c1 = c1.unsqueeze(1)
|
57 |
+
|
58 |
+
x_noisy = c0 * x + c1 * e_rand
|
59 |
+
|
60 |
+
return x_noisy, beta, e_rand
|
61 |
+
|
62 |
+
def get_loss(
|
63 |
+
self,
|
64 |
+
x0,
|
65 |
+
t=None,
|
66 |
+
noisy_in=False,
|
67 |
+
beta_in=None,
|
68 |
+
e_rand_in=None,
|
69 |
+
):
|
70 |
+
if x0.dim() == 2:
|
71 |
+
B, D = x0.shape
|
72 |
+
else:
|
73 |
+
B, G, D = x0.shape
|
74 |
+
if not noisy_in:
|
75 |
+
if t is None:
|
76 |
+
t = self.var_sched.uniform_sample_t(B)
|
77 |
+
x_noisy, beta, e_rand = self.add_noise(x0, t)
|
78 |
+
else:
|
79 |
+
x_noisy = x0
|
80 |
+
beta = beta_in
|
81 |
+
e_rand = e_rand_in
|
82 |
+
|
83 |
+
e_theta = self.net(x_noisy, beta=beta)
|
84 |
+
loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
|
85 |
+
return loss
|
86 |
+
|
87 |
+
@torch.no_grad()
|
88 |
+
def sample(
|
89 |
+
self,
|
90 |
+
batch_size=0,
|
91 |
+
return_traj=False,
|
92 |
+
):
|
93 |
+
raise NotImplementedError
|
94 |
+
|
95 |
+
def validation_epoch_end(self, outputs):
|
96 |
+
if self.hparams.no_run_validation:
|
97 |
+
return
|
98 |
+
if not self.trainer.sanity_checking:
|
99 |
+
if (self.current_epoch) % self.hparams.validation_step == 0:
|
100 |
+
self.validation()
|
101 |
+
|
102 |
+
def _build_dataset(self, stage):
|
103 |
+
if hasattr(self, f"data_{stage}"):
|
104 |
+
return getattr(self, f"data_{stage}")
|
105 |
+
if stage == "train":
|
106 |
+
ds = SALADDataset(**self.hparams.dataset_kwargs)
|
107 |
+
else:
|
108 |
+
dataset_kwargs = self.hparams.dataset_kwargs.copy()
|
109 |
+
dataset_kwargs["repeat"] = 1
|
110 |
+
ds = SALADDataset(**dataset_kwargs)
|
111 |
+
setattr(self, f"data_{stage}", ds)
|
112 |
+
return ds
|
113 |
+
|
114 |
+
def _build_dataloader(self, stage):
|
115 |
+
try:
|
116 |
+
ds = getattr(self, f"data_{stage}")
|
117 |
+
except:
|
118 |
+
ds = self._build_dataset(stage)
|
119 |
+
|
120 |
+
return torch.utils.data.DataLoader(
|
121 |
+
ds,
|
122 |
+
batch_size=self.hparams.batch_size,
|
123 |
+
shuffle=stage == "train",
|
124 |
+
drop_last=stage == "train",
|
125 |
+
num_workers=4,
|
126 |
+
)
|
127 |
+
|
128 |
+
def train_dataloader(self):
|
129 |
+
return self._build_dataloader("train")
|
130 |
+
|
131 |
+
def val_dataloader(self):
|
132 |
+
return self._build_dataloader("val")
|
133 |
+
|
134 |
+
def test_dataloader(self):
|
135 |
+
return self._build_dataloader("test")
|
136 |
+
|
137 |
+
def configure_optimizers(self):
|
138 |
+
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
139 |
+
scheduler = PolyDecayScheduler(optimizer, self.hparams.lr, power=0.999)
|
140 |
+
return [optimizer], [scheduler]
|
141 |
+
|
142 |
+
#TODO move get_wandb_logger to logutil.py
|
143 |
+
def get_wandb_logger(self):
|
144 |
+
for logger in self.logger:
|
145 |
+
if isinstance(logger, pl.loggers.wandb.WandbLogger):
|
146 |
+
return logger
|
147 |
+
return None
|
salad/models/language_phase1.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from transformers import BertModel, BertTokenizer
|
5 |
+
|
6 |
+
from salad.model_components.lstm import LSTM
|
7 |
+
from salad.models.phase1 import Phase1Model
|
8 |
+
from salad.utils import imageutil, nputil, visutil
|
9 |
+
from salad.utils.spaghetti_util import (clip_eigenvalues,
|
10 |
+
generate_zc_from_sj_gaus,
|
11 |
+
get_mesh_from_spaghetti, load_mesher,
|
12 |
+
load_spaghetti, project_eigenvectors)
|
13 |
+
from salad.utils.train_util import get_dropout_mask
|
14 |
+
from salad.data.dataset import LangSALADDataset
|
15 |
+
|
16 |
+
|
17 |
+
class LangPhase1Model(Phase1Model):
|
18 |
+
def __init__(self, network, variance_schedule, **kwargs):
|
19 |
+
super().__init__(network, variance_schedule, **kwargs)
|
20 |
+
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
21 |
+
if self.hparams.get("use_lstm"):
|
22 |
+
self.bertmodel = LSTM(
|
23 |
+
text_dim=768, embedding_dim=768, vocab_size=30522, padding_idx=0
|
24 |
+
)
|
25 |
+
else:
|
26 |
+
self.bertmodel = BertModel.from_pretrained("bert-base-uncased")
|
27 |
+
if self.hparams.get("text_encoder_freeze"):
|
28 |
+
for p in self.bertmodel.parameters():
|
29 |
+
p.requires_grad_(False)
|
30 |
+
|
31 |
+
def forward(self, x, text):
|
32 |
+
"""
|
33 |
+
Input:
|
34 |
+
x: [B,G,16]
|
35 |
+
text: list of length [B]
|
36 |
+
"""
|
37 |
+
B, G = x.shape[:2]
|
38 |
+
text = self.random_mask_text(text)
|
39 |
+
lang_emb = self.text_to_embedding(text)
|
40 |
+
return self.get_loss(x, lang_emb)
|
41 |
+
|
42 |
+
def tokenizing(self, text):
|
43 |
+
tokenized = self.tokenizer(
|
44 |
+
text, return_tensors="pt", padding=True, truncation=True
|
45 |
+
).to(self.device)
|
46 |
+
return tokenized
|
47 |
+
|
48 |
+
def text_to_embedding(self, text):
|
49 |
+
"""
|
50 |
+
text: list of length [B]
|
51 |
+
return [B,768]
|
52 |
+
"""
|
53 |
+
tokenized = self.tokenizing(text)
|
54 |
+
if self.hparams.get("use_lstm"):
|
55 |
+
lang_emb, _ = self.bertmodel(tokenized.input_ids)
|
56 |
+
else:
|
57 |
+
if self.hparams.get("text_encoder_return_seq"):
|
58 |
+
lang_emb = self.bertmodel(**tokenized).last_hidden_state
|
59 |
+
else:
|
60 |
+
lang_emb = self.bertmodel(**tokenized).pooler_output
|
61 |
+
if lang_emb.ndim == 2:
|
62 |
+
lang_emb = lang_emb.unsqueeze(1)
|
63 |
+
return lang_emb
|
64 |
+
|
65 |
+
def random_mask_text(self, text):
|
66 |
+
text = list(text)
|
67 |
+
B = len(text)
|
68 |
+
if self.hparams.get("classifier_free_guidance"):
|
69 |
+
random_dp_mask = get_dropout_mask(
|
70 |
+
B, self.hparams.conditioning_dropout_prob, self.device
|
71 |
+
)
|
72 |
+
for i in range(B):
|
73 |
+
if random_dp_mask[i] == 0:
|
74 |
+
text[i] = ""
|
75 |
+
return text
|
76 |
+
|
77 |
+
def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
|
78 |
+
B, G, D = x0.shape
|
79 |
+
|
80 |
+
if not noisy_in:
|
81 |
+
if t is None:
|
82 |
+
t = self.var_sched.uniform_sample_t(B)
|
83 |
+
x_noisy, beta, e_rand = self.add_noise(x0, t)
|
84 |
+
else:
|
85 |
+
x_noisy = x0
|
86 |
+
beta = beta_in
|
87 |
+
e_rand = e_rand_in
|
88 |
+
|
89 |
+
e_theta = self.net(x_noisy, beta, cond)
|
90 |
+
loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
|
91 |
+
return loss
|
92 |
+
|
93 |
+
def step(self, batch, stage: str):
|
94 |
+
x, text = batch
|
95 |
+
loss = self(x, text)
|
96 |
+
self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
|
97 |
+
return loss
|
98 |
+
|
99 |
+
@torch.no_grad()
|
100 |
+
def sample(
|
101 |
+
self,
|
102 |
+
num_samples_or_text,
|
103 |
+
return_traj=False,
|
104 |
+
return_cond=False,
|
105 |
+
classifier_free_guidance=True,
|
106 |
+
free_guidance_weight=2.0,
|
107 |
+
):
|
108 |
+
if isinstance(num_samples_or_text, str):
|
109 |
+
num_samples_or_text = [num_samples_or_text]
|
110 |
+
if isinstance(num_samples_or_text, int):
|
111 |
+
batch_size = num_samples_or_text
|
112 |
+
ds = self._build_dataset("val")
|
113 |
+
texts = [ds[i][1] for i in range(batch_size)]
|
114 |
+
elif isinstance(num_samples_or_text, list):
|
115 |
+
texts = num_samples_or_text
|
116 |
+
batch_size = len(num_samples_or_text)
|
117 |
+
if self.hparams.get("use_zc"):
|
118 |
+
x_T = torch.randn([batch_size, 16, 512]).to(self.device)
|
119 |
+
else:
|
120 |
+
x_T = torch.randn([batch_size, 16, 16]).to(self.device)
|
121 |
+
G = x_T.shape[1]
|
122 |
+
lang_emb = self.text_to_embedding(texts)
|
123 |
+
|
124 |
+
if classifier_free_guidance:
|
125 |
+
null_texts = ["" for _ in range(batch_size)]
|
126 |
+
null_lang_emb = self.text_to_embedding(null_texts)
|
127 |
+
|
128 |
+
traj = {self.var_sched.num_steps: x_T}
|
129 |
+
for t in range(self.var_sched.num_steps, 0, -1):
|
130 |
+
z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
|
131 |
+
alpha = self.var_sched.alphas[t]
|
132 |
+
alpha_bar = self.var_sched.alpha_bars[t]
|
133 |
+
sigma = self.var_sched.get_sigmas(t, flexibility=0)
|
134 |
+
|
135 |
+
c0 = 1.0 / torch.sqrt(alpha)
|
136 |
+
c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
|
137 |
+
|
138 |
+
x_t = traj[t]
|
139 |
+
|
140 |
+
beta = self.var_sched.betas[[t] * batch_size]
|
141 |
+
e_theta = self.net(x_t, beta=beta, context=lang_emb)
|
142 |
+
|
143 |
+
if classifier_free_guidance:
|
144 |
+
null_e_theta = self.net(x_t, beta=beta, context=null_lang_emb)
|
145 |
+
w = free_guidance_weight
|
146 |
+
e_theta = (1 + w) * e_theta - w * null_e_theta
|
147 |
+
|
148 |
+
x_next = c0 * (x_t - c1 * e_theta) + sigma * z
|
149 |
+
traj[t - 1] = x_next.detach()
|
150 |
+
|
151 |
+
traj[t] = traj[t].cpu()
|
152 |
+
|
153 |
+
if not return_traj:
|
154 |
+
del traj[t]
|
155 |
+
|
156 |
+
if return_traj:
|
157 |
+
if return_cond:
|
158 |
+
return traj, lang_emb
|
159 |
+
return traj
|
160 |
+
else:
|
161 |
+
if return_cond:
|
162 |
+
return traj[0], lang_emb
|
163 |
+
return traj[0]
|
164 |
+
|
165 |
+
def sampling_gaussians(
|
166 |
+
self,
|
167 |
+
num_samples_or_text,
|
168 |
+
classifier_free_guidance=True,
|
169 |
+
free_guidance_weight=2.0,
|
170 |
+
return_cond=False,
|
171 |
+
):
|
172 |
+
gaus = self.sample(
|
173 |
+
num_samples_or_text,
|
174 |
+
classifier_free_guidance=classifier_free_guidance,
|
175 |
+
free_guidance_weight=free_guidance_weight,
|
176 |
+
return_cond=return_cond,
|
177 |
+
)
|
178 |
+
if isinstance(gaus, tuple):
|
179 |
+
text = gaus[1]
|
180 |
+
gaus = gaus[0]
|
181 |
+
# gaus = reflect_and_concat_gmms(raw_gaus)
|
182 |
+
if self.hparams.get("global_normalization"):
|
183 |
+
if not hasattr(self, "data_val"):
|
184 |
+
self._build_dataset("val")
|
185 |
+
if self.hparams.get("global_normalization") == "partial":
|
186 |
+
gaus = self.data_val.unnormalize_global_static(gaus, slice(12, None))
|
187 |
+
elif self.hparams.get("global_normalization") == "all":
|
188 |
+
gaus = self.data_val.unnormalize_global_static(gaus, slice(None))
|
189 |
+
|
190 |
+
gaus = project_eigenvectors(clip_eigenvalues(gaus))
|
191 |
+
if return_cond:
|
192 |
+
return gaus, text
|
193 |
+
return gaus
|
194 |
+
|
195 |
+
def _build_dataset(self, stage):
|
196 |
+
if hasattr(self, f"data_{stage}"):
|
197 |
+
return getattr(self, f"data_{stage}")
|
198 |
+
|
199 |
+
ds_class = (
|
200 |
+
LangSALADDataset
|
201 |
+
)
|
202 |
+
if stage == "train":
|
203 |
+
ds = ds_class(**self.hparams.dataset_kwargs)
|
204 |
+
else:
|
205 |
+
dataset_kwargs = self.hparams.dataset_kwargs.copy()
|
206 |
+
dataset_kwargs["repeat"] = 1
|
207 |
+
ds = ds_class(**dataset_kwargs)
|
208 |
+
setattr(self, f"data_{stage}", ds)
|
209 |
+
return ds
|
210 |
+
|
211 |
+
def validation_zc(self):
|
212 |
+
vis_num_shapes = 4
|
213 |
+
vis_zcs = []
|
214 |
+
vis_texts = []
|
215 |
+
ds = self._build_dataset("val")
|
216 |
+
for i in [0, 1, 2, 3]:
|
217 |
+
zcs, text = ds[i]
|
218 |
+
vis_zcs.append(zcs)
|
219 |
+
vis_texts.append(text)
|
220 |
+
vis_zcs = torch.stack(vis_zcs, 0)
|
221 |
+
ldm_zcs = self.sample(vis_texts)
|
222 |
+
|
223 |
+
if not hasattr(self, "spaghetti"):
|
224 |
+
self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
|
225 |
+
spaghetti = self.spaghetti
|
226 |
+
|
227 |
+
if not hasattr(self, "mesher"):
|
228 |
+
self.mesher = load_mesher(self.device)
|
229 |
+
mesher = self.mesher
|
230 |
+
|
231 |
+
wandb_logger = self.get_wandb_logger()
|
232 |
+
images = []
|
233 |
+
for i in range(vis_num_shapes):
|
234 |
+
try:
|
235 |
+
v, f = get_mesh_from_spaghetti(spaghetti, mesher, vis_zcs[i], res=128)
|
236 |
+
gt_img = visutil.render_mesh(v, f, resolution=(256, 256))
|
237 |
+
except:
|
238 |
+
pass
|
239 |
+
try:
|
240 |
+
v, f = get_mesh_from_spaghetti(spaghetti, mesher, ldm_zcs[i], res=128)
|
241 |
+
pred_img = visutil.render_mesh(v, f, resolution=(256, 256))
|
242 |
+
except:
|
243 |
+
pass
|
244 |
+
|
245 |
+
img = imageutil.merge_images([gt_img, pred_img])
|
246 |
+
img = imageutil.draw_text(
|
247 |
+
img,
|
248 |
+
f"Left: GT | Right: Pred \n{vis_texts[i]}",
|
249 |
+
font_size=14,
|
250 |
+
max_seq_length=50,
|
251 |
+
)
|
252 |
+
images.append([img])
|
253 |
+
|
254 |
+
images = imageutil.merge_images(images)
|
255 |
+
wandb_logger.log_image("vis", [images])
|
256 |
+
|
257 |
+
def validation(self):
|
258 |
+
if self.hparams.get("use_zc"):
|
259 |
+
self.validation_zc()
|
260 |
+
return
|
261 |
+
|
262 |
+
vis_num_shapes = 4
|
263 |
+
vis_gaus = []
|
264 |
+
vis_texts = []
|
265 |
+
ds = self._build_dataset("val")
|
266 |
+
vis_indices = [18453, 13036, 13204, 48244]
|
267 |
+
for i in vis_indices:
|
268 |
+
gaus, text = ds[i]
|
269 |
+
vis_gaus.append(gaus)
|
270 |
+
vis_texts.append(text)
|
271 |
+
|
272 |
+
vis_gaus = torch.stack(vis_gaus, 0)
|
273 |
+
if self.hparams.get("global_normalization"):
|
274 |
+
if self.hparams.get("global_normalization") == "partial":
|
275 |
+
vis_gaus = self.data_val.unnormalize_global_static(
|
276 |
+
vis_gaus, slice(12, None)
|
277 |
+
)
|
278 |
+
elif self.hparams.get("global_normalization") == "all":
|
279 |
+
vis_gaus = self.dataval.unnormalize_global_static(vis_gaus, slice(None))
|
280 |
+
|
281 |
+
# vis_gaus = reflect_and_concat_gmms(vis_gaus)
|
282 |
+
pred_gaus = self.sampling_gaussians(vis_texts)
|
283 |
+
|
284 |
+
if not hasattr(self, "spaghetti"):
|
285 |
+
self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
|
286 |
+
spaghetti = self.spaghetti
|
287 |
+
|
288 |
+
if not hasattr(self, "mesher"):
|
289 |
+
self.mesher = load_mesher(self.device)
|
290 |
+
mesher = self.mesher
|
291 |
+
|
292 |
+
""" get intrinsics """
|
293 |
+
# TODO change the ckpt path.
|
294 |
+
if not hasattr(self, "phase2_model"):
|
295 |
+
phase2_ckpt = "/home/juil/pvddir/results/phase2/augment_final_0214/0214_202607/checkpoints/epoch=4999-val_loss=0.0000.ckpt"
|
296 |
+
self.phase2_model = SpaghettiConditionSALDM.load_from_checkpoint(
|
297 |
+
phase2_ckpt, strict=False
|
298 |
+
).to(self.device)
|
299 |
+
self.phase2_model.eval()
|
300 |
+
for p in self.phase2_model.parameters():
|
301 |
+
p.requires_grad_(False)
|
302 |
+
|
303 |
+
phase2_model = self.phase2_model
|
304 |
+
|
305 |
+
gt_sj = phase2_model.sample(vis_gaus)
|
306 |
+
pred_sj = phase2_model.sample(pred_gaus)
|
307 |
+
|
308 |
+
gt_zcs = generate_zc_from_sj_gaus(spaghetti, gt_sj, vis_gaus)
|
309 |
+
pred_zcs = generate_zc_from_sj_gaus(spaghetti, pred_sj, pred_gaus)
|
310 |
+
|
311 |
+
wandb_logger = self.get_wandb_logger()
|
312 |
+
images = []
|
313 |
+
for i in range(vis_num_shapes):
|
314 |
+
gt_img = visutil.render_gaussians(vis_gaus[i], resolution=(256, 256))
|
315 |
+
try:
|
316 |
+
v, f = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
|
317 |
+
gt_mesh_img = visutil.render_mesh(v, f, resolution=(256, 256))
|
318 |
+
gt_img = imageutil.merge_images([gt_img, gt_mesh_img])
|
319 |
+
except:
|
320 |
+
pass
|
321 |
+
|
322 |
+
pred_img = visutil.render_gaussians(pred_gaus[i], resolution=(256, 256))
|
323 |
+
try:
|
324 |
+
v, f = get_mesh_from_spaghetti(spaghetti, mesher, pred_zcs[i], res=128)
|
325 |
+
pred_mesh_img = visutil.render_mesh(v, f, resolution=(256, 256))
|
326 |
+
pred_img = imageutil.merge_images([pred_img, pred_mesh_img])
|
327 |
+
except:
|
328 |
+
pass
|
329 |
+
|
330 |
+
img = imageutil.merge_images([gt_img, pred_img])
|
331 |
+
img = imageutil.draw_text(
|
332 |
+
img,
|
333 |
+
f"Left: GT | Right: Pred \n{vis_texts[i]}",
|
334 |
+
font_size=14,
|
335 |
+
max_seq_length=50,
|
336 |
+
)
|
337 |
+
images.append([img])
|
338 |
+
|
339 |
+
images = imageutil.merge_images(images)
|
340 |
+
wandb_logger.log_image("vis", [images])
|
salad/models/language_phase2.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from transformers import BertModel, BertTokenizer
|
5 |
+
|
6 |
+
from salad.model_components.lstm import LSTM
|
7 |
+
from salad.models.language_phase1 import LangPhase1Model
|
8 |
+
from salad.utils import imageutil, nputil, visutil
|
9 |
+
from salad.utils.spaghetti_util import (generate_zc_from_sj_gaus,
|
10 |
+
get_mesh_from_spaghetti, load_mesher,
|
11 |
+
load_spaghetti)
|
12 |
+
from salad.utils.train_util import get_dropout_mask
|
13 |
+
|
14 |
+
|
15 |
+
class LangPhase2Model(LangPhase1Model):
|
16 |
+
def __init__(self, network, variance_schedule, **kwargs):
|
17 |
+
super().__init__(network, variance_schedule, **kwargs)
|
18 |
+
|
19 |
+
def random_mask_gaus_text(self, gaus, text):
|
20 |
+
if self.hparams.get("classifier_free_guidance"):
|
21 |
+
text = list(text)
|
22 |
+
B = gaus.shape[0]
|
23 |
+
random_dp_mask = get_dropout_mask(
|
24 |
+
B, self.hparams.conditioning_dropout_prob, self.device
|
25 |
+
)
|
26 |
+
gaus = gaus * random_dp_mask.unsqueeze(1).unsqueeze(2)
|
27 |
+
for i in range(B):
|
28 |
+
if random_dp_mask[i] == 0:
|
29 |
+
text[i] = ""
|
30 |
+
|
31 |
+
return gaus, text
|
32 |
+
|
33 |
+
def forward(self, x, gaus, text):
|
34 |
+
"""
|
35 |
+
Input:
|
36 |
+
x: [B,G,512]
|
37 |
+
gaus: [B,G,16]
|
38 |
+
text: list of [B]
|
39 |
+
"""
|
40 |
+
B, G = x.shape[:2]
|
41 |
+
gaus, text = self.random_mask_gaus_text(gaus, text)
|
42 |
+
lang_emb = self.text_to_embedding(text)
|
43 |
+
cond = self.cond_from_gaus_lang_f(gaus, lang_emb)
|
44 |
+
|
45 |
+
return self.get_loss(x, cond)
|
46 |
+
|
47 |
+
def step(self, batch, stage):
|
48 |
+
x, gaus, text = batch
|
49 |
+
loss = self(x, gaus, text)
|
50 |
+
self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
|
51 |
+
return loss
|
52 |
+
|
53 |
+
def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
|
54 |
+
B, G, D = x0.shape
|
55 |
+
if not noisy_in:
|
56 |
+
if t is None:
|
57 |
+
t = self.var_sched.uniform_sample_t(B)
|
58 |
+
x_noisy, beta, e_rand = self.add_noise(x0, t)
|
59 |
+
else:
|
60 |
+
x_noisy = x0
|
61 |
+
beta = beta_in
|
62 |
+
e_rand = e_rand_in
|
63 |
+
e_theta = self.net(x_noisy, beta, cond)
|
64 |
+
loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
|
65 |
+
return loss
|
66 |
+
|
67 |
+
def cond_from_gaus_lang_f(self, gaus, lang_f):
|
68 |
+
gaus = nputil.np2th(gaus).to(self.device)
|
69 |
+
G = gaus.shape[1]
|
70 |
+
lang_f = nputil.np2th(lang_f).to(self.device)
|
71 |
+
assert gaus.ndim == 3
|
72 |
+
if lang_f.ndim == 2:
|
73 |
+
lang_f = lang_f.unsqueeze(1)
|
74 |
+
lang_f = lang_f.expand(-1, G, -1)
|
75 |
+
return torch.cat([gaus, lang_f], -1)
|
76 |
+
|
77 |
+
def generate_null_cond(self, B, G):
|
78 |
+
text = ["" for _ in range(B)]
|
79 |
+
lang_emb = self.text_to_embedding(text)
|
80 |
+
gaus = torch.zeros(B, G, 16, dtype=torch.float, device=self.device)
|
81 |
+
return self.cond_from_gaus_lang_f(gaus, lang_emb)
|
82 |
+
|
83 |
+
@torch.no_grad()
|
84 |
+
def sample(
|
85 |
+
self,
|
86 |
+
num_samples_or_cond,
|
87 |
+
return_traj=False,
|
88 |
+
return_cond=False,
|
89 |
+
classifier_free_guidance=False,
|
90 |
+
free_guidance_weight=0.7,
|
91 |
+
):
|
92 |
+
|
93 |
+
if isinstance(num_samples_or_cond, int):
|
94 |
+
batch_size = num_samples_or_cond
|
95 |
+
ds = self._build_dataset("val")
|
96 |
+
batch_gaus = []
|
97 |
+
batch_text = []
|
98 |
+
for i in range(batch_size):
|
99 |
+
_, gaus, text = ds[i]
|
100 |
+
batch_gaus.append(gaus)
|
101 |
+
batch_text.append(text)
|
102 |
+
|
103 |
+
batch_gaus = torch.stack(batch_gaus, 0)
|
104 |
+
lang_emb = self.text_to_embedding(batch_text)
|
105 |
+
cond = self.cond_from_gaus_lang_f(batch_gaus, lang_emb).to(self.device)
|
106 |
+
|
107 |
+
elif isinstance(num_samples_or_cond, np.ndarray) or isinstance(
|
108 |
+
num_samples_or_cond, torch.Tensor
|
109 |
+
):
|
110 |
+
cond = nputil.np2th(num_samples_or_cond).to(self.device)
|
111 |
+
batch_size = len(cond)
|
112 |
+
|
113 |
+
G = cond.shape[1]
|
114 |
+
if classifier_free_guidance:
|
115 |
+
null_cond = self.generate_null_cond(batch_size, G)
|
116 |
+
|
117 |
+
x_T = torch.randn([batch_size, 16, 512]).to(self.device)
|
118 |
+
traj = {self.var_sched.num_steps: x_T}
|
119 |
+
for t in range(self.var_sched.num_steps, 0, -1):
|
120 |
+
z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
|
121 |
+
alpha = self.var_sched.alphas[t]
|
122 |
+
alpha_bar = self.var_sched.alpha_bars[t]
|
123 |
+
sigma = self.var_sched.get_sigmas(t, flexibility=0)
|
124 |
+
|
125 |
+
c0 = 1.0 / torch.sqrt(alpha)
|
126 |
+
c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
|
127 |
+
|
128 |
+
x_t = traj[t]
|
129 |
+
|
130 |
+
beta = self.var_sched.betas[[t] * batch_size]
|
131 |
+
e_theta = self.net(x_t, beta=beta, context=cond)
|
132 |
+
|
133 |
+
if classifier_free_guidance:
|
134 |
+
null_e_theta = self.net(x_t, beta=beta, context=null_cond)
|
135 |
+
w = free_guidance_weight
|
136 |
+
e_theta = (1 + w) * e_theta - w * null_e_theta
|
137 |
+
|
138 |
+
x_next = c0 * (x_t - c1 * e_theta) + sigma * z
|
139 |
+
traj[t - 1] = x_next.detach()
|
140 |
+
|
141 |
+
traj[t] = traj[t].cpu()
|
142 |
+
|
143 |
+
if not return_traj:
|
144 |
+
del traj[t]
|
145 |
+
|
146 |
+
if return_traj:
|
147 |
+
if return_cond:
|
148 |
+
return traj, cond
|
149 |
+
return traj
|
150 |
+
else:
|
151 |
+
if return_cond:
|
152 |
+
return traj[0], cond
|
153 |
+
return traj[0]
|
154 |
+
|
155 |
+
def validation(self):
|
156 |
+
vis_num_shapes = 4
|
157 |
+
vis_gt_sj = []
|
158 |
+
vis_gaus = []
|
159 |
+
vis_texts = []
|
160 |
+
ds = self._build_dataset("val")
|
161 |
+
vis_indices = [18453, 13036, 13204, 48244]
|
162 |
+
for i in vis_indices:
|
163 |
+
sj, gaus, text = ds[i]
|
164 |
+
vis_gt_sj.append(sj)
|
165 |
+
vis_gaus.append(gaus)
|
166 |
+
vis_texts.append(text)
|
167 |
+
|
168 |
+
vis_gt_sj = torch.stack(vis_gt_sj, 0)
|
169 |
+
vis_gaus = torch.stack(vis_gaus, 0).to(self.device)
|
170 |
+
vis_lang_f = self.text_to_embedding(vis_texts)
|
171 |
+
vis_cond = self.cond_from_gaus_lang_f(vis_gaus, vis_lang_f)
|
172 |
+
pred_sj = self.sample(vis_cond)
|
173 |
+
|
174 |
+
if not hasattr(self, "spaghetti"):
|
175 |
+
self.spaghetti = load_spaghetti(self.device, self.hparams.spaghetti_tag)
|
176 |
+
spaghetti = self.spaghetti
|
177 |
+
|
178 |
+
if not hasattr(self, "mesher"):
|
179 |
+
self.mesher = load_mesher(self.device)
|
180 |
+
mesher = self.mesher
|
181 |
+
|
182 |
+
gt_zcs = generate_zc_from_sj_gaus(spaghetti, vis_gt_sj, vis_gaus)
|
183 |
+
pred_zcs = generate_zc_from_sj_gaus(spaghetti, pred_sj, vis_gaus)
|
184 |
+
|
185 |
+
wandb_logger = self.get_wandb_logger()
|
186 |
+
for i in range(vis_num_shapes):
|
187 |
+
gaus_img = visutil.render_gaussians(vis_gaus[i], resolution=(256, 256))
|
188 |
+
vert, face = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
|
189 |
+
gt_mesh_img = visutil.render_mesh(vert, face, resolution=(256, 256))
|
190 |
+
img = [gaus_img, gt_mesh_img]
|
191 |
+
try:
|
192 |
+
vert, face = get_mesh_from_spaghetti(spaghetti, mesher, pred_zcs[i])
|
193 |
+
pred_mesh_img = visutil.render_mesh(vert, face, resolution=(256, 256))
|
194 |
+
img.append(pred_mesh_img)
|
195 |
+
except Exception as e:
|
196 |
+
print(e)
|
197 |
+
img = imageutil.merge_images(img)
|
198 |
+
img = imageutil.draw_text(
|
199 |
+
img, vis_texts[i], font_size=14, max_seq_length=50
|
200 |
+
)
|
201 |
+
wandb_logger.log_image("vis", [img])
|
salad/models/phase1.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from salad.models.base_model import BaseModel
|
4 |
+
from salad.utils import nputil, thutil
|
5 |
+
from salad.utils.spaghetti_util import clip_eigenvalues, project_eigenvectors
|
6 |
+
|
7 |
+
class Phase1Model(BaseModel):
|
8 |
+
def __init__(self, network, variance_schedule, **kwargs):
|
9 |
+
super().__init__(network, variance_schedule, **kwargs)
|
10 |
+
|
11 |
+
@torch.no_grad()
|
12 |
+
def sample(
|
13 |
+
self,
|
14 |
+
batch_size=0,
|
15 |
+
return_traj=False,
|
16 |
+
):
|
17 |
+
x_T = torch.randn([batch_size, 16, 16]).to(self.device)
|
18 |
+
|
19 |
+
traj = {self.var_sched.num_steps: x_T}
|
20 |
+
for t in range(self.var_sched.num_steps, 0, -1):
|
21 |
+
z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
|
22 |
+
alpha = self.var_sched.alphas[t]
|
23 |
+
alpha_bar = self.var_sched.alpha_bars[t]
|
24 |
+
sigma = self.var_sched.get_sigmas(t, flexibility=0)
|
25 |
+
|
26 |
+
c0 = 1.0 / torch.sqrt(alpha)
|
27 |
+
c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
|
28 |
+
|
29 |
+
x_t = traj[t]
|
30 |
+
|
31 |
+
beta = self.var_sched.betas[[t] * batch_size]
|
32 |
+
e_theta = self.net(x_t, beta=beta)
|
33 |
+
# print(e_theta.norm(-1).mean())
|
34 |
+
|
35 |
+
x_next = c0 * (x_t - c1 * e_theta) + sigma * z
|
36 |
+
traj[t - 1] = x_next.detach()
|
37 |
+
|
38 |
+
traj[t] = traj[t].cpu()
|
39 |
+
|
40 |
+
if not return_traj:
|
41 |
+
del traj[t]
|
42 |
+
if return_traj:
|
43 |
+
return traj
|
44 |
+
else:
|
45 |
+
return traj[0]
|
46 |
+
|
47 |
+
def sampling_gaussians(self, num_shapes):
|
48 |
+
"""
|
49 |
+
Return:
|
50 |
+
ldm_gaus: np.ndarray
|
51 |
+
gt_gaus: np.ndarray
|
52 |
+
"""
|
53 |
+
ldm_gaus = self.sample(num_shapes)
|
54 |
+
|
55 |
+
if self.hparams.get("global_normalization"):
|
56 |
+
if not hasattr(self, "data_val"):
|
57 |
+
self._build_dataset("val")
|
58 |
+
if self.hparams.get("global_normalization") == "partial":
|
59 |
+
ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(12,None))
|
60 |
+
elif self.hparams.get("global_normalization") == "all":
|
61 |
+
ldm_gaus = self.data_val.unnormalize_global_static(ldm_gaus, slice(None))
|
62 |
+
|
63 |
+
ldm_gaus = clip_eigenvalues(ldm_gaus)
|
64 |
+
ldm_gaus = project_eigenvectors(ldm_gaus)
|
65 |
+
return ldm_gaus
|
salad/models/phase2.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from salad.models.base_model import BaseModel
|
8 |
+
from salad.utils import imageutil, nputil, sysutil, thutil, visutil
|
9 |
+
from salad.utils.spaghetti_util import (clip_eigenvalues,
|
10 |
+
generate_zc_from_sj_gaus,
|
11 |
+
get_mesh_from_spaghetti, load_mesher,
|
12 |
+
load_spaghetti, project_eigenvectors)
|
13 |
+
|
14 |
+
|
15 |
+
class Phase2Model(BaseModel):
|
16 |
+
def __init__(self, network, variance_schedule, **kwargs):
|
17 |
+
super().__init__(network, variance_schedule, **kwargs)
|
18 |
+
|
19 |
+
def forward(self, x, cond):
|
20 |
+
return self.get_loss(x, cond)
|
21 |
+
|
22 |
+
def step(self, batch, stage: str):
|
23 |
+
x, cond = batch
|
24 |
+
loss = self(x, cond)
|
25 |
+
self.log(f"{stage}/loss", loss, on_step=stage == "train", prog_bar=True)
|
26 |
+
return loss
|
27 |
+
|
28 |
+
def get_loss(self, x0, cond, t=None, noisy_in=False, beta_in=None, e_rand_in=None):
|
29 |
+
B, G, D = x0.shape
|
30 |
+
|
31 |
+
if not noisy_in:
|
32 |
+
if t is None:
|
33 |
+
t = self.var_sched.uniform_sample_t(B)
|
34 |
+
x_noisy, beta, e_rand = self.add_noise(x0, t)
|
35 |
+
else:
|
36 |
+
x_noisy = x0
|
37 |
+
beta = beta_in
|
38 |
+
e_rand = e_rand_in
|
39 |
+
|
40 |
+
e_theta = self.net(x_noisy, beta, cond)
|
41 |
+
loss = F.mse_loss(e_theta.flatten(), e_rand.flatten(), reduction="mean")
|
42 |
+
return loss
|
43 |
+
|
44 |
+
@torch.no_grad()
|
45 |
+
def sample(
|
46 |
+
self,
|
47 |
+
num_samples_or_gaus: Union[torch.Tensor, np.ndarray, int],
|
48 |
+
return_traj=False,
|
49 |
+
classifier_free_guidance=None,
|
50 |
+
free_guidance_weight=-0.7,
|
51 |
+
augment_condition_in_test=False,
|
52 |
+
return_cond=False,
|
53 |
+
):
|
54 |
+
if isinstance(num_samples_or_gaus, int):
|
55 |
+
batch_size = num_samples_or_gaus
|
56 |
+
ds = self._build_dataset("val")
|
57 |
+
cond = torch.stack([ds[i][1] for i in range(batch_size)], 0)
|
58 |
+
|
59 |
+
elif isinstance(num_samples_or_gaus, np.ndarray) or isinstance(
|
60 |
+
num_samples_or_gaus, torch.Tensor
|
61 |
+
):
|
62 |
+
cond = nputil.np2th(num_samples_or_gaus)
|
63 |
+
if cond.dim() == 2:
|
64 |
+
cond = cond[None]
|
65 |
+
batch_size = len(cond)
|
66 |
+
else:
|
67 |
+
raise ValueError(
|
68 |
+
"'num_samples_or_gaus' should be int, torch.Tensor or np.ndarray."
|
69 |
+
)
|
70 |
+
|
71 |
+
x_T = torch.randn([batch_size, 16, 512]).to(self.device)
|
72 |
+
cond = cond.to(self.device)
|
73 |
+
|
74 |
+
traj = {self.var_sched.num_steps: x_T}
|
75 |
+
for t in range(self.var_sched.num_steps, 0, -1):
|
76 |
+
z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)
|
77 |
+
alpha = self.var_sched.alphas[t]
|
78 |
+
alpha_bar = self.var_sched.alpha_bars[t]
|
79 |
+
sigma = self.var_sched.get_sigmas(t, flexibility=0)
|
80 |
+
|
81 |
+
c0 = 1.0 / torch.sqrt(alpha)
|
82 |
+
c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
|
83 |
+
|
84 |
+
x_t = traj[t]
|
85 |
+
|
86 |
+
beta = self.var_sched.betas[[t] * batch_size]
|
87 |
+
e_theta = self.net(x_t, beta=beta, context=cond)
|
88 |
+
|
89 |
+
x_next = c0 * (x_t - c1 * e_theta) + sigma * z
|
90 |
+
traj[t - 1] = x_next.detach()
|
91 |
+
|
92 |
+
traj[t] = traj[t].cpu()
|
93 |
+
|
94 |
+
if not return_traj:
|
95 |
+
del traj[t]
|
96 |
+
|
97 |
+
if return_traj:
|
98 |
+
if return_cond:
|
99 |
+
return traj, cond
|
100 |
+
return traj
|
101 |
+
else:
|
102 |
+
if return_cond:
|
103 |
+
return traj[0], cond
|
104 |
+
return traj[0]
|
105 |
+
|
106 |
+
def validation(self):
|
107 |
+
latent_ds = self._build_dataset("val")
|
108 |
+
vis_num_shapes = 3
|
109 |
+
num_variations = 3
|
110 |
+
sysutil.clean_gpu()
|
111 |
+
|
112 |
+
if not hasattr(self, "spaghetti"):
|
113 |
+
spaghetti = load_spaghetti(
|
114 |
+
self.device,
|
115 |
+
self.hparams.spaghetti_tag
|
116 |
+
if self.hparams.get("spaghetti_tag")
|
117 |
+
else "chairs_large",
|
118 |
+
)
|
119 |
+
self.spaghetti = spaghetti
|
120 |
+
else:
|
121 |
+
spaghetti = self.spaghetti
|
122 |
+
|
123 |
+
if not hasattr(self, "mesher"):
|
124 |
+
mesher = load_mesher(self.device)
|
125 |
+
self.mesher = mesher
|
126 |
+
else:
|
127 |
+
mesher = self.mesher
|
128 |
+
|
129 |
+
"""======== Sampling ========"""
|
130 |
+
gt_zs = []
|
131 |
+
gt_gaus = []
|
132 |
+
|
133 |
+
gt_zs, gt_gaus = zip(*[latent_ds[i + 3] for i in range(vis_num_shapes)])
|
134 |
+
gt_zs, gt_gaus = list(map(lambda x: torch.stack(x), [gt_zs, gt_gaus]))
|
135 |
+
if self.hparams.get("sj_global_normalization"):
|
136 |
+
gt_zs = thutil.th2np(gt_zs)
|
137 |
+
gt_zs = latent_ds.unnormalize_sj_global_static(gt_zs)
|
138 |
+
gt_zs = nputil.np2th(gt_zs).to(self.device)
|
139 |
+
|
140 |
+
gt_gaus_repeated = gt_gaus.repeat_interleave(num_variations, 0)
|
141 |
+
clean_ldm_zs, clean_gaus = self.sample(gt_gaus_repeated, return_cond=True)
|
142 |
+
clean_gaus = project_eigenvectors(clip_eigenvalues(clean_gaus))
|
143 |
+
clean_zcs = generate_zc_from_sj_gaus(spaghetti, clean_ldm_zs, clean_gaus)
|
144 |
+
gt_zcs = generate_zc_from_sj_gaus(spaghetti, gt_zs, gt_gaus)
|
145 |
+
sysutil.clean_gpu()
|
146 |
+
|
147 |
+
"""=========================="""
|
148 |
+
|
149 |
+
""" Spaghetti Decoding """
|
150 |
+
wandb_logger = self.get_wandb_logger()
|
151 |
+
resolution = (256, 256)
|
152 |
+
for i in range(vis_num_shapes):
|
153 |
+
img_per_shape = []
|
154 |
+
gaus_img = visutil.render_gaussians(gt_gaus[i], resolution=resolution)
|
155 |
+
vert, face = get_mesh_from_spaghetti(spaghetti, mesher, gt_zcs[i], res=128)
|
156 |
+
gt_mesh_img = visutil.render_mesh(vert, face, resolution=resolution)
|
157 |
+
gt_img = imageutil.merge_images([gaus_img, gt_mesh_img])
|
158 |
+
gt_img = imageutil.draw_text(gt_img, "GT", font_size=24)
|
159 |
+
img_per_shape.append(gt_img)
|
160 |
+
for j in range(num_variations):
|
161 |
+
try:
|
162 |
+
gaus_img = visutil.render_gaussians(
|
163 |
+
clean_gaus[i * num_variations + j], resolution=resolution
|
164 |
+
)
|
165 |
+
vert, face = get_mesh_from_spaghetti(
|
166 |
+
spaghetti, mesher, clean_zcs[i * num_variations + j], res=128
|
167 |
+
)
|
168 |
+
mesh_img = visutil.render_mesh(vert, face, resolution=resolution)
|
169 |
+
pred_img = imageutil.merge_images([gaus_img, mesh_img])
|
170 |
+
pred_img = imageutil.draw_text(
|
171 |
+
pred_img, f"{j}-th clean gaus", font_size=24
|
172 |
+
)
|
173 |
+
img_per_shape.append(pred_img)
|
174 |
+
except Exception as e:
|
175 |
+
print(e)
|
176 |
+
|
177 |
+
try:
|
178 |
+
image = imageutil.merge_images(img_per_shape)
|
179 |
+
wandb_logger.log_image("visualization", [image])
|
180 |
+
except Exception as e:
|
181 |
+
print(e)
|
182 |
+
|
183 |
+
""" ================== """
|
salad/spaghetti/.gitignore
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/assets/*
|
2 |
+
!/assets/readme_resources/
|
3 |
+
!/assets/ui_resources/
|
4 |
+
!/assets/splits/
|
5 |
+
!/assets/mesh/
|
6 |
+
*.vtk
|
7 |
+
.idea/
|
8 |
+
__pycache__/
|
9 |
+
**_ig_**
|