heheyas committed on

Commit: cfb7702 · Parent(s): f5c8d4d

init

This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +50 -0
 - configs/ae/video.yaml +35 -0
 - configs/embedder/clip_image.yaml +8 -0
 - configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml +104 -0
 - configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml +105 -0
 - configs/example_training/imagenet-f8_cond.yaml +185 -0
 - configs/example_training/toy/cifar10_cond.yaml +98 -0
 - configs/example_training/toy/mnist.yaml +79 -0
 - configs/example_training/toy/mnist_cond.yaml +98 -0
 - configs/example_training/toy/mnist_cond_discrete_eps.yaml +103 -0
 - configs/example_training/toy/mnist_cond_l1_loss.yaml +99 -0
 - configs/example_training/toy/mnist_cond_with_ema.yaml +100 -0
 - configs/example_training/txt2img-clipl-legacy-ucg-training.yaml +182 -0
 - configs/example_training/txt2img-clipl.yaml +184 -0
 - configs/inference/sd_2_1.yaml +60 -0
 - configs/inference/sd_2_1_768.yaml +60 -0
 - configs/inference/sd_xl_base.yaml +93 -0
 - configs/inference/sd_xl_refiner.yaml +86 -0
 - configs/inference/svd.yaml +131 -0
 - configs/inference/svd_image_decoder.yaml +114 -0
 - configs/inference/svd_mv.yaml +202 -0
 - mesh_recon/configs/neuralangelo-ortho-wmask.yaml +145 -0
 - mesh_recon/configs/v3d.yaml +144 -0
 - mesh_recon/configs/videonvs.yaml +144 -0
 - mesh_recon/datasets/__init__.py +17 -0
 - mesh_recon/datasets/blender.py +143 -0
 - mesh_recon/datasets/colmap.py +332 -0
 - mesh_recon/datasets/colmap_utils.py +295 -0
 - mesh_recon/datasets/dtu.py +201 -0
 - mesh_recon/datasets/fixed_poses/000_back_RT.txt +3 -0
 - mesh_recon/datasets/fixed_poses/000_back_left_RT.txt +3 -0
 - mesh_recon/datasets/fixed_poses/000_back_right_RT.txt +3 -0
 - mesh_recon/datasets/fixed_poses/000_front_RT.txt +3 -0
 - mesh_recon/datasets/fixed_poses/000_front_left_RT.txt +3 -0
 - mesh_recon/datasets/fixed_poses/000_front_right_RT.txt +3 -0
 - mesh_recon/datasets/fixed_poses/000_left_RT.txt +3 -0
 - mesh_recon/datasets/fixed_poses/000_right_RT.txt +3 -0
 - mesh_recon/datasets/fixed_poses/000_top_RT.txt +3 -0
 - mesh_recon/datasets/ortho.py +287 -0
 - mesh_recon/datasets/utils.py +0 -0
 - mesh_recon/datasets/v3d.py +284 -0
 - mesh_recon/datasets/videonvs.py +256 -0
 - mesh_recon/datasets/videonvs_co3d.py +252 -0
 - mesh_recon/launch.py +144 -0
 - mesh_recon/mesh.py +845 -0
 - mesh_recon/models/__init__.py +16 -0
 - mesh_recon/models/base.py +32 -0
 - mesh_recon/models/geometry.py +238 -0
 - mesh_recon/models/nerf.py +161 -0
 - mesh_recon/models/network_utils.py +215 -0
 
    	
.gitignore
ADDED
@@ -0,0 +1,50 @@
+# extensions
+*.egg-info
+*.py[cod]
+
+# envs
+.pt13
+.pt2
+
+# directories
+/checkpoints
+/dist
+/outputs
+/build
+/src
+logs/
+ckpts/
+tmp/
+lightning_logs/
+images/
+images*/
+kb_configs/
+debug_lvis.log
+*.log
+.cache/
+redirects/
+submits/
+extern/
+assets/images
+output/
+assets/scene
+assets/GSO
+assets/SD
+spirals
+*.zip
+paper/
+spirals_co3d/
+scene_spirals/
+blenders/
+colmap_results/
+depth_spirals/
+recon/SIBR_viewers/
+recon/assets/
+mesh_recon/exp
+mesh_recon/runs
+mesh_recon/renders
+mesh_recon/refined
+*.png
+*.pdf
+*.npz
+*.npy
    	
configs/ae/video.yaml
ADDED
@@ -0,0 +1,35 @@
+target: sgm.models.autoencoder.AutoencodingEngine
+params:
+  loss_config:
+    target: torch.nn.Identity
+  regularizer_config:
+    target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+  encoder_config:
+    target: sgm.modules.diffusionmodules.model.Encoder
+    params:
+      attn_type: vanilla
+      double_z: True
+      z_channels: 4
+      resolution: 256
+      in_channels: 3
+      out_ch: 3
+      ch: 128
+      ch_mult: [1, 2, 4, 4]
+      num_res_blocks: 2
+      attn_resolutions: []
+      dropout: 0.0
+  decoder_config:
+    target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
+    params:
+      attn_type: vanilla
+      double_z: True
+      z_channels: 4
+      resolution: 256
+      in_channels: 3
+      out_ch: 3
+      ch: 128
+      ch_mult: [1, 2, 4, 4]
+      num_res_blocks: 2
+      attn_resolutions: []
+      dropout: 0.0
+      video_kernel_size: [3, 1, 1]
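The file above is a target/params tree in the style used throughout sgm configs. As a minimal sketch of how such a file is typically consumed (assuming OmegaConf and the instantiate_from_config helper from sgm.util, which are dependencies of this kind of project rather than part of this commit):

    # Minimal sketch, not code from this commit.
    from omegaconf import OmegaConf
    from sgm.util import instantiate_from_config

    cfg = OmegaConf.load("configs/ae/video.yaml")
    # Builds the class named in `target`, passing `params` as keyword arguments;
    # nested *_config blocks are instantiated recursively by the engine itself.
    autoencoder = instantiate_from_config(cfg)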
    	
configs/embedder/clip_image.yaml
ADDED
@@ -0,0 +1,8 @@
+target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
+params:
+  n_cond_frames: 1
+  n_copies: 1
+  open_clip_embedding_config:
+    target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+    params:
+      freeze: True
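freeze: True on the inner FrozenOpenCLIPImageEmbedder indicates that the CLIP image encoder is kept fixed during training. A minimal sketch of what freezing a module conventionally means in PyTorch (an assumption about the embedder's behaviour, not code from this commit):

    import torch.nn as nn

    def freeze_module(module: nn.Module) -> nn.Module:
        # Put the module in eval mode and drop it from gradient computation,
        # so the optimizer never updates its weights.
        module.eval()
        for p in module.parameters():
            p.requires_grad_(False)
        return module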
    	
configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml
ADDED
@@ -0,0 +1,104 @@
+model:
+  base_learning_rate: 4.5e-6
+  target: sgm.models.autoencoder.AutoencodingEngine
+  params:
+    input_key: jpg
+    monitor: val/rec_loss
+
+    loss_config:
+      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
+      params:
+        perceptual_weight: 0.25
+        disc_start: 20001
+        disc_weight: 0.5
+        learn_logvar: True
+
+        regularization_weights:
+          kl_loss: 1.0
+
+    regularizer_config:
+      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+
+    encoder_config:
+      target: sgm.modules.diffusionmodules.model.Encoder
+      params:
+        attn_type: none
+        double_z: True
+        z_channels: 4
+        resolution: 256
+        in_channels: 3
+        out_ch: 3
+        ch: 128
+        ch_mult: [1, 2, 4]
+        num_res_blocks: 4
+        attn_resolutions: []
+        dropout: 0.0
+
+    decoder_config:
+      target: sgm.modules.diffusionmodules.model.Decoder
+      params: ${model.params.encoder_config.params}
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          - DATA-PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            params:
+              h_key: height
+              w_key: width
+
+      loader:
+        batch_size: 8
+        num_workers: 4
+
+
+lightning:
+  strategy:
+    target: pytorch_lightning.strategies.DDPStrategy
+    params:
+      find_unused_parameters: True
+
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 50000
+
+    image_logger:
+      target: main.ImageLogger
+      params:
+        enable_autocast: False
+        batch_frequency: 1000
+        max_images: 8
+        increase_log_steps: True
+
+  trainer:
+    devices: 0,
+    limit_val_batches: 50
+    benchmark: True
+    accumulate_grad_batches: 1
+    val_check_interval: 10000
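Note that decoder_config re-uses the encoder hyper-parameters through OmegaConf interpolation (params: ${model.params.encoder_config.params}), so the two stay in sync from a single definition. A small sketch of how that reference resolves once the file is loaded (a hypothetical check, assuming OmegaConf):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load(
        "configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml"
    )
    enc = cfg.model.params.encoder_config.params
    dec = cfg.model.params.decoder_config.params
    # The ${...} interpolation resolves on access, so encoder and decoder
    # share the same z_channels, ch_mult, etc.
    assert dec.z_channels == enc.z_channels == 4
    assert list(dec.ch_mult) == [1, 2, 4]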
    	
configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml
ADDED
@@ -0,0 +1,105 @@
+model:
+  base_learning_rate: 4.5e-6
+  target: sgm.models.autoencoder.AutoencodingEngine
+  params:
+    input_key: jpg
+    monitor: val/loss/rec
+    disc_start_iter: 0
+
+    encoder_config:
+      target: sgm.modules.diffusionmodules.model.Encoder
+      params:
+        attn_type: vanilla-xformers
+        double_z: true
+        z_channels: 8
+        resolution: 256
+        in_channels: 3
+        out_ch: 3
+        ch: 128
+        ch_mult: [1, 2, 4, 4]
+        num_res_blocks: 2
+        attn_resolutions: []
+        dropout: 0.0
+
+    decoder_config:
+      target: sgm.modules.diffusionmodules.model.Decoder
+      params: ${model.params.encoder_config.params}
+
+    regularizer_config:
+      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+
+    loss_config:
+      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
+      params:
+        perceptual_weight: 0.25
+        disc_start: 20001
+        disc_weight: 0.5
+        learn_logvar: True
+
+        regularization_weights:
+          kl_loss: 1.0
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          - DATA-PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            params:
+              h_key: height
+              w_key: width
+
+      loader:
+        batch_size: 8
+        num_workers: 4
+
+
+lightning:
+  strategy:
+    target: pytorch_lightning.strategies.DDPStrategy
+    params:
+      find_unused_parameters: True
+
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 50000
+
+    image_logger:
+      target: main.ImageLogger
+      params:
+        enable_autocast: False
+        batch_frequency: 1000
+        max_images: 8
+        increase_log_steps: True
+
+  trainer:
+    devices: 0,
+    limit_val_batches: 50
+    benchmark: True
+    accumulate_grad_batches: 1
+    val_check_interval: 10000
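base_learning_rate here is a per-sample base rate; latent-diffusion-style trainers conventionally scale it by the effective batch size before handing it to the optimizer. A hedged arithmetic sketch under that assumption (the scaling rule itself is not part of this commit):

    # Values taken from the config above; the scaling convention is an assumption.
    base_learning_rate = 4.5e-6        # model.base_learning_rate
    batch_size = 8                     # data.params.train.loader.batch_size
    accumulate_grad_batches = 1        # lightning.trainer.accumulate_grad_batches
    num_devices = 1                    # lightning.trainer.devices: 0,

    learning_rate = base_learning_rate * batch_size * accumulate_grad_batches * num_devices
    print(f"optimizer lr = {learning_rate:.2e}")   # 3.60e-05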
    	
configs/example_training/imagenet-f8_cond.yaml
ADDED
@@ -0,0 +1,185 @@
+model:
+  base_learning_rate: 1.0e-4
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.13025
+    disable_first_stage_autocast: True
+    log_keys:
+      - cls
+
+    scheduler_config:
+      target: sgm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [10000]
+        cycle_lengths: [10000000000000]
+        f_start: [1.e-6]
+        f_max: [1.]
+        f_min: [1.]
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
+      params:
+        num_idx: 1000
+
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    network_config:
+      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        use_checkpoint: True
+        in_channels: 4
+        out_channels: 4
+        model_channels: 256
+        attention_resolutions: [1, 2, 4]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4]
+        num_head_channels: 64
+        num_classes: sequential
+        adm_in_channels: 1024
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: softmax-xformers
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+          - is_trainable: True
+            input_key: cls
+            ucg_rate: 0.2
+            target: sgm.modules.encoders.modules.ClassEmbedder
+            params:
+              add_sequence_dim: True
+              embed_dim: 1024
+              n_classes: 1000
+
+          - is_trainable: False
+            ucg_rate: 0.2
+            input_key: original_size_as_tuple
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+          - is_trainable: False
+            input_key: crop_coords_top_left
+            ucg_rate: 0.2
+            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
+            params:
+              outdim: 256
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencoderKL
+      params:
+        ckpt_path: CKPT_PATH
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          attn_type: vanilla-xformers
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult: [1, 2, 4, 4]
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    loss_fn_config:
+      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
+      params:
+        loss_weighting_config:
+          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
+        sigma_sampler_config:
+          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
+          params:
+            num_idx: 1000
+
+            discretization_config:
+              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 50
+
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
+          params:
+            scale: 5.0
+
+data:
+  target: sgm.data.dataset.StableDataModuleFromConfig
+  params:
+    train:
+      datapipeline:
+        urls:
+          # USER: adapt this path the root of your custom dataset
+          - DATA_PATH
+        pipeline_config:
+          shardshuffle: 10000
+          sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM
+
+        decoders:
+          - pil
+
+        postprocessors:
+          - target: sdata.mappers.TorchVisionImageTransforms
+            params:
+              key: jpg # USER: you might wanna adapt this for your custom dataset
+              transforms:
+                - target: torchvision.transforms.Resize
+                  params:
+                    size: 256
+                    interpolation: 3
+                - target: torchvision.transforms.ToTensor
+          - target: sdata.mappers.Rescaler
+
+          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
+            params:
+              h_key: height # USER: you might wanna adapt this for your custom dataset
+              w_key: width # USER: you might wanna adapt this for your custom dataset
+
+      loader:
+        batch_size: 64
+        num_workers: 6
+
+lightning:
+  modelcheckpoint:
+    params:
+      every_n_train_steps: 5000
+
+  callbacks:
+    metrics_over_trainsteps_checkpoint:
+      params:
+        every_n_train_steps: 25000
+
+    image_logger:
+      target: main.ImageLogger
+      params:
+        disabled: False
+        enable_autocast: False
+        batch_frequency: 1000
+        max_images: 8
+        increase_log_steps: True
+        log_first_step: False
+        log_images_kwargs:
+          use_ema_scope: False
+          N: 8
+          n_rows: 2
+
+  trainer:
+    devices: 0,
+    benchmark: True
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
+    max_epochs: 1000
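In this config the sampler's guider_config is VanillaCFG with scale: 5.0, i.e. standard classifier-free guidance, and the ucg_rate: 0.2 entries in the conditioner are what make an unconditional branch available by dropping the conditioning for roughly 20% of training samples. A minimal sketch of the guidance combination (the standard CFG formula, assumed rather than taken from this commit):

    import torch

    def cfg_combine(cond: torch.Tensor, uncond: torch.Tensor, scale: float = 5.0) -> torch.Tensor:
        # Classifier-free guidance: push the conditional prediction away from the
        # unconditional one by `scale`; scale = 1.0 recovers the conditional output.
        return uncond + scale * (cond - uncond)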
    	
        configs/example_training/toy/cifar10_cond.yaml
    ADDED
    
@@ -0,0 +1,98 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
            sigma_data: 1.0

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        in_channels: 3
        out_channels: 3
        model_channels: 32
        attention_resolutions: []
        num_res_blocks: 4
        channel_mult: [1, 2, 2]
        num_head_channels: 32
        num_classes: sequential
        adm_in_channels: 128

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: True
            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              embed_dim: 128
              n_classes: 10

    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
          params:
            sigma_data: 1.0
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 3.0

data:
  target: sgm.data.cifar10.CIFAR10Loader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 64
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 64
          n_rows: 8

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20
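Every target:/params: pair in these configs names a Python class by dotted import path plus its constructor kwargs. The snippet below is a minimal, illustrative version of that convention, assuming OmegaConf is available; the repository ships its own helper for this (commonly called instantiate_from_config), so this is a sketch of the idea, not the repo's exact implementation.

import importlib

from omegaconf import OmegaConf


def instantiate_from_config(config):
    # "target" is a dotted import path; "params" are constructor kwargs.
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", dict()))


if __name__ == "__main__":
    cfg = OmegaConf.load("configs/example_training/toy/cifar10_cond.yaml")
    # Build just the sampler described under model.params.sampler_config.
    sampler = instantiate_from_config(cfg.model.params.sampler_config)

The same resolution step applies recursively: nested *_config blocks (denoiser, conditioner, sampler, guider, ...) are each instantiated from their own target/params pair.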
    	
        configs/example_training/toy/mnist.yaml
    ADDED
    
@@ -0,0 +1,79 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
            sigma_data: 1.0

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        in_channels: 1
        out_channels: 1
        model_channels: 32
        attention_resolutions: []
        num_res_blocks: 4
        channel_mult: [1, 2, 2]
        num_head_channels: 32

    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
          params:
            sigma_data: 1.0
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization

data:
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 64
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 64
          n_rows: 8

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 10
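The sigma_sampler_config above (EDMSampling) draws the per-example training noise level from a log-normal distribution. A hedged sketch of that draw, assuming the common EDM defaults P_mean = -1.2 and P_std = 1.2, which this config does not override:

import torch


def sample_sigmas(n: int, p_mean: float = -1.2, p_std: float = 1.2) -> torch.Tensor:
    # sigma = exp(p_mean + p_std * z), z ~ N(0, 1)
    return torch.exp(p_mean + p_std * torch.randn(n))


sigmas = sample_sigmas(4)  # four positive noise levels, log-normally distributed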
    	
        configs/example_training/toy/mnist_cond.yaml
    ADDED
    
@@ -0,0 +1,98 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
            sigma_data: 1.0

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        in_channels: 1
        out_channels: 1
        model_channels: 32
        attention_resolutions: []
        num_res_blocks: 4
        channel_mult: [1, 2, 2]
        num_head_channels: 32
        num_classes: sequential
        adm_in_channels: 128

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: True
            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              embed_dim: 128
              n_classes: 10

    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
          params:
            sigma_data: 1.0
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 3.0

data:
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 16
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 16
          n_rows: 4

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20
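In this conditional variant, ucg_rate: 0.2 drops the class conditioning during training, and VanillaCFG with scale: 3.0 recombines conditional and unconditional predictions at sampling time. The sketch below shows the standard classifier-free-guidance arithmetic as an illustration; the repo's own guider may differ in detail.

import torch


def cfg_combine(pred_uncond: torch.Tensor, pred_cond: torch.Tensor, scale: float = 3.0) -> torch.Tensor:
    # scale corresponds to guider_config.params.scale in the config above.
    return pred_uncond + scale * (pred_cond - pred_uncond)


def maybe_drop_label(label_emb: torch.Tensor, ucg_rate: float = 0.2) -> torch.Tensor:
    # Training-time conditioning dropout: zero the class embedding with
    # probability ucg_rate, mirroring the ucg_rate set on the ClassEmbedder entry.
    keep = (torch.rand(label_emb.shape[0], 1) > ucg_rate).float()
    return label_emb * keep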
    	
        configs/example_training/toy/mnist_cond_discrete_eps.yaml
    ADDED
    
@@ -0,0 +1,103 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        in_channels: 1
        out_channels: 1
        model_channels: 32
        attention_resolutions: []
        num_res_blocks: 4
        channel_mult: [1, 2, 2]
        num_head_channels: 32
        num_classes: sequential
        adm_in_channels: 128

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: True
            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              embed_dim: 128
              n_classes: 10

    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000

            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 5.0

data:
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 16
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 16
          n_rows: 4

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20
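This variant replaces continuous EDM noise levels with a fixed 1000-index DDPM-style schedule (LegacyDDPMDiscretization) and trains by drawing one of those indices per example (DiscreteSampling). Below is a sketch of how such a sigma table can be built; the linear-beta endpoints are an assumption borrowed from common Stable Diffusion schedules, not values read from this config.

import numpy as np


def ddpm_sigmas(num_idx: int = 1000, beta_start: float = 0.00085, beta_end: float = 0.0120) -> np.ndarray:
    betas = np.linspace(beta_start**0.5, beta_end**0.5, num_idx) ** 2
    alphas_cumprod = np.cumprod(1.0 - betas)
    # Convert the DDPM schedule to EDM-style sigmas: sqrt((1 - abar) / abar).
    return np.sqrt((1.0 - alphas_cumprod) / alphas_cumprod)


sigmas = ddpm_sigmas()  # DiscreteSampling would pick a random index into this table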
    	
        configs/example_training/toy/mnist_cond_l1_loss.yaml
    ADDED
    
@@ -0,0 +1,99 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
          params:
            sigma_data: 1.0

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        in_channels: 1
        out_channels: 1
        model_channels: 32
        attention_resolutions: []
        num_res_blocks: 4
        channel_mult: [1, 2, 2]
        num_head_channels: 32
        num_classes: sequential
        adm_in_channels: 128

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: True
            input_key: cls
            ucg_rate: 0.2
            target: sgm.modules.encoders.modules.ClassEmbedder
            params:
              embed_dim: 128
              n_classes: 10

    first_stage_config:
      target: sgm.models.autoencoder.IdentityFirstStage

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_type: l1
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
          params:
            sigma_data: 1.0
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 3.0

data:
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 64
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 64
          n_rows: 8

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20
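The only change relative to mnist_cond.yaml is loss_type: l1. A hedged sketch of a per-sigma weighted diffusion loss with a selectable l1/l2 norm; the EDM-style weight (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 follows the EDM literature and is assumed rather than copied from the repo.

import torch


def weighted_loss(pred: torch.Tensor, target: torch.Tensor, sigma: torch.Tensor,
                  loss_type: str = "l1", sigma_data: float = 1.0) -> torch.Tensor:
    # Per-example weight as a function of the sampled noise level sigma (shape [B]).
    w = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    err = (pred - target).flatten(1)
    per_sample = err.abs().mean(1) if loss_type == "l1" else err.pow(2).mean(1)
    return (w * per_sample).mean()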
    	
        configs/example_training/toy/mnist_cond_with_ema.yaml
    ADDED
    
    | 
         @@ -0,0 +1,100 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            model:
         
     | 
| 2 | 
         
            +
              base_learning_rate: 1.0e-4
         
     | 
| 3 | 
         
            +
              target: sgm.models.diffusion.DiffusionEngine
         
     | 
| 4 | 
         
            +
              params:
         
     | 
| 5 | 
         
            +
                use_ema: True
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
                denoiser_config:
         
     | 
| 8 | 
         
            +
                  target: sgm.modules.diffusionmodules.denoiser.Denoiser
         
     | 
| 9 | 
         
            +
                  params:
         
     | 
| 10 | 
         
            +
                    scaling_config:
         
     | 
| 11 | 
         
            +
                      target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling
         
     | 
| 12 | 
         
            +
                      params:
         
     | 
| 13 | 
         
            +
                        sigma_data: 1.0
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
                network_config:
         
     | 
| 16 | 
         
            +
                  target: sgm.modules.diffusionmodules.openaimodel.UNetModel
         
     | 
| 17 | 
         
            +
                  params:
         
     | 
| 18 | 
         
            +
                    in_channels: 1
         
     | 
| 19 | 
         
            +
                    out_channels: 1
         
     | 
| 20 | 
         
            +
                    model_channels: 32
         
     | 
| 21 | 
         
            +
                    attention_resolutions: []
         
     | 
| 22 | 
         
            +
                    num_res_blocks: 4
         
     | 
| 23 | 
         
            +
                    channel_mult: [1, 2, 2]
         
     | 
| 24 | 
         
            +
                    num_head_channels: 32
         
     | 
| 25 | 
         
            +
                    num_classes: sequential
         
     | 
| 26 | 
         
            +
                    adm_in_channels: 128
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
                conditioner_config:
         
     | 
| 29 | 
         
            +
                  target: sgm.modules.GeneralConditioner
         
     | 
| 30 | 
         
            +
                  params:
         
     | 
| 31 | 
         
            +
                    emb_models:
         
     | 
| 32 | 
         
            +
                      - is_trainable: True
         
     | 
| 33 | 
         
            +
                        input_key: cls
         
     | 
| 34 | 
         
            +
                        ucg_rate: 0.2
         
     | 
| 35 | 
         
            +
                        target: sgm.modules.encoders.modules.ClassEmbedder
         
     | 
| 36 | 
         
            +
                        params:
         
     | 
| 37 | 
         
            +
                          embed_dim: 128
         
     | 
| 38 | 
         
            +
                          n_classes: 10
         
     | 
| 39 | 
         
            +
             
     | 
| 40 | 
         
            +
                first_stage_config:
         
     | 
| 41 | 
         
            +
                  target: sgm.models.autoencoder.IdentityFirstStage
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                loss_fn_config:
         
     | 
| 44 | 
         
            +
                  target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
         
     | 
| 45 | 
         
            +
                  params:
         
     | 
| 46 | 
         
            +
                    loss_weighting_config:
         
     | 
| 47 | 
         
            +
                      target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
         
     | 
| 48 | 
         
            +
                      params:
         
     | 
| 49 | 
         
            +
                        sigma_data: 1.0
         
     | 
| 50 | 
         
            +
                    sigma_sampler_config:
         
     | 
| 51 | 
         
            +
                      target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
            +
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 3.0

data:
  target: sgm.data.mnist.MNISTLoader
  params:
    batch_size: 512
    num_workers: 1

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        batch_frequency: 1000
        max_images: 64
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 64
          n_rows: 8

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 20
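Note (editorial, not part of the committed file): a minimal Python sketch of how the sampler block above could be instantiated, assuming OmegaConf and the repository's sgm.util.instantiate_from_config helper are importable:

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

# Load the example config and build only its sampler block.
config = OmegaConf.load("configs/example_training/toy/mnist_cond_with_ema.yaml")
sampler_cfg = config.model.params.sampler_config  # EulerEDMSampler, 50 steps, CFG scale 3.0
sampler = instantiate_from_config(sampler_cfg)
print(type(sampler).__name__)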
    	
configs/example_training/txt2img-clipl-legacy-ucg-training.yaml
ADDED
@@ -0,0 +1,182 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True
    log_keys:
      - txt

    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [10000]
        cycle_lengths: [10000000000000]
        f_start: [1.e-6]
        f_max: [1.]
        f_min: [1.]

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [1, 2, 4]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        num_classes: sequential
        adm_in_channels: 1792
        num_heads: 1
        transformer_depth: 1
        context_dim: 768
        spatial_transformer_attn_type: softmax-xformers

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: True
            input_key: txt
            ucg_rate: 0.1
            legacy_ucg_value: ""
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              always_return_pooled: True

          - is_trainable: False
            ucg_rate: 0.1
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: False
            input_key: crop_coords_top_left
            ucg_rate: 0.1
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        ckpt_path: CKPT_PATH
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [ 1, 2, 4, 4 ]
          num_res_blocks: 2
          attn_resolutions: [ ]
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000

            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 7.5

data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # USER: adapt this path to the root of your custom dataset
          - DATA_PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM

        decoders:
          - pil

        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              key: jpg # USER: you might wanna adapt this for your custom dataset
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            # USER: you might wanna use non-default parameters due to your custom dataset

      loader:
        batch_size: 64
        num_workers: 6

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        enable_autocast: False
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 8
          n_rows: 2

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 1000
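Note (editorial): ucg_rate: 0.1 together with legacy_ucg_value: "" expresses the legacy unconditional-guidance dropout this config is named after — roughly 10% of text prompts are replaced by the empty string during training so the model also learns the unconditional branch used by classifier-free guidance. A minimal illustrative sketch; apply_legacy_ucg is a hypothetical helper, not part of the repository:

import random

def apply_legacy_ucg(prompt: str, ucg_rate: float = 0.1, legacy_ucg_value: str = "") -> str:
    # With probability ucg_rate, drop the prompt and use the legacy placeholder value.
    return legacy_ucg_value if random.random() < ucg_rate else prompt

batch = ["a photo of a cat", "an oil painting of a ship"]
batch = [apply_legacy_ucg(p) for p in batch]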
    	
configs/example_training/txt2img-clipl.yaml
ADDED
@@ -0,0 +1,184 @@
model:
  base_learning_rate: 1.0e-4
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True
    log_keys:
      - txt

    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [10000]
        cycle_lengths: [10000000000000]
        f_start: [1.e-6]
        f_max: [1.]
        f_min: [1.]

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [1, 2, 4]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        num_classes: sequential
        adm_in_channels: 1792
        num_heads: 1
        transformer_depth: 1
        context_dim: 768
        spatial_transformer_attn_type: softmax-xformers

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: True
            input_key: txt
            ucg_rate: 0.1
            legacy_ucg_value: ""
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              always_return_pooled: True

          - is_trainable: False
            ucg_rate: 0.1
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: False
            input_key: crop_coords_top_left
            ucg_rate: 0.1
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        ckpt_path: CKPT_PATH
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000

            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 7.5

data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # USER: adapt this path to the root of your custom dataset
          - DATA_PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000


        decoders:
          - pil

        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              key: jpg # USER: you might wanna adapt this for your custom dataset
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler
            # USER: you might wanna use non-default parameters due to your custom dataset
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            # USER: you might wanna use non-default parameters due to your custom dataset

      loader:
        batch_size: 64
        num_workers: 6

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 25000

    image_logger:
      target: main.ImageLogger
      params:
        disabled: False
        enable_autocast: False
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: True
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          N: 8
          n_rows: 2

  trainer:
    devices: 0,
    benchmark: True
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 1000
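Note (editorial): the guider_config above (VanillaCFG, scale: 7.5) applies the standard classifier-free guidance combination at sampling time. A minimal sketch of that combination, assuming conditional and unconditional model outputs of the same shape:

import torch

def cfg_combine(cond: torch.Tensor, uncond: torch.Tensor, scale: float = 7.5) -> torch.Tensor:
    # Move from the unconditional prediction towards the conditional one by the guidance scale.
    return uncond + scale * (cond - uncond)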
    	
configs/inference/sd_2_1.yaml
ADDED
@@ -0,0 +1,60 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
            params:
              freeze: true
              layer: penultimate

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
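Note (editorial, not part of the committed file): a minimal sketch of building the SD 2.1 inference graph from this config, assuming OmegaConf and the repository's sgm.util.instantiate_from_config helper; checkpoint loading is repository-specific and omitted here:

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("configs/inference/sd_2_1.yaml")
model = instantiate_from_config(config.model)  # sgm.models.diffusion.DiffusionEngine
model.eval()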
    	
configs/inference/sd_2_1_768.yaml
ADDED
@@ -0,0 +1,60 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
            params:
              freeze: true
              layer: penultimate

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
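Note (editorial): this config differs from sd_2_1.yaml in its scaling_config, which uses VScaling instead of EpsScaling because the 768 checkpoint is a v-prediction model. A hedged sketch of the standard v-prediction scalings such a module is expected to compute; check the repository's denoiser_scaling.py for the authoritative definitions:

import torch

def v_scalings(sigma: torch.Tensor):
    # Standard v-prediction coefficients as functions of the noise level sigma (assumption,
    # shown for illustration; not copied from the repository).
    c_skip = 1.0 / (sigma**2 + 1.0)
    c_out = -sigma / torch.sqrt(sigma**2 + 1.0)
    c_in = 1.0 / torch.sqrt(sigma**2 + 1.0)
    return c_skip, c_out, c_in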
    	
configs/inference/sd_xl_base.yaml
ADDED
@@ -0,0 +1,93 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2816
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: [1, 2, 10]
        context_dim: 2048
        spatial_transformer_attn_type: softmax-xformers

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenCLIPEmbedder
            params:
              layer: hidden
              layer_idx: 11

          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: True
              layer: penultimate
              always_return_pooled: True
              legacy: False

          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: False
            input_key: target_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
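Configs in this layout are consumed by loading the YAML with OmegaConf and instantiating the model tree. A minimal sketch, assuming the sgm package from this repo is importable and that sgm.util.instantiate_from_config is available as in the upstream generative-models codebase (everything outside the config itself is an assumption):

import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config  # assumed helper from the sgm package

# Load the inference config and build the DiffusionEngine it describes.
config = OmegaConf.load("configs/inference/sd_xl_base.yaml")
model = instantiate_from_config(config.model)
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()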
    	
configs/inference/sd_xl_refiner.yaml
ADDED
@@ -0,0 +1,86 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        adm_in_channels: 2560
        num_classes: sequential
        use_checkpoint: True
        in_channels: 4
        out_channels: 4
        model_channels: 384
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 4
        context_dim: [1280, 1280, 1280, 1280]
        spatial_transformer_attn_type: softmax-xformers

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          - is_trainable: False
            input_key: txt
            target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              legacy: False
              freeze: True
              layer: penultimate
              always_return_pooled: True

          - is_trainable: False
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: False
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - is_trainable: False
            input_key: aesthetic_score
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
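Each emb_models entry above reads one key from the conditioning batch (its input_key). A minimal sketch of the batch the refiner's GeneralConditioner would consume; the shapes, example values, and the commented conditioner call are illustrative assumptions, not part of the config:

import torch

# Keys mirror the input_key fields in the conditioner_config above.
batch = {
    "txt": ["a detailed photograph of a mountain lake at dawn"],
    "original_size_as_tuple": torch.tensor([[1024, 1024]]),
    "crop_coords_top_left": torch.tensor([[0, 0]]),
    "aesthetic_score": torch.tensor([[6.0]]),
}
# cond = model.conditioner(batch)  # assumed call; returns crossattn/vector embeddings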
    	
configs/inference/svd.yaml
ADDED
@@ -0,0 +1,131 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: False
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: fps_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: motion_bucket_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: sgm.modules.diffusionmodules.model.Encoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
        decoder_config:
          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            double_z: True
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
            video_kernel_size: [3, 1, 1]
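The SVD conditioner above keys on one reference frame plus three scalar controls. A minimal sketch of that conditioning batch for a single 576x1024 frame; the cond_aug noise recipe and the example values follow the usual SVD sampling scripts and are assumptions here, not part of the config:

import torch

cond_aug = 0.02                                    # strength of the augmentation noise
image = torch.rand(1, 3, 576, 1024) * 2.0 - 1.0    # reference frame scaled to [-1, 1]

batch = {
    "cond_frames_without_noise": image,                         # read by the CLIP image embedder
    "cond_frames": image + cond_aug * torch.randn_like(image),  # read by the VAE encoder branch
    "fps_id": torch.tensor([6]),
    "motion_bucket_id": torch.tensor([127]),
    "cond_aug": torch.tensor([cond_aug]),
}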
    	
configs/inference/svd_image_decoder.yaml
ADDED
@@ -0,0 +1,114 @@
model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    disable_first_stage_autocast: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: True
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: True
        use_spatial_context: True
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: False
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: True

        - input_key: fps_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: motion_bucket_id
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

        - input_key: cond_frames
          is_trainable: False
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: True
            n_cond_frames: 1
            n_copies: 1
            is_ae: True
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: True
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult: [1, 2, 4, 4]
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity

        - input_key: cond_aug
          is_trainable: False
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256

    first_stage_config:
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: True
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity
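svd_image_decoder.yaml differs from svd.yaml only in its first stage: the denoiser, VideoUNet and conditioner blocks are identical, but decoding uses the plain AutoencoderKL image decoder instead of the temporal VideoDecoder. A quick check of that difference straight from the two configs (a sketch; paths assume the repo root):

from omegaconf import OmegaConf

svd = OmegaConf.load("configs/inference/svd.yaml")
svd_img = OmegaConf.load("configs/inference/svd_image_decoder.yaml")

print(svd.model.params.first_stage_config.target)      # sgm.models.autoencoder.AutoencodingEngine
print(svd_img.model.params.first_stage_config.target)  # sgm.models.autoencoder.AutoencoderKL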
    	
configs/inference/svd_mv.yaml
ADDED
@@ -0,0 +1,202 @@
model:
  base_learning_rate: 1.0e-05
  target: sgm.models.video_diffusion.DiffusionEngine
  params:
    ckpt_path: ckpts/svd_xt.safetensors
    scale_factor: 0.18215
    disable_first_stage_autocast: true
    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps:
        - 1
        cycle_lengths:
        - 10000000000000
        f_start:
        - 1.0e-06
        f_max:
        - 1.0
        f_min:
        - 1.0
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
    network_config:
      target: sgm.modules.diffusionmodules.video_model.VideoUNet
      params:
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: true
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions:
        - 4
        - 2
        - 1
        num_res_blocks: 2
        channel_mult:
        - 1
        - 2
        - 4
        - 4
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: true
        use_spatial_context: true
        merge_strategy: learned_with_images
        video_kernel_size:
        - 3
        - 1
        - 1
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
        - is_trainable: false
          ucg_rate: 0.2
          input_key: cond_frames_without_noise
          target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
          params:
            n_cond_frames: 1
            n_copies: 1
            open_clip_embedding_config:
              target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
              params:
                freeze: true
        - input_key: fps_id
          is_trainable: true
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - input_key: motion_bucket_id
          is_trainable: true
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
        - input_key: cond_frames
          is_trainable: false
          ucg_rate: 0.2
          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
          params:
            disable_encoder_autocast: true
            n_cond_frames: 1
            n_copies: 1
            is_ae: true
            encoder_config:
              target: sgm.models.autoencoder.AutoencoderKLModeOnly
              params:
                embed_dim: 4
                monitor: val/rec_loss
                ddconfig:
                  attn_type: vanilla-xformers
                  double_z: true
                  z_channels: 4
                  resolution: 256
                  in_channels: 3
                  out_ch: 3
                  ch: 128
                  ch_mult:
                  - 1
                  - 2
                  - 4
                  - 4
                  num_res_blocks: 2
                  attn_resolutions: []
                  dropout: 0.0
                lossconfig:
                  target: torch.nn.Identity
        - input_key: cond_aug
          is_trainable: true
          target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          params:
            outdim: 256
    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: sgm.modules.diffusionmodules.model.Encoder
          params:
            attn_type: vanilla
            double_z: true
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult:
            - 1
            - 2
            - 4
            - 4
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
        decoder_config:
          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            double_z: true
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult:
            - 1
            - 2
            - 4
         
     | 
| 158 | 
         
            +
                        - 4
         
     | 
| 159 | 
         
            +
                        num_res_blocks: 2
         
     | 
| 160 | 
         
            +
                        attn_resolutions: []
         
     | 
| 161 | 
         
            +
                        dropout: 0.0
         
     | 
| 162 | 
         
            +
                        video_kernel_size:
         
     | 
| 163 | 
         
            +
                        - 3
         
     | 
| 164 | 
         
            +
                        - 1
         
     | 
| 165 | 
         
            +
                        - 1
         
     | 
| 166 | 
         
            +
                sampler_config:
         
     | 
| 167 | 
         
            +
                  target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
         
     | 
| 168 | 
         
            +
                  params:
         
     | 
| 169 | 
         
            +
                    num_steps: 30
         
     | 
| 170 | 
         
            +
                    discretization_config:
         
     | 
| 171 | 
         
            +
                      target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
         
     | 
| 172 | 
         
            +
                      params:
         
     | 
| 173 | 
         
            +
                        sigma_max: 700.0
         
     | 
| 174 | 
         
            +
                    guider_config:
         
     | 
| 175 | 
         
            +
                      target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider
         
     | 
| 176 | 
         
            +
                      params:
         
     | 
| 177 | 
         
            +
                        max_scale: 2.5
         
     | 
| 178 | 
         
            +
                        min_scale: 1.0
         
     | 
| 179 | 
         
            +
                        num_frames: 24
         
     | 
| 180 | 
         
            +
                loss_fn_config:
         
     | 
| 181 | 
         
            +
                  target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
         
     | 
| 182 | 
         
            +
                  params:
         
     | 
| 183 | 
         
            +
                    batch2model_keys:
         
     | 
| 184 | 
         
            +
                    - num_video_frames
         
     | 
| 185 | 
         
            +
                    - image_only_indicator
         
     | 
| 186 | 
         
            +
                    loss_weighting_config:
         
     | 
| 187 | 
         
            +
                      target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting
         
     | 
| 188 | 
         
            +
                      params:
         
     | 
| 189 | 
         
            +
                        sigma_data: 1.0
         
     | 
| 190 | 
         
            +
                    sigma_sampler_config:
         
     | 
| 191 | 
         
            +
                      target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling
         
     | 
| 192 | 
         
            +
                      params:
         
     | 
| 193 | 
         
            +
                        p_mean: 0.3
         
     | 
| 194 | 
         
            +
                        p_std: 1.2
         
     | 
| 195 | 
         
            +
            data:
         
     | 
| 196 | 
         
            +
              target: sgm.data.objaverse.ObjaverseSpiralDataset
         
     | 
| 197 | 
         
            +
              params:
         
     | 
| 198 | 
         
            +
                root_dir: /mnt/mfs/zilong.chen/Downloads/objaverse-ndd-samples
         
     | 
| 199 | 
         
            +
                random_front: true
         
     | 
| 200 | 
         
            +
                batch_size: 2
         
     | 
| 201 | 
         
            +
                num_workers: 16
         
     | 
| 202 | 
         
            +
                cond_aug_mean: -0.0
         
     | 
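Every block in this config follows the sgm `target:`/`params:` convention: `target` names a class by import path and `params` are its constructor arguments, with nested `target:`/`params:` blocks instantiated by the corresponding module constructors. A minimal sketch of how a config of this shape is typically consumed, assuming the vendored sgm package exposes `sgm.util.instantiate_from_config` (the config path below is only an example):

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

# Load the YAML and build the diffusion model from its `model` block.
config = OmegaConf.load("configs/inference/svd_mv.yaml")
model = instantiate_from_config(config.model)
model.eval()

# The `data` block uses the same convention and could be instantiated the same way:
# datamodule = instantiate_from_config(config.data)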
    	
        mesh_recon/configs/neuralangelo-ortho-wmask.yaml
    ADDED
    
@@ -0,0 +1,145 @@
name: ${basename:${dataset.scene}}
tag: ""
seed: 42

dataset:
  name: ortho
  root_dir: /home/xiaoxiao/Workplace/wonder3Dplus/outputs/joint-twice/aigc/cropsize-224-cfg1.0
  cam_pose_dir: null
  scene: scene_name
  imSize: [1024, 1024]  # should use larger res, otherwise the exported mesh has wrong colors
  camera_type: ortho
  apply_mask: true
  camera_params: null
  view_weights: [1.0, 0.8, 0.2, 1.0, 0.4, 0.7]  #['front', 'front_right', 'right', 'back', 'left', 'front_left']
  # view_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

model:
  name: neus
  radius: 1.0
  num_samples_per_ray: 1024
  train_num_rays: 256
  max_train_num_rays: 8192
  grid_prune: true
  grid_prune_occ_thre: 0.001
  dynamic_ray_sampling: true
  batch_image_sampling: true
  randomized: true
  ray_chunk: 2048
  cos_anneal_end: 20000
  learned_background: false
  background_color: black
  variance:
    init_val: 0.3
    modulate: false
  geometry:
    name: volume-sdf
    radius: ${model.radius}
    feature_dim: 13
    grad_type: finite_difference
    finite_difference_eps: progressive
    isosurface:
      method: mc
      resolution: 192
      chunk: 2097152
      threshold: 0.
    xyz_encoding_config:
      otype: ProgressiveBandHashGrid
      n_levels: 10 # 12 modify
      n_features_per_level: 2
      log2_hashmap_size: 19
      base_resolution: 32
      per_level_scale: 1.3195079107728942
      include_xyz: true
      start_level: 4
      start_step: 0
      update_steps: 1000
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 1
      sphere_init: true
      sphere_init_radius: 0.5
      weight_norm: true
  texture:
    name: volume-radiance
    input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input
    dir_encoding_config:
      otype: SphericalHarmonics
      degree: 4
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 2
    color_activation: sigmoid

system:
  name: ortho-neus-system
  loss:
    lambda_rgb_mse: 0.5
    lambda_rgb_l1: 0.
    lambda_mask: 1.0
    lambda_eikonal: 0.2  # cannot be too large, will cause holes to thin objects
    lambda_normal: 1.0  # cannot be too large
    lambda_3d_normal_smooth: 1.0
    # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup
    lambda_curvature: 0.
    lambda_sparsity: 0.5
    lambda_distortion: 0.0
    lambda_distortion_bg: 0.0
    lambda_opaque: 0.0
    sparsity_scale: 100.0
    geo_aware: true
    rgb_p_ratio: 0.8
    normal_p_ratio: 0.8
    mask_p_ratio: 0.9
  optimizer:
    name: AdamW
    args:
      lr: 0.01
      betas: [0.9, 0.99]
      eps: 1.e-15
    params:
      geometry:
        lr: 0.001
      texture:
        lr: 0.01
      variance:
        lr: 0.001
  constant_steps: 500
  scheduler:
    name: SequentialLR
    interval: step
    milestones:
      - ${system.constant_steps}
    schedulers:
      - name: ConstantLR
        args:
          factor: 1.0
          total_iters: ${system.constant_steps}
      - name: ExponentialLR
        args:
          gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}}

checkpoint:
  save_top_k: -1
  every_n_train_steps: ${trainer.max_steps}

export:
  chunk_size: 2097152
  export_vertex_color: True
  ortho_scale: 1.35   #modify

trainer:
  max_steps: 3000
  log_every_n_steps: 100
  num_sanity_val_steps: 0
  val_check_interval: 4000
  limit_train_batches: 1.0
  limit_val_batches: 2
  enable_progress_bar: true
  precision: 16
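The `${basename:...}`, `${add:...}`, `${sub:...}` and `${calc_exp_lr_decay_rate:...}` interpolations in this config are not built into OmegaConf; they rely on custom resolvers that the mesh_recon code registers before loading the config, and their exact definitions are not part of this diff. A rough sketch of what such registrations typically look like, with the ExponentialLR gamma worked out for the numbers in this file:

import os
from omegaconf import OmegaConf

# Assumed resolver definitions (illustrative; the real ones live elsewhere in mesh_recon):
OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p))
OmegaConf.register_new_resolver("add", lambda a, b: a + b)
OmegaConf.register_new_resolver("sub", lambda a, b: a - b)
# Per-step decay rate so the LR shrinks by `factor` over `n` steps:
# gamma = factor ** (1 / n); here 0.1 ** (1 / (3000 - 500)) ≈ 0.99908.
OmegaConf.register_new_resolver(
    "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n)
)

cfg = OmegaConf.load("mesh_recon/configs/neuralangelo-ortho-wmask.yaml")
print(cfg.system.scheduler.schedulers[1].args.gamma)  # ≈ 0.99908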
    	
        mesh_recon/configs/v3d.yaml
    ADDED
    
@@ -0,0 +1,144 @@
name: ${basename:${dataset.scene}}
tag: ""
seed: 42

dataset:
  name: v3d
  root_dir: ./spirals
  cam_pose_dir: null
  scene: pizza_man
  apply_mask: true
  train_split: train
  test_split: train
  val_split: train
  img_wh: [1024, 1024]

model:
  name: neus
  radius: 1.0 ## check this
  num_samples_per_ray: 1024
  train_num_rays: 256
  max_train_num_rays: 8192
  grid_prune: true
  grid_prune_occ_thre: 0.001
  dynamic_ray_sampling: true
  batch_image_sampling: true
  randomized: true
  ray_chunk: 2048
  cos_anneal_end: 20000
  learned_background: false
  background_color: black
  variance:
    init_val: 0.3
    modulate: false
  geometry:
    name: volume-sdf
    radius: ${model.radius}
    feature_dim: 13
    grad_type: finite_difference
    finite_difference_eps: progressive
    isosurface:
      method: mc
      resolution: 384
      chunk: 2097152
      threshold: 0.
    xyz_encoding_config:
      otype: ProgressiveBandHashGrid
      n_levels: 10 # 12 modify
      n_features_per_level: 2
      log2_hashmap_size: 19
      base_resolution: 32
      per_level_scale: 1.3195079107728942
      include_xyz: true
      start_level: 4
      start_step: 0
      update_steps: 1000
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 1
      sphere_init: true
      sphere_init_radius: 0.5
      weight_norm: true
  texture:
    name: volume-radiance
    input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input
    dir_encoding_config:
      otype: SphericalHarmonics
      degree: 4
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 2
    color_activation: sigmoid

system:
  name: videonvs-neus-system
  loss:
    lambda_rgb_mse: 0.5
    lambda_rgb_l1: 0.
    lambda_mask: 1.0
    lambda_eikonal: 0.2  # cannot be too large, will cause holes to thin objects
    lambda_normal: 0.0  # cannot be too large
    lambda_3d_normal_smooth: 1.0
    # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup
    lambda_curvature: 0.
    lambda_sparsity: 0.5
    lambda_distortion: 0.0
    lambda_distortion_bg: 0.0
    lambda_opaque: 0.0
    sparsity_scale: 100.0
    geo_aware: true
    rgb_p_ratio: 0.8
    normal_p_ratio: 0.8
    mask_p_ratio: 0.9
  optimizer:
    name: AdamW
    args:
      lr: 0.01
      betas: [0.9, 0.99]
      eps: 1.e-15
    params:
      geometry:
        lr: 0.001
      texture:
        lr: 0.01
      variance:
        lr: 0.001
  constant_steps: 500
  scheduler:
    name: SequentialLR
    interval: step
    milestones:
      - ${system.constant_steps}
    schedulers:
      - name: ConstantLR
        args:
          factor: 1.0
          total_iters: ${system.constant_steps}
      - name: ExponentialLR
        args:
          gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}}

checkpoint:
  save_top_k: -1
  every_n_train_steps: ${trainer.max_steps}

export:
  chunk_size: 2097152
  export_vertex_color: True
  ortho_scale: null   #modify

trainer:
  max_steps: 3000
  log_every_n_steps: 100
  num_sanity_val_steps: 0
  val_check_interval: 3000
  limit_train_batches: 1.0
  limit_val_batches: 2
  enable_progress_bar: true
  precision: 16
    	
        mesh_recon/configs/videonvs.yaml
    ADDED
    
@@ -0,0 +1,144 @@
name: ${basename:${dataset.scene}}
tag: ""
seed: 42

dataset:
  name: videonvs
  root_dir: ./spirals
  cam_pose_dir: null
  scene: pizza_man
  apply_mask: true
  train_split: train
  test_split: train
  val_split: train
  img_wh: [1024, 1024]

model:
  name: neus
  radius: 1.0 ## check this
  num_samples_per_ray: 1024
  train_num_rays: 256
  max_train_num_rays: 8192
  grid_prune: true
  grid_prune_occ_thre: 0.001
  dynamic_ray_sampling: true
  batch_image_sampling: true
  randomized: true
  ray_chunk: 2048
  cos_anneal_end: 20000
  learned_background: false
  background_color: black
  variance:
    init_val: 0.3
    modulate: false
  geometry:
    name: volume-sdf
    radius: ${model.radius}
    feature_dim: 13
    grad_type: finite_difference
    finite_difference_eps: progressive
    isosurface:
      method: mc
      resolution: 384
      chunk: 2097152
      threshold: 0.
    xyz_encoding_config:
      otype: ProgressiveBandHashGrid
      n_levels: 10 # 12 modify
      n_features_per_level: 2
      log2_hashmap_size: 19
      base_resolution: 32
      per_level_scale: 1.3195079107728942
      include_xyz: true
      start_level: 4
      start_step: 0
      update_steps: 1000
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 1
      sphere_init: true
      sphere_init_radius: 0.5
      weight_norm: true
  texture:
    name: volume-radiance
    input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input
    dir_encoding_config:
      otype: SphericalHarmonics
      degree: 4
    mlp_network_config:
      otype: VanillaMLP
      activation: ReLU
      output_activation: none
      n_neurons: 64
      n_hidden_layers: 2
    color_activation: sigmoid

system:
  name: videonvs-neus-system
  loss:
    lambda_rgb_mse: 0.5
    lambda_rgb_l1: 0.
    lambda_mask: 1.0
    lambda_eikonal: 0.2  # cannot be too large, will cause holes to thin objects
    lambda_normal: 1.0  # cannot be too large
    lambda_3d_normal_smooth: 1.0
    # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup
    lambda_curvature: 0.
    lambda_sparsity: 0.5
    lambda_distortion: 0.0
    lambda_distortion_bg: 0.0
    lambda_opaque: 0.0
    sparsity_scale: 100.0
    geo_aware: true
    rgb_p_ratio: 0.8
    normal_p_ratio: 0.8
    mask_p_ratio: 0.9
  optimizer:
    name: AdamW
    args:
      lr: 0.01
      betas: [0.9, 0.99]
      eps: 1.e-15
    params:
      geometry:
        lr: 0.001
      texture:
        lr: 0.01
      variance:
        lr: 0.001
  constant_steps: 500
  scheduler:
    name: SequentialLR
    interval: step
    milestones:
      - ${system.constant_steps}
    schedulers:
      - name: ConstantLR
        args:
          factor: 1.0
          total_iters: ${system.constant_steps}
      - name: ExponentialLR
        args:
          gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}}

checkpoint:
  save_top_k: -1
  every_n_train_steps: ${trainer.max_steps}

export:
  chunk_size: 2097152
  export_vertex_color: True
  ortho_scale: null   #modify

trainer:
  max_steps: 3000
  log_every_n_steps: 100
  num_sanity_val_steps: 0
  val_check_interval: 3000
  limit_train_batches: 1.0
  limit_val_batches: 2
  enable_progress_bar: true
  precision: 16
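The `system.optimizer` block in these configs keeps global AdamW arguments under `args` and per-submodule learning rates under `params` (geometry, texture, variance). A rough sketch of how such a block is commonly expanded into torch parameter groups, assuming the model exposes submodules with those names; the actual helper in mesh_recon is defined elsewhere and is not part of this diff:

import torch
from omegaconf import OmegaConf

def build_optimizer(model, optim_cfg):
    # One parameter group per named submodule; each group's entries (e.g. lr)
    # override the global optimizer args for those parameters.
    groups = [
        {"params": getattr(model, name).parameters(), **OmegaConf.to_container(overrides)}
        for name, overrides in optim_cfg.params.items()
    ]
    optim_cls = getattr(torch.optim, optim_cfg.name)  # AdamW for these configs
    return optim_cls(groups, **OmegaConf.to_container(optim_cfg.args))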
    	
        mesh_recon/datasets/__init__.py
    ADDED
    
@@ -0,0 +1,17 @@
datasets = {}


def register(name):
    def decorator(cls):
        datasets[name] = cls
        return cls

    return decorator


def make(name, config):
    dataset = datasets[name](config)
    return dataset


from . import blender, colmap, dtu, ortho, videonvs, videonvs_co3d, v3d
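This module is a small name-based registry: `register(name)` is a class decorator that records a dataset class under a string key (the trailing import line triggers registration of every bundled loader), and `make(name, config)` looks the class up and instantiates it. An illustrative sketch of the intended usage; the "toy" dataset and its config are made up for the example:

from omegaconf import OmegaConf

import datasets  # the mesh_recon package shown above


@datasets.register("toy")  # hypothetical dataset, for illustration only
class ToyDataset:
    def __init__(self, config):
        self.config = config


cfg = OmegaConf.create({"root_dir": "./data", "img_wh": [256, 256]})
ds = datasets.make("toy", cfg)  # looks up the registered class and instantiates it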
    	
        mesh_recon/datasets/blender.py
    ADDED
    
@@ -0,0 +1,143 @@
+import os
+import json
+import math
+import numpy as np
+from PIL import Image
+
+import torch
+from torch.utils.data import Dataset, DataLoader, IterableDataset
+import torchvision.transforms.functional as TF
+
+import pytorch_lightning as pl
+
+import datasets
+from models.ray_utils import get_ray_directions
+from utils.misc import get_rank
+
+
+class BlenderDatasetBase:
+    def setup(self, config, split):
+        self.config = config
+        self.split = split
+        self.rank = get_rank()
+
+        self.has_mask = True
+        self.apply_mask = True
+
+        with open(
+            os.path.join(self.config.root_dir, f"transforms_{self.split}.json"), "r"
+        ) as f:
+            meta = json.load(f)
+
+        if "w" in meta and "h" in meta:
+            W, H = int(meta["w"]), int(meta["h"])
+        else:
+            W, H = 800, 800
+
+        if "img_wh" in self.config:
+            w, h = self.config.img_wh
+            assert round(W / w * h) == H
+        elif "img_downscale" in self.config:
+            w, h = W // self.config.img_downscale, H // self.config.img_downscale
+        else:
+            raise KeyError("Either img_wh or img_downscale should be specified.")
+
+        self.w, self.h = w, h
+        self.img_wh = (self.w, self.h)
+
+        self.near, self.far = self.config.near_plane, self.config.far_plane
+
+        self.focal = (
+            0.5 * w / math.tan(0.5 * meta["camera_angle_x"])
+        )  # scaled focal length
+
+        # ray directions for all pixels, same for all images (same H, W, focal)
+        self.directions = get_ray_directions(
+            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
+        ).to(
+            self.rank
+        )  # (h, w, 3)
+
+        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []
+
+        for i, frame in enumerate(meta["frames"]):
+            c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4])
+            self.all_c2w.append(c2w)
+
+            img_path = os.path.join(self.config.root_dir, f"{frame['file_path']}.png")
+            img = Image.open(img_path)
+            img = img.resize(self.img_wh, Image.BICUBIC)
+            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)
+
+            self.all_fg_masks.append(img[..., -1])  # (h, w)
+            self.all_images.append(img[..., :3])
+
+        self.all_c2w, self.all_images, self.all_fg_masks = (
+            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
+            torch.stack(self.all_images, dim=0).float().to(self.rank),
+            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
+        )
+
+
+class BlenderDataset(Dataset, BlenderDatasetBase):
+    def __init__(self, config, split):
+        self.setup(config, split)
+
+    def __len__(self):
+        return len(self.all_images)
+
+    def __getitem__(self, index):
+        return {"index": index}
+
+
+class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
+    def __init__(self, config, split):
+        self.setup(config, split)
+
+    def __iter__(self):
+        while True:
+            yield {}
+
+
+@datasets.register("blender")
+class VideoNVSDataModule(pl.LightningDataModule):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+    def setup(self, stage=None):
+        if stage in [None, "fit"]:
+            self.train_dataset = BlenderIterableDataset(
+                self.config, self.config.train_split
+            )
+        if stage in [None, "fit", "validate"]:
+            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
+        if stage in [None, "test"]:
+            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
+        if stage in [None, "predict"]:
+            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)
+
+    def prepare_data(self):
+        pass
+
+    def general_loader(self, dataset, batch_size):
+        sampler = None
+        return DataLoader(
+            dataset,
+            num_workers=os.cpu_count(),
+            batch_size=batch_size,
+            pin_memory=True,
+            sampler=sampler,
+        )
+
+    def train_dataloader(self):
+        return self.general_loader(self.train_dataset, batch_size=1)
+
+    def val_dataloader(self):
+        return self.general_loader(self.val_dataset, batch_size=1)
+
+    def test_dataloader(self):
+        return self.general_loader(self.test_dataset, batch_size=1)
+
+    def predict_dataloader(self):
+        return self.general_loader(self.predict_dataset, batch_size=1)
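
For reference, `BlenderDatasetBase` above expects the standard NeRF-synthetic layout under `root_dir`: a `transforms_{split}.json` holding `camera_angle_x` plus a list of frames, and RGBA PNGs whose alpha channel becomes the foreground mask. An illustrative sketch of that file, written here as a Python dict (the numeric values are made up):

    transforms_train = {
        "camera_angle_x": 0.6911,      # horizontal field of view in radians
        "frames": [
            {
                "file_path": "./train/r_0",      # resolved to <root_dir>/train/r_0.png
                "transform_matrix": [            # 4x4 camera-to-world matrix
                    [1.0, 0.0, 0.0, 0.0],
                    [0.0, 1.0, 0.0, 0.0],
                    [0.0, 0.0, 1.0, 4.0],
                    [0.0, 0.0, 0.0, 1.0],
                ],
            },
            # ... one entry per view
        ],
    }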
    	
mesh_recon/datasets/colmap.py
ADDED
@@ -0,0 +1,332 @@
+import os
+import math
+import numpy as np
+from PIL import Image
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader, IterableDataset
+import torchvision.transforms.functional as TF
+
+import pytorch_lightning as pl
+
+import datasets
+from datasets.colmap_utils import \
+    read_cameras_binary, read_images_binary, read_points3d_binary
+from models.ray_utils import get_ray_directions
+from utils.misc import get_rank
+
+
+def get_center(pts):
+    center = pts.mean(0)
+    dis = (pts - center[None,:]).norm(p=2, dim=-1)
+    mean, std = dis.mean(), dis.std()
+    q25, q75 = torch.quantile(dis, 0.25), torch.quantile(dis, 0.75)
+    valid = (dis > mean - 1.5 * std) & (dis < mean + 1.5 * std) & (dis > mean - (q75 - q25) * 1.5) & (dis < mean + (q75 - q25) * 1.5)
+    center = pts[valid].mean(0)
+    return center
+
+def normalize_poses(poses, pts, up_est_method, center_est_method):
+    if center_est_method == 'camera':
+        # estimation scene center as the average of all camera positions
+        center = poses[...,3].mean(0)
+    elif center_est_method == 'lookat':
+        # estimation scene center as the average of the intersection of selected pairs of camera rays
+        cams_ori = poses[...,3]
+        cams_dir = poses[:,:3,:3] @ torch.as_tensor([0.,0.,-1.])
+        cams_dir = F.normalize(cams_dir, dim=-1)
+        A = torch.stack([cams_dir, -cams_dir.roll(1,0)], dim=-1)
+        b = -cams_ori + cams_ori.roll(1,0)
+        t = torch.linalg.lstsq(A, b).solution
+        center = (torch.stack([cams_dir, cams_dir.roll(1,0)], dim=-1) * t[:,None,:] + torch.stack([cams_ori, cams_ori.roll(1,0)], dim=-1)).mean((0,2))
+    elif center_est_method == 'point':
+        # first estimation scene center as the average of all camera positions
+        # later we'll use the center of all points bounded by the cameras as the final scene center
+        center = poses[...,3].mean(0)
+    else:
+        raise NotImplementedError(f'Unknown center estimation method: {center_est_method}')
+
+    if up_est_method == 'ground':
+        # estimate up direction as the normal of the estimated ground plane
+        # use RANSAC to estimate the ground plane in the point cloud
+        import pyransac3d as pyrsc
+        ground = pyrsc.Plane()
+        plane_eq, inliers = ground.fit(pts.numpy(), thresh=0.01) # TODO: determine thresh based on scene scale
+        plane_eq = torch.as_tensor(plane_eq) # A, B, C, D in Ax + By + Cz + D = 0
+        z = F.normalize(plane_eq[:3], dim=-1) # plane normal as up direction
+        signed_distance = (torch.cat([pts, torch.ones_like(pts[...,0:1])], dim=-1) * plane_eq).sum(-1)
+        if signed_distance.mean() < 0:
+            z = -z # flip the direction if points lie under the plane
+    elif up_est_method == 'camera':
+        # estimate up direction as the average of all camera up directions
+        z = F.normalize((poses[...,3] - center).mean(0), dim=0)
+    else:
+        raise NotImplementedError(f'Unknown up estimation method: {up_est_method}')
+
+    # new axis
+    y_ = torch.as_tensor([z[1], -z[0], 0.])
+    x = F.normalize(y_.cross(z), dim=0)
+    y = z.cross(x)
+
+    if center_est_method == 'point':
+        # rotation
+        Rc = torch.stack([x, y, z], dim=1)
+        R = Rc.T
+        poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1)
+        inv_trans = torch.cat([torch.cat([R, torch.as_tensor([[0.,0.,0.]]).T], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
+        poses_norm = (inv_trans @ poses_homo)[:,:3]
+        pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
+
+        # translation and scaling
+        poses_min, poses_max = poses_norm[...,3].min(0)[0], poses_norm[...,3].max(0)[0]
+        pts_fg = pts[(poses_min[0] < pts[:,0]) & (pts[:,0] < poses_max[0]) & (poses_min[1] < pts[:,1]) & (pts[:,1] < poses_max[1])]
+        center = get_center(pts_fg)
+        tc = center.reshape(3, 1)
+        t = -tc
+        poses_homo = torch.cat([poses_norm, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses_norm.shape[0], -1, -1)], dim=1)
+        inv_trans = torch.cat([torch.cat([torch.eye(3), t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
+        poses_norm = (inv_trans @ poses_homo)[:,:3]
+        scale = poses_norm[...,3].norm(p=2, dim=-1).min()
+        poses_norm[...,3] /= scale
+        pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
+        pts = pts / scale
+    else:
+        # rotation and translation
+        Rc = torch.stack([x, y, z], dim=1)
+        tc = center.reshape(3, 1)
+        R, t = Rc.T, -Rc.T @ tc
+        poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1)
+        inv_trans = torch.cat([torch.cat([R, t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0)
+        poses_norm = (inv_trans @ poses_homo)[:,:3] # (N_images, 4, 4)
+
+        # scaling
+        scale = poses_norm[...,3].norm(p=2, dim=-1).min()
+        poses_norm[...,3] /= scale
+
+        # apply the transformation to the point cloud
+        pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0]
+        pts = pts / scale
+
+    return poses_norm, pts
+
+def create_spheric_poses(cameras, n_steps=120):
+    center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device)
+    mean_d = (cameras - center[None,:]).norm(p=2, dim=-1).mean()
+    mean_h = cameras[:,2].mean()
+    r = (mean_d**2 - mean_h**2).sqrt()
+    up = torch.as_tensor([0., 0., 1.], dtype=center.dtype, device=center.device)
+
+    all_c2w = []
+    for theta in torch.linspace(0, 2 * math.pi, n_steps):
+        cam_pos = torch.stack([r * theta.cos(), r * theta.sin(), mean_h])
+        l = F.normalize(center - cam_pos, p=2, dim=0)
+        s = F.normalize(l.cross(up), p=2, dim=0)
+        u = F.normalize(s.cross(l), p=2, dim=0)
+        c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1)
+        all_c2w.append(c2w)
+
+    all_c2w = torch.stack(all_c2w, dim=0)
+
+    return all_c2w
+
+class ColmapDatasetBase():
+    # the data only has to be processed once
+    initialized = False
+    properties = {}
+
+    def setup(self, config, split):
+        self.config = config
+        self.split = split
+        self.rank = get_rank()
+
+        if not ColmapDatasetBase.initialized:
+            camdata = read_cameras_binary(os.path.join(self.config.root_dir, 'sparse/0/cameras.bin'))
+
+            H = int(camdata[1].height)
+            W = int(camdata[1].width)
+
+            if 'img_wh' in self.config:
+                w, h = self.config.img_wh
+                assert round(W / w * h) == H
+            elif 'img_downscale' in self.config:
+                w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
+            else:
+                raise KeyError("Either img_wh or img_downscale should be specified.")
+
+            img_wh = (w, h)
+            factor = w / W
+
+            if camdata[1].model == 'SIMPLE_RADIAL':
+                fx = fy = camdata[1].params[0] * factor
+                cx = camdata[1].params[1] * factor
+                cy = camdata[1].params[2] * factor
+            elif camdata[1].model in ['PINHOLE', 'OPENCV']:
+                fx = camdata[1].params[0] * factor
+                fy = camdata[1].params[1] * factor
+                cx = camdata[1].params[2] * factor
+                cy = camdata[1].params[3] * factor
+            else:
+                raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!")
+
+            directions = get_ray_directions(w, h, fx, fy, cx, cy).to(self.rank)
+
+            imdata = read_images_binary(os.path.join(self.config.root_dir, 'sparse/0/images.bin'))
+
+            mask_dir = os.path.join(self.config.root_dir, 'masks')
+            has_mask = os.path.exists(mask_dir) # TODO: support partial masks
+            apply_mask = has_mask and self.config.apply_mask
+
+            all_c2w, all_images, all_fg_masks = [], [], []
+
+            for i, d in enumerate(imdata.values()):
+                R = d.qvec2rotmat()
+                t = d.tvec.reshape(3, 1)
+                c2w = torch.from_numpy(np.concatenate([R.T, -R.T@t], axis=1)).float()
+                c2w[:,1:3] *= -1. # COLMAP => OpenGL
+                all_c2w.append(c2w)
+                if self.split in ['train', 'val']:
+                    img_path = os.path.join(self.config.root_dir, 'images', d.name)
+                    img = Image.open(img_path)
+                    img = img.resize(img_wh, Image.BICUBIC)
+                    img = TF.to_tensor(img).permute(1, 2, 0)[...,:3]
+                    img = img.to(self.rank) if self.config.load_data_on_gpu else img.cpu()
+                    if has_mask:
+                        mask_paths = [os.path.join(mask_dir, d.name), os.path.join(mask_dir, d.name[3:])]
+                        mask_paths = list(filter(os.path.exists, mask_paths))
+                        assert len(mask_paths) == 1
+                        mask = Image.open(mask_paths[0]).convert('L') # (H, W, 1)
+                        mask = mask.resize(img_wh, Image.BICUBIC)
+                        mask = TF.to_tensor(mask)[0]
+                    else:
+                        mask = torch.ones_like(img[...,0], device=img.device)
+                    all_fg_masks.append(mask) # (h, w)
+                    all_images.append(img)
+
+            all_c2w = torch.stack(all_c2w, dim=0)
+
+            pts3d = read_points3d_binary(os.path.join(self.config.root_dir, 'sparse/0/points3D.bin'))
+            pts3d = torch.from_numpy(np.array([pts3d[k].xyz for k in pts3d])).float()
+            all_c2w, pts3d = normalize_poses(all_c2w, pts3d, up_est_method=self.config.up_est_method, center_est_method=self.config.center_est_method)
+
+            ColmapDatasetBase.properties = {
+                'w': w,
+                'h': h,
+                'img_wh': img_wh,
+                'factor': factor,
+                'has_mask': has_mask,
+                'apply_mask': apply_mask,
+                'directions': directions,
+                'pts3d': pts3d,
+                'all_c2w': all_c2w,
+                'all_images': all_images,
+                'all_fg_masks': all_fg_masks
+            }
+
+            ColmapDatasetBase.initialized = True
+
+        for k, v in ColmapDatasetBase.properties.items():
+            setattr(self, k, v)
+
+        if self.split == 'test':
+            self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps)
+            self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
+            self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
+        else:
+            self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0).float(), torch.stack(self.all_fg_masks, dim=0).float()
+
+        """
+        # for debug use
+        from models.ray_utils import get_rays
+        rays_o, rays_d = get_rays(self.directions.cpu(), self.all_c2w, keepdim=True)
+        pts_out = []
+        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 0.0 0.0' for l in rays_o[:,0,0].reshape(-1, 3).tolist()]))
+
+        t_vals = torch.linspace(0, 1, 8)
+        z_vals = 0.05 * (1 - t_vals) + 0.5 * t_vals
+
+        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,0][..., None, :])
+        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 0.0' for l in ray_pts.view(-1, 3).tolist()]))
+
+        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,0][..., None, :])
+        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))

+        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,self.w-1][..., None, :])
+        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))
+
+        ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,self.w-1][..., None, :])
+        pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()]))
+
+        open('cameras.txt', 'w').write('\n'.join(pts_out))
+        open('scene.txt', 'w').write('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 0.0' for l in self.pts3d.view(-1, 3).tolist()]))
+
+        exit(1)
+        """
+
+        self.all_c2w = self.all_c2w.float().to(self.rank)
+        if self.config.load_data_on_gpu:
+            self.all_images = self.all_images.to(self.rank)
+            self.all_fg_masks = self.all_fg_masks.to(self.rank)
+
+
+class ColmapDataset(Dataset, ColmapDatasetBase):
+    def __init__(self, config, split):
+        self.setup(config, split)
+
+    def __len__(self):
+        return len(self.all_images)
+
+    def __getitem__(self, index):
+        return {
+            'index': index
+        }
+
+
+class ColmapIterableDataset(IterableDataset, ColmapDatasetBase):
+    def __init__(self, config, split):
+        self.setup(config, split)
+
+    def __iter__(self):
+        while True:
+            yield {}
+
+
+@datasets.register('colmap')
+class ColmapDataModule(pl.LightningDataModule):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+    def setup(self, stage=None):
+        if stage in [None, 'fit']:
+            self.train_dataset = ColmapIterableDataset(self.config, 'train')
+        if stage in [None, 'fit', 'validate']:
+            self.val_dataset = ColmapDataset(self.config, self.config.get('val_split', 'train'))
+        if stage in [None, 'test']:
+            self.test_dataset = ColmapDataset(self.config, self.config.get('test_split', 'test'))
+        if stage in [None, 'predict']:
+            self.predict_dataset = ColmapDataset(self.config, 'train')
+
+    def prepare_data(self):
+        pass
+
+    def general_loader(self, dataset, batch_size):
+        sampler = None
+        return DataLoader(
+            dataset,
+            num_workers=os.cpu_count(),
+            batch_size=batch_size,
+            pin_memory=True,
+            sampler=sampler
+        )
+
+    def train_dataloader(self):
+        return self.general_loader(self.train_dataset, batch_size=1)
+
+    def val_dataloader(self):
+        return self.general_loader(self.val_dataset, batch_size=1)
+
+    def test_dataloader(self):
+        return self.general_loader(self.test_dataset, batch_size=1)
+
+    def predict_dataloader(self):
+        return self.general_loader(self.predict_dataset, batch_size=1)
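
For reference, `ColmapDatasetBase` above reads a standard COLMAP workspace rooted at `config.root_dir`; the paths it touches are:

    <root_dir>/
        sparse/0/cameras.bin     # intrinsics (SIMPLE_RADIAL, PINHOLE or OPENCV models are handled)
        sparse/0/images.bin      # per-image poses (qvec, tvec) and image names
        sparse/0/points3D.bin    # sparse point cloud, consumed by normalize_poses
        images/<name>            # RGB frames, resized to img_wh
        masks/<name>             # optional foreground masks; their presence sets has_mask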
    	
mesh_recon/datasets/colmap_utils.py
ADDED
@@ -0,0 +1,295 @@
# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#     * Redistributions of source code must retain the above copyright
#       notice, this list of conditions and the following disclaimer.
#
#     * Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
#       its contributors may be used to endorse or promote products derived
#       from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)

import os
import collections
import numpy as np
import struct


CameraModel = collections.namedtuple(
    "CameraModel", ["model_id", "model_name", "num_params"])
Camera = collections.namedtuple(
    "Camera", ["id", "model", "width", "height", "params"])
BaseImage = collections.namedtuple(
    "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
Point3D = collections.namedtuple(
    "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])

class Image(BaseImage):
    def qvec2rotmat(self):
        return qvec2rotmat(self.qvec)


CAMERA_MODELS = {
    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
    CameraModel(model_id=7, model_name="FOV", num_params=5),
    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
}
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \
                         for camera_model in CAMERA_MODELS])


def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read and unpack the next bytes from a binary file.
    :param fid:
    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of read and unpacked values.
    """
    data = fid.read(num_bytes)
    return struct.unpack(endian_character + format_char_sequence, data)


def read_cameras_text(path):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasText(const std::string& path)
        void Reconstruction::ReadCamerasText(const std::string& path)
    """
    cameras = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                camera_id = int(elems[0])
                model = elems[1]
                width = int(elems[2])
                height = int(elems[3])
                params = np.array(tuple(map(float, elems[4:])))
                cameras[camera_id] = Camera(id=camera_id, model=model,
                                            width=width, height=height,
                                            params=params)
    return cameras


def read_cameras_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)
    """
    cameras = {}
    with open(path_to_model_file, "rb") as fid:
        num_cameras = read_next_bytes(fid, 8, "Q")[0]
        for camera_line_index in range(num_cameras):
            camera_properties = read_next_bytes(
                fid, num_bytes=24, format_char_sequence="iiQQ")
            camera_id = camera_properties[0]
            model_id = camera_properties[1]
            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
            width = camera_properties[2]
            height = camera_properties[3]
            num_params = CAMERA_MODEL_IDS[model_id].num_params
            params = read_next_bytes(fid, num_bytes=8*num_params,
                                     format_char_sequence="d"*num_params)
            cameras[camera_id] = Camera(id=camera_id,
                                        model=model_name,
                                        width=width,
                                        height=height,
                                        params=np.array(params))
        assert len(cameras) == num_cameras
    return cameras


def read_images_text(path):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesText(const std::string& path)
        void Reconstruction::WriteImagesText(const std::string& path)
    """
    images = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                image_id = int(elems[0])
                qvec = np.array(tuple(map(float, elems[1:5])))
                tvec = np.array(tuple(map(float, elems[5:8])))
                camera_id = int(elems[8])
                image_name = elems[9]
                elems = fid.readline().split()
                xys = np.column_stack([tuple(map(float, elems[0::3])),
                                       tuple(map(float, elems[1::3]))])
                point3D_ids = np.array(tuple(map(int, elems[2::3])))
                images[image_id] = Image(
                    id=image_id, qvec=qvec, tvec=tvec,
                    camera_id=camera_id, name=image_name,
                    xys=xys, point3D_ids=point3D_ids)
    return images


def read_images_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
        for image_index in range(num_reg_images):
            binary_image_properties = read_next_bytes(
                fid, num_bytes=64, format_char_sequence="idddddddi")
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]
            image_name = ""
            current_char = read_next_bytes(fid, 1, "c")[0]
            while current_char != b"\x00":   # look for the ASCII 0 entry
                image_name += current_char.decode("utf-8")
                current_char = read_next_bytes(fid, 1, "c")[0]
            num_points2D = read_next_bytes(fid, num_bytes=8,
                                           format_char_sequence="Q")[0]
            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
                                       format_char_sequence="ddq"*num_points2D)
            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
                                   tuple(map(float, x_y_id_s[1::3]))])
            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images


def read_points3D_text(path):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)
    """
    points3D = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                point3D_id = int(elems[0])
                xyz = np.array(tuple(map(float, elems[1:4])))
                rgb = np.array(tuple(map(int, elems[4:7])))
                error = float(elems[7])
                image_ids = np.array(tuple(map(int, elems[8::2])))
                point2D_idxs = np.array(tuple(map(int, elems[9::2])))
                points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
                                               error=error, image_ids=image_ids,
                                               point2D_idxs=point2D_idxs)
    return points3D


def read_points3d_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """
    points3D = {}
    with open(path_to_model_file, "rb") as fid:
        num_points = read_next_bytes(fid, 8, "Q")[0]
        for point_line_index in range(num_points):
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            point3D_id = binary_point_line_properties[0]
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            track_elems = read_next_bytes(
                fid, num_bytes=8*track_length,
                format_char_sequence="ii"*track_length)
            image_ids = np.array(tuple(map(int, track_elems[0::2])))
            point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
            points3D[point3D_id] = Point3D(
                id=point3D_id, xyz=xyz, rgb=rgb,
                error=error, image_ids=image_ids,
                point2D_idxs=point2D_idxs)
    return points3D


def read_model(path, ext):
    if ext == ".txt":
        cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
        images = read_images_text(os.path.join(path, "images" + ext))
        points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
    else:
        cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
        images = read_images_binary(os.path.join(path, "images" + ext))
        points3D = read_points3d_binary(os.path.join(path, "points3D") + ext)
    return cameras, images, points3D


def qvec2rotmat(qvec):
    return np.array([
        [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])


def rotmat2qvec(R):
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
    K = np.array([
        [Rxx - Ryy - Rzz, 0, 0, 0],
        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
    eigvals, eigvecs = np.linalg.eigh(K)
    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
    if qvec[0] < 0:
        qvec *= -1
    return qvec
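Usage note (not part of the committed file): colmap_utils.py is the standard COLMAP model reader, so the datasets above can turn a sparse reconstruction into camera-to-world poses. A minimal sketch follows; the "sparse/0" path and binary extension are assumptions about where your COLMAP output lives, not something this commit specifies.

import numpy as np
from colmap_utils import read_model, qvec2rotmat

# Load cameras, registered images, and 3D points from a COLMAP sparse model.
cameras, images, points3D = read_model("sparse/0", ext=".bin")

for image_id, image in images.items():
    # COLMAP stores a world-to-camera rotation (qvec) and translation (tvec);
    # invert them to obtain the camera-to-world pose used by the datasets above.
    R = qvec2rotmat(image.qvec)        # (3, 3) world-to-camera rotation
    t = image.tvec.reshape(3, 1)       # (3, 1) world-to-camera translation
    c2w = np.eye(4)
    c2w[:3, :3] = R.T
    c2w[:3, 3] = (-R.T @ t).squeeze()  # camera center in world coordinates
    print(image.name, c2w[:3, 3])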
    	
        mesh_recon/datasets/dtu.py
    ADDED
    
         @@ -0,0 +1,201 @@ 
import os
import json
import math
import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF

import pytorch_lightning as pl

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank


def load_K_Rt_from_P(P=None):
    out = cv2.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose

def create_spheric_poses(cameras, n_steps=120):
    center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device)
    cam_center = F.normalize(cameras.mean(0), p=2, dim=-1) * cameras.mean(0).norm(2)
    eigvecs = torch.linalg.eig(cameras.T @ cameras).eigenvectors
    rot_axis = F.normalize(eigvecs[:,1].real.float(), p=2, dim=-1)
    up = rot_axis
    rot_dir = torch.cross(rot_axis, cam_center)
    max_angle = (F.normalize(cameras, p=2, dim=-1) * F.normalize(cam_center, p=2, dim=-1)).sum(-1).acos().max()

    all_c2w = []
    for theta in torch.linspace(-max_angle, max_angle, n_steps):
        cam_pos = cam_center * math.cos(theta) + rot_dir * math.sin(theta)
        l = F.normalize(center - cam_pos, p=2, dim=0)
        s = F.normalize(l.cross(up), p=2, dim=0)
        u = F.normalize(s.cross(l), p=2, dim=0)
        c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1)
        all_c2w.append(c2w)

    all_c2w = torch.stack(all_c2w, dim=0)

    return all_c2w

class DTUDatasetBase():
    def setup(self, config, split):
        self.config = config
        self.split = split
        self.rank = get_rank()

        cams = np.load(os.path.join(self.config.root_dir, self.config.cameras_file))

        img_sample = cv2.imread(os.path.join(self.config.root_dir, 'image', '000000.png'))
        H, W = img_sample.shape[0], img_sample.shape[1]

        if 'img_wh' in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif 'img_downscale' in self.config:
            w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (w, h)
        self.factor = w / W

        mask_dir = os.path.join(self.config.root_dir, 'mask')
        self.has_mask = True
        self.apply_mask = self.config.apply_mask

        self.directions = []
        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        n_images = max([int(k.split('_')[-1]) for k in cams.keys()]) + 1

        for i in range(n_images):
            world_mat, scale_mat = cams[f'world_mat_{i}'], cams[f'scale_mat_{i}']
            P = (world_mat @ scale_mat)[:3,:4]
            K, c2w = load_K_Rt_from_P(P)
            fx, fy, cx, cy = K[0,0] * self.factor, K[1,1] * self.factor, K[0,2] * self.factor, K[1,2] * self.factor
            directions = get_ray_directions(w, h, fx, fy, cx, cy)
            self.directions.append(directions)

            c2w = torch.from_numpy(c2w).float()

            # blender follows opengl camera coordinates (right up back)
            # NeuS DTU data coordinate system (right down front) is different from blender
            # https://github.com/Totoro97/NeuS/issues/9
            # for c2w, flip the sign of input camera coordinate yz
            c2w_ = c2w.clone()
            c2w_[:3,1:3] *= -1. # flip input sign
            self.all_c2w.append(c2w_[:3,:4])

            if self.split in ['train', 'val']:
                img_path = os.path.join(self.config.root_dir, 'image', f'{i:06d}.png')
                img = Image.open(img_path)
                img = img.resize(self.img_wh, Image.BICUBIC)
                img = TF.to_tensor(img).permute(1, 2, 0)[...,:3]

                mask_path = os.path.join(mask_dir, f'{i:03d}.png')
                mask = Image.open(mask_path).convert('L') # (H, W, 1)
                mask = mask.resize(self.img_wh, Image.BICUBIC)
                mask = TF.to_tensor(mask)[0]

                self.all_fg_masks.append(mask) # (h, w)
                self.all_images.append(img)

        self.all_c2w = torch.stack(self.all_c2w, dim=0)

        if self.split == 'test':
            self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps)
            self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
            self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
            self.directions = self.directions[0]
        else:
            self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0), torch.stack(self.all_fg_masks, dim=0)
            self.directions = torch.stack(self.directions, dim=0)

        self.directions = self.directions.float().to(self.rank)
        self.all_c2w, self.all_images, self.all_fg_masks = \
            self.all_c2w.float().to(self.rank), \
            self.all_images.float().to(self.rank), \
            self.all_fg_masks.float().to(self.rank)


class DTUDataset(Dataset, DTUDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        return {
            'index': index
        }


class DTUIterableDataset(IterableDataset, DTUDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        while True:
            yield {}


@datasets.register('dtu')
class DTUDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, 'fit']:
            self.train_dataset = DTUIterableDataset(self.config, 'train')
        if stage in [None, 'fit', 'validate']:
            self.val_dataset = DTUDataset(self.config, self.config.get('val_split', 'train'))
        if stage in [None, 'test']:
            self.test_dataset = DTUDataset(self.config, self.config.get('test_split', 'test'))
        if stage in [None, 'predict']:
            self.predict_dataset = DTUDataset(self.config, 'train')

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
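Usage note (not part of the committed file): load_K_Rt_from_P above recovers intrinsics and a camera-to-world pose from a DTU projection matrix P = K [R | t]. A small self-contained check sketches what it returns; the synthetic K, R, t values are made up for illustration, and the import path assumes the helper is importable from within mesh_recon/.

import numpy as np
import cv2
from datasets.dtu import load_K_Rt_from_P  # assumption: run from mesh_recon/

# Build a synthetic projection matrix P = K [R | t] with a known pose.
K = np.array([[800.0, 0.0, 320.0],
              [0.0, 800.0, 240.0],
              [0.0, 0.0, 1.0]])
R = cv2.Rodrigues(np.array([0.1, -0.2, 0.05]))[0]  # world-to-camera rotation
t = np.array([[0.3], [-0.1], [2.0]])               # world-to-camera translation
P = K @ np.hstack([R, t])

intrinsics, pose = load_K_Rt_from_P(P)
# pose is camera-to-world: rotation R.T and camera center -R.T @ t.
assert np.allclose(pose[:3, :3], R.T, atol=1e-5)
assert np.allclose(pose[:3, 3], (-R.T @ t).squeeze(), atol=1e-5)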
    	
        mesh_recon/datasets/fixed_poses/000_back_RT.txt
    ADDED
    
@@ -0,0 +1,3 @@
-1.000000238418579102e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 1.746665105883948854e-07
0.000000000000000000e+00 1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00
    	
        mesh_recon/datasets/fixed_poses/000_back_left_RT.txt
ADDED
@@ -0,0 +1,3 @@
+-7.071069478988647461e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07
+0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08
+-7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
    	
        mesh_recon/datasets/fixed_poses/000_back_right_RT.txt
ADDED
@@ -0,0 +1,3 @@
+-7.071069478988647461e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07
+0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08
+7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
    	
        mesh_recon/datasets/fixed_poses/000_front_RT.txt
ADDED
@@ -0,0 +1,3 @@
+1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
+0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 -1.746665105883948854e-07
+0.000000000000000000e+00 -1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00
    	
        mesh_recon/datasets/fixed_poses/000_front_left_RT.txt
ADDED
@@ -0,0 +1,3 @@
+7.071067690849304199e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07
+0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08
+-7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
    	
        mesh_recon/datasets/fixed_poses/000_front_right_RT.txt
ADDED
@@ -0,0 +1,3 @@
+7.071067690849304199e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07
+0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08
+7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00
    	
        mesh_recon/datasets/fixed_poses/000_left_RT.txt
ADDED
@@ -0,0 +1,3 @@
+-2.220446049250313081e-16 -1.000000000000000000e+00 0.000000000000000000e+00 -2.886579758146288598e-16
+0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00
+-1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00
    	
        mesh_recon/datasets/fixed_poses/000_right_RT.txt
ADDED
@@ -0,0 +1,3 @@
+-2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 2.886579758146288598e-16
+0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00
+1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00
    	
        mesh_recon/datasets/fixed_poses/000_top_RT.txt
ADDED
@@ -0,0 +1,3 @@
+1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
+0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00
+0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 -1.299999952316284180e+00
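
Each of the nine `fixed_poses/000_*_RT.txt` files above holds a 3x4 world-to-camera extrinsic [R|t] for one canonical view, stored in the OpenGL/Blender convention. `load_a_prediction()` in `ortho.py` below reads them with `np.loadtxt`, converts them to the OpenCV convention via `RT_opengl2opencv`, and inverts them with `inv_RT` to obtain camera-to-world poses. The standalone sketch below mirrors that conversion; the file path is the one added in this commit, and the helper logic is inlined from the functions in the diff.

# Sketch: load one fixed pose and convert it the way ortho.py does.
import numpy as np

RT = np.loadtxt("mesh_recon/datasets/fixed_poses/000_front_RT.txt")   # (3, 4) world2cam, OpenGL/Blender
R_bcam2cv = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)

# RT_opengl2opencv: flip the camera y/z axes.
RT_cv = np.concatenate([R_bcam2cv @ RT[:3, :3], (R_bcam2cv @ RT[:3, 3])[:, None]], axis=1)

# inv_RT: homogenize, invert, drop the last row -> camera-to-world pose.
RT_h = np.concatenate([RT_cv, np.array([[0.0, 0.0, 0.0, 1.0]])], axis=0)
c2w = np.linalg.inv(RT_h)[:3, :]
print(c2w.shape)   # (3, 4)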
    	
        mesh_recon/datasets/ortho.py
ADDED
@@ -0,0 +1,287 @@
+import os
+import json
+import math
+import numpy as np
+from PIL import Image
+import cv2
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader, IterableDataset
+import torchvision.transforms.functional as TF
+
+import pytorch_lightning as pl
+
+import datasets
+from models.ray_utils import get_ortho_ray_directions_origins, get_ortho_rays, get_ray_directions
+from utils.misc import get_rank
+
+from glob import glob
+import PIL.Image
+
+
+def camNormal2worldNormal(rot_c2w, camNormal):
+    H, W, _ = camNormal.shape
+    normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1, 3)[:, :, None]).reshape([H, W, 3])
+
+    return normal_img
+
+def worldNormal2camNormal(rot_w2c, worldNormal):
+    H, W, _ = worldNormal.shape
+    normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1, 3)[:, :, None]).reshape([H, W, 3])
+
+    return normal_img
+
+def trans_normal(normal, RT_w2c, RT_w2c_target):
+
+    normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3, :3]), normal)
+    normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3, :3], normal_world)
+
+    return normal_target_cam
+
+def img2normal(img):
+    return (img / 255.) * 2 - 1
+
+def normal2img(normal):
+    return np.uint8((normal * 0.5 + 0.5) * 255)
+
+def norm_normalize(normal, dim=-1):
+
+    normal = normal / (np.linalg.norm(normal, axis=dim, keepdims=True) + 1e-6)
+
+    return normal
+
+def RT_opengl2opencv(RT):
+    # Build the coordinate transform matrix from world to computer vision camera
+    # R_world2cv = R_bcam2cv @ R_world2bcam
+    # T_world2cv = R_bcam2cv @ T_world2bcam
+
+    R = RT[:3, :3]
+    t = RT[:3, 3]
+
+    R_bcam2cv = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32)
+
+    R_world2cv = R_bcam2cv @ R
+    t_world2cv = R_bcam2cv @ t
+
+    RT = np.concatenate([R_world2cv, t_world2cv[:, None]], 1)
+
+    return RT
+
+def normal_opengl2opencv(normal):
+    H, W, C = np.shape(normal)
+    # normal_img = np.reshape(normal, (H*W, C))
+    R_bcam2cv = np.array([1, -1, -1], np.float32)
+    normal_cv = normal * R_bcam2cv[None, None, :]
+
+    print(np.shape(normal_cv))
+
+    return normal_cv
+
+def inv_RT(RT):
+    RT_h = np.concatenate([RT, np.array([[0, 0, 0, 1]])], axis=0)
+    RT_inv = np.linalg.inv(RT_h)
+
+    return RT_inv[:3, :]
+
+
+def load_a_prediction(root_dir, test_object, imSize, view_types, load_color=False, cam_pose_dir=None,
+                      normal_system='front', erode_mask=True, camera_type='ortho', cam_params=None):
+
+    all_images = []
+    all_normals = []
+    all_normals_world = []
+    all_masks = []
+    all_color_masks = []
+    all_poses = []
+    all_w2cs = []
+    directions = []
+    ray_origins = []
+
+    RT_front = np.loadtxt(glob(os.path.join(cam_pose_dir, '*_%s_RT.txt' % ('front')))[0])   # world2cam matrix
+    RT_front_cv = RT_opengl2opencv(RT_front)   # convert normal from opengl to opencv
+    for idx, view in enumerate(view_types):
+        print(os.path.join(root_dir, test_object))
+        normal_filepath = os.path.join(root_dir, test_object, 'normals_000_%s.png' % (view))
+        # Load key frame
+        if load_color:  # use bgr
+            image = np.array(PIL.Image.open(normal_filepath.replace("normals", "rgb")).resize(imSize))[:, :, :3]
+
+        normal = np.array(PIL.Image.open(normal_filepath).resize(imSize))
+        mask = normal[:, :, 3]
+        normal = normal[:, :, :3]
+
+        color_mask = np.array(PIL.Image.open(os.path.join(root_dir, test_object, 'masked_colors/rgb_000_%s.png' % (view))).resize(imSize))[:, :, 3]
+        invalid_color_mask = color_mask < 255 * 0.5
+        threshold = np.ones_like(image[:, :, 0]) * 250
+        invalid_white_mask = (image[:, :, 0] > threshold) & (image[:, :, 1] > threshold) & (image[:, :, 2] > threshold)
+        invalid_color_mask_final = invalid_color_mask & invalid_white_mask
+        color_mask = (1 - invalid_color_mask_final) > 0
+
+        # if erode_mask:
+        #     kernel = np.ones((3, 3), np.uint8)
+        #     mask = cv2.erode(mask, kernel, iterations=1)
+
+        RT = np.loadtxt(os.path.join(cam_pose_dir, '000_%s_RT.txt' % (view)))  # world2cam matrix
+
+        normal = img2normal(normal)
+
+        normal[mask == 0] = [0, 0, 0]
+        mask = mask > (0.5 * 255)
+        if load_color:
+            all_images.append(image)
+
+        all_masks.append(mask)
+        all_color_masks.append(color_mask)
+        RT_cv = RT_opengl2opencv(RT)   # convert normal from opengl to opencv
+        all_poses.append(inv_RT(RT_cv))   # cam2world
+        all_w2cs.append(RT_cv)
+
+        # whether to
+        normal_cam_cv = normal_opengl2opencv(normal)
+
+        if normal_system == 'front':
+            print("the loaded normals are defined in the system of front view")
+            normal_world = camNormal2worldNormal(inv_RT(RT_front_cv)[:3, :3], normal_cam_cv)
+        elif normal_system == 'self':
+            print("the loaded normals are in their independent camera systems")
+            normal_world = camNormal2worldNormal(inv_RT(RT_cv)[:3, :3], normal_cam_cv)
+        all_normals.append(normal_cam_cv)
+        all_normals_world.append(normal_world)
+
+        if camera_type == 'ortho':
+            origins, dirs = get_ortho_ray_directions_origins(W=imSize[0], H=imSize[1])
+        elif camera_type == 'pinhole':
+            dirs = get_ray_directions(W=imSize[0], H=imSize[1],
+                                      fx=cam_params[0], fy=cam_params[1], cx=cam_params[2], cy=cam_params[3])
+            origins = dirs  # occupy a position
+        else:
+            raise Exception("not support camera type")
+        ray_origins.append(origins)
+        directions.append(dirs)
+
+        if not load_color:
+            all_images = [normal2img(x) for x in all_normals_world]
+
+    return np.stack(all_images), np.stack(all_masks), np.stack(all_normals), \
+        np.stack(all_normals_world), np.stack(all_poses), np.stack(all_w2cs), np.stack(ray_origins), np.stack(directions), np.stack(all_color_masks)
+
+
+class OrthoDatasetBase():
+    def setup(self, config, split):
+        self.config = config
+        self.split = split
+        self.rank = get_rank()
+
+        self.data_dir = self.config.root_dir
+        self.object_name = self.config.scene
+        self.scene = self.config.scene
+        self.imSize = self.config.imSize
+        self.load_color = True
+        self.img_wh = [self.imSize[0], self.imSize[1]]
+        self.w = self.img_wh[0]
+        self.h = self.img_wh[1]
+        self.camera_type = self.config.camera_type
+        self.camera_params = self.config.camera_params  # [fx, fy, cx, cy]
+
+        self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
+
+        self.view_weights = torch.from_numpy(np.array(self.config.view_weights)).float().to(self.rank).view(-1)
+        self.view_weights = self.view_weights.view(-1, 1, 1).repeat(1, self.h, self.w)
+
+        if self.config.cam_pose_dir is None:
+            self.cam_pose_dir = "./datasets/fixed_poses"
+        else:
+            self.cam_pose_dir = self.config.cam_pose_dir
+
+        self.images_np, self.masks_np, self.normals_cam_np, self.normals_world_np, \
+            self.pose_all_np, self.w2c_all_np, self.origins_np, self.directions_np, self.rgb_masks_np = load_a_prediction(
+                self.data_dir, self.object_name, self.imSize, self.view_types,
+                self.load_color, self.cam_pose_dir, normal_system='front',
+                camera_type=self.camera_type, cam_params=self.camera_params)
+
+        self.has_mask = True
+        self.apply_mask = self.config.apply_mask
+
+        self.all_c2w = torch.from_numpy(self.pose_all_np)
+        self.all_images = torch.from_numpy(self.images_np) / 255.
+        self.all_fg_masks = torch.from_numpy(self.masks_np)
+        self.all_rgb_masks = torch.from_numpy(self.rgb_masks_np)
+        self.all_normals_world = torch.from_numpy(self.normals_world_np)
+        self.origins = torch.from_numpy(self.origins_np)
+        self.directions = torch.from_numpy(self.directions_np)
+
+        self.directions = self.directions.float().to(self.rank)
+        self.origins = self.origins.float().to(self.rank)
+        self.all_rgb_masks = self.all_rgb_masks.float().to(self.rank)
+        self.all_c2w, self.all_images, self.all_fg_masks, self.all_normals_world = \
+            self.all_c2w.float().to(self.rank), \
+            self.all_images.float().to(self.rank), \
+            self.all_fg_masks.float().to(self.rank), \
+            self.all_normals_world.float().to(self.rank)
+
+
+class OrthoDataset(Dataset, OrthoDatasetBase):
+    def __init__(self, config, split):
+        self.setup(config, split)
+
+    def __len__(self):
+        return len(self.all_images)
+
+    def __getitem__(self, index):
+        return {
+            'index': index
+        }
+
+
+class OrthoIterableDataset(IterableDataset, OrthoDatasetBase):
+    def __init__(self, config, split):
+        self.setup(config, split)
+
+    def __iter__(self):
+        while True:
+            yield {}
+
+
+@datasets.register('ortho')
+class OrthoDataModule(pl.LightningDataModule):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+    def setup(self, stage=None):
+        if stage in [None, 'fit']:
+            self.train_dataset = OrthoIterableDataset(self.config, 'train')
+        if stage in [None, 'fit', 'validate']:
+            self.val_dataset = OrthoDataset(self.config, self.config.get('val_split', 'train'))
+        if stage in [None, 'test']:
+            self.test_dataset = OrthoDataset(self.config, self.config.get('test_split', 'test'))
+        if stage in [None, 'predict']:
+            self.predict_dataset = OrthoDataset(self.config, 'train')
+
+    def prepare_data(self):
+        pass
+
+    def general_loader(self, dataset, batch_size):
+        sampler = None
+        return DataLoader(
+            dataset,
+            num_workers=os.cpu_count(),
+            batch_size=batch_size,
+            pin_memory=True,
+            sampler=sampler
+        )
+
+    def train_dataloader(self):
+        return self.general_loader(self.train_dataset, batch_size=1)
+
+    def val_dataloader(self):
+        return self.general_loader(self.val_dataset, batch_size=1)
+
+    def test_dataloader(self):
+        return self.general_loader(self.test_dataset, batch_size=1)
+
+    def predict_dataloader(self):
+        return self.general_loader(self.predict_dataset, batch_size=1)
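
For reference, `OrthoDatasetBase.setup()` above reads these config fields: `root_dir`, `scene`, `imSize`, `camera_type`, `camera_params`, `view_weights` (one weight per entry in `view_types`), `cam_pose_dir`, and `apply_mask`; the data module additionally looks up `val_split` and `test_split`. The sketch below shows a matching config; the values are placeholders, and an attribute-style config object (OmegaConf here) is assumed because the code mixes `config.root_dir` with `config.get(...)`.

# Illustrative config for the 'ortho' data module; values are placeholders, not from the diff.
from omegaconf import OmegaConf

ortho_config = OmegaConf.create({
    "root_dir": "outputs/mv_predictions",            # contains <scene>/normals_000_<view>.png
    "scene": "example_object",
    "imSize": [256, 256],
    "camera_type": "ortho",                          # or "pinhole" with camera_params = [fx, fy, cx, cy]
    "camera_params": None,
    "view_weights": [1.0, 0.8, 1.0, 1.0, 1.0, 0.8],  # placeholder weights, one per view
    "cam_pose_dir": None,                            # None falls back to "./datasets/fixed_poses"
    "apply_mask": True,
    "val_split": "train",
    "test_split": "test",
})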
    	
        mesh_recon/datasets/utils.py
ADDED
File without changes
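
The normal-map helpers in `datasets/ortho.py` above (`camNormal2worldNormal`, `worldNormal2camNormal`) rotate per-pixel normals using only a 3x3 rotation, and `v3d.py` below imports several of these helpers. A quick round-trip self-check, assuming the functions are imported from `datasets.ortho` as in the diff and the script is run from the `mesh_recon` directory (the random input is purely illustrative):

# Sketch: world -> camera -> world round trip for an (H, W, 3) normal map.
import numpy as np
from datasets.ortho import camNormal2worldNormal, worldNormal2camNormal

H, W = 4, 4
n_world = np.random.randn(H, W, 3).astype(np.float32)
n_world /= np.linalg.norm(n_world, axis=-1, keepdims=True)

RT_w2c = np.loadtxt("datasets/fixed_poses/000_front_RT.txt")  # (3, 4) world2cam
R_w2c = RT_w2c[:3, :3]

n_cam = worldNormal2camNormal(R_w2c, n_world)
n_back = camNormal2worldNormal(np.linalg.inv(R_w2c), n_cam)
print(np.allclose(n_world, n_back, atol=1e-5))   # expected: True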
    	
        mesh_recon/datasets/v3d.py
ADDED
@@ -0,0 +1,284 @@
+import os
+import json
+import math
+import numpy as np
+from PIL import Image
+
+import torch
+from torch.utils.data import Dataset, DataLoader, IterableDataset
+import torchvision.transforms.functional as TF
+from torchvision.utils import make_grid, save_image
+from einops import rearrange
+from mediapy import read_video
+from pathlib import Path
+from rembg import remove, new_session
+
+import pytorch_lightning as pl
+
+import datasets
+from models.ray_utils import get_ray_directions
+from utils.misc import get_rank
+from datasets.ortho import (
+    inv_RT,
+    camNormal2worldNormal,
+    RT_opengl2opencv,
+    normal_opengl2opencv,
+)
+from utils.dpt import DPT
+
+
+def get_c2w_from_up_and_look_at(
+    up,
+    look_at,
+    pos,
+    opengl=False,
+):
+    up = up / np.linalg.norm(up)
+    z = look_at - pos
+    z = z / np.linalg.norm(z)
+    y = -up
+    x = np.cross(y, z)
+    x /= np.linalg.norm(x)
+    y = np.cross(z, x)
+
+    c2w = np.zeros([4, 4], dtype=np.float32)
+    c2w[:3, 0] = x
+    c2w[:3, 1] = y
+    c2w[:3, 2] = z
+    c2w[:3, 3] = pos
+    c2w[3, 3] = 1.0
+
+    # opencv to opengl
+    if opengl:
+        c2w[..., 1:3] *= -1
+
+    return c2w
+
+
+def get_uniform_poses(num_frames, radius, elevation, opengl=False):
+    T = num_frames
+    azimuths = np.deg2rad(np.linspace(0, 360, T + 1)[:T])
+    elevations = np.full_like(azimuths, np.deg2rad(elevation))
+    cam_dists = np.full_like(azimuths, radius)
+
+    campos = np.stack(
+        [
+            cam_dists * np.cos(elevations) * np.cos(azimuths),
+            cam_dists * np.cos(elevations) * np.sin(azimuths),
+            cam_dists * np.sin(elevations),
+        ],
+        axis=-1,
+    )
+
+    center = np.array([0, 0, 0], dtype=np.float32)
+    up = np.array([0, 0, 1], dtype=np.float32)
+    poses = []
+    for t in range(T):
+        poses.append(get_c2w_from_up_and_look_at(up, center, campos[t], opengl=opengl))
+
+    return np.stack(poses, axis=0)
+
+
+def blender2midas(img):
+    """Blender: rub
+    midas: lub
+    """
+    img[..., 0] = -img[..., 0]
+    img[..., 1] = -img[..., 1]
+    img[..., -1] = -img[..., -1]
+    return img
+
+
+def midas2blender(img):
+    """Blender: rub
+    midas: lub
+    """
+    img[..., 0] = -img[..., 0]
+    img[..., 1] = -img[..., 1]
+    img[..., -1] = -img[..., -1]
+    return img
+
+
+class BlenderDatasetBase:
+    def setup(self, config, split):
+        self.config = config
+        self.rank = get_rank()
+
+        self.has_mask = True
+        self.apply_mask = True
+
+        dpt = DPT(device=self.rank, mode="normal")
+
+        # with open(
+        #     os.path.join(
+        #         self.config.root_dir, self.config.scene, f"transforms_train.json"
+        #     ),
+        #     "r",
+        # ) as f:
+        #     meta = json.load(f)
+
+        # if "w" in meta and "h" in meta:
+        #     W, H = int(meta["w"]), int(meta["h"])
+        # else:
+        #     W, H = 800, 800
+        frames = read_video(Path(self.config.root_dir) / f"{self.config.scene}")
+        rembg_session = new_session()
+        num_frames, H, W = frames.shape[:3]
+
+        if "img_wh" in self.config:
+            w, h = self.config.img_wh
+            assert round(W / w * h) == H
+        elif "img_downscale" in self.config:
+            w, h = W // self.config.img_downscale, H // self.config.img_downscale
+        else:
+            raise KeyError("Either img_wh or img_downscale should be specified.")
+
+        self.w, self.h = w, h
+        self.img_wh = (self.w, self.h)
+
+        # self.near, self.far = self.config.near_plane, self.config.far_plane
+
+        self.focal = 0.5 * w / math.tan(0.5 * np.deg2rad(60))  # scaled focal length
+
+        # ray directions for all pixels, same for all images (same H, W, focal)
+        self.directions = get_ray_directions(
+            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
+        ).to(
+            self.rank
+        )  # (h, w, 3)
+
+        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []
+
+        radius = 2.0
+        elevation = 0.0
+        poses = get_uniform_poses(num_frames, radius, elevation, opengl=True)
+        for i, (c2w, frame) in enumerate(zip(poses, frames)):
+            c2w = torch.from_numpy(np.array(c2w)[:3, :4])
+            self.all_c2w.append(c2w)
+
+            img = Image.fromarray(frame)
+            img = remove(img, session=rembg_session)
+            img = img.resize(self.img_wh, Image.BICUBIC)
+            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)
+
+            self.all_fg_masks.append(img[..., -1])  # (h, w)
+            self.all_images.append(img[..., :3])
+
+        self.all_c2w, self.all_images, self.all_fg_masks = (
+            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
+            torch.stack(self.all_images, dim=0).float().to(self.rank),
+            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
+        )
+
+        self.normals = dpt(self.all_images)
+
+        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1
+
+        self.normals = self.normals * 2.0 - 1.0
+        self.normals = midas2blender(self.normals).cpu().numpy()
+        # self.normals = self.normals.cpu().numpy()
+        self.normals[..., 0] *= -1
+        self.normals[~self.all_masks] = [0, 0, 0]
+        normals = rearrange(self.normals, "b h w c -> b c h w")
+        normals = normals * 0.5 + 0.5
+        normals = torch.from_numpy(normals)
+        # save_image(make_grid(normals, nrow=4), "tmp/normals.png")
+        # exit(0)
+
+        (
+            self.all_poses,
+            self.all_normals,
            +
                        self.all_normals_world,
         
     | 
| 192 | 
         
            +
                        self.all_w2cs,
         
     | 
| 193 | 
         
            +
                        self.all_color_masks,
         
     | 
| 194 | 
         
            +
                    ) = ([], [], [], [], [])
         
     | 
| 195 | 
         
            +
             
     | 
| 196 | 
         
            +
                    for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
         
     | 
| 197 | 
         
            +
                        RT_opengl = inv_RT(c2w_opengl)
         
     | 
| 198 | 
         
            +
                        RT_opencv = RT_opengl2opencv(RT_opengl)
         
     | 
| 199 | 
         
            +
                        c2w_opencv = inv_RT(RT_opencv)
         
     | 
| 200 | 
         
            +
                        self.all_poses.append(c2w_opencv)
         
     | 
| 201 | 
         
            +
                        self.all_w2cs.append(RT_opencv)
         
     | 
| 202 | 
         
            +
                        normal = normal_opengl2opencv(normal)
         
     | 
| 203 | 
         
            +
                        normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
         
     | 
| 204 | 
         
            +
                        self.all_normals.append(normal)
         
     | 
| 205 | 
         
            +
                        self.all_normals_world.append(normal_world)
         
     | 
| 206 | 
         
            +
             
     | 
| 207 | 
         
            +
                    self.directions = torch.stack([self.directions] * len(self.all_images))
         
     | 
| 208 | 
         
            +
                    self.origins = self.directions
         
     | 
| 209 | 
         
            +
                    self.all_poses = np.stack(self.all_poses)
         
     | 
| 210 | 
         
            +
                    self.all_normals = np.stack(self.all_normals)
         
     | 
| 211 | 
         
            +
                    self.all_normals_world = np.stack(self.all_normals_world)
         
     | 
| 212 | 
         
            +
                    self.all_w2cs = np.stack(self.all_w2cs)
         
     | 
| 213 | 
         
            +
             
     | 
| 214 | 
         
            +
                    self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
         
     | 
| 215 | 
         
            +
                    self.all_images = self.all_images.to(self.rank)
         
     | 
| 216 | 
         
            +
                    self.all_fg_masks = self.all_fg_masks.to(self.rank)
         
     | 
| 217 | 
         
            +
                    self.all_rgb_masks = self.all_fg_masks.to(self.rank)
         
     | 
| 218 | 
         
            +
                    self.all_normals_world = (
         
     | 
| 219 | 
         
            +
                        torch.from_numpy(self.all_normals_world).float().to(self.rank)
         
     | 
| 220 | 
         
            +
                    )
         
     | 
| 221 | 
         
            +
             
     | 
| 222 | 
         
            +
             
     | 
| 223 | 
         
            +
            class BlenderDataset(Dataset, BlenderDatasetBase):
         
     | 
| 224 | 
         
            +
                def __init__(self, config, split):
         
     | 
| 225 | 
         
            +
                    self.setup(config, split)
         
     | 
| 226 | 
         
            +
             
     | 
| 227 | 
         
            +
                def __len__(self):
         
     | 
| 228 | 
         
            +
                    return len(self.all_images)
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
                def __getitem__(self, index):
         
     | 
| 231 | 
         
            +
                    return {"index": index}
         
     | 
| 232 | 
         
            +
             
     | 
| 233 | 
         
            +
             
     | 
| 234 | 
         
            +
            class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
         
     | 
| 235 | 
         
            +
                def __init__(self, config, split):
         
     | 
| 236 | 
         
            +
                    self.setup(config, split)
         
     | 
| 237 | 
         
            +
             
     | 
| 238 | 
         
            +
                def __iter__(self):
         
     | 
| 239 | 
         
            +
                    while True:
         
     | 
| 240 | 
         
            +
                        yield {}
         
     | 
| 241 | 
         
            +
             
     | 
| 242 | 
         
            +
             
     | 
| 243 | 
         
            +
            @datasets.register("v3d")
         
     | 
| 244 | 
         
            +
            class BlenderDataModule(pl.LightningDataModule):
         
     | 
| 245 | 
         
            +
                def __init__(self, config):
         
     | 
| 246 | 
         
            +
                    super().__init__()
         
     | 
| 247 | 
         
            +
                    self.config = config
         
     | 
| 248 | 
         
            +
             
     | 
| 249 | 
         
            +
                def setup(self, stage=None):
         
     | 
| 250 | 
         
            +
                    if stage in [None, "fit"]:
         
     | 
| 251 | 
         
            +
                        self.train_dataset = BlenderIterableDataset(
         
     | 
| 252 | 
         
            +
                            self.config, self.config.train_split
         
     | 
| 253 | 
         
            +
                        )
         
     | 
| 254 | 
         
            +
                    if stage in [None, "fit", "validate"]:
         
     | 
| 255 | 
         
            +
                        self.val_dataset = BlenderDataset(self.config, self.config.val_split)
         
     | 
| 256 | 
         
            +
                    if stage in [None, "test"]:
         
     | 
| 257 | 
         
            +
                        self.test_dataset = BlenderDataset(self.config, self.config.test_split)
         
     | 
| 258 | 
         
            +
                    if stage in [None, "predict"]:
         
     | 
| 259 | 
         
            +
                        self.predict_dataset = BlenderDataset(self.config, self.config.train_split)
         
     | 
| 260 | 
         
            +
             
     | 
| 261 | 
         
            +
                def prepare_data(self):
         
     | 
| 262 | 
         
            +
                    pass
         
     | 
| 263 | 
         
            +
             
     | 
| 264 | 
         
            +
                def general_loader(self, dataset, batch_size):
         
     | 
| 265 | 
         
            +
                    sampler = None
         
     | 
| 266 | 
         
            +
                    return DataLoader(
         
     | 
| 267 | 
         
            +
                        dataset,
         
     | 
| 268 | 
         
            +
                        num_workers=os.cpu_count(),
         
     | 
| 269 | 
         
            +
                        batch_size=batch_size,
         
     | 
| 270 | 
         
            +
                        pin_memory=True,
         
     | 
| 271 | 
         
            +
                        sampler=sampler,
         
     | 
| 272 | 
         
            +
                    )
         
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
            +
                def train_dataloader(self):
         
     | 
| 275 | 
         
            +
                    return self.general_loader(self.train_dataset, batch_size=1)
         
     | 
| 276 | 
         
            +
             
     | 
| 277 | 
         
            +
                def val_dataloader(self):
         
     | 
| 278 | 
         
            +
                    return self.general_loader(self.val_dataset, batch_size=1)
         
     | 
| 279 | 
         
            +
             
     | 
| 280 | 
         
            +
                def test_dataloader(self):
         
     | 
| 281 | 
         
            +
                    return self.general_loader(self.test_dataset, batch_size=1)
         
     | 
| 282 | 
         
            +
             
     | 
| 283 | 
         
            +
                def predict_dataloader(self):
         
     | 
| 284 | 
         
            +
                    return self.general_loader(self.predict_dataset, batch_size=1)
         
     | 
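For orientation, the sketch below shows how the "v3d" datamodule defined above might be exercised on its own, outside a Lightning Trainer. It is illustrative only: the config is assumed to be an OmegaConf DictConfig (the setup() above relies on both `"img_wh" in self.config` and attribute access), and every value shown (root_dir, scene, split names) is a placeholder rather than part of this commit. Note that BlenderDatasetBase.setup is eager: it reads the whole video, runs rembg matting and DPT normal estimation, and moves all tensors to the current rank's device before the first batch is served.

# Hypothetical usage sketch (not part of the commit).
from omegaconf import OmegaConf

from datasets.v3d import BlenderDataModule

config = OmegaConf.create(
    {
        "root_dir": "load/videos",   # placeholder path
        "scene": "example.mp4",      # passed to read_video(root_dir / scene)
        "img_wh": [576, 576],        # or set img_downscale instead
        "train_split": "train",
        "val_split": "val",
        "test_split": "test",
    }
)

dm = BlenderDataModule(config)
dm.setup("fit")                             # builds train/val datasets eagerly
batch = next(iter(dm.train_dataloader()))   # the IterableDataset yields empty dicts;
                                            # rays are presumably sampled downstream
                                            # from dm.train_dataset's tensors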
    	
mesh_recon/datasets/videonvs.py
ADDED
@@ -0,0 +1,256 @@
import os
import json
import math
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid, save_image
from einops import rearrange

import pytorch_lightning as pl

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank
from datasets.ortho import (
    inv_RT,
    camNormal2worldNormal,
    RT_opengl2opencv,
    normal_opengl2opencv,
)
from utils.dpt import DPT


def blender2midas(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


def midas2blender(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


class BlenderDatasetBase:
    def setup(self, config, split):
        self.config = config
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        dpt = DPT(device=self.rank, mode="normal")

        with open(
            os.path.join(
                self.config.root_dir, self.config.scene, f"transforms_train.json"
            ),
            "r",
        ) as f:
            meta = json.load(f)

        if "w" in meta and "h" in meta:
            W, H = int(meta["w"]), int(meta["h"])
        else:
            W, H = 800, 800

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        # self.near, self.far = self.config.near_plane, self.config.far_plane

        self.focal = (
            0.5 * w / math.tan(0.5 * meta["camera_angle_x"])
        )  # scaled focal length

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(
            self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2
        ).to(
            self.rank
        )  # (h, w, 3)

        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        for i, frame in enumerate(meta["frames"]):
            c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4])
            self.all_c2w.append(c2w)

            img_path = os.path.join(
                self.config.root_dir,
                self.config.scene,
                f"{frame['file_path']}.png",
            )
            img = Image.open(img_path)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)

            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )

        self.normals = dpt(self.all_images)

        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1

        self.normals = self.normals * 2.0 - 1.0
        self.normals = midas2blender(self.normals).cpu().numpy()
        # self.normals = self.normals.cpu().numpy()
        self.normals[..., 0] *= -1
        self.normals[~self.all_masks] = [0, 0, 0]
        normals = rearrange(self.normals, "b h w c -> b c h w")
        normals = normals * 0.5 + 0.5
        normals = torch.from_numpy(normals)
        save_image(make_grid(normals, nrow=4), "tmp/normals.png")
        # exit(0)

        (
            self.all_poses,
            self.all_normals,
            self.all_normals_world,
            self.all_w2cs,
            self.all_color_masks,
        ) = ([], [], [], [], [])

        for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
            RT_opengl = inv_RT(c2w_opengl)
            RT_opencv = RT_opengl2opencv(RT_opengl)
            c2w_opencv = inv_RT(RT_opencv)
            self.all_poses.append(c2w_opencv)
            self.all_w2cs.append(RT_opencv)
            normal = normal_opengl2opencv(normal)
            normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
            self.all_normals.append(normal)
            self.all_normals_world.append(normal_world)

        self.directions = torch.stack([self.directions] * len(self.all_images))
        self.origins = self.directions
        self.all_poses = np.stack(self.all_poses)
        self.all_normals = np.stack(self.all_normals)
        self.all_normals_world = np.stack(self.all_normals_world)
        self.all_w2cs = np.stack(self.all_w2cs)

        self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
        self.all_images = self.all_images.to(self.rank)
        self.all_fg_masks = self.all_fg_masks.to(self.rank)
        self.all_rgb_masks = self.all_fg_masks.to(self.rank)
        self.all_normals_world = (
            torch.from_numpy(self.all_normals_world).float().to(self.rank)
        )

        # normals = rearrange(self.all_normals_world, "b h w c -> b c h w")
        # normals = normals * 0.5 + 0.5
        # # normals = torch.from_numpy(normals)
        # save_image(make_grid(normals, nrow=4), "tmp/normals_world.png")
        # # exit(0)

        # # normals = (normals + 1) / 2.0
        # # for debug
        # index = [0, 9]
        # self.all_poses = self.all_poses[index]
        # self.all_c2w = self.all_c2w[index]
        # self.all_normals_world = self.all_normals_world[index]
        # self.all_w2cs = self.all_w2cs[index]
        # self.rgb_masks = self.all_rgb_masks[index]
        # self.fg_masks = self.all_fg_masks[index]
        # self.all_images = self.all_images[index]
        # self.directions = self.directions[index]
        # self.origins = self.origins[index]

        # images = rearrange(self.all_images, "b h w c -> b c h w")
        # normals = rearrange(normals, "b h w c -> b c h w")
        # save_image(make_grid(images, nrow=4), "tmp/images.png")
        # save_image(make_grid(normals, nrow=4), "tmp/normals.png")
        # breakpoint()

        # self.normals = self.normals * 2.0 - 1.0


class BlenderDataset(Dataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, index):
        return {"index": index}


class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        while True:
            yield {}


@datasets.register("videonvs")
class BlenderDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        if stage in [None, "fit"]:
            self.train_dataset = BlenderIterableDataset(
                self.config, self.config.train_split
            )
        if stage in [None, "fit", "validate"]:
            self.val_dataset = BlenderDataset(self.config, self.config.val_split)
        if stage in [None, "test"]:
            self.test_dataset = BlenderDataset(self.config, self.config.test_split)
        if stage in [None, "predict"]:
            self.predict_dataset = BlenderDataset(self.config, self.config.train_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        sampler = None
        return DataLoader(
            dataset,
            num_workers=os.cpu_count(),
            batch_size=batch_size,
            pin_memory=True,
            sampler=sampler,
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)
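One detail worth noting: videonvs.py derives its focal length from the camera_angle_x field of transforms_train.json, whereas v3d.py above hard-codes a 60-degree field of view for the synthesized orbit. Both use the same pinhole relation with the principal point at the image centre (w // 2, h // 2); the snippet below just restates that relation with a worked number. The helper name is illustrative, not part of the commit.

# Pinhole relation used by both loaders: focal length in pixels from the
# horizontal field of view (in radians, as camera_angle_x is).
import math

def focal_from_fov_x(width_px: int, fov_x_rad: float) -> float:
    # f = 0.5 * W / tan(0.5 * fov_x)
    return 0.5 * width_px / math.tan(0.5 * fov_x_rad)

print(focal_from_fov_x(576, math.radians(60)))  # ~498.8 px for a 576-px-wide frame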
    	
mesh_recon/datasets/videonvs_co3d.py
ADDED
@@ -0,0 +1,252 @@
import os
import json
import math
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader, IterableDataset
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid, save_image
from einops import rearrange
from rembg import remove, new_session

import pytorch_lightning as pl

import datasets
from models.ray_utils import get_ray_directions
from utils.misc import get_rank
from datasets.ortho import (
    inv_RT,
    camNormal2worldNormal,
    RT_opengl2opencv,
    normal_opengl2opencv,
)
from utils.dpt import DPT


def blender2midas(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


def midas2blender(img):
    """Blender: rub
    midas: lub
    """
    img[..., 0] = -img[..., 0]
    img[..., 1] = -img[..., 1]
    img[..., -1] = -img[..., -1]
    return img


class BlenderDatasetBase:
    def setup(self, config, split):
        self.config = config
        self.rank = get_rank()

        self.has_mask = True
        self.apply_mask = True

        dpt = DPT(device=self.rank, mode="normal")

        self.directions = []
        with open(
            os.path.join(self.config.root_dir, self.config.scene, f"transforms.json"),
            "r",
        ) as f:
            meta = json.load(f)

        if "w" in meta and "h" in meta:
            W, H = int(meta["w"]), int(meta["h"])
        else:
            W, H = 800, 800

        if "img_wh" in self.config:
            w, h = self.config.img_wh
            assert round(W / w * h) == H
        elif "img_downscale" in self.config:
            w, h = W // self.config.img_downscale, H // self.config.img_downscale
        else:
            raise KeyError("Either img_wh or img_downscale should be specified.")

        self.w, self.h = w, h
        self.img_wh = (self.w, self.h)

        # self.near, self.far = self.config.near_plane, self.config.far_plane
        _session = new_session()
        self.all_c2w, self.all_images, self.all_fg_masks = [], [], []

        for i, frame in enumerate(meta["frames"]):
            c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4])
            self.all_c2w.append(c2w)

            img_path = os.path.join(
                self.config.root_dir,
                self.config.scene,
                f"{frame['file_path']}",
            )
            img = Image.open(img_path)
            img = remove(img, session=_session)
            img = img.resize(self.img_wh, Image.BICUBIC)
            img = TF.to_tensor(img).permute(1, 2, 0)  # (4, h, w) => (h, w, 4)
            fx = frame["fl_x"]
            fy = frame["fl_y"]
            cx = frame["cx"]
            cy = frame["cy"]

            self.all_fg_masks.append(img[..., -1])  # (h, w)
            self.all_images.append(img[..., :3])

            self.directions.append(get_ray_directions(self.w, self.h, fx, fy, cx, cy))

        self.all_c2w, self.all_images, self.all_fg_masks = (
            torch.stack(self.all_c2w, dim=0).float().to(self.rank),
            torch.stack(self.all_images, dim=0).float().to(self.rank),
            torch.stack(self.all_fg_masks, dim=0).float().to(self.rank),
        )

        self.normals = dpt(self.all_images)

        self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1

        self.normals = self.normals * 2.0 - 1.0
        self.normals = midas2blender(self.normals).cpu().numpy()
        # self.normals = self.normals.cpu().numpy()
        self.normals[..., 0] *= -1
        self.normals[~self.all_masks] = [0, 0, 0]
        normals = rearrange(self.normals, "b h w c -> b c h w")
        normals = normals * 0.5 + 0.5
        normals = torch.from_numpy(normals)
        save_image(make_grid(normals, nrow=4), "tmp/normals.png")
        # exit(0)

        (
            self.all_poses,
            self.all_normals,
            self.all_normals_world,
            self.all_w2cs,
            self.all_color_masks,
        ) = ([], [], [], [], [])

        for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals):
            RT_opengl = inv_RT(c2w_opengl)
            RT_opencv = RT_opengl2opencv(RT_opengl)
            c2w_opencv = inv_RT(RT_opencv)
            self.all_poses.append(c2w_opencv)
            self.all_w2cs.append(RT_opencv)
            normal = normal_opengl2opencv(normal)
            normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal)
            self.all_normals.append(normal)
            self.all_normals_world.append(normal_world)

        self.directions = torch.stack(self.directions).to(self.rank)
        self.origins = self.directions
        self.all_poses = np.stack(self.all_poses)
        self.all_normals = np.stack(self.all_normals)
        self.all_normals_world = np.stack(self.all_normals_world)
        self.all_w2cs = np.stack(self.all_w2cs)

        self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank)
        self.all_images = self.all_images.to(self.rank)
        self.all_fg_masks = self.all_fg_masks.to(self.rank)
        self.all_rgb_masks = self.all_fg_masks.to(self.rank)
        self.all_normals_world = (
            torch.from_numpy(self.all_normals_world).float().to(self.rank)
         
     | 
| 161 | 
         
            +
                    )
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
                    # normals = rearrange(self.all_normals_world, "b h w c -> b c h w")
         
     | 
| 164 | 
         
            +
                    # normals = normals * 0.5 + 0.5
         
     | 
| 165 | 
         
            +
                    # # normals = torch.from_numpy(normals)
         
     | 
| 166 | 
         
            +
                    # save_image(make_grid(normals, nrow=4), "tmp/normals_world.png")
         
     | 
| 167 | 
         
            +
                    # # exit(0)
         
     | 
| 168 | 
         
            +
             
     | 
| 169 | 
         
            +
                    # # normals = (normals + 1) / 2.0
         
     | 
| 170 | 
         
            +
                    # # for debug
         
     | 
| 171 | 
         
            +
                    # index = [0, 9]
         
     | 
| 172 | 
         
            +
                    # self.all_poses = self.all_poses[index]
         
     | 
| 173 | 
         
            +
                    # self.all_c2w = self.all_c2w[index]
         
     | 
| 174 | 
         
            +
                    # self.all_normals_world = self.all_normals_world[index]
         
     | 
| 175 | 
         
            +
                    # self.all_w2cs = self.all_w2cs[index]
         
     | 
| 176 | 
         
            +
                    # self.rgb_masks = self.all_rgb_masks[index]
         
     | 
| 177 | 
         
            +
                    # self.fg_masks = self.all_fg_masks[index]
         
     | 
| 178 | 
         
            +
                    # self.all_images = self.all_images[index]
         
     | 
| 179 | 
         
            +
                    # self.directions = self.directions[index]
         
     | 
| 180 | 
         
            +
                    # self.origins = self.origins[index]
         
     | 
| 181 | 
         
            +
             
     | 
| 182 | 
         
            +
                    # images = rearrange(self.all_images, "b h w c -> b c h w")
         
     | 
| 183 | 
         
            +
                    # normals = rearrange(normals, "b h w c -> b c h w")
         
     | 
| 184 | 
         
            +
                    # save_image(make_grid(images, nrow=4), "tmp/images.png")
         
     | 
| 185 | 
         
            +
                    # save_image(make_grid(normals, nrow=4), "tmp/normals.png")
         
     | 
| 186 | 
         
            +
                    # breakpoint()
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
                    # self.normals = self.normals * 2.0 - 1.0
         
     | 
| 189 | 
         
            +
             
     | 
| 190 | 
         
            +
             
     | 
| 191 | 
         
            +
            class BlenderDataset(Dataset, BlenderDatasetBase):
         
     | 
| 192 | 
         
            +
                def __init__(self, config, split):
         
     | 
| 193 | 
         
            +
                    self.setup(config, split)
         
     | 
| 194 | 
         
            +
             
     | 
| 195 | 
         
            +
                def __len__(self):
         
     | 
| 196 | 
         
            +
                    return len(self.all_images)
         
     | 
| 197 | 
         
            +
             
     | 
| 198 | 
         
            +
                def __getitem__(self, index):
         
     | 
| 199 | 
         
            +
                    return {"index": index}
         
     | 
| 200 | 
         
            +
             
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
            class BlenderIterableDataset(IterableDataset, BlenderDatasetBase):
         
     | 
| 203 | 
         
            +
                def __init__(self, config, split):
         
     | 
| 204 | 
         
            +
                    self.setup(config, split)
         
     | 
| 205 | 
         
            +
             
     | 
| 206 | 
         
            +
                def __iter__(self):
         
     | 
| 207 | 
         
            +
                    while True:
         
     | 
| 208 | 
         
            +
                        yield {}
         
     | 
| 209 | 
         
            +
             
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
            @datasets.register("videonvs-scene")
         
     | 
| 212 | 
         
            +
            class VideoNVSScene(pl.LightningDataModule):
         
     | 
| 213 | 
         
            +
                def __init__(self, config):
         
     | 
| 214 | 
         
            +
                    super().__init__()
         
     | 
| 215 | 
         
            +
                    self.config = config
         
     | 
| 216 | 
         
            +
             
     | 
| 217 | 
         
            +
                def setup(self, stage=None):
         
     | 
| 218 | 
         
            +
                    if stage in [None, "fit"]:
         
     | 
| 219 | 
         
            +
                        self.train_dataset = BlenderIterableDataset(
         
     | 
| 220 | 
         
            +
                            self.config, self.config.train_split
         
     | 
| 221 | 
         
            +
                        )
         
     | 
| 222 | 
         
            +
                    if stage in [None, "fit", "validate"]:
         
     | 
| 223 | 
         
            +
                        self.val_dataset = BlenderDataset(self.config, self.config.val_split)
         
     | 
| 224 | 
         
            +
                    if stage in [None, "test"]:
         
     | 
| 225 | 
         
            +
                        self.test_dataset = BlenderDataset(self.config, self.config.test_split)
         
     | 
| 226 | 
         
            +
                    if stage in [None, "predict"]:
         
     | 
| 227 | 
         
            +
                        self.predict_dataset = BlenderDataset(self.config, self.config.train_split)
         
     | 
| 228 | 
         
            +
             
     | 
| 229 | 
         
            +
                def prepare_data(self):
         
     | 
| 230 | 
         
            +
                    pass
         
     | 
| 231 | 
         
            +
             
     | 
| 232 | 
         
            +
                def general_loader(self, dataset, batch_size):
         
     | 
| 233 | 
         
            +
                    sampler = None
         
     | 
| 234 | 
         
            +
                    return DataLoader(
         
     | 
| 235 | 
         
            +
                        dataset,
         
     | 
| 236 | 
         
            +
                        num_workers=os.cpu_count(),
         
     | 
| 237 | 
         
            +
                        batch_size=batch_size,
         
     | 
| 238 | 
         
            +
                        pin_memory=True,
         
     | 
| 239 | 
         
            +
                        sampler=sampler,
         
     | 
| 240 | 
         
            +
                    )
         
     | 
| 241 | 
         
            +
             
     | 
| 242 | 
         
            +
                def train_dataloader(self):
         
     | 
| 243 | 
         
            +
                    return self.general_loader(self.train_dataset, batch_size=1)
         
     | 
| 244 | 
         
            +
             
     | 
| 245 | 
         
            +
                def val_dataloader(self):
         
     | 
| 246 | 
         
            +
                    return self.general_loader(self.val_dataset, batch_size=1)
         
     | 
| 247 | 
         
            +
             
     | 
| 248 | 
         
            +
                def test_dataloader(self):
         
     | 
| 249 | 
         
            +
                    return self.general_loader(self.test_dataset, batch_size=1)
         
     | 
| 250 | 
         
            +
             
     | 
| 251 | 
         
            +
                def predict_dataloader(self):
         
     | 
| 252 | 
         
            +
                    return self.general_loader(self.predict_dataset, batch_size=1)
         
     | 
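The loading loop above depends on pose and normal conversion helpers (inv_RT, RT_opengl2opencv, normal_opengl2opencv, camNormal2worldNormal) that live elsewhere in the repository and are not part of this file. For orientation only, here is a minimal sketch of what such helpers conventionally compute, assuming the standard OpenGL-to-OpenCV convention in which the camera y and z axes are flipped; the bodies below are illustrative, not the repository's implementation.

import numpy as np


def inv_RT(RT):
    # Invert a rigid camera transform; accepts 3x4 or 4x4, returns 4x4.
    if RT.shape[0] == 3:
        RT = np.vstack([RT, np.array([0.0, 0.0, 0.0, 1.0])])
    return np.linalg.inv(RT)


def RT_opengl2opencv(RT):
    # OpenGL cameras look down -z with +y up, OpenCV cameras look down +z with -y up,
    # so flip the y and z camera axes of a 4x4 world-to-camera matrix.
    flip_yz = np.diag([1.0, -1.0, -1.0, 1.0])
    return flip_yz @ RT


def normal_opengl2opencv(normal):
    # Apply the same axis flip to per-pixel camera-space normals of shape (H, W, 3).
    return normal * np.array([1.0, -1.0, -1.0])


def camNormal2worldNormal(R_c2w, normal):
    # Rotate camera-space normals into world space with a 3x3 camera-to-world rotation.
    return normal @ R_c2w.T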
    	
mesh_recon/launch.py ADDED
@@ -0,0 +1,144 @@
import sys
import argparse
import os
import time
import logging
from datetime import datetime


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument("--gpu", default="0", help="GPU(s) to be used")
    parser.add_argument(
        "--resume", default=None, help="path to the weights to be resumed"
    )
    parser.add_argument(
        "--resume_weights_only",
        action="store_true",
        help="specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only",
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--predict", action="store_true")
    # group.add_argument('--export', action='store_true') # TODO: a separate export action

    parser.add_argument("--exp_dir", default="./exp")
    parser.add_argument("--runs_dir", default="./runs")
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )

    args, extras = parser.parse_known_args()

    # set CUDA_VISIBLE_DEVICES then import pytorch-lightning
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    n_gpus = len(args.gpu.split(","))

    import datasets
    import systems
    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
    from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
    from utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
    )
    from utils.misc import load_config

    # parse YAML config to OmegaConf
    config = load_config(args.config, cli_args=extras)
    config.cmd_args = vars(args)

    config.trial_name = config.get("trial_name") or (
        config.tag + datetime.now().strftime("@%Y%m%d-%H%M%S")
    )
    config.exp_dir = config.get("exp_dir") or os.path.join(args.exp_dir, config.name)
    config.save_dir = config.get("save_dir") or os.path.join(
        config.exp_dir, config.trial_name, "save"
    )
    config.ckpt_dir = config.get("ckpt_dir") or os.path.join(
        config.exp_dir, config.trial_name, "ckpt"
    )
    config.code_dir = config.get("code_dir") or os.path.join(
        config.exp_dir, config.trial_name, "code"
    )
    config.config_dir = config.get("config_dir") or os.path.join(
        config.exp_dir, config.trial_name, "config"
    )

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)

    if "seed" not in config:
        config.seed = int(time.time() * 1000) % 1000
    pl.seed_everything(config.seed)

    dm = datasets.make(config.dataset.name, config.dataset)
    system = systems.make(
        config.system.name,
        config,
        load_from_checkpoint=None if not args.resume_weights_only else args.resume,
    )

    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(dirpath=config.ckpt_dir, **config.checkpoint),
            LearningRateMonitor(logging_interval="step"),
            # CodeSnapshotCallback(
            #     config.code_dir, use_version=False
            # ),
            ConfigSnapshotCallback(config, config.config_dir, use_version=False),
            CustomProgressBar(refresh_rate=1),
        ]

    loggers = []
    if args.train:
        loggers += [
            TensorBoardLogger(
                args.runs_dir, name=config.name, version=config.trial_name
            ),
            CSVLogger(config.exp_dir, name=config.trial_name, version="csv_logs"),
        ]

    if sys.platform == "win32":
        # does not support multi-gpu on windows
        strategy = "dp"
        assert n_gpus == 1
    else:
        strategy = "ddp_find_unused_parameters_false"

    trainer = Trainer(
        devices=n_gpus,
        accelerator="gpu",
        callbacks=callbacks,
        logger=loggers,
        strategy=strategy,
        **config.trainer
    )

    if args.train:
        if args.resume and not args.resume_weights_only:
            # FIXME: different behavior in pytorch-lightning > 1.9 ?
            trainer.fit(system, datamodule=dm, ckpt_path=args.resume)
        else:
            trainer.fit(system, datamodule=dm)
        trainer.test(system, datamodule=dm)
    elif args.validate:
        trainer.validate(system, datamodule=dm, ckpt_path=args.resume)
    elif args.test:
        trainer.test(system, datamodule=dm, ckpt_path=args.resume)
    elif args.predict:
        trainer.predict(system, datamodule=dm, ckpt_path=args.resume)


if __name__ == "__main__":
    main()
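For orientation, launch.py is the command-line entry point of mesh_recon: it requires --config plus exactly one of --train, --validate, --test, or --predict, builds the datamodule and system named in the YAML config, and hands both to a pytorch_lightning Trainer. A typical run would therefore look like `python launch.py --config path/to/config.yaml --gpu 0 --train`, optionally adding `--resume path/to/checkpoint.ckpt` (and `--resume_weights_only` to restore weights without optimizer or trainer state); the two paths here are placeholders, not files referenced by this commit.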
    	
mesh_recon/mesh.py ADDED
@@ -0,0 +1,845 @@
| 1 | 
         
            +
            import os
         
     | 
| 2 | 
         
            +
            import cv2
         
     | 
| 3 | 
         
            +
            import torch
         
     | 
| 4 | 
         
            +
            import trimesh
         
     | 
| 5 | 
         
            +
            import numpy as np
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            from kiui.op import safe_normalize, dot
         
     | 
| 8 | 
         
            +
            from kiui.typing import *
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            class Mesh:
         
     | 
| 11 | 
         
            +
                """
         
     | 
| 12 | 
         
            +
                A torch-native trimesh class, with support for ``ply/obj/glb`` formats.
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
                Note:
         
     | 
| 15 | 
         
            +
                    This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture).
         
     | 
| 16 | 
         
            +
                """
         
     | 
| 17 | 
         
            +
                def __init__(
         
     | 
| 18 | 
         
            +
                    self,
         
     | 
| 19 | 
         
            +
                    v: Optional[Tensor] = None,
         
     | 
| 20 | 
         
            +
                    f: Optional[Tensor] = None,
         
     | 
| 21 | 
         
            +
                    vn: Optional[Tensor] = None,
         
     | 
| 22 | 
         
            +
                    fn: Optional[Tensor] = None,
         
     | 
| 23 | 
         
            +
                    vt: Optional[Tensor] = None,
         
     | 
| 24 | 
         
            +
                    ft: Optional[Tensor] = None,
         
     | 
| 25 | 
         
            +
                    vc: Optional[Tensor] = None, # vertex color
         
     | 
| 26 | 
         
            +
                    albedo: Optional[Tensor] = None,
         
     | 
| 27 | 
         
            +
                    metallicRoughness: Optional[Tensor] = None,
         
     | 
| 28 | 
         
            +
                    device: Optional[torch.device] = None,
         
     | 
| 29 | 
         
            +
                ):
         
     | 
| 30 | 
         
            +
                    """Init a mesh directly using all attributes.
         
     | 
| 31 | 
         
            +
             
     | 
| 32 | 
         
            +
                    Args:
         
     | 
| 33 | 
         
            +
                        v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None.
         
     | 
| 34 | 
         
            +
                        f (Optional[Tensor]): faces, int [M, 3]. Defaults to None.
         
     | 
| 35 | 
         
            +
                        vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None.
         
     | 
| 36 | 
         
            +
                        fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None.
         
     | 
| 37 | 
         
            +
                        vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None.
         
     | 
| 38 | 
         
            +
                        ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None.
         
     | 
| 39 | 
         
            +
                        vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None.
         
     | 
| 40 | 
         
            +
                        albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None.
         
     | 
| 41 | 
         
            +
                        metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None.
         
     | 
| 42 | 
         
            +
                        device (Optional[torch.device]): torch device. Defaults to None.
         
     | 
| 43 | 
         
            +
                    """
         
     | 
| 44 | 
         
            +
                    self.device = device
         
     | 
| 45 | 
         
            +
                    self.v = v
         
     | 
| 46 | 
         
            +
                    self.vn = vn
         
     | 
| 47 | 
         
            +
                    self.vt = vt
         
     | 
| 48 | 
         
            +
                    self.f = f
         
     | 
| 49 | 
         
            +
                    self.fn = fn
         
     | 
| 50 | 
         
            +
                    self.ft = ft
         
     | 
| 51 | 
         
            +
                    # will first see if there is vertex color to use
         
     | 
| 52 | 
         
            +
                    self.vc = vc
         
     | 
| 53 | 
         
            +
                    # only support a single albedo image
         
     | 
| 54 | 
         
            +
                    self.albedo = albedo
         
     | 
| 55 | 
         
            +
                    # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]
         
     | 
| 56 | 
         
            +
                    # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html
         
     | 
| 57 | 
         
            +
                    self.metallicRoughness = metallicRoughness
         
     | 
| 58 | 
         
            +
             
     | 
| 59 | 
         
            +
                    self.ori_center = 0
         
     | 
| 60 | 
         
            +
                    self.ori_scale = 1
         
     | 
| 61 | 
         
            +
             
     | 
| 62 | 
         
            +
                @classmethod
         
     | 
| 63 | 
         
            +
                def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs):
         
     | 
| 64 | 
         
            +
                    """load mesh from path.
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                    Args:
         
     | 
| 67 | 
         
            +
                        path (str): path to mesh file, supports ply, obj, glb.
         
     | 
| 68 | 
         
            +
                        clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False.
         
     | 
| 69 | 
         
            +
                        resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True.
         
     | 
| 70 | 
         
            +
                        renormal (bool, optional): re-calc the vertex normals. Defaults to True.
         
     | 
| 71 | 
         
            +
                        retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False.
         
     | 
| 72 | 
         
            +
                        bound (float, optional): bound to resize. Defaults to 0.9.
         
     | 
| 73 | 
         
            +
                        front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'.
         
     | 
| 74 | 
         
            +
                        device (torch.device, optional): torch device. Defaults to None.
         
     | 
| 75 | 
         
            +
                    
         
     | 
| 76 | 
         
            +
                    Note:
         
     | 
| 77 | 
         
            +
                        a ``device`` keyword argument can be provided to specify the torch device. 
         
     | 
| 78 | 
         
            +
                        If it's not provided, we will try to use ``'cuda'`` as the device if it's available.
         
     | 
| 79 | 
         
            +
             
     | 
| 80 | 
         
            +
                    Returns:
         
     | 
| 81 | 
         
            +
                        Mesh: the loaded Mesh object.
         
     | 
| 82 | 
         
            +
                    """
         
     | 
| 83 | 
         
            +
                    # obj supports face uv
         
     | 
| 84 | 
         
            +
                    if path.endswith(".obj"):
         
     | 
| 85 | 
         
            +
                        mesh = cls.load_obj(path, **kwargs)
         
     | 
| 86 | 
         
            +
                    # trimesh only supports vertex uv, but can load more formats
         
     | 
| 87 | 
         
            +
                    else:
         
     | 
| 88 | 
         
            +
                        mesh = cls.load_trimesh(path, **kwargs)
         
     | 
| 89 | 
         
            +
                    
         
     | 
| 90 | 
         
            +
                    # clean
         
     | 
| 91 | 
         
            +
                    if clean:
         
     | 
| 92 | 
         
            +
                        from kiui.mesh_utils import clean_mesh
         
     | 
| 93 | 
         
            +
                        vertices = mesh.v.detach().cpu().numpy()
         
     | 
| 94 | 
         
            +
                        triangles = mesh.f.detach().cpu().numpy()
         
     | 
| 95 | 
         
            +
                        vertices, triangles = clean_mesh(vertices, triangles, remesh=False)
         
     | 
| 96 | 
         
            +
                        mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device)
         
     | 
| 97 | 
         
            +
                        mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device)
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
        print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
        # auto-normalize
        if resize:
            mesh.auto_size(bound=bound)
        # auto-fix normal
        if renormal or mesh.vn is None:
            mesh.auto_normal()
            print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
        # auto-fix texcoords
        if retex or (mesh.albedo is not None and mesh.vt is None):
            mesh.auto_uv(cache_path=path)
            print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")

        # rotate front dir to +z
        if front_dir != "+z":
            # axis switch
            if "-z" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32)
            elif "+x" in front_dir:
                T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
            elif "-x" in front_dir:
                T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32)
            elif "+y" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
            elif "-y" in front_dir:
                T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32)
            else:
                T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            # rotation (how many 90 degrees)
            if '1' in front_dir:
                T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            elif '2' in front_dir:
                T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            elif '3' in front_dir:
                T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32)
            mesh.v @= T
            mesh.vn @= T

        return mesh

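    # Hypothetical usage sketch (the file name and keyword values are illustrative,
    # not taken from this repo): after the earlier part of load() has read the file,
    # the tail above normalizes and re-orients it, e.g.
    #   mesh = Mesh.load("model.obj", bound=0.9, front_dir="+x1")
    # rescales the mesh into [-0.9, 0.9]^3, rotates its +x front direction to +z, and
    # applies one extra 90-degree turn for the trailing digit.
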
    # load from obj file
    @classmethod
    def load_obj(cls, path, albedo_path=None, device=None):
        """load an ``obj`` mesh.

        Args:
            path (str): path to mesh.
            albedo_path (str, optional): path to the albedo texture image; overrides the texture path specified in the mtl file if provided. Defaults to None.
            device (torch.device, optional): torch device. Defaults to None.

        Note:
            We will try to read the `mtl` path from the `obj` file; otherwise we assume it has the same file name as the `obj` but with the `mtl` extension.
            The `usemtl` statement is ignored, and we only use the last material path in the `mtl` file.

        Returns:
            Mesh: the loaded Mesh object.
        """
        assert os.path.splitext(path)[-1] == ".obj"

        mesh = cls()

        # device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        mesh.device = device

        # load obj
        with open(path, "r") as f:
            lines = f.readlines()

        def parse_f_v(fv):
            # pass in a vertex term of a face, return (v, vt, vn) indices (-1 if not provided)
            # supported forms:
            # f v1 v2 v3
            # f v1/vt1 v2/vt2 v3/vt3
            # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
            # f v1//vn1 v2//vn2 v3//vn3
            xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
            xs.extend([-1] * (3 - len(xs)))
            return xs[0], xs[1], xs[2]
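        # For example, parse_f_v("5/2/7") gives (4, 1, 6) and parse_f_v("5//7") gives
        # (4, -1, 6): OBJ indices are 1-based, hence the -1, and missing fields map to -1.
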
        vertices, texcoords, normals = [], [], []
        faces, tfaces, nfaces = [], [], []
        mtl_path = None

        for line in lines:
            split_line = line.split()
            # empty line
            if len(split_line) == 0:
                continue
            prefix = split_line[0].lower()
            # mtllib
            if prefix == "mtllib":
                mtl_path = split_line[1]
            # usemtl
            elif prefix == "usemtl":
                pass  # ignored
            # v/vn/vt
            elif prefix == "v":
                vertices.append([float(v) for v in split_line[1:]])
            elif prefix == "vn":
                normals.append([float(v) for v in split_line[1:]])
            elif prefix == "vt":
                val = [float(v) for v in split_line[1:]]
                texcoords.append([val[0], 1.0 - val[1]])
            elif prefix == "f":
                vs = split_line[1:]
                nv = len(vs)
                v0, t0, n0 = parse_f_v(vs[0])
                for i in range(nv - 2):  # triangulate (assume vertices are ordered)
                    v1, t1, n1 = parse_f_v(vs[i + 1])
                    v2, t2, n2 = parse_f_v(vs[i + 2])
                    faces.append([v0, v1, v2])
                    tfaces.append([t0, t1, t2])
                    nfaces.append([n0, n1, n2])

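        # The loop above fan-triangulates polygons: a face with nv vertices becomes
        # nv - 2 triangles that all share the face's first vertex.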
        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
        mesh.vt = (
            torch.tensor(texcoords, dtype=torch.float32, device=device)
            if len(texcoords) > 0
            else None
        )
        mesh.vn = (
            torch.tensor(normals, dtype=torch.float32, device=device)
            if len(normals) > 0
            else None
        )

        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
        mesh.ft = (
            torch.tensor(tfaces, dtype=torch.int32, device=device)
            if len(texcoords) > 0
            else None
        )
        mesh.fn = (
            torch.tensor(nfaces, dtype=torch.int32, device=device)
            if len(normals) > 0
            else None
        )

        # see if there is vertex color
        use_vertex_color = False
        if mesh.v.shape[1] == 6:
            use_vertex_color = True
            mesh.vc = mesh.v[:, 3:]
            mesh.v = mesh.v[:, :3]
            print(f"[load_obj] use vertex color: {mesh.vc.shape}")

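        # Some exporters write per-vertex RGB after XYZ on each "v" line; the 6-column
        # check above splits those extra channels off into mesh.vc.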
        # try to load texture image
        if not use_vertex_color:
            # try to retrieve mtl file
            mtl_path_candidates = []
            if mtl_path is not None:
                mtl_path_candidates.append(mtl_path)
                mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path))
            mtl_path_candidates.append(path.replace(".obj", ".mtl"))

            mtl_path = None
            for candidate in mtl_path_candidates:
                if os.path.exists(candidate):
                    mtl_path = candidate
                    break

            # if albedo_path is not provided, try to retrieve it from mtl
            metallic_path = None
            roughness_path = None
            if mtl_path is not None and albedo_path is None:
                with open(mtl_path, "r") as f:
                    lines = f.readlines()

                for line in lines:
                    split_line = line.split()
                    # empty line
                    if len(split_line) == 0:
                        continue
                    prefix = split_line[0]

                    if "map_Kd" in prefix:
                        # assume relative path!
                        albedo_path = os.path.join(os.path.dirname(path), split_line[1])
                        print(f"[load_obj] use texture from: {albedo_path}")
                    elif "map_Pm" in prefix:
                        metallic_path = os.path.join(os.path.dirname(path), split_line[1])
                    elif "map_Pr" in prefix:
                        roughness_path = os.path.join(os.path.dirname(path), split_line[1])

            # albedo_path still not found, or the path doesn't exist
            if albedo_path is None or not os.path.exists(albedo_path):
                # init an empty texture
                print(f"[load_obj] init empty albedo!")
                # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
                albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])  # default color
            else:
                albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
                albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
                albedo = albedo.astype(np.float32) / 255
                print(f"[load_obj] load texture: {albedo.shape}")

            mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)

            # try to load metallic and roughness
            if metallic_path is not None and roughness_path is not None:
                print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}")
                metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED)
                metallic = metallic.astype(np.float32) / 255
                roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED)
                roughness = roughness.astype(np.float32) / 255
                metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1)
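                # Channel packing follows the glTF metallicRoughness convention:
                # green holds roughness, blue holds metallic, red is left unused.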
                mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()

        return mesh

    @classmethod
    def load_trimesh(cls, path, device=None):
        """load a mesh using ``trimesh.load()``.

        Can load various formats like ``glb`` and serves as a fallback.

        Note:
            We will try to merge all meshes if the glb contains more than one,
            but **this may cause the texture to be lost**, since we only support one texture image!

        Args:
            path (str): path to the mesh file.
            device (torch.device, optional): torch device. Defaults to None.

        Returns:
            Mesh: the loaded Mesh object.
        """
        mesh = cls()

        # device
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        mesh.device = device

        # use trimesh to load ply/glb
        _data = trimesh.load(path)
        if isinstance(_data, trimesh.Scene):
            if len(_data.geometry) == 1:
                _mesh = list(_data.geometry.values())[0]
            else:
                print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.")
                _concat = []
                # loop the scene graph and apply transform to each mesh
                scene_graph = _data.graph.to_flattened()  # dict {name: {transform: 4x4 mat, geometry: str}}
                for k, v in scene_graph.items():
                    name = v['geometry']
                    if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh):
                        transform = v['transform']
                        _concat.append(_data.geometry[name].apply_transform(transform))
                _mesh = trimesh.util.concatenate(_concat)
        else:
            _mesh = _data

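        # trimesh reports appearance either as per-vertex colors or as a texture via
        # _mesh.visual.kind; both cases are handled below, falling back to a flat gray
        # albedo when neither is available.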
        if _mesh.visual.kind == 'vertex':
            vertex_colors = _mesh.visual.vertex_colors
            vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255
            mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device)
            print(f"[load_trimesh] use vertex color: {mesh.vc.shape}")
        elif _mesh.visual.kind == 'texture':
            _material = _mesh.visual.material
            if isinstance(_material, trimesh.visual.material.PBRMaterial):
                texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
                # load metallicRoughness if present
                if _material.metallicRoughnessTexture is not None:
                    metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255
                    mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous()
            elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
                texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
            else:
                raise NotImplementedError(f"material type {type(_material)} not supported!")
            mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous()
            print(f"[load_trimesh] load texture: {texture.shape}")
        else:
            texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5])
            mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
            print(f"[load_trimesh] failed to load texture.")

        vertices = _mesh.vertices

        try:
            texcoords = _mesh.visual.uv
            texcoords[:, 1] = 1 - texcoords[:, 1]
        except Exception as e:
            texcoords = None

        try:
            normals = _mesh.vertex_normals
        except Exception as e:
            normals = None

        # trimesh only supports per-vertex uv, so the uv and normal faces reuse the vertex faces
        faces = tfaces = nfaces = _mesh.faces

        mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
        mesh.vt = (
            torch.tensor(texcoords, dtype=torch.float32, device=device)
            if texcoords is not None
            else None
        )
        mesh.vn = (
            torch.tensor(normals, dtype=torch.float32, device=device)
            if normals is not None
            else None
        )

        mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
        mesh.ft = (
            torch.tensor(tfaces, dtype=torch.int32, device=device)
            if texcoords is not None
            else None
        )
        mesh.fn = (
            torch.tensor(nfaces, dtype=torch.int32, device=device)
            if normals is not None
            else None
        )

        return mesh

    # sample surface (using trimesh)
    def sample_surface(self, count: int):
        """sample points on the surface of the mesh.

        Args:
            count (int): number of points to sample.

        Returns:
            torch.Tensor: the sampled points, float [count, 3].
        """
        _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy())
        points, face_idx = trimesh.sample.sample_surface(_mesh, count)
        points = torch.from_numpy(points).float().to(self.device)
        return points

    # aabb
    def aabb(self):
        """get the axis-aligned bounding box of the mesh.

        Returns:
            Tuple[torch.Tensor]: the min xyz and max xyz of the mesh.
        """
        return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values

    # unit size
    @torch.no_grad()
    def auto_size(self, bound=0.9):
        """auto resize the mesh.

        Args:
            bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9.
        """
        vmin, vmax = self.aabb()
        self.ori_center = (vmax + vmin) / 2
        self.ori_scale = 2 * bound / torch.max(vmax - vmin).item()
        self.v = (self.v - self.ori_center) * self.ori_scale

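    # Worked example (illustrative numbers): a mesh spanning [0, 2] along its longest
    # axis with bound=0.9 gets ori_center=1 on that axis and ori_scale = 2*0.9/2 = 0.9,
    # so its vertices end up inside [-0.9, 0.9]^3.
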
    def auto_normal(self):
        """auto calculate the vertex normals.
        """
        i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
        v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]

        face_normals = torch.cross(v1 - v0, v2 - v0)

        # Splat face normals to vertices
        vn = torch.zeros_like(self.v)
        vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
        vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
        vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        vn = torch.where(
            dot(vn, vn) > 1e-20,
            vn,
            torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
        )
        vn = safe_normalize(vn)

        self.vn = vn
        self.fn = self.f

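    # Note: the face normals above are unnormalized cross products, so their length is
    # proportional to triangle area; the scatter_add calls therefore accumulate
    # area-weighted normals at each vertex before the final normalization.
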
    def auto_uv(self, cache_path=None, vmap=True):
        """auto calculate the uv coordinates.

        Args:
            cache_path (str, optional): path to save/load the uv cache as a npz file. This avoids recomputing the uv every time the same mesh is loaded, which is time-consuming. Defaults to None.
            vmap (bool, optional): remap vertices based on uv coordinates, so each v corresponds to a unique vt (necessary for formats like gltf).
                Usually this will duplicate the vertices on the edge of the uv atlas. Defaults to True.
        """
        # try to load cache
        if cache_path is not None:
            cache_path = os.path.splitext(cache_path)[0] + "_uv.npz"
        if cache_path is not None and os.path.exists(cache_path):
            data = np.load(cache_path)
            vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"]
        else:
            import xatlas

            v_np = self.v.detach().cpu().numpy()
            f_np = self.f.detach().int().cpu().numpy()
            atlas = xatlas.Atlas()
            atlas.add_mesh(v_np, f_np)
            chart_options = xatlas.ChartOptions()
            # chart_options.max_iterations = 4
            atlas.generate(chart_options=chart_options)
            vmapping, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]

            # save to cache
            if cache_path is not None:
                np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping)

        vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
        ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
        self.vt = vt
        self.ft = ft

        if vmap:
            vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device)
            self.align_v_to_vt(vmapping)

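    # xatlas's vmapping gives, for each (possibly duplicated) atlas vertex, the index
    # of the original vertex it was copied from; that is the gather index consumed by
    # align_v_to_vt below.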
    def align_v_to_vt(self, vmapping=None):
        """remap v/f and vn/fn to vt/ft.

        Args:
            vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None.
        """
        if vmapping is None:
            ft = self.ft.view(-1).long()
            f = self.f.view(-1).long()
            vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device)
            vmapping[ft] = f  # scatter, randomly choose one if index is not unique

        self.v = self.v[vmapping]
        self.f = self.ft

        if self.vn is not None:
            self.vn = self.vn[vmapping]
            self.fn = self.ft

    def to(self, device):
        """move all tensor attributes to device.

        Args:
            device (torch.device): target device.

        Returns:
            Mesh: self.
        """
        self.device = device
        for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]:
            tensor = getattr(self, name)
            if tensor is not None:
                setattr(self, name, tensor.to(device))
        return self

    def write(self, path):
        """write the mesh to a path.

        Args:
            path (str): path to write, supports ply, obj and glb.
        """
        if path.endswith(".ply"):
            self.write_ply(path)
        elif path.endswith(".obj"):
            self.write_obj(path)
        elif path.endswith(".glb") or path.endswith(".gltf"):
            self.write_glb(path)
        else:
            raise NotImplementedError(f"format {path} not supported!")

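    # Hypothetical usage sketch (file names are illustrative): mesh.write("out.glb")
    # keeps geometry plus texture, while mesh.write("out.ply") goes through write_ply
    # below and therefore drops the texture.
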
    def write_ply(self, path):
        """write the mesh in ply format. Only for geometry!

        Args:
            path (str): path to write.
        """

        if self.albedo is not None:
            print(f'[WARN] ply format does not support exporting texture, will ignore!')

        v_np = self.v.detach().cpu().numpy()
        f_np = self.f.detach().cpu().numpy()

        _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
        _mesh.export(path)


    def write_glb(self, path):
        """write the mesh in glb/gltf format.
        This will create a scene with a single mesh.

        Args:
            path (str): path to write.
        """

        # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0]
        if self.vt is not None and self.v.shape[0] != self.vt.shape[0]:
            self.align_v_to_vt()

        import pygltflib

        f_np = self.f.detach().cpu().numpy().astype(np.uint32)
        f_np_blob = f_np.flatten().tobytes()

        v_np = self.v.detach().cpu().numpy().astype(np.float32)
        v_np_blob = v_np.tobytes()

        blob = f_np_blob + v_np_blob
        byteOffset = len(blob)

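        # Layout of the single binary buffer assembled below: [face indices | vertex
        # positions | (optional) uv coordinates | (optional) PNG-encoded albedo], with
        # one bufferView per segment; byteOffset tracks where the next segment starts.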
        # base mesh
        gltf = pygltflib.GLTF2(
            scene=0,
            scenes=[pygltflib.Scene(nodes=[0])],
            nodes=[pygltflib.Node(mesh=0)],
            meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive(
                # indices to accessors (0 is triangles)
                attributes=pygltflib.Attributes(
                    POSITION=1,
                ),
                indices=0,
            )])],
            buffers=[
                pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob))
            ],
            # buffer view (based on dtype)
            bufferViews=[
                # triangles; as flatten (element) array
                pygltflib.BufferView(
                    buffer=0,
                    byteLength=len(f_np_blob),
                    target=pygltflib.ELEMENT_ARRAY_BUFFER,  # GL_ELEMENT_ARRAY_BUFFER (34963)
                ),
                # positions; as vec3 array
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=len(f_np_blob),
                    byteLength=len(v_np_blob),
                    byteStride=12,  # vec3
                    target=pygltflib.ARRAY_BUFFER,  # GL_ARRAY_BUFFER (34962)
                ),
            ],
            accessors=[
                # 0 = triangles
                pygltflib.Accessor(
                    bufferView=0,
                    componentType=pygltflib.UNSIGNED_INT,  # GL_UNSIGNED_INT (5125)
                    count=f_np.size,
                    type=pygltflib.SCALAR,
                    max=[int(f_np.max())],
                    min=[int(f_np.min())],
                ),
                # 1 = positions
                pygltflib.Accessor(
                    bufferView=1,
                    componentType=pygltflib.FLOAT,  # GL_FLOAT (5126)
                    count=len(v_np),
                    type=pygltflib.VEC3,
                    max=v_np.max(axis=0).tolist(),
                    min=v_np.min(axis=0).tolist(),
                ),
            ],
        )

        # append texture info
        if self.vt is not None:

            vt_np = self.vt.detach().cpu().numpy().astype(np.float32)
            vt_np_blob = vt_np.tobytes()

            albedo = self.albedo.detach().cpu().numpy()
            albedo = (albedo * 255).astype(np.uint8)
            albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)
            albedo_blob = cv2.imencode('.png', albedo)[1].tobytes()

            # update primitive
            gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2
            gltf.meshes[0].primitives[0].material = 0

            # update materials
            gltf.materials.append(pygltflib.Material(
                pbrMetallicRoughness=pygltflib.PbrMetallicRoughness(
                    baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0),
                    metallicFactor=0.0,
                    roughnessFactor=1.0,
                ),
                alphaMode=pygltflib.OPAQUE,
                alphaCutoff=None,
                doubleSided=True,
            ))

            gltf.textures.append(pygltflib.Texture(sampler=0, source=0))
            gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
            gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png"))

            # update buffers
            gltf.bufferViews.append(
                # index = 2, texcoords; as vec2 array
                pygltflib.BufferView(
                    buffer=0,
                    byteOffset=byteOffset,
                    byteLength=len(vt_np_blob),
                    byteStride=8,  # vec2
                    target=pygltflib.ARRAY_BUFFER,
                )
            )

            gltf.accessors.append(
                # 2 = texcoords
                pygltflib.Accessor(
                    bufferView=2,
                    componentType=pygltflib.FLOAT,
                    count=len(vt_np),
                    type=pygltflib.VEC2,
                    max=vt_np.max(axis=0).tolist(),
                    min=vt_np.min(axis=0).tolist(),
                )
            )

            blob += vt_np_blob
            byteOffset += len(vt_np_blob)

            gltf.bufferViews.append(
                # index = 3, albedo texture; as none target
                pygltflib.BufferView(
                    buffer=0,
         
     | 
| 730 | 
         
            +
                                byteOffset=byteOffset,
         
     | 
| 731 | 
         
            +
                                byteLength=len(albedo_blob),
         
     | 
| 732 | 
         
            +
                            )
         
     | 
| 733 | 
         
            +
                        )
         
     | 
| 734 | 
         
            +
             
     | 
| 735 | 
         
            +
                        blob += albedo_blob
         
     | 
| 736 | 
         
            +
                        byteOffset += len(albedo_blob)
         
     | 
| 737 | 
         
            +
             
     | 
| 738 | 
         
            +
                        gltf.buffers[0].byteLength = byteOffset
         
     | 
| 739 | 
         
            +
             
     | 
| 740 | 
         
            +
                        # append metllic roughness
         
     | 
| 741 | 
         
            +
                        if self.metallicRoughness is not None:
         
     | 
| 742 | 
         
            +
                            metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
         
     | 
| 743 | 
         
            +
                            metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
         
     | 
| 744 | 
         
            +
                            metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR)
         
     | 
| 745 | 
         
            +
                            metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes()
         
     | 
| 746 | 
         
            +
             
     | 
| 747 | 
         
            +
                            # update texture definition
         
     | 
| 748 | 
         
            +
                            gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0
         
     | 
| 749 | 
         
            +
                            gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0
         
     | 
| 750 | 
         
            +
                            gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0)
         
     | 
| 751 | 
         
            +
             
     | 
| 752 | 
         
            +
                            gltf.textures.append(pygltflib.Texture(sampler=1, source=1))
         
     | 
| 753 | 
         
            +
                            gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT))
         
     | 
| 754 | 
         
            +
                            gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png"))
         
     | 
| 755 | 
         
            +
             
     | 
| 756 | 
         
            +
                            # update buffers
         
     | 
| 757 | 
         
            +
                            gltf.bufferViews.append(
         
     | 
| 758 | 
         
            +
                                # index = 4, metallicRoughness texture; as none target
         
     | 
| 759 | 
         
            +
                                pygltflib.BufferView(
         
     | 
| 760 | 
         
            +
                                    buffer=0,
         
     | 
| 761 | 
         
            +
                                    byteOffset=byteOffset,
         
     | 
| 762 | 
         
            +
                                    byteLength=len(metallicRoughness_blob),
         
     | 
| 763 | 
         
            +
                                )
         
     | 
| 764 | 
         
            +
                            )
         
     | 
| 765 | 
         
            +
             
     | 
| 766 | 
         
            +
                            blob += metallicRoughness_blob
         
     | 
| 767 | 
         
            +
                            byteOffset += len(metallicRoughness_blob)
         
     | 
| 768 | 
         
            +
             
     | 
| 769 | 
         
            +
                            gltf.buffers[0].byteLength = byteOffset
         
     | 
| 770 | 
         
            +
             
     | 
| 771 | 
         
            +
                        
         
     | 
| 772 | 
         
            +
                    # set actual data
         
     | 
| 773 | 
         
            +
                    gltf.set_binary_blob(blob)
         
     | 
| 774 | 
         
            +
             
     | 
| 775 | 
         
            +
                    # glb = b"".join(gltf.save_to_bytes())
         
     | 
| 776 | 
         
            +
                    gltf.save(path)
         
     | 
| 777 | 
         
            +
             
     | 
| 778 | 
         
            +
             
     | 
| 779 | 
         
            +
                def write_obj(self, path):
         
     | 
| 780 | 
         
            +
                    """write the mesh in obj format. Will also write the texture and mtl files.
         
     | 
| 781 | 
         
            +
             
     | 
| 782 | 
         
            +
                    Args:
         
     | 
| 783 | 
         
            +
                        path (str): path to write.
         
     | 
| 784 | 
         
            +
                    """
         
     | 
| 785 | 
         
            +
             
     | 
| 786 | 
         
            +
                    mtl_path = path.replace(".obj", ".mtl")
         
     | 
| 787 | 
         
            +
                    albedo_path = path.replace(".obj", "_albedo.png")
         
     | 
| 788 | 
         
            +
                    metallic_path = path.replace(".obj", "_metallic.png")
         
     | 
| 789 | 
         
            +
                    roughness_path = path.replace(".obj", "_roughness.png")
         
     | 
| 790 | 
         
            +
             
     | 
| 791 | 
         
            +
                    v_np = self.v.detach().cpu().numpy()
         
     | 
| 792 | 
         
            +
                    vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
         
     | 
| 793 | 
         
            +
                    vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
         
     | 
| 794 | 
         
            +
                    f_np = self.f.detach().cpu().numpy()
         
     | 
| 795 | 
         
            +
                    ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
         
     | 
| 796 | 
         
            +
                    fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None
         
     | 
| 797 | 
         
            +
             
     | 
| 798 | 
         
            +
                    with open(path, "w") as fp:
         
     | 
| 799 | 
         
            +
                        fp.write(f"mtllib {os.path.basename(mtl_path)} \n")
         
     | 
| 800 | 
         
            +
             
     | 
| 801 | 
         
            +
                        for v in v_np:
         
     | 
| 802 | 
         
            +
                            fp.write(f"v {v[0]} {v[1]} {v[2]} \n")
         
     | 
| 803 | 
         
            +
             
     | 
| 804 | 
         
            +
                        if vt_np is not None:
         
     | 
| 805 | 
         
            +
                            for v in vt_np:
         
     | 
| 806 | 
         
            +
                                fp.write(f"vt {v[0]} {1 - v[1]} \n")
         
     | 
| 807 | 
         
            +
             
     | 
| 808 | 
         
            +
                        if vn_np is not None:
         
     | 
| 809 | 
         
            +
                            for v in vn_np:
         
     | 
| 810 | 
         
            +
                                fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")
         
     | 
| 811 | 
         
            +
             
     | 
| 812 | 
         
            +
                        fp.write(f"usemtl defaultMat \n")
         
     | 
| 813 | 
         
            +
                        for i in range(len(f_np)):
         
     | 
| 814 | 
         
            +
                            fp.write(
         
     | 
| 815 | 
         
            +
                                f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
         
     | 
| 816 | 
         
            +
                                         {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
         
     | 
| 817 | 
         
            +
                                         {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
         
     | 
| 818 | 
         
            +
                            )
         
     | 
| 819 | 
         
            +
             
     | 
| 820 | 
         
            +
                    with open(mtl_path, "w") as fp:
         
     | 
| 821 | 
         
            +
                        fp.write(f"newmtl defaultMat \n")
         
     | 
| 822 | 
         
            +
                        fp.write(f"Ka 1 1 1 \n")
         
     | 
| 823 | 
         
            +
                        fp.write(f"Kd 1 1 1 \n")
         
     | 
| 824 | 
         
            +
                        fp.write(f"Ks 0 0 0 \n")
         
     | 
| 825 | 
         
            +
                        fp.write(f"Tr 1 \n")
         
     | 
| 826 | 
         
            +
                        fp.write(f"illum 1 \n")
         
     | 
| 827 | 
         
            +
                        fp.write(f"Ns 0 \n")
         
     | 
| 828 | 
         
            +
                        if self.albedo is not None:
         
     | 
| 829 | 
         
            +
                            fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
         
     | 
| 830 | 
         
            +
                        if self.metallicRoughness is not None:
         
     | 
| 831 | 
         
            +
                            # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering
         
     | 
| 832 | 
         
            +
                            fp.write(f"map_Pm {os.path.basename(metallic_path)} \n")
         
     | 
| 833 | 
         
            +
                            fp.write(f"map_Pr {os.path.basename(roughness_path)} \n")
         
     | 
| 834 | 
         
            +
             
     | 
| 835 | 
         
            +
                    if self.albedo is not None:
         
     | 
| 836 | 
         
            +
                        albedo = self.albedo.detach().cpu().numpy()
         
     | 
| 837 | 
         
            +
                        albedo = (albedo * 255).astype(np.uint8)
         
     | 
| 838 | 
         
            +
                        cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
         
     | 
| 839 | 
         
            +
                    
         
     | 
| 840 | 
         
            +
                    if self.metallicRoughness is not None:
         
     | 
| 841 | 
         
            +
                        metallicRoughness = self.metallicRoughness.detach().cpu().numpy()
         
     | 
| 842 | 
         
            +
                        metallicRoughness = (metallicRoughness * 255).astype(np.uint8)
         
     | 
| 843 | 
         
            +
                        cv2.imwrite(metallic_path, metallicRoughness[..., 2])
         
     | 
| 844 | 
         
            +
                        cv2.imwrite(roughness_path, metallicRoughness[..., 1])
         
     | 
| 845 | 
         
            +
             
     | 
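A hypothetical usage sketch of the OBJ exporter above; `mesh` stands for an already-populated Mesh instance (its construction is defined earlier in mesh.py and is not shown in this excerpt), and the output paths are illustrative only:

# Hypothetical usage sketch, not part of the repository.
# `mesh` is an already-populated Mesh instance from mesh.py.
mesh.write_obj("out/model.obj")   # also writes out/model.mtl plus
                                  # out/model_albedo.png, out/model_metallic.png, out/model_roughness.png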
    	
mesh_recon/models/__init__.py
ADDED
@@ -0,0 +1,16 @@
models = {}


def register(name):
    def decorator(cls):
        models[name] = cls
        return cls
    return decorator


def make(name, config):
    model = models[name](config)
    return model


from . import nerf, neus, geometry, texture
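A quick, hypothetical illustration of the registry above; `toy-model` and `ToyModel` are made-up names, not part of the repository. The concrete geometries, textures and renderers in the files below register themselves the same way and are instantiated from their configs via `models.make`.

# Illustrative sketch of the registry pattern in models/__init__.py.
# "toy-model" and ToyModel are made-up names.
import models


@models.register("toy-model")
class ToyModel:
    def __init__(self, config):
        self.config = config


model = models.make("toy-model", {"radius": 1.0})  # looks the class up by name and instantiates it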
    	
mesh_recon/models/base.py
ADDED
@@ -0,0 +1,32 @@
import torch
import torch.nn as nn

from utils.misc import get_rank

class BaseModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rank = get_rank()
        self.setup()
        if self.config.get('weights', None):
            self.load_state_dict(torch.load(self.config.weights))

    def setup(self):
        raise NotImplementedError

    def update_step(self, epoch, global_step):
        pass

    def train(self, mode=True):
        return super().train(mode=mode)

    def eval(self):
        return super().eval()

    def regularizations(self, out):
        return {}

    @torch.no_grad()
    def export(self, export_config):
        return {}
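The models in this package subclass BaseModel and override setup() rather than __init__, so config storage, rank lookup and optional weight loading stay in one place. A minimal, made-up sketch (ToyDensity is illustrative only, not repository code):

# Illustrative BaseModel subclass; ToyDensity and its layers are made up.
import torch.nn as nn
from models.base import BaseModel


class ToyDensity(BaseModel):
    def setup(self):
        # called by BaseModel.__init__ after self.config and self.rank are set
        self.mlp = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, points):
        return self.mlp(points)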
    	
mesh_recon/models/geometry.py
ADDED
@@ -0,0 +1,238 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_lightning.utilities.rank_zero import rank_zero_info

import models
from models.base import BaseModel
from models.utils import scale_anything, get_activation, cleanup, chunk_batch
from models.network_utils import get_encoding, get_mlp, get_encoding_with_network
from utils.misc import get_rank
from systems.utils import update_module_step
from nerfacc import ContractionType


def contract_to_unisphere(x, radius, contraction_type):
    if contraction_type == ContractionType.AABB:
        x = scale_anything(x, (-radius, radius), (0, 1))
    elif contraction_type == ContractionType.UN_BOUNDED_SPHERE:
        x = scale_anything(x, (-radius, radius), (0, 1))
        x = x * 2 - 1  # aabb is at [-1, 1]
        mag = x.norm(dim=-1, keepdim=True)
        mask = mag.squeeze(-1) > 1
        x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask])
        x = x / 4 + 0.5  # [-inf, inf] is at [0, 1]
    else:
        raise NotImplementedError
    return x


class MarchingCubeHelper(nn.Module):
    def __init__(self, resolution, use_torch=True):
        super().__init__()
        self.resolution = resolution
        self.use_torch = use_torch
        self.points_range = (0, 1)
        if self.use_torch:
            import torchmcubes
            self.mc_func = torchmcubes.marching_cubes
        else:
            import mcubes
            self.mc_func = mcubes.marching_cubes
        self.verts = None

    def grid_vertices(self):
        if self.verts is None:
            x, y, z = torch.linspace(*self.points_range, self.resolution), torch.linspace(*self.points_range, self.resolution), torch.linspace(*self.points_range, self.resolution)
            x, y, z = torch.meshgrid(x, y, z, indexing='ij')
            verts = torch.cat([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1).reshape(-1, 3)
            self.verts = verts
        return self.verts

    def forward(self, level, threshold=0.):
        level = level.float().view(self.resolution, self.resolution, self.resolution)
        if self.use_torch:
            verts, faces = self.mc_func(level.to(get_rank()), threshold)
            verts, faces = verts.cpu(), faces.cpu().long()
        else:
            verts, faces = self.mc_func(-level.numpy(), threshold)  # transform to numpy
            verts, faces = torch.from_numpy(verts.astype(np.float32)), torch.from_numpy(faces.astype(np.int64))  # transform back to pytorch
        verts = verts / (self.resolution - 1.)
        return {
            'v_pos': verts,
            't_pos_idx': faces
        }


class BaseImplicitGeometry(BaseModel):
    def __init__(self, config):
        super().__init__(config)
        if self.config.isosurface is not None:
            assert self.config.isosurface.method in ['mc', 'mc-torch']
            if self.config.isosurface.method == 'mc-torch':
                raise NotImplementedError("Please do not use mc-torch. It currently has some scaling issues I haven't fixed yet.")
            self.helper = MarchingCubeHelper(self.config.isosurface.resolution, use_torch=self.config.isosurface.method=='mc-torch')
        self.radius = self.config.radius
        self.contraction_type = None  # assigned in system

    def forward_level(self, points):
        raise NotImplementedError

    def isosurface_(self, vmin, vmax):
        def batch_func(x):
            x = torch.stack([
                scale_anything(x[...,0], (0, 1), (vmin[0], vmax[0])),
                scale_anything(x[...,1], (0, 1), (vmin[1], vmax[1])),
                scale_anything(x[...,2], (0, 1), (vmin[2], vmax[2])),
            ], dim=-1).to(self.rank)
            rv = self.forward_level(x).cpu()
            cleanup()
            return rv

        level = chunk_batch(batch_func, self.config.isosurface.chunk, True, self.helper.grid_vertices())
        mesh = self.helper(level, threshold=self.config.isosurface.threshold)
        mesh['v_pos'] = torch.stack([
            scale_anything(mesh['v_pos'][...,0], (0, 1), (vmin[0], vmax[0])),
            scale_anything(mesh['v_pos'][...,1], (0, 1), (vmin[1], vmax[1])),
            scale_anything(mesh['v_pos'][...,2], (0, 1), (vmin[2], vmax[2]))
        ], dim=-1)
        return mesh

    @torch.no_grad()
    def isosurface(self):
        if self.config.isosurface is None:
            raise NotImplementedError
        mesh_coarse = self.isosurface_((-self.radius, -self.radius, -self.radius), (self.radius, self.radius, self.radius))
        vmin, vmax = mesh_coarse['v_pos'].amin(dim=0), mesh_coarse['v_pos'].amax(dim=0)
        vmin_ = (vmin - (vmax - vmin) * 0.1).clamp(-self.radius, self.radius)
        vmax_ = (vmax + (vmax - vmin) * 0.1).clamp(-self.radius, self.radius)
        mesh_fine = self.isosurface_(vmin_, vmax_)
        return mesh_fine


@models.register('volume-density')
class VolumeDensity(BaseImplicitGeometry):
    def setup(self):
        self.n_input_dims = self.config.get('n_input_dims', 3)
        self.n_output_dims = self.config.feature_dim
        self.encoding_with_network = get_encoding_with_network(self.n_input_dims, self.n_output_dims, self.config.xyz_encoding_config, self.config.mlp_network_config)

    def forward(self, points):
        points = contract_to_unisphere(points, self.radius, self.contraction_type)
        out = self.encoding_with_network(points.view(-1, self.n_input_dims)).view(*points.shape[:-1], self.n_output_dims).float()
        density, feature = out[...,0], out
        if 'density_activation' in self.config:
            density = get_activation(self.config.density_activation)(density + float(self.config.density_bias))
        if 'feature_activation' in self.config:
            feature = get_activation(self.config.feature_activation)(feature)
        return density, feature

    def forward_level(self, points):
        points = contract_to_unisphere(points, self.radius, self.contraction_type)
        density = self.encoding_with_network(points.reshape(-1, self.n_input_dims)).reshape(*points.shape[:-1], self.n_output_dims)[...,0]
        if 'density_activation' in self.config:
            density = get_activation(self.config.density_activation)(density + float(self.config.density_bias))
        return -density

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding_with_network, epoch, global_step)


@models.register('volume-sdf')
class VolumeSDF(BaseImplicitGeometry):
    def setup(self):
        self.n_output_dims = self.config.feature_dim
        encoding = get_encoding(3, self.config.xyz_encoding_config)
        network = get_mlp(encoding.n_output_dims, self.n_output_dims, self.config.mlp_network_config)
        self.encoding, self.network = encoding, network
        self.grad_type = self.config.grad_type
        self.finite_difference_eps = self.config.get('finite_difference_eps', 1e-3)
        # the actual value used in training
        # will update at certain steps if finite_difference_eps="progressive"
        self._finite_difference_eps = None
        if self.grad_type == 'finite_difference':
            rank_zero_info(f"Using finite difference to compute gradients with eps={self.finite_difference_eps}")

    def forward(self, points, with_grad=True, with_feature=True, with_laplace=False):
        with torch.inference_mode(torch.is_inference_mode_enabled() and not (with_grad and self.grad_type == 'analytic')):
            with torch.set_grad_enabled(self.training or (with_grad and self.grad_type == 'analytic')):
                if with_grad and self.grad_type == 'analytic':
                    if not self.training:
                        points = points.clone()  # points may be in inference mode, get a copy to enable grad
                    points.requires_grad_(True)

                points_ = points  # points in the original scale
                points = contract_to_unisphere(points, self.radius, self.contraction_type)  # points normalized to (0, 1)

                out = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims).float()
                sdf, feature = out[...,0], out
                if 'sdf_activation' in self.config:
                    sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias))
                if 'feature_activation' in self.config:
                    feature = get_activation(self.config.feature_activation)(feature)
                if with_grad:
                    if self.grad_type == 'analytic':
                        grad = torch.autograd.grad(
                            sdf, points_, grad_outputs=torch.ones_like(sdf),
                            create_graph=True, retain_graph=True, only_inputs=True
                        )[0]
                    elif self.grad_type == 'finite_difference':
                        eps = self._finite_difference_eps
                        offsets = torch.as_tensor(
                            [
                                [eps, 0.0, 0.0],
                                [-eps, 0.0, 0.0],
                                [0.0, eps, 0.0],
                                [0.0, -eps, 0.0],
                                [0.0, 0.0, eps],
                                [0.0, 0.0, -eps],
                            ]
                        ).to(points_)
                        points_d_ = (points_[...,None,:] + offsets).clamp(-self.radius, self.radius)
                        points_d = scale_anything(points_d_, (-self.radius, self.radius), (0, 1))
                        points_d_sdf = self.network(self.encoding(points_d.view(-1, 3)))[...,0].view(*points.shape[:-1], 6).float()
                        grad = 0.5 * (points_d_sdf[..., 0::2] - points_d_sdf[..., 1::2]) / eps

                        if with_laplace:
                            laplace = (points_d_sdf[..., 0::2] + points_d_sdf[..., 1::2] - 2 * sdf[..., None]).sum(-1) / (eps ** 2)

        rv = [sdf]
        if with_grad:
            rv.append(grad)
        if with_feature:
            rv.append(feature)
        if with_laplace:
            assert self.config.grad_type == 'finite_difference', "Laplace computation is only supported with grad_type='finite_difference'"
            rv.append(laplace)
        rv = [v if self.training else v.detach() for v in rv]
        return rv[0] if len(rv) == 1 else rv

    def forward_level(self, points):
        points = contract_to_unisphere(points, self.radius, self.contraction_type)  # points normalized to (0, 1)
        sdf = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims)[...,0]
        if 'sdf_activation' in self.config:
            sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias))
        return sdf

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding, epoch, global_step)
        update_module_step(self.network, epoch, global_step)
        if self.grad_type == 'finite_difference':
            if isinstance(self.finite_difference_eps, float):
                self._finite_difference_eps = self.finite_difference_eps
            elif self.finite_difference_eps == 'progressive':
                hg_conf = self.config.xyz_encoding_config
                assert hg_conf.otype == "ProgressiveBandHashGrid", "finite_difference_eps='progressive' only works with ProgressiveBandHashGrid"
                current_level = min(
                    hg_conf.start_level + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps,
                    hg_conf.n_levels
                )
                grid_res = hg_conf.base_resolution * hg_conf.per_level_scale**(current_level - 1)
                grid_size = 2 * self.config.radius / grid_res
                if grid_size != self._finite_difference_eps:
                    rank_zero_info(f"Update finite_difference_eps to {grid_size}")
                self._finite_difference_eps = grid_size
            else:
                raise ValueError(f"Unknown finite_difference_eps={self.finite_difference_eps}")
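When grad_type='finite_difference', VolumeSDF.forward probes the SDF at x ± eps along each axis and forms central differences, grad_i ≈ (f(x + eps·e_i) − f(x − eps·e_i)) / (2·eps), reusing the same six probes for the Laplacian. A standalone NumPy sketch on an analytic sphere SDF (illustration only, not repository code):

# Standalone illustration of the central-difference scheme used in VolumeSDF
# with grad_type='finite_difference'; not part of the repository.
import numpy as np

def sphere_sdf(p, r=0.5):
    return np.linalg.norm(p, axis=-1) - r

eps = 1e-3
x = np.array([0.3, 0.1, -0.2])
offsets = np.array([
    [eps, 0, 0], [-eps, 0, 0],
    [0, eps, 0], [0, -eps, 0],
    [0, 0, eps], [0, 0, -eps],
])
probes = sphere_sdf(x + offsets)                   # six probe values, analogous to points_d_sdf
grad = 0.5 * (probes[0::2] - probes[1::2]) / eps   # central differences along x, y, z
laplace = (probes[0::2] + probes[1::2] - 2 * sphere_sdf(x)).sum() / eps**2

print(grad)     # ~ x / |x|, the unit normal of the sphere (gradient norm ~ 1)
print(laplace)  # ~ 2 / |x| for a sphere SDF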
    	
mesh_recon/models/nerf.py
ADDED
@@ -0,0 +1,161 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import models
from models.base import BaseModel
from models.utils import chunk_batch
from systems.utils import update_module_step
from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, accumulate_along_rays


@models.register('nerf')
class NeRFModel(BaseModel):
    def setup(self):
        self.geometry = models.make(self.config.geometry.name, self.config.geometry)
        self.texture = models.make(self.config.texture.name, self.config.texture)
        self.register_buffer('scene_aabb', torch.as_tensor([-self.config.radius, -self.config.radius, -self.config.radius, self.config.radius, self.config.radius, self.config.radius], dtype=torch.float32))

        if self.config.learned_background:
            self.occupancy_grid_res = 256
            self.near_plane, self.far_plane = 0.2, 1e4
            self.cone_angle = 10**(math.log10(self.far_plane) / self.config.num_samples_per_ray) - 1. # approximate
            self.render_step_size = 0.01 # render_step_size = max(distance_to_camera * self.cone_angle, self.render_step_size)
            self.contraction_type = ContractionType.UN_BOUNDED_SPHERE
        else:
            self.occupancy_grid_res = 128
            self.near_plane, self.far_plane = None, None
            self.cone_angle = 0.0
            self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray
            self.contraction_type = ContractionType.AABB

        self.geometry.contraction_type = self.contraction_type

        if self.config.grid_prune:
            self.occupancy_grid = OccupancyGrid(
                roi_aabb=self.scene_aabb,
                resolution=self.occupancy_grid_res,
                contraction_type=self.contraction_type
            )
        self.randomized = self.config.randomized
        self.background_color = None

    def update_step(self, epoch, global_step):
        update_module_step(self.geometry, epoch, global_step)
        update_module_step(self.texture, epoch, global_step)

        def occ_eval_fn(x):
            density, _ = self.geometry(x)
            # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size) based on taylor series
            return density[...,None] * self.render_step_size

        if self.training and self.config.grid_prune:
            self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn)

    def isosurface(self):
        mesh = self.geometry.isosurface()
        return mesh

    def forward_(self, rays):
        n_rays = rays.shape[0]
        rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)

        def sigma_fn(t_starts, t_ends, ray_indices):
            ray_indices = ray_indices.long()
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            positions = t_origins + t_dirs * (t_starts + t_ends) / 2.
            density, _ = self.geometry(positions)
            return density[...,None]

        def rgb_sigma_fn(t_starts, t_ends, ray_indices):
            ray_indices = ray_indices.long()
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            positions = t_origins + t_dirs * (t_starts + t_ends) / 2.
            density, feature = self.geometry(positions)
            rgb = self.texture(feature, t_dirs)
            return rgb, density[...,None]

        with torch.no_grad():
            ray_indices, t_starts, t_ends = ray_marching(
                rays_o, rays_d,
                scene_aabb=None if self.config.learned_background else self.scene_aabb,
                grid=self.occupancy_grid if self.config.grid_prune else None,
                sigma_fn=sigma_fn,
                near_plane=self.near_plane, far_plane=self.far_plane,
                render_step_size=self.render_step_size,
                stratified=self.randomized,
                cone_angle=self.cone_angle,
                alpha_thre=0.0
            )

        ray_indices = ray_indices.long()
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        midpoints = (t_starts + t_ends) / 2.
        positions = t_origins + t_dirs * midpoints
        intervals = t_ends - t_starts

        density, feature = self.geometry(positions)
        rgb = self.texture(feature, t_dirs)

        weights = render_weight_from_density(t_starts, t_ends, density[...,None], ray_indices=ray_indices, n_rays=n_rays)
        opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays)
        depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays)
        comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays)
        comp_rgb = comp_rgb + self.background_color * (1.0 - opacity)

        out = {
            'comp_rgb': comp_rgb,
            'opacity': opacity,
            'depth': depth,
            'rays_valid': opacity > 0,
            'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device)
        }

        if self.training:
            out.update({
                'weights': weights.view(-1),
                'points': midpoints.view(-1),
                'intervals': intervals.view(-1),
                'ray_indices': ray_indices.view(-1)
            })

        return out

    def forward(self, rays):
        if self.training:
            out = self.forward_(rays)
        else:
            out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays)
        return {
            **out,
        }

    def train(self, mode=True):
        self.randomized = mode and self.config.randomized
        return super().train(mode=mode)

    def eval(self):
        self.randomized = False
        return super().eval()

    def regularizations(self, out):
        losses = {}
        losses.update(self.geometry.regularizations(out))
        losses.update(self.texture.regularizations(out))
        return losses

    @torch.no_grad()
    def export(self, export_config):
        mesh = self.isosurface()
        if export_config.export_vertex_color:
            _, feature = chunk_batch(self.geometry, export_config.chunk_size, False, mesh['v_pos'].to(self.rank))
            viewdirs = torch.zeros(feature.shape[0], 3).to(feature)
            viewdirs[...,2] = -1. # set the viewing directions to be -z (looking down)
            rgb = self.texture(feature, viewdirs).clamp(0,1)
            mesh['v_rgb'] = rgb.cpu()
        return mesh
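For orientation only (this sketch is not part of the commit): `NeRFModel.forward_` above unpacks each ray from a single `(N_rays, 6)` tensor, with origins in columns 0:3 and directions in 3:6. The snippet below shows that expected layout; how a `model` instance is obtained (via the `models` registry and a config) and the concrete ray values are assumptions for illustration.

# Hypothetical usage sketch -- not part of this commit.
# NeRFModel.forward expects rays packed as (N_rays, 6): origins in [:, 0:3], directions in [:, 3:6].
import torch
import torch.nn.functional as F

n_rays = 4096
rays_o = torch.zeros(n_rays, 3)                               # example camera centers
rays_d = F.normalize(torch.randn(n_rays, 3), dim=-1)          # example unit view directions
rays = torch.cat([rays_o, rays_d], dim=-1)                    # shape (n_rays, 6)

# out = model(rays)   # 'model' assumed to be a NeRFModel built through models.make(...)
# out['comp_rgb'], out['opacity'], out['depth'] each have one row per ray;
# in training mode, out also carries per-sample 'weights' / 'intervals' / 'ray_indices'.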
    	
mesh_recon/models/network_utils.py
ADDED
@@ -0,0 +1,215 @@
import math
import numpy as np

import torch
import torch.nn as nn
import tinycudann as tcnn

from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info

from utils.misc import config_to_primitive, get_rank
from models.utils import get_activation
from systems.utils import update_module_step

class VanillaFrequency(nn.Module):
    def __init__(self, in_channels, config):
        super().__init__()
        self.N_freqs = config['n_frequencies']
        self.in_channels, self.n_input_dims = in_channels, in_channels
        self.funcs = [torch.sin, torch.cos]
        self.freq_bands = 2**torch.linspace(0, self.N_freqs-1, self.N_freqs)
        self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs)
        self.n_masking_step = config.get('n_masking_step', 0)
        self.update_step(None, None) # mask should be updated at the beginning each step

    def forward(self, x):
        out = []
        for freq, mask in zip(self.freq_bands, self.mask):
            for func in self.funcs:
                out += [func(freq*x) * mask]
        return torch.cat(out, -1)

    def update_step(self, epoch, global_step):
        if self.n_masking_step <= 0 or global_step is None:
            self.mask = torch.ones(self.N_freqs, dtype=torch.float32)
        else:
            self.mask = (1. - torch.cos(math.pi * (global_step / self.n_masking_step * self.N_freqs - torch.arange(0, self.N_freqs)).clamp(0, 1))) / 2.
            rank_zero_debug(f'Update mask: {global_step}/{self.n_masking_step} {self.mask}')


class ProgressiveBandHashGrid(nn.Module):
    def __init__(self, in_channels, config):
        super().__init__()
        self.n_input_dims = in_channels
        encoding_config = config.copy()
        encoding_config['otype'] = 'HashGrid'
        with torch.cuda.device(get_rank()):
            self.encoding = tcnn.Encoding(in_channels, encoding_config)
        self.n_output_dims = self.encoding.n_output_dims
        self.n_level = config['n_levels']
        self.n_features_per_level = config['n_features_per_level']
        self.start_level, self.start_step, self.update_steps = config['start_level'], config['start_step'], config['update_steps']
        self.current_level = self.start_level
        self.mask = torch.zeros(self.n_level * self.n_features_per_level, dtype=torch.float32, device=get_rank())

    def forward(self, x):
        enc = self.encoding(x)
        enc = enc * self.mask
        return enc

    def update_step(self, epoch, global_step):
        current_level = min(self.start_level + max(global_step - self.start_step, 0) // self.update_steps, self.n_level)
        if current_level > self.current_level:
            rank_zero_info(f'Update grid level to {current_level}')
        self.current_level = current_level
        self.mask[:self.current_level * self.n_features_per_level] = 1.


class CompositeEncoding(nn.Module):
    def __init__(self, encoding, include_xyz=False, xyz_scale=1., xyz_offset=0.):
        super(CompositeEncoding, self).__init__()
        self.encoding = encoding
        self.include_xyz, self.xyz_scale, self.xyz_offset = include_xyz, xyz_scale, xyz_offset
        self.n_output_dims = int(self.include_xyz) * self.encoding.n_input_dims + self.encoding.n_output_dims

    def forward(self, x, *args):
        return self.encoding(x, *args) if not self.include_xyz else torch.cat([x * self.xyz_scale + self.xyz_offset, self.encoding(x, *args)], dim=-1)

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding, epoch, global_step)


def get_encoding(n_input_dims, config):
    # input suppose to be range [0, 1]
    if config.otype == 'VanillaFrequency':
        encoding = VanillaFrequency(n_input_dims, config_to_primitive(config))
    elif config.otype == 'ProgressiveBandHashGrid':
        encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config))
    else:
        with torch.cuda.device(get_rank()):
            encoding = tcnn.Encoding(n_input_dims, config_to_primitive(config))
    encoding = CompositeEncoding(encoding, include_xyz=config.get('include_xyz', False), xyz_scale=2., xyz_offset=-1.)
    return encoding


class VanillaMLP(nn.Module):
    def __init__(self, dim_in, dim_out, config):
        super().__init__()
        self.n_neurons, self.n_hidden_layers = config['n_neurons'], config['n_hidden_layers']
        self.sphere_init, self.weight_norm = config.get('sphere_init', False), config.get('weight_norm', False)
        self.sphere_init_radius = config.get('sphere_init_radius', 0.5)
        self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()]
        for i in range(self.n_hidden_layers - 1):
            self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()]
        self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)]
        self.layers = nn.Sequential(*self.layers)
        self.output_activation = get_activation(config['output_activation'])

    @torch.cuda.amp.autocast(False)
    def forward(self, x):
        x = self.layers(x.float())
        x = self.output_activation(x)
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality
        if self.sphere_init:
            if is_last:
                torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
                torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001)
            elif is_first:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
                torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out))
            else:
                torch.nn.init.constant_(layer.bias, 0.0)
                torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))
        else:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')

        if self.weight_norm:
            layer = nn.utils.weight_norm(layer)
        return layer

    def make_activation(self):
        if self.sphere_init:
            return nn.Softplus(beta=100)
        else:
            return nn.ReLU(inplace=True)


def sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network):
    rank_zero_debug('Initialize tcnn MLP to approximately represent a sphere.')
    """
    from https://github.com/NVlabs/tiny-cuda-nn/issues/96
    It's the weight matrices of each layer laid out in row-major order and then concatenated.
    Notably: inputs and output dimensions are padded to multiples of 8 (CutlassMLP) or 16 (FullyFusedMLP).
    The padded input dimensions get a constant value of 1.0,
    whereas the padded output dimensions are simply ignored,
    so the weights pertaining to those can have any value.
    """
    padto = 16 if config.otype == 'FullyFusedMLP' else 8
    n_input_dims = n_input_dims + (padto - n_input_dims % padto) % padto
    n_output_dims = n_output_dims + (padto - n_output_dims % padto) % padto
    data = list(network.parameters())[0].data
    assert data.shape[0] == (n_input_dims + n_output_dims) * config.n_neurons + (config.n_hidden_layers - 1) * config.n_neurons**2
    new_data = []
    # first layer
    weight = torch.zeros((config.n_neurons, n_input_dims)).to(data)
    torch.nn.init.constant_(weight[:, 3:], 0.0)
    torch.nn.init.normal_(weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(config.n_neurons))
    new_data.append(weight.flatten())
    # hidden layers
    for i in range(config.n_hidden_layers - 1):
        weight = torch.zeros((config.n_neurons, config.n_neurons)).to(data)
        torch.nn.init.normal_(weight, 0.0, math.sqrt(2) / math.sqrt(config.n_neurons))
        new_data.append(weight.flatten())
    # last layer
    weight = torch.zeros((n_output_dims, config.n_neurons)).to(data)
    torch.nn.init.normal_(weight, mean=math.sqrt(math.pi) / math.sqrt(config.n_neurons), std=0.0001)
    new_data.append(weight.flatten())
    new_data = torch.cat(new_data)
    data.copy_(new_data)


def get_mlp(n_input_dims, n_output_dims, config):
    if config.otype == 'VanillaMLP':
        network = VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config))
    else:
        with torch.cuda.device(get_rank()):
            network = tcnn.Network(n_input_dims, n_output_dims, config_to_primitive(config))
            if config.get('sphere_init', False):
                sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network)
    return network


class EncodingWithNetwork(nn.Module):
    def __init__(self, encoding, network):
        super().__init__()
        self.encoding, self.network = encoding, network

    def forward(self, x):
        return self.network(self.encoding(x))

    def update_step(self, epoch, global_step):
        update_module_step(self.encoding, epoch, global_step)
        update_module_step(self.network, epoch, global_step)


def get_encoding_with_network(n_input_dims, n_output_dims, encoding_config, network_config):
    # input suppose to be range [0, 1]
    if encoding_config.otype in ['VanillaFrequency', 'ProgressiveBandHashGrid'] \
        or network_config.otype in ['VanillaMLP']:
        encoding = get_encoding(n_input_dims, encoding_config)
        network = get_mlp(encoding.n_output_dims, n_output_dims, network_config)
        encoding_with_network = EncodingWithNetwork(encoding, network)
    else:
        with torch.cuda.device(get_rank()):
            encoding_with_network = tcnn.NetworkWithInputEncoding(
                n_input_dims=n_input_dims,
                n_output_dims=n_output_dims,
                encoding_config=config_to_primitive(encoding_config),
                network_config=config_to_primitive(network_config)
            )
    return encoding_with_network
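For orientation only (this sketch is not part of the commit): the factories above dispatch on `config.otype` and otherwise read configs like dictionaries, which suggests OmegaConf-style configs (see `config_to_primitive`). The example below wires a `ProgressiveBandHashGrid` encoding into a `VanillaMLP` through `get_encoding_with_network`; the concrete field values, the `output_activation` string, and the availability of tinycudann/CUDA are assumptions for illustration.

# Hypothetical illustration only -- not part of this commit.
from omegaconf import OmegaConf

pos_encoding_config = OmegaConf.create({
    'otype': 'ProgressiveBandHashGrid',   # routed to ProgressiveBandHashGrid in get_encoding
    'n_levels': 16,
    'n_features_per_level': 2,
    'log2_hashmap_size': 19,
    'base_resolution': 16,
    'per_level_scale': 1.447,
    'include_xyz': True,                  # CompositeEncoding prepends xyz scaled to [-1, 1]
    'start_level': 4,                     # coarse-to-fine schedule read in __init__/update_step
    'start_step': 0,
    'update_steps': 1000,
})
mlp_network_config = OmegaConf.create({
    'otype': 'VanillaMLP',                # routed to VanillaMLP in get_mlp
    'n_neurons': 64,
    'n_hidden_layers': 1,
    'output_activation': 'none',          # assumed to be accepted by models.utils.get_activation
})

# 3D positions in [0, 1] -> 16-dim feature; calling update_step() on the module
# advances the progressive hash-grid mask during training.
encoding_with_network = get_encoding_with_network(3, 16, pos_encoding_config, mlp_network_config)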