File size: 14,112 Bytes
2fd6166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40d0f76
2fd6166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
531dfb5
2fd6166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Iterable
import os.path as osp

from hydra.core.config_store import ConfigStore
from hydra.conf import RunDir


@dataclass
class CustomHydraRunDir(RunDir):
    dir: str = './outputs/${run.name}/single'


@dataclass
class RunConfig:
    name: str = 'debug'
    job: str = 'train'
    mixed_precision: str = 'fp16'  # 'no'
    cpu: bool = False
    seed: int = 42
    val_before_training: bool = True
    vis_before_training: bool = True
    limit_train_batches: Optional[int] = None
    limit_val_batches: Optional[int] = None
    max_steps: int = 100_000
    checkpoint_freq: int = 1_000
    val_freq: int = 5_000
    vis_freq: int = 5_000
    # vis_freq: int = 10_000
    log_step_freq: int = 20
    print_step_freq: int = 100

    # config to run demo
    stage1_name: str = 'stage1'     # experiment name to the stage 1 model
    stage2_name: str = 'stage2'     # experiment name to the stage 2 model
    image_path: str = ''            # the path to the images for running demo, can be a single file or a glob pattern
    share: bool = False             # whether to run gradio with a temporal public url or not

    # abs path to working dir
    code_dir_abs: str = osp.dirname(osp.dirname(osp.abspath(__file__)))

    # Inference configs
    num_inference_steps: int = 1000
    diffusion_scheduler: Optional[str] = 'ddpm'
    num_samples: int = 1
    # num_sample_batches: Optional[int] = None
    num_sample_batches: Optional[int] = 2000  # XH: change to 2
    sample_from_ema: bool = False
    sample_save_evolutions: bool = False  # temporarily set by default
    save_name: str = 'sample'  # XH: additional save name
    redo: bool = False

    # for parallel sampling in slurm
    batch_start: int = 0
    batch_end: Optional[int] = None

    # Training configs
    freeze_feature_model: bool = True

    # Coloring training configs
    coloring_training_noise_std: float = 0.0
    coloring_sample_dir: Optional[str] = None

    sample_mode: str = 'sample'  # whether from noise or from some intermediate steps
    sample_noise_step: int = 500  # add noise to GT up to some steps, and then denoise
    sample_save_gt: bool = True


@dataclass
class LoggingConfig:
    wandb: bool = True
    wandb_project: str = 'pc2'



@dataclass
class PointCloudProjectionModelConfig:
    # Feature extraction arguments
    image_size: int = '${dataset.image_size}'
    image_feature_model: str = 'vit_base_patch16_224_mae' # or 'vit_small_patch16_224_msn' or 'identity'
    use_local_colors: bool = True
    use_local_features: bool = True
    use_global_features: bool = False
    use_mask: bool = True
    use_distance_transform: bool = True

    # Point cloud data arguments. Note these are here because the processing happens
    # inside the model, rather than inside the dataset.
    scale_factor: float = "${dataset.scale_factor}"
    colors_mean: float = 0.5
    colors_std: float = 0.5
    color_channels: int = 3
    predict_shape: bool = True
    predict_color: bool = False

    # added by XH
    load_sample_init: bool = False  # load init samples from file
    sample_init_scale: float = 1.0  # scale the initial pc samples
    test_init_with_gtpc: bool = False  # test time init samples with GT samples
    consistent_center: bool = True  # use consistent center prediction by CCD-3DR
    voxel_resolution_multiplier: float = 1  # increase network voxel resolution

    # predict binary segmentation
    predict_binary: bool = False # True for stage 1 model, False for others
    lw_binary: float = 3.0  # to have roughly the same magnitude of the binary segmentation loss
    # for separate model
    binary_training_noise_std: float = 0.1  # from github doc for predicting color
    self_conditioning: bool = False

