stevee00 committed (verified)
Commit df12227 · Parent(s): c88ec9c

Upload docs/TRAINING.md

Files changed (1): docs/TRAINING.md (+352, −0)
docs/TRAINING.md ADDED

# InteriorFusion Training Guide

## Hardware Requirements

| Stage | GPUs | VRAM Each | Duration | Cost (Cloud) |
|-------|------|-----------|----------|--------------|
| VAE Pre-training | 8× A100 (80GB) | 80GB | 7 days | ~$15K |
| Structure DiT | 32× A100 (80GB) | 80GB | 14 days | ~$30K |
| Material DiT | 16× A100 (80GB) | 80GB | 7 days | ~$15K |
| Fine-tuning | 8× A100 (80GB) | 80GB | 3 days | ~$5K |
| **Total** | **Variable** | — | **~4 weeks** | **~$65K** |

Minimum viable: 8× A100 (all stages, longer duration)
Budget option: 8× RTX 4090 (24GB) — requires gradient accumulation, ~2× longer (see the sketch below)
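
On the budget option, gradient accumulation recovers the effective batch size of the A100 recipe. A minimal sketch of the idea, assuming a standard PyTorch loop (`model`, `optimizer`, `loader`, and the per-GPU batch of 2 are illustrative placeholders, not project code):

```python
# Hypothetical: match the 64-sample effective batch of the 8x A100 recipe
# on 8x RTX 4090 by accumulating gradients over several micro-batches.
accum_steps = 4  # 2 samples/GPU x 8 GPUs x 4 accumulation steps = 64

optimizer.zero_grad()
for i, batch in enumerate(loader):
    loss = model(batch) / accum_steps  # scale so accumulated grads average
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()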

## Stage 1: SLAT-Interior VAE Pre-training

### Architecture
- **Encoder**: Sparse 3D convolutional U-Net
  - Input: dense occupancy grid O ∈ {0,1}^(N³), where N = 256/512/1024
  - Sparse convolution layers with channel-to-space shortcuts
  - 16× spatial compression (1024³ → 64³ latent)

- **Decoder**:
  - Sparse conv upsampler with skip connections
  - Early pruning: predict a binary mask of active children before upsampling (see the sketch below)
  - Outputs: per-voxel shape features + material features
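
The early-pruning step can be pictured as follows. This is an illustrative dense-tensor sketch of the mechanism, not the project's sparse-convolution implementation; `mask_head` and `upsample` are hypothetical modules:

```python
import torch

def pruned_upsample(feats, mask_head, upsample, threshold=0.5):
    """Illustrative early pruning: predict which parents spawn active
    children, upsample, then zero out the pruned regions.

    feats: (B, C, D, H, W) dense stand-in for a sparse voxel tensor.
    mask_head / upsample: hypothetical conv and transposed-conv modules.
    """
    # Per-parent occupancy logit: does this voxel have active children?
    child_logits = mask_head(feats)              # (B, 1, D, H, W)
    keep = (child_logits.sigmoid() > threshold)  # binary activity mask

    up = upsample(feats)                         # (B, C', 2D, 2H, 2W)
    # Broadcast the parent mask to its 2x2x2 children
    keep_up = keep.float().repeat_interleave(2, -1) \
                          .repeat_interleave(2, -2) \
                          .repeat_interleave(2, -3)
    return up * keep_up  # children of pruned parents carry no features
```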

### Training Configuration
```yaml
# configs/vae_pretrain.yaml
model:
  latent_dim: 64
  base_resolution: 256
  target_resolution: 1024

optimizer:
  type: AdamW
  lr: 1.0e-4
  weight_decay: 0.01
  betas: [0.9, 0.999]

scheduler:
  type: cosine_with_restarts
  warmup_steps: 10000

training:
  batch_size: 8            # per GPU
  num_gpus: 8
  effective_batch_size: 64
  max_steps: 200000
  gradient_accumulation: 1
  mixed_precision: bf16

curriculum:
  - resolution: 256
    steps: 50000
    lr: 1.0e-4
  - resolution: 512
    steps: 100000
    lr: 1.0e-4
  - resolution: 1024
    steps: 50000
    lr: 5.0e-5

data:
  dataset: InteriorFusion-Train
  num_workers: 8
  pin_memory: true

loss:
  reconstruction:
    weight: 1.0
    type: l1
  kl_divergence:
    weight: 1.0e-3
  depth_consistency:
    weight: 0.5
    type: l1
  normal_consistency:
    weight: 0.3
    type: cosine
  edge_preservation:
    weight: 0.2
    type: l1
```
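
The three-stage curriculum amounts to a lookup from global step to the active (resolution, lr) pair. A minimal helper sketch, assuming the stage list from the config above:

```python
def curriculum_stage(step, stages=((256, 50_000, 1.0e-4),
                                   (512, 100_000, 1.0e-4),
                                   (1024, 50_000, 5.0e-5))):
    """Map a global step to (resolution, lr).

    `stages` mirrors the curriculum in configs/vae_pretrain.yaml:
    (resolution, num_steps, lr) tuples consumed in order.
    """
    boundary = 0
    for resolution, num_steps, lr in stages:
        boundary += num_steps
        if step < boundary:
            return resolution, lr
    return stages[-1][0], stages[-1][2]  # past the end: stay at final stage

# e.g. curriculum_stage(60_000) -> (512, 1.0e-4)
```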

### Loss Functions

```python
import torch
import torch.nn.functional as F


def vae_loss(pred_shape, pred_material, target_shape, target_material,
             pred_depth, target_depth, pred_normal, target_normal, mu, logvar):
    # Reconstruction: L1 on shape and material features (weight 1.0)
    loss_recon = F.l1_loss(pred_shape, target_shape) + \
                 F.l1_loss(pred_material, target_material)

    # KL divergence, averaged over the batch so the 1e-3 weight from the
    # config is independent of batch size
    loss_kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    loss_kl = loss_kl / mu.shape[0] * 1e-3

    # Depth consistency (weight 0.5)
    loss_depth = F.l1_loss(pred_depth, target_depth)

    # Normal consistency: 1 - mean cosine similarity (weight 0.3)
    loss_normal = 1 - F.cosine_similarity(pred_normal, target_normal, dim=-1).mean()

    return loss_recon + loss_kl + 0.5 * loss_depth + 0.3 * loss_normal
```
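
The config also lists an edge-preservation term (weight 0.2) that `vae_loss` does not show. One plausible form is an L1 loss on finite-difference spatial gradients; the exact formulation is an assumption here:

```python
import torch.nn.functional as F

def edge_preservation_loss(pred, target):
    """L1 on finite-difference gradients along each spatial axis of a
    (B, C, D, H, W) grid. An assumed formulation of the edge term."""
    loss = 0.0
    for dim in (-1, -2, -3):  # W, H, D axes
        loss = loss + F.l1_loss(pred.diff(dim=dim), target.diff(dim=dim))
    return loss / 3
```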

## Stage 2: Structure DiT (Rectified Flow)

### Architecture
- **DiT model**: Flow-matching transformer
  - Width: 1536
  - Depth: 30 blocks
  - Heads: 12
  - MLP hidden dim: 8192
  - Parameters: ~1.3B

- **Conditioning encoders**:
  - Image: DINOv3-L (frozen, 1024-dim features)
  - Depth: custom CNN encoder (256-dim)
  - Layout: Transformer encoder on SpatialLM tokens (512-dim)
  - Semantic: Mask2Former feature pyramid (256-dim)

- **Conditioning fusion**: cross-attention + AdaLN-single modulation (see the sketch below)
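
The fusion combines two standard ingredients: cross-attention pulls in the conditioning tokens, and AdaLN-single modulates every block from one shared projection of the timestep embedding. An illustrative single-block sketch, assuming the conditioning streams are first projected to a shared 1024-dim token sequence (module layout is an assumption, not the project's code):

```python
import torch
import torch.nn as nn

class DiTBlockSketch(nn.Module):
    """One DiT block with cross-attention conditioning and AdaLN-single
    modulation. Illustrative only; widths follow the Stage 2 config."""

    def __init__(self, width=1536, heads=12, cond_dim=1024):
        super().__init__()
        self.norm1 = nn.LayerNorm(width, elementwise_affine=False)
        self.self_attn = nn.MultiheadAttention(width, heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(width, heads, batch_first=True,
                                                kdim=cond_dim, vdim=cond_dim)
        self.norm2 = nn.LayerNorm(width, elementwise_affine=False)
        self.mlp = nn.Sequential(nn.Linear(width, 8192), nn.GELU(),
                                 nn.Linear(8192, width))
        # AdaLN-single: one shared projection of the timestep embedding
        # yields scale/shift/gate for both sublayers.
        self.ada = nn.Linear(width, 6 * width)

    def forward(self, x, t_emb, cond_tokens):
        # x: (B, L, width); t_emb: (B, width); cond_tokens: (B, Lc, cond_dim)
        s1, b1, g1, s2, b2, g2 = self.ada(t_emb).chunk(6, dim=-1)
        h = self.norm1(x) * (1 + s1.unsqueeze(1)) + b1.unsqueeze(1)
        h = self.self_attn(h, h, h, need_weights=False)[0]
        h = h + self.cross_attn(h, cond_tokens, cond_tokens, need_weights=False)[0]
        x = x + g1.unsqueeze(1) * h
        h = self.norm2(x) * (1 + s2.unsqueeze(1)) + b2.unsqueeze(1)
        return x + g2.unsqueeze(1) * self.mlp(h)
```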

### Training Configuration
```yaml
# configs/dit_structure.yaml
model:
  width: 1536
  depth: 30
  num_heads: 12
  mlp_hidden_dim: 8192

optimizer:
  type: AdamW
  lr: 1.0e-4
  weight_decay: 0.01

scheduler:
  type: linear_warmup_cosine
  warmup_steps: 10000

training:
  batch_size: 8            # per GPU
  num_gpus: 32
  effective_batch_size: 256
  max_steps: 400000
  mixed_precision: bf16

curriculum:
  - resolution: 256
    steps: 100000
    lr: 1.0e-4
  - resolution: 512
    steps: 200000
    lr: 1.0e-4
  - resolution: 1024
    steps: 100000
    lr: 2.0e-5

data:
  dataset: InteriorFusion-Train
  num_workers: 8

flow_matching:
  sigma_min: 0.001
  sigma_max: 80.0
  p_mean: -1.2
  p_std: 1.2

loss:
  flow_matching:
    weight: 1.0
  depth_guidance:
    weight: 0.3
```

### Flow Matching Loss

```python
import torch
import torch.nn.functional as F


def flow_matching_loss(model, x_1, cond_img, cond_depth, cond_layout, cond_semantic):
    """
    Rectified flow matching for 3D generation.
    x_1: target structured latent (from the VAE encoder)
    """
    # Sample Gaussian noise as the source distribution
    x_0 = torch.randn_like(x_1)

    # Sample a timestep t ~ U(0, 1) per batch element
    t = torch.rand(x_1.shape[0], device=x_1.device)

    # Linearly interpolate between noise and data along the straight path
    t_exp = t.view(-1, *([1] * (x_1.dim() - 1)))  # broadcast t over latent dims
    x_t = (1 - t_exp) * x_0 + t_exp * x_1

    # Model predicts the velocity field at (x_t, t)
    v_pred = model(x_t, t, cond_img, cond_depth, cond_layout, cond_semantic)

    # Rectified-flow target velocity: the constant straight-line direction
    v_target = x_1 - x_0

    return F.mse_loss(v_pred, v_target)
```
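
At inference time the learned velocity field is integrated from noise (t = 0) to data (t = 1). A minimal Euler-integration sketch; the step count and conditioning signature are assumptions:

```python
import torch

@torch.no_grad()
def sample(model, shape, conds, num_steps=50, device="cuda"):
    """Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (data) with Euler steps."""
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((shape[0],), i * dt, device=device)
        v = model(x, t, *conds)  # predicted velocity at the current point
        x = x + v * dt
    return x
```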

## Stage 3: Material DiT

### Architecture
- Same DiT backbone as Stage 2
- Additional conditioning: generated geometry latent
- Output: per-voxel material features (albedo RGB, metallic, roughness, normal XYZ)

### Training
```yaml
# configs/dit_material.yaml
training:
  batch_size: 16           # per GPU
  num_gpus: 16
  effective_batch_size: 256
  max_steps: 200000

loss:
  albedo:
    weight: 1.0
    type: l1
  metallic_roughness:
    weight: 0.5
    type: l1
  normal:
    weight: 0.5
    type: cosine
  perceptual:
    weight: 0.3
    type: lpips
    network: vgg
  rendering:
    weight: 0.5
    type: mse              # rendered vs. ground truth
```
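
Put together, the material objective is a weighted sum of these terms. A sketch using the config weights; the dict layout and the external handles are assumptions (`lpips_fn` as in `lpips.LPIPS(net='vgg')`, rendered images from a differentiable renderer):

```python
import torch.nn.functional as F

def material_loss(pred, target, pred_render, target_render, lpips_fn):
    """Weighted sum of the Stage 3 terms (weights from configs/dit_material.yaml).
    pred/target: dicts of per-voxel material maps; pred_render/target_render:
    rendered images; lpips_fn: a perceptual metric, e.g. lpips.LPIPS(net='vgg')."""
    loss = F.l1_loss(pred["albedo"], target["albedo"])                    # 1.0, l1
    loss = loss + 0.5 * F.l1_loss(pred["metallic_roughness"],
                                  target["metallic_roughness"])           # 0.5, l1
    loss = loss + 0.5 * (1 - F.cosine_similarity(pred["normal"],
                                                 target["normal"],
                                                 dim=-1).mean())          # 0.5, cosine
    loss = loss + 0.3 * lpips_fn(pred_render, target_render).mean()      # 0.3, lpips
    loss = loss + 0.5 * F.mse_loss(pred_render, target_render)           # 0.5, mse
    return loss
```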

## Stage 4: Real-World Fine-tuning

### LoRA Configuration
```yaml
# configs/finetune_lora.yaml
lora:
  rank: 32
  alpha: 32
  target_modules:
    - "attention.qkv"
    - "attention.proj"
    - "mlp.fc1"
    - "mlp.fc2"
  dropout: 0.0

training:
  batch_size: 4
  num_gpus: 8
  max_steps: 50000
  lr: 1.0e-5

data:
  dataset: InteriorFusion-Real   # ScanNet + HM3D
  weight: 1.0
```
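
If the adapters are applied with the Hugging Face peft library (an assumption; the repo may ship its own LoRA implementation), the config above maps directly onto `LoraConfig`:

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["attention.qkv", "attention.proj", "mlp.fc1", "mlp.fc2"],
    lora_dropout=0.0,
)
model = get_peft_model(model, lora_config)  # `model` is the pretrained DiT
model.print_trainable_parameters()          # adapters only; base weights frozen
```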

### RL Fine-tuning (Optional)
```yaml
# configs/rl_finetune.yaml
rl:
  algorithm: GRPO
  group_size: 8
  reward_weights:
    depth_consistency: 0.25
    point_cloud_consistency: 0.25
    pose_stability: 0.25
    edit_quality: 0.25

vggt_model: "facebook/VGGT-1B"   # for geometric rewards

training:
  num_iterations: 10000
  lr: 1.0e-6
  kl_penalty: 0.01
```
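
GRPO needs no learned value network: each rollout's reward is normalized against the other rollouts in its group of 8. A minimal sketch of the group-relative advantage (reward tensor layout is an assumption):

```python
import torch

def grpo_advantages(rewards, group_size=8, eps=1e-6):
    """Group-relative advantages: normalize each reward against the
    mean/std of its group of rollouts from the same prompt.
    rewards: (num_groups * group_size,) scalar rewards."""
    r = rewards.view(-1, group_size)
    adv = (r - r.mean(dim=1, keepdim=True)) / (r.std(dim=1, keepdim=True) + eps)
    return adv.view(-1)
```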

## Distributed Training

### Using Accelerate / DeepSpeed
```bash
# Launch with DeepSpeed ZeRO-3
accelerate launch --config_file configs/accelerate_deepspeed.yaml \
    scripts/train_vae.py --config configs/vae_pretrain.yaml
```

```yaml
# configs/accelerate_deepspeed.yaml
deepspeed_config:
  zero_stage: 3
  offload_optimizer_device: none
  offload_param_device: none
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  train_batch_size: auto
  train_micro_batch_size_per_gpu: auto
```

### LR Scaling for Distributed Training
Following Grendel-GS:
```python
import math


def scale_lr_for_distributed(base_lr, batch_size):
    """Square-root learning-rate scaling with batch size."""
    return base_lr * math.sqrt(batch_size)


def scale_adam_betas_for_distributed(beta1, beta2, batch_size):
    """Exponential momentum scaling: apply each beta's decay once per
    sample in the batch (beta -> beta ** batch_size)."""
    return beta1 ** batch_size, beta2 ** batch_size
```
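
A usage sketch with hypothetical numbers, reading `base_lr` and the betas as a batch-size-1 baseline in the Grendel-GS sense:

```python
lr = scale_lr_for_distributed(1.0e-4, 16)                  # sqrt(16) -> 4.0e-4
b1, b2 = scale_adam_betas_for_distributed(0.9, 0.999, 16)  # 0.9**16 ~ 0.185
```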

## Checkpointing & Resumption

```python
import torch
from omegaconf import OmegaConf

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'step': step,
    'epoch': epoch,
    'best_val_loss': best_val_loss,
    'config': OmegaConf.to_container(config),
}

torch.save(checkpoint, f'checkpoints/stage1_step{step}.pt')
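
Resumption mirrors the save. A minimal sketch, assuming `model`, `optimizer`, and `scheduler` have already been constructed:

```python
ckpt = torch.load(f'checkpoints/stage1_step{step}.pt', map_location='cpu')
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
scheduler.load_state_dict(ckpt['scheduler'])
step = ckpt['step']
best_val_loss = ckpt['best_val_loss']
```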

## Validation Metrics

| Metric | Target | How to Compute |
|--------|--------|----------------|
| Chamfer Distance | < 0.01 | Point-cloud comparison |
| F-Score @ 0.1 | > 0.80 | Precision/recall on surface samples |
| LPIPS | < 0.06 | Perceptual similarity |
| PSNR | > 28 | Rendering quality |
| SSIM | > 0.90 | Structural similarity |
| Layout IoU | > 0.85 | Room layout accuracy |
| Object Detection mAP | > 0.70 | Furniture detection |
| Scale Error | < 5% | Metric depth consistency |
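
For reference, the Chamfer distance between two point clouds can be computed brute-force; whether the squared or unsquared variant is used here is an assumption:

```python
import torch

def chamfer_distance(p, q):
    """Symmetric squared Chamfer distance between point clouds
    p: (N, 3) and q: (M, 3). Brute force; O(N*M) memory, fine for
    validation-sized clouds."""
    d = torch.cdist(p, q)  # (N, M) pairwise Euclidean distances
    return d.min(dim=1).values.pow(2).mean() + d.min(dim=0).values.pow(2).mean()
```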