
# InteriorFusion Training Guide

## Hardware Requirements

| Stage | GPUs | VRAM Each | Duration | Cost (Cloud) |
|---|---|---|---|---|
| VAE Pre-training | 8× A100 (80GB) | 80GB | 7 days | ~$15K |
| Structure DiT | 32× A100 (80GB) | 80GB | 14 days | ~$30K |
| Material DiT | 16× A100 (80GB) | 80GB | 7 days | ~$15K |
| Fine-tuning | 8× A100 (80GB) | 80GB | 3 days | ~$5K |
| **Total** | Variable | | ~4 weeks | ~$65K |

- **Minimum viable:** 8× A100 (all stages, longer duration)
- **Budget option:** 8× RTX 4090 (24GB each); requires gradient accumulation (see the sketch below) and roughly 2× the duration
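On the budget setup, gradient accumulation recovers the A100 effective batch size by summing gradients over several micro-batches before each optimizer step. A minimal sketch, assuming the model returns a per-sample loss (`model`, `optimizer`, and `loader` are placeholders, not this repo's API):

```python
import torch

ACCUM_STEPS = 8  # e.g. micro-batch of 1 per GPU instead of 8

def train_with_accumulation(model, optimizer, loader):
    optimizer.zero_grad()
    for i, batch in enumerate(loader):
        # Scale each micro-batch loss so the accumulated gradient
        # matches one large-batch step
        loss = model(batch).mean() / ACCUM_STEPS
        loss.backward()  # gradients accumulate in .grad buffers
        if (i + 1) % ACCUM_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
```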

## Stage 1: SLAT-Interior VAE Pre-training

### Architecture

- **Encoder:** sparse 3D convolutional U-Net (shape-level sketch after this list)
  - Input: dense occupancy grid O ∈ {0,1}^(N³) with N ∈ {256, 512, 1024}
  - Sparse convolution layers with channel-to-space shortcuts
  - 16× spatial compression (1024³ → 64³ latent)
- **Decoder:**
  - Sparse conv upsampler with skip connections
  - Early pruning: predicts a binary mask of active children before each upsampling step
  - Outputs: per-voxel shape features + material features
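A shape-level sketch of the encoder's compression path, using dense `nn.Conv3d` as a stand-in for the sparse convolutions (the real encoder is sparse and includes the channel-to-space shortcuts, omitted here): four stride-2 stages yield the 16× reduction.

```python
import torch
import torch.nn as nn

class OccupancyEncoderSketch(nn.Module):
    """Dense stand-in for the sparse encoder: four stride-2 stages = 16x compression."""
    def __init__(self, latent_dim=64):
        super().__init__()
        chans = [1, 32, 64, 128, latent_dim]
        self.stages = nn.Sequential(*[
            nn.Sequential(
                nn.Conv3d(chans[i], chans[i + 1], kernel_size=3, stride=2, padding=1),
                nn.GroupNorm(8, chans[i + 1]),
                nn.SiLU(),
            )
            for i in range(4)
        ])

    def forward(self, occ):      # occ: (B, 1, N, N, N), values in {0, 1}
        return self.stages(occ)  # latent: (B, latent_dim, N/16, N/16, N/16)
```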

### Training Configuration

```yaml
# configs/vae_pretrain.yaml
model:
  latent_dim: 64
  base_resolution: 256
  target_resolution: 1024
  
optimizer:
  type: AdamW
  lr: 1.0e-4
  weight_decay: 0.01
  betas: [0.9, 0.999]

scheduler:
  type: cosine_with_restarts
  warmup_steps: 10000
  
training:
  batch_size: 8  # per GPU
  num_gpus: 8
  effective_batch_size: 64
  max_steps: 200000
  gradient_accumulation: 1
  mixed_precision: bf16
  
  curriculum:
    - resolution: 256
      steps: 50000
      lr: 1.0e-4
    - resolution: 512
      steps: 100000
      lr: 1.0e-4
    - resolution: 1024
      steps: 50000
      lr: 5.0e-5

data:
  dataset: InteriorFusion-Train
  num_workers: 8
  pin_memory: true
  
loss:
  reconstruction:
    weight: 1.0
    type: l1
  kl_divergence:
    weight: 1.0e-3
  depth_consistency:
    weight: 0.5
    type: l1
  normal_consistency:
    weight: 0.3
    type: cosine
  edge_preservation:
    weight: 0.2
    type: l1
```
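The curriculum is consumed as an ordered list of `(resolution, steps, lr)` stages. A hedged sketch of how a training loop might resolve the active stage from the global step (`curriculum_stage` is illustrative, not this repo's code):

```python
def curriculum_stage(curriculum, step):
    """Map a global step to (resolution, lr) using cumulative stage lengths."""
    start = 0
    for stage in curriculum:
        if step < start + stage["steps"]:
            return stage["resolution"], stage["lr"]
        start += stage["steps"]
    last = curriculum[-1]  # past the end: stay on the final stage
    return last["resolution"], last["lr"]

curriculum = [
    {"resolution": 256,  "steps": 50_000,  "lr": 1.0e-4},
    {"resolution": 512,  "steps": 100_000, "lr": 1.0e-4},
    {"resolution": 1024, "steps": 50_000,  "lr": 5.0e-5},
]
assert curriculum_stage(curriculum, 120_000) == (512, 1.0e-4)
```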

### Loss Functions

```python
import torch
import torch.nn.functional as F

def vae_loss(pred_shape, pred_material, target_shape, target_material,
             pred_depth, target_depth, pred_normal, target_normal, mu, logvar):
    # Reconstruction: L1 on shape and material features
    loss_recon = F.l1_loss(pred_shape, target_shape) + \
                 F.l1_loss(pred_material, target_material)

    # KL divergence: sum over latent dims, mean over the batch
    loss_kl = torch.mean(
        -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))

    # Depth consistency
    loss_depth = F.l1_loss(pred_depth, target_depth)

    # Normal consistency (1 - cosine similarity)
    loss_normal = 1 - F.cosine_similarity(pred_normal, target_normal, dim=-1).mean()

    # Weights match configs/vae_pretrain.yaml; the edge-preservation
    # term is shown separately below
    return loss_recon + 1e-3 * loss_kl + 0.5 * loss_depth + 0.3 * loss_normal
```
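The config also lists an `edge_preservation` term (weight 0.2) that the function above omits. One plausible form, sketched here under the assumption that it penalizes finite-difference edges of the depth maps:

```python
import torch.nn.functional as F

def edge_preservation_loss(pred_depth, target_depth):
    """L1 on horizontal/vertical finite differences; assumes (B, 1, H, W) depth maps."""
    def edges(d):
        return d[..., :, 1:] - d[..., :, :-1], d[..., 1:, :] - d[..., :-1, :]
    px, py = edges(pred_depth)
    tx, ty = edges(target_depth)
    return F.l1_loss(px, tx) + F.l1_loss(py, ty)
```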

## Stage 2: Structure DiT (Rectified Flow)

### Architecture

- **DiT backbone:** flow-matching transformer
  - Width: 1536
  - Depth: 30 blocks
  - Heads: 12
  - MLP hidden width: 8192
  - Parameters: ~1.3B
- **Conditioning encoders:**
  - Image: DINOv3-L (frozen, 1024-dim features)
  - Depth: custom CNN encoder (256-dim)
  - Layout: transformer encoder over SpatialLM tokens (512-dim)
  - Semantic: Mask2Former feature pyramid (256-dim)
- **Conditioning fusion:** cross-attention + AdaLN-single modulation (sketched after this list)
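A sketch of the AdaLN-single modulation path, following the PixArt-α formulation (the exact wiring is an assumption, not this repo's code): one MLP shared across all blocks maps the timestep embedding to 6 × width modulation parameters, and each block adds its own learned offset.

```python
import torch
import torch.nn as nn

class AdaLNSingleModulation(nn.Module):
    """Per-block modulation in the AdaLN-single scheme."""
    def __init__(self, shared_mlp, width=1536):
        super().__init__()
        self.shared_mlp = shared_mlp                        # shared across all blocks
        self.offset = nn.Parameter(torch.zeros(6 * width))  # learned per block

    def forward(self, t_emb):  # t_emb: (B, width)
        params = self.shared_mlp(t_emb) + self.offset
        # shift/scale/gate for attention, then shift/scale/gate for the MLP
        return params.chunk(6, dim=-1)

width = 1536
shared_mlp = nn.Sequential(nn.SiLU(), nn.Linear(width, 6 * width))
block_mod = AdaLNSingleModulation(shared_mlp, width)  # one instance per DiT block
```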

### Training Configuration

```yaml
# configs/dit_structure.yaml
model:
  width: 1536
  depth: 30
  num_heads: 12
  mlp_ratio: 8192
  
optimizer:
  type: AdamW
  lr: 1.0e-4
  weight_decay: 0.01

scheduler:
  type: linear_warmup_cosine
  warmup_steps: 10000
  
training:
  batch_size: 8  # per GPU
  num_gpus: 32
  effective_batch_size: 256
  max_steps: 400000
  mixed_precision: bf16
  
  curriculum:
    - resolution: 256
      steps: 100000
      lr: 1.0e-4
    - resolution: 512
      steps: 200000
      lr: 1.0e-4
    - resolution: 1024
      steps: 100000
      lr: 2.0e-5

data:
  dataset: InteriorFusion-Train
  num_workers: 8
  
flow_matching:
  sigma_min: 0.001
  sigma_max: 80.0
  p_mean: -1.2
  p_std: 1.2
  
loss:
  flow_matching:
    weight: 1.0
  depth_guidance:
    weight: 0.3
```
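The `p_mean`/`p_std` values match EDM's convention of drawing noise levels log-normally. Assuming that convention applies here, a sketch of the sampler, clamped to `[sigma_min, sigma_max]`:

```python
import torch

def sample_sigma(batch_size, p_mean=-1.2, p_std=1.2,
                 sigma_min=0.001, sigma_max=80.0, device="cpu"):
    """EDM-style noise-level sampling: ln(sigma) ~ N(p_mean, p_std^2)."""
    log_sigma = p_mean + p_std * torch.randn(batch_size, device=device)
    return log_sigma.exp().clamp(sigma_min, sigma_max)
```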

### Flow Matching Loss

```python
import torch
import torch.nn.functional as F

def flow_matching_loss(model, x_1, cond_img, cond_depth, cond_layout, cond_semantic):
    """
    Rectified flow matching for 3D generation.
    x_1: target structured latent (from the VAE encoder)
    """
    # Sample Gaussian noise with the same shape as the target latent
    x_0 = torch.randn_like(x_1)

    # Sample one timestep per batch element, t ~ U(0, 1)
    t = torch.rand(x_1.shape[0], device=x_1.device)

    # Broadcast t over all non-batch dims, then interpolate on the straight path
    t_b = t.view(-1, *([1] * (x_1.dim() - 1)))
    x_t = (1 - t_b) * x_0 + t_b * x_1

    # Model predicts the velocity field at (x_t, t)
    v_pred = model(x_t, t, cond_img, cond_depth, cond_layout, cond_semantic)

    # The rectified-flow target velocity is the constant displacement x_1 - x_0
    v_target = x_1 - x_0

    return F.mse_loss(v_pred, v_target)
```
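At inference, the learned velocity field is integrated from noise (t = 0) to data (t = 1). A minimal Euler sketch; the step count and conditioning tuple are illustrative:

```python
import torch

@torch.no_grad()
def sample(model, shape, conds, num_steps=50, device="cuda"):
    """Integrate dx/dt = v(x, t) with fixed-step Euler."""
    x = torch.randn(shape, device=device)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((shape[0],), i * dt, device=device)
        x = x + dt * model(x, t, *conds)  # the predicted velocity plays the role of dx/dt
    return x
```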

## Stage 3: Material DiT

### Architecture

- Same DiT backbone as Stage 2
- Additional conditioning: the generated geometry latent
- Output: per-voxel material features (albedo RGB, metallic, roughness, normal XYZ)

### Training

```yaml
# configs/dit_material.yaml
training:
  batch_size: 16  # per GPU
  num_gpus: 16
  effective_batch_size: 256
  max_steps: 200000
  
loss:
  albedo:
    weight: 1.0
    type: l1
  metallic_roughness:
    weight: 0.5
    type: l1
  normal:
    weight: 0.5
    type: cosine
  perceptual:
    weight: 0.3
    type: lpips
    network: vgg
  rendering:
    weight: 0.5
    type: mse  # rendered vs ground truth
```
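The perceptual term presumably uses the `lpips` package with its VGG backbone (an assumption consistent with `network: vgg`); inputs are expected in [-1, 1]:

```python
import torch
import lpips

loss_fn = lpips.LPIPS(net='vgg')

pred = torch.rand(4, 3, 256, 256) * 2 - 1    # rendered output, scaled to [-1, 1]
target = torch.rand(4, 3, 256, 256) * 2 - 1  # ground-truth render
loss_perceptual = loss_fn(pred, target).mean()
```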

## Stage 4: Real-World Fine-tuning

### LoRA Configuration

```yaml
# configs/finetune_lora.yaml
lora:
  rank: 32
  alpha: 32
  target_modules:
    - "attention.qkv"
    - "attention.proj"
    - "mlp.fc1"
    - "mlp.fc2"
  dropout: 0.0

training:
  batch_size: 4
  num_gpus: 8
  max_steps: 50000
  lr: 1.0e-5
  
data:
  dataset: InteriorFusion-Real  # ScanNet + HM3D
  weight: 1.0
```
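Assuming the Hugging Face `peft` library is used (the repo does not confirm this), the config above maps directly onto a `LoraConfig`:

```python
from peft import LoraConfig, get_peft_model

def add_lora(base_model):
    """Wrap the pretrained DiT with adapters matching configs/finetune_lora.yaml."""
    lora_config = LoraConfig(
        r=32,
        lora_alpha=32,
        target_modules=["attention.qkv", "attention.proj", "mlp.fc1", "mlp.fc2"],
        lora_dropout=0.0,
    )
    model = get_peft_model(base_model, lora_config)
    model.print_trainable_parameters()
    return model
```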

### RL Fine-tuning (Optional)

```yaml
# configs/rl_finetune.yaml
rl:
  algorithm: GRPO
  group_size: 8
  reward_weights:
    depth_consistency: 0.25
    point_cloud_consistency: 0.25
    pose_stability: 0.25
    edit_quality: 0.25
  
  vggt_model: "facebook/VGGT-1B"  # For geometric rewards
  
training:
  num_iterations: 10000
  lr: 1.0e-6
  kl_penalty: 0.01
```
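GRPO scores each group of rollouts generated from the same prompt and standardizes rewards within the group. A minimal sketch of the advantage computation, with reward weights taken from the config:

```python
import torch

def grpo_advantages(rewards, group_size=8, eps=1e-6):
    """Group-relative advantages; assumes rewards are ordered group-by-group."""
    groups = rewards.view(-1, group_size)  # (num_prompts, group_size)
    mean = groups.mean(dim=1, keepdim=True)
    std = groups.std(dim=1, keepdim=True)
    return ((groups - mean) / (std + eps)).view(-1)

def combined_reward(r_depth, r_pcd, r_pose, r_edit):
    # Weights from configs/rl_finetune.yaml
    return 0.25 * r_depth + 0.25 * r_pcd + 0.25 * r_pose + 0.25 * r_edit
```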

## Distributed Training

### Using Accelerate / DeepSpeed

```bash
# Launch with DeepSpeed ZeRO-3
accelerate launch --config_file configs/accelerate_deepspeed.yaml \
    scripts/train_vae.py --config configs/vae_pretrain.yaml
```

```yaml
# configs/accelerate_deepspeed.yaml
deepspeed_config:
  zero_stage: 3
  offload_optimizer_device: none
  offload_param_device: none
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  train_batch_size: auto
  train_micro_batch_size_per_gpu: auto
```

### LR Scaling for Distributed Training

Following Grendel-GS:

```python
import math

def scale_lr_for_distributed(base_lr, batch_size):
    """Square-root LR scaling for distributed training."""
    return base_lr * math.sqrt(batch_size)

def scale_adam_betas_for_distributed(beta1, beta2, batch_size):
    """Exponential momentum scaling."""
    return beta1 ** batch_size, beta2 ** batch_size
```
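Usage, treating `batch_size` as the multiplier over the single-sample baseline:

```python
lr = scale_lr_for_distributed(1.0e-4, 32)                  # ≈ 5.66e-4
b1, b2 = scale_adam_betas_for_distributed(0.9, 0.999, 32)  # ≈ (0.034, 0.968)
```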

## Checkpointing & Resumption

```python
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
    'step': step,
    'epoch': epoch,
    'best_val_loss': best_val_loss,
    'config': OmegaConf.to_container(config),
}

torch.save(checkpoint, f'checkpoints/stage1_step{step}.pt')
```
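Resumption mirrors the save:

```python
checkpoint = torch.load(f'checkpoints/stage1_step{step}.pt', map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
scheduler.load_state_dict(checkpoint['scheduler'])
step = checkpoint['step']
best_val_loss = checkpoint['best_val_loss']
```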

## Validation Metrics

| Metric | Target | How to Compute |
|---|---|---|
| Chamfer Distance | < 0.01 | Point cloud comparison |
| F-Score @ 0.1 | > 0.80 | Precision/recall on the surface |
| LPIPS | < 0.06 | Perceptual similarity |
| PSNR | > 28 dB | Rendering quality |
| SSIM | > 0.90 | Structural similarity |
| Layout IoU | > 0.85 | Room layout accuracy |
| Object Detection mAP | > 0.70 | Furniture detection |
| Scale Error | < 5% | Metric depth consistency |
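For reference, a brute-force Chamfer distance in PyTorch; fine for validation-sized clouds, though production code would use a KD-tree or a CUDA kernel:

```python
import torch

def chamfer_distance(p, q):
    """Symmetric Chamfer distance between point clouds p: (N, 3) and q: (M, 3)."""
    d = torch.cdist(p, q)  # (N, M) pairwise Euclidean distances
    return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()
```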