@dataclass
class PVCNNAEModelConfig(PointCloudProjectionModelConfig):
    "my own model config, must inherit parent class"
    model_name: str = 'pvcnn-ae'
    latent_dim: int = 1024
    num_dec_blocks: int = 6
    block_dims: List[int] = field(default_factory=lambda: [512, 256])
    num_points: int = 1500
    bottleneck_dim: int = -1 # the input dim to the last MLP layer

@dataclass
class PointCloudDiffusionModelConfig(PointCloudProjectionModelConfig):
    model_name: str = 'pc2-diff-ho'  # default as behave

    # Diffusion arguments
    beta_start: float = 1e-5  # 0.00085
    beta_end: float = 8e-3  # 0.012
    beta_schedule: str = 'linear'  # 'custom'
    dm_pred_type: str = 'epsilon'  # diffusion model prediction type, sample (x0) or noise
    ddim_eta: float = 1.0  # DDIM eta parameter: 0 is the default one which does deterministic generation

    # Point cloud model arguments
    point_cloud_model: str = 'pvcnn'
    point_cloud_model_embed_dim: int = 64

    dataset_type: str = '${dataset.type}'

@dataclass
class CrossAttnHOModelConfig(PointCloudDiffusionModelConfig):
    model_name: str = 'diff-ho-attn'

    attn_type: str = 'coord3d+posenc-learnable'
    attn_weight: float = 1.0
    point_visible_test: str = 'combine'  # To compute point visibility: use all points or only human/object points


@dataclass
class DirectTransModelConfig(PointCloudProjectionModelConfig):
    model_name: str = 'direct-transl-ho'

    pooling: str = "avg"
    act: str = 'gelu'
    out_act: str = 'relu'
    # feat_dims_transl: Iterable[Any] = (384, 256, 128, 6) # cannot use List[int] https://github.com/facebookresearch/hydra/issues/1752#issuecomment-893174197
    # feat_dims_scale: Iterable[Any] = (384, 128, 64, 2)
    feat_dims_transl: List[int] = field(default_factory=lambda: [384, 256, 128, 6])
    feat_dims_scale: List[int] = field(default_factory=lambda: [384, 128, 64, 2])
    lw_transl: float = 10000.0
    lw_scale: float = 10000.0


@dataclass
class PointCloudColoringModelConfig(PointCloudProjectionModelConfig):
    # Projection arguments
    predict_shape: bool = False
    predict_color: bool = True

    # Point cloud model arguments
    point_cloud_model: str = 'pvcnn'
    point_cloud_model_layers: int = 1
    point_cloud_model_embed_dim: int = 64


@dataclass
class DatasetConfig:
    type: str


@dataclass
class PointCloudDatasetConfig(DatasetConfig):
    eval_split: str = 'val'
    max_points: int = 16_384
    image_size: int = 224
    scale_factor: float = 1.0
    restrict_model_ids: Optional[List] = None  # for only running on a subset of data points


@dataclass
class CO3DConfig(PointCloudDatasetConfig):
    type: str = 'co3dv2'
    # root: str = os.getenv('CO3DV2_DATASET_ROOT')
    root: str = "/BS/xxie-2/work/co3d/hydrant"
    category: str = 'hydrant'
    subset_name: str = 'fewview_dev'
    mask_images: bool = '${model.use_mask}'


@dataclass
class ShapeNetR2N2Config(PointCloudDatasetConfig):
    # added by XH
    fix_sample: bool = True
    category: str = 'chair'

    type: str = 'shapenet_r2n2'
    root: str = "/BS/chiban2/work/data_shapenet/ShapeNetCore.v1"
    r2n2_dir: str = "/BS/databases20/3d-r2n2"
    shapenet_dir: str = "/BS/chiban2/work/data_shapenet/ShapeNetCore.v1"
    preprocessed_r2n2_dir: str = "${dataset.root}/r2n2_preprocessed_renders"
    splits_file: str = "${dataset.root}/r2n2_standard_splits_from_ShapeNet_taxonomy.json"
    # splits_file: str = "${dataset.root}/pix2mesh_splits_val05.json"  # <-- incorrect
    scale_factor: float = 7.0
    point_cloud_filename: str = 'pointcloud_r2n2.npz'  # should use 'pointcloud_mesh.npz'



@dataclass
class BehaveDatasetConfig(PointCloudDatasetConfig):
    # added by XH
    type: str = 'behave'

    fix_sample: bool = True
    behave_dir: str = "/BS/xxie-5/static00/behave_release/sequences/"
    split_file: str = "" # specify you dataset split file here
    scale_factor: float = 7.0  # use the same as shapenet
    sample_ratio_hum: float = 0.5
    image_size: int = 224

    normalize_type: str = 'comb'
    smpl_type: str = 'gt'  # use which SMPL mesh to obtain normalization parameters
    test_transl_type: str = 'norm'

    load_corr_points: bool = False  # load autoencoder points for object and SMPL
    uniform_obj_sample: bool = False

    # configs for direct translation prediction
    bkg_type: str = 'none'
    bbox_params: str = 'none'
    ho_segm_pred_path: Optional[str] = None
    use_gt_transl: bool = False

    cam_noise_std: float = 0. # add noise to the camera pose
    sep_same_crop: bool = False # use same input image crop to separate models
    aug_blur: float = 0. # blur augmentation

    std_coverage: float=3.5 # a heuristic value to estimate translation

    v2v_path: str = '' # object v2v corr path

@dataclass
class ShapeDatasetConfig(BehaveDatasetConfig):
    "the dataset to train AE for aligned shapes"
    type: str = 'shape'
    fix_sample: bool = False
    split_file: str = "/BS/xxie-2/work/pc2-diff/experiments/splits/shapes-chair.pkl"


# TODO
@dataclass
class ShapeNetNMRConfig(PointCloudDatasetConfig):
    type: str = 'shapenet_nmr'
    shapenet_nmr_dir: str = "/work/lukemk/machine-learning-datasets/3d-reconstruction/ShapeNet_NMR/NMR_Dataset"
    synset_names: str = 'chair'  # comma-separated or 'all'
    augmentation: str = 'all'
    scale_factor: float = 7.0


@dataclass
class AugmentationConfig:
    # need to specify the variable type in order to define it properly
    max_radius: int = 0  # generate a random square to mask object, this is the radius for the square in pixel size, zero means no occlusion


@dataclass
class DataloaderConfig:
    # batch_size: int = 8  # 2 for debug
    batch_size: int = 16
    num_workers: int = 14  # 0 for debug # suggested by accelerator for gpu20


@dataclass
class LossConfig:
    diffusion_weight: float = 1.0
    rgb_weight: float = 1.0
    consistency_weight: float = 1.0


@dataclass
class CheckpointConfig:
    resume: Optional[str] = "test"
    resume_training: bool = True
    resume_training_optimizer: bool = True
    resume_training_scheduler: bool = True
    resume_training_state: bool = True


@dataclass
class ExponentialMovingAverageConfig:
    use_ema: bool = False
    # # From Diffusers EMA (should probably switch)
    # ema_inv_gamma: float = 1.0
    # ema_power: float = 0.75
    # ema_max_decay: float = 0.9999
    decay: float = 0.999
    update_every: int = 20


@dataclass
class OptimizerConfig:
    type: str
    name: str
    lr: float = 3e-4
    weight_decay: float = 0.0
    scale_learning_rate_with_batch_size: bool = False
    gradient_accumulation_steps: int = 1
    clip_grad_norm: Optional[float] = 50.0  # 5.0
    kwargs: Dict = field(default_factory=lambda: dict())


@dataclass
class AdadeltaOptimizerConfig(OptimizerConfig):
    type: str = 'torch'
    name: str = 'Adadelta'
    kwargs: Dict = field(default_factory=lambda: dict(
        weight_decay=1e-6,
    ))


@dataclass
class AdamOptimizerConfig(OptimizerConfig):
    type: str = 'torch'
    name: str = 'AdamW'
    weight_decay: float = 1e-6
    kwargs: Dict = field(default_factory=lambda: dict(betas=(0.95, 0.999)))


@dataclass
class SchedulerConfig:
    type: str
    kwargs: Dict = field(default_factory=lambda: dict())


@dataclass
class LinearSchedulerConfig(SchedulerConfig):
    type: str = 'transformers'
    kwargs: Dict = field(default_factory=lambda: dict(
        name='linear',
        num_warmup_steps=0,
        num_training_steps="${run.max_steps}",
    ))


@dataclass
class CosineSchedulerConfig(SchedulerConfig):
    type: str = 'transformers'
    kwargs: Dict = field(default_factory=lambda: dict(
        name='cosine',
        num_warmup_steps=2000,  # 0
        num_training_steps="${run.max_steps}",
    ))


@dataclass
class ProjectConfig:
    run: RunConfig
    logging: LoggingConfig
    dataset: PointCloudDatasetConfig
    augmentations: AugmentationConfig
    dataloader: DataloaderConfig
    loss: LossConfig
    model: PointCloudProjectionModelConfig
    ema: ExponentialMovingAverageConfig
    checkpoint: CheckpointConfig
    optimizer: OptimizerConfig
    scheduler: SchedulerConfig

    defaults: List[Any] = field(default_factory=lambda: [
        'custom_hydra_run_dir',
        {'run': 'default'},
        {'logging': 'default'},
        {'model': 'ho-attn'},
        # {'dataset': 'co3d'},
        {'dataset': 'behave'},
        {'augmentations': 'default'},
        {'dataloader': 'default'},
        {'ema': 'default'},
        {'loss': 'default'},
        {'checkpoint': 'default'},
        {'optimizer': 'adam'}, # default adamw
        {'scheduler': 'linear'},
        # {'scheduler': 'cosine'},
    ])


cs = ConfigStore.instance()
cs.store(name='custom_hydra_run_dir', node=CustomHydraRunDir, package="hydra.run")
cs.store(group='run', name='default', node=RunConfig)
cs.store(group='logging', name='default', node=LoggingConfig)
cs.store(group='model', name='diffrec', node=PointCloudDiffusionModelConfig)
cs.store(group='model', name='coloring_model', node=PointCloudColoringModelConfig)
cs.store(group='model', name='direct-transl', node=DirectTransModelConfig)
cs.store(group='model', name='ho-attn', node=CrossAttnHOModelConfig)
cs.store(group='model', name='pvcnn-ae', node=PVCNNAEModelConfig)
cs.store(group='dataset', name='co3d', node=CO3DConfig)
# TODO
cs.store(group='dataset', name='shapenet_r2n2', node=ShapeNetR2N2Config)
cs.store(group='dataset', name='behave', node=BehaveDatasetConfig)
cs.store(group='dataset', name='shape', node=ShapeDatasetConfig)
# cs.store(group='dataset', name='shapenet_nmr', node=ShapeNetNMRConfig)
cs.store(group='augmentations', name='default', node=AugmentationConfig)
cs.store(group='dataloader', name='default', node=DataloaderConfig)
cs.store(group='loss', name='default', node=LossConfig)
cs.store(group='ema', name='default', node=ExponentialMovingAverageConfig)
cs.store(group='checkpoint', name='default', node=CheckpointConfig)
cs.store(group='optimizer', name='adadelta', node=AdadeltaOptimizerConfig)
cs.store(group='optimizer', name='adam', node=AdamOptimizerConfig)
cs.store(group='scheduler', name='linear', node=LinearSchedulerConfig)
cs.store(group='scheduler', name='cosine', node=CosineSchedulerConfig)
cs.store(name='configs', node=ProjectConfig)