diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..c3017c215c14f5737c85ae26e51690d903f47d0c --- /dev/null +++ b/.gitignore @@ -0,0 +1,50 @@ +# extensions +*.egg-info +*.py[cod] + +# envs +.pt13 +.pt2 + +# directories +/checkpoints +/dist +/outputs +/build +/src +logs/ +ckpts/ +tmp/ +lightning_logs/ +images/ +images*/ +kb_configs/ +debug_lvis.log +*.log +.cache/ +redirects/ +submits/ +extern/ +assets/images +output/ +assets/scene +assets/GSO +assets/SD +spirals +*.zip +paper/ +spirals_co3d/ +scene_spirals/ +blenders/ +colmap_results/ +depth_spirals/ +recon/SIBR_viewers/ +recon/assets/ +mesh_recon/exp +mesh_recon/runs +mesh_recon/renders +mesh_recon/refined +*.png +*.pdf +*.npz +*.npy diff --git a/configs/ae/video.yaml b/configs/ae/video.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ecc65942a2203ddb468763cff2fc894616fc47a3 --- /dev/null +++ b/configs/ae/video.yaml @@ -0,0 +1,35 @@ +target: sgm.models.autoencoder.AutoencodingEngine +params: + loss_config: + target: torch.nn.Identity + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + encoder_config: + target: sgm.modules.diffusionmodules.model.Encoder + params: + attn_type: vanilla + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + decoder_config: + target: sgm.modules.autoencoding.temporal_ae.VideoDecoder + params: + attn_type: vanilla + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + video_kernel_size: [3, 1, 1] \ No newline at end of file diff --git a/configs/embedder/clip_image.yaml b/configs/embedder/clip_image.yaml new file mode 100644 index 0000000000000000000000000000000000000000..54a2a92c162d9c950c16b0f12170d1d73d999212 --- /dev/null +++ b/configs/embedder/clip_image.yaml @@ -0,0 +1,8 @@ +target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder +params: + n_cond_frames: 1 + n_copies: 1 + open_clip_embedding_config: + target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: True \ No newline at end of file diff --git a/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml b/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml new file mode 100644 index 0000000000000000000000000000000000000000..731f55930fba00cb9de758c90eefbcd1afd59d47 --- /dev/null +++ b/configs/example_training/autoencoder/kl-f4/imagenet-attnfree-logvar.yaml @@ -0,0 +1,104 @@ +model: + base_learning_rate: 4.5e-6 + target: sgm.models.autoencoder.AutoencodingEngine + params: + input_key: jpg + monitor: val/rec_loss + + loss_config: + target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator + params: + perceptual_weight: 0.25 + disc_start: 20001 + disc_weight: 0.5 + learn_logvar: True + + regularization_weights: + kl_loss: 1.0 + + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + + encoder_config: + target: sgm.modules.diffusionmodules.model.Encoder + params: + attn_type: none + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4] + num_res_blocks: 4 + attn_resolutions: [] + dropout: 0.0 + + decoder_config: + target: sgm.modules.diffusionmodules.model.Decoder + params: 
${model.params.encoder_config.params} + +data: + target: sgm.data.dataset.StableDataModuleFromConfig + params: + train: + datapipeline: + urls: + - DATA-PATH + pipeline_config: + shardshuffle: 10000 + sample_shuffle: 10000 + + decoders: + - pil + + postprocessors: + - target: sdata.mappers.TorchVisionImageTransforms + params: + key: jpg + transforms: + - target: torchvision.transforms.Resize + params: + size: 256 + interpolation: 3 + - target: torchvision.transforms.ToTensor + - target: sdata.mappers.Rescaler + - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare + params: + h_key: height + w_key: width + + loader: + batch_size: 8 + num_workers: 4 + + +lightning: + strategy: + target: pytorch_lightning.strategies.DDPStrategy + params: + find_unused_parameters: True + + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 50000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + + trainer: + devices: 0, + limit_val_batches: 50 + benchmark: True + accumulate_grad_batches: 1 + val_check_interval: 10000 \ No newline at end of file diff --git a/configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml b/configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml new file mode 100644 index 0000000000000000000000000000000000000000..39c7c9df5da1c657d2ce72ac8b6269ae86185e91 --- /dev/null +++ b/configs/example_training/autoencoder/kl-f4/imagenet-kl_f8_8chn.yaml @@ -0,0 +1,105 @@ +model: + base_learning_rate: 4.5e-6 + target: sgm.models.autoencoder.AutoencodingEngine + params: + input_key: jpg + monitor: val/loss/rec + disc_start_iter: 0 + + encoder_config: + target: sgm.modules.diffusionmodules.model.Encoder + params: + attn_type: vanilla-xformers + double_z: true + z_channels: 8 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + + decoder_config: + target: sgm.modules.diffusionmodules.model.Decoder + params: ${model.params.encoder_config.params} + + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + + loss_config: + target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator + params: + perceptual_weight: 0.25 + disc_start: 20001 + disc_weight: 0.5 + learn_logvar: True + + regularization_weights: + kl_loss: 1.0 + +data: + target: sgm.data.dataset.StableDataModuleFromConfig + params: + train: + datapipeline: + urls: + - DATA-PATH + pipeline_config: + shardshuffle: 10000 + sample_shuffle: 10000 + + decoders: + - pil + + postprocessors: + - target: sdata.mappers.TorchVisionImageTransforms + params: + key: jpg + transforms: + - target: torchvision.transforms.Resize + params: + size: 256 + interpolation: 3 + - target: torchvision.transforms.ToTensor + - target: sdata.mappers.Rescaler + - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare + params: + h_key: height + w_key: width + + loader: + batch_size: 8 + num_workers: 4 + + +lightning: + strategy: + target: pytorch_lightning.strategies.DDPStrategy + params: + find_unused_parameters: True + + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 50000 + + image_logger: + target: main.ImageLogger + params: + enable_autocast: False + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True 
+ + trainer: + devices: 0, + limit_val_batches: 50 + benchmark: True + accumulate_grad_batches: 1 + val_check_interval: 10000 diff --git a/configs/example_training/imagenet-f8_cond.yaml b/configs/example_training/imagenet-f8_cond.yaml new file mode 100644 index 0000000000000000000000000000000000000000..23cded00a72e2883df1a4bf2b639a49cda763a8e --- /dev/null +++ b/configs/example_training/imagenet-f8_cond.yaml @@ -0,0 +1,185 @@ +model: + base_learning_rate: 1.0e-4 + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + log_keys: + - cls + + scheduler_config: + target: sgm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [10000] + cycle_lengths: [10000000000000] + f_start: [1.e-6] + f_max: [1.] + f_min: [1.] + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 256 + attention_resolutions: [1, 2, 4] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + num_classes: sequential + adm_in_channels: 1024 + transformer_depth: 1 + context_dim: 1024 + spatial_transformer_attn_type: softmax-xformers + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: True + input_key: cls + ucg_rate: 0.2 + target: sgm.modules.encoders.modules.ClassEmbedder + params: + add_sequence_dim: True + embed_dim: 1024 + n_classes: 1000 + + - is_trainable: False + ucg_rate: 0.2 + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: crop_coords_top_left + ucg_rate: 0.2 + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKL + params: + ckpt_path: CKPT_PATH + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + num_idx: 1000 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 50 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + guider_config: + target: sgm.modules.diffusionmodules.guiders.VanillaCFG + params: + scale: 5.0 + +data: + target: sgm.data.dataset.StableDataModuleFromConfig + params: + train: + datapipeline: + urls: + # USER: adapt this path the root of your custom dataset + - DATA_PATH + pipeline_config: + shardshuffle: 10000 + sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM + + decoders: + 
- pil + + postprocessors: + - target: sdata.mappers.TorchVisionImageTransforms + params: + key: jpg # USER: you might wanna adapt this for your custom dataset + transforms: + - target: torchvision.transforms.Resize + params: + size: 256 + interpolation: 3 + - target: torchvision.transforms.ToTensor + - target: sdata.mappers.Rescaler + + - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare + params: + h_key: height # USER: you might wanna adapt this for your custom dataset + w_key: width # USER: you might wanna adapt this for your custom dataset + + loader: + batch_size: 64 + num_workers: 6 + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 25000 + + image_logger: + target: main.ImageLogger + params: + disabled: False + enable_autocast: False + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + log_first_step: False + log_images_kwargs: + use_ema_scope: False + N: 8 + n_rows: 2 + + trainer: + devices: 0, + benchmark: True + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 + max_epochs: 1000 \ No newline at end of file diff --git a/configs/example_training/toy/cifar10_cond.yaml b/configs/example_training/toy/cifar10_cond.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fca9958464488a66ed2a54d57c59228215690606 --- /dev/null +++ b/configs/example_training/toy/cifar10_cond.yaml @@ -0,0 +1,98 @@ +model: + base_learning_rate: 1.0e-4 + target: sgm.models.diffusion.DiffusionEngine + params: + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling + params: + sigma_data: 1.0 + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + in_channels: 3 + out_channels: 3 + model_channels: 32 + attention_resolutions: [] + num_res_blocks: 4 + channel_mult: [1, 2, 2] + num_head_channels: 32 + num_classes: sequential + adm_in_channels: 128 + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: True + input_key: cls + ucg_rate: 0.2 + target: sgm.modules.encoders.modules.ClassEmbedder + params: + embed_dim: 128 + n_classes: 10 + + first_stage_config: + target: sgm.models.autoencoder.IdentityFirstStage + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting + params: + sigma_data: 1.0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 50 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + + guider_config: + target: sgm.modules.diffusionmodules.guiders.VanillaCFG + params: + scale: 3.0 + +data: + target: sgm.data.cifar10.CIFAR10Loader + params: + batch_size: 512 + num_workers: 1 + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 25000 + + image_logger: + target: main.ImageLogger + params: + disabled: False + batch_frequency: 1000 + max_images: 64 + increase_log_steps: True + log_first_step: False + log_images_kwargs: + use_ema_scope: False + N: 64 + n_rows: 8 + + trainer: + devices: 0, + benchmark: True + num_sanity_val_steps: 0 + 
accumulate_grad_batches: 1 + max_epochs: 20 \ No newline at end of file diff --git a/configs/example_training/toy/mnist.yaml b/configs/example_training/toy/mnist.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a86d05ca1efa537b57646c3923c1f54ac0d6ccf4 --- /dev/null +++ b/configs/example_training/toy/mnist.yaml @@ -0,0 +1,79 @@ +model: + base_learning_rate: 1.0e-4 + target: sgm.models.diffusion.DiffusionEngine + params: + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling + params: + sigma_data: 1.0 + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + in_channels: 1 + out_channels: 1 + model_channels: 32 + attention_resolutions: [] + num_res_blocks: 4 + channel_mult: [1, 2, 2] + num_head_channels: 32 + + first_stage_config: + target: sgm.models.autoencoder.IdentityFirstStage + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting + params: + sigma_data: 1.0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 50 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + +data: + target: sgm.data.mnist.MNISTLoader + params: + batch_size: 512 + num_workers: 1 + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 25000 + + image_logger: + target: main.ImageLogger + params: + disabled: False + batch_frequency: 1000 + max_images: 64 + increase_log_steps: False + log_first_step: False + log_images_kwargs: + use_ema_scope: False + N: 64 + n_rows: 8 + + trainer: + devices: 0, + benchmark: True + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 + max_epochs: 10 \ No newline at end of file diff --git a/configs/example_training/toy/mnist_cond.yaml b/configs/example_training/toy/mnist_cond.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8378acd7acd4c23039a659789b6e6ff5de1a1058 --- /dev/null +++ b/configs/example_training/toy/mnist_cond.yaml @@ -0,0 +1,98 @@ +model: + base_learning_rate: 1.0e-4 + target: sgm.models.diffusion.DiffusionEngine + params: + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling + params: + sigma_data: 1.0 + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + in_channels: 1 + out_channels: 1 + model_channels: 32 + attention_resolutions: [] + num_res_blocks: 4 + channel_mult: [1, 2, 2] + num_head_channels: 32 + num_classes: sequential + adm_in_channels: 128 + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: True + input_key: cls + ucg_rate: 0.2 + target: sgm.modules.encoders.modules.ClassEmbedder + params: + embed_dim: 128 + n_classes: 10 + + first_stage_config: + target: sgm.models.autoencoder.IdentityFirstStage + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting + params: + sigma_data: 1.0 + sigma_sampler_config: + 
target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 50 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + + guider_config: + target: sgm.modules.diffusionmodules.guiders.VanillaCFG + params: + scale: 3.0 + +data: + target: sgm.data.mnist.MNISTLoader + params: + batch_size: 512 + num_workers: 1 + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 25000 + + image_logger: + target: main.ImageLogger + params: + disabled: False + batch_frequency: 1000 + max_images: 16 + increase_log_steps: True + log_first_step: False + log_images_kwargs: + use_ema_scope: False + N: 16 + n_rows: 4 + + trainer: + devices: 0, + benchmark: True + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 + max_epochs: 20 \ No newline at end of file diff --git a/configs/example_training/toy/mnist_cond_discrete_eps.yaml b/configs/example_training/toy/mnist_cond_discrete_eps.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e58aae58dd108887d8d2ac06933a31f84ea61509 --- /dev/null +++ b/configs/example_training/toy/mnist_cond_discrete_eps.yaml @@ -0,0 +1,103 @@ +model: + base_learning_rate: 1.0e-4 + target: sgm.models.diffusion.DiffusionEngine + params: + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + in_channels: 1 + out_channels: 1 + model_channels: 32 + attention_resolutions: [] + num_res_blocks: 4 + channel_mult: [1, 2, 2] + num_head_channels: 32 + num_classes: sequential + adm_in_channels: 128 + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: True + input_key: cls + ucg_rate: 0.2 + target: sgm.modules.encoders.modules.ClassEmbedder + params: + embed_dim: 128 + n_classes: 10 + + first_stage_config: + target: sgm.models.autoencoder.IdentityFirstStage + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + num_idx: 1000 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 50 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + guider_config: + target: sgm.modules.diffusionmodules.guiders.VanillaCFG + params: + scale: 5.0 + +data: + target: sgm.data.mnist.MNISTLoader + params: + batch_size: 512 + num_workers: 1 + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 25000 + + image_logger: + target: main.ImageLogger + params: + disabled: False + batch_frequency: 1000 + max_images: 16 + increase_log_steps: True + log_first_step: False + log_images_kwargs: + use_ema_scope: False + N: 16 + n_rows: 4 + + 
trainer: + devices: 0, + benchmark: True + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 + max_epochs: 20 \ No newline at end of file diff --git a/configs/example_training/toy/mnist_cond_l1_loss.yaml b/configs/example_training/toy/mnist_cond_l1_loss.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ee2f780358b7fe100efa226ae20f6ac58b441632 --- /dev/null +++ b/configs/example_training/toy/mnist_cond_l1_loss.yaml @@ -0,0 +1,99 @@ +model: + base_learning_rate: 1.0e-4 + target: sgm.models.diffusion.DiffusionEngine + params: + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling + params: + sigma_data: 1.0 + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + in_channels: 1 + out_channels: 1 + model_channels: 32 + attention_resolutions: [] + num_res_blocks: 4 + channel_mult: [1, 2, 2] + num_head_channels: 32 + num_classes: sequential + adm_in_channels: 128 + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: True + input_key: cls + ucg_rate: 0.2 + target: sgm.modules.encoders.modules.ClassEmbedder + params: + embed_dim: 128 + n_classes: 10 + + first_stage_config: + target: sgm.models.autoencoder.IdentityFirstStage + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + loss_type: l1 + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting + params: + sigma_data: 1.0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 50 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + + guider_config: + target: sgm.modules.diffusionmodules.guiders.VanillaCFG + params: + scale: 3.0 + +data: + target: sgm.data.mnist.MNISTLoader + params: + batch_size: 512 + num_workers: 1 + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 25000 + + image_logger: + target: main.ImageLogger + params: + disabled: False + batch_frequency: 1000 + max_images: 64 + increase_log_steps: True + log_first_step: False + log_images_kwargs: + use_ema_scope: False + N: 64 + n_rows: 8 + + trainer: + devices: 0, + benchmark: True + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 + max_epochs: 20 \ No newline at end of file diff --git a/configs/example_training/toy/mnist_cond_with_ema.yaml b/configs/example_training/toy/mnist_cond_with_ema.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c666e7143b7cb0a920d384f3f6294231b8bb1726 --- /dev/null +++ b/configs/example_training/toy/mnist_cond_with_ema.yaml @@ -0,0 +1,100 @@ +model: + base_learning_rate: 1.0e-4 + target: sgm.models.diffusion.DiffusionEngine + params: + use_ema: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EDMScaling + params: + sigma_data: 1.0 + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + in_channels: 1 + out_channels: 1 + model_channels: 32 + attention_resolutions: [] + num_res_blocks: 4 + channel_mult: [1, 2, 2] + num_head_channels: 32 + num_classes: sequential + adm_in_channels: 128 + 
+ conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: True + input_key: cls + ucg_rate: 0.2 + target: sgm.modules.encoders.modules.ClassEmbedder + params: + embed_dim: 128 + n_classes: 10 + + first_stage_config: + target: sgm.models.autoencoder.IdentityFirstStage + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting + params: + sigma_data: 1.0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 50 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + + guider_config: + target: sgm.modules.diffusionmodules.guiders.VanillaCFG + params: + scale: 3.0 + +data: + target: sgm.data.mnist.MNISTLoader + params: + batch_size: 512 + num_workers: 1 + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 25000 + + image_logger: + target: main.ImageLogger + params: + disabled: False + batch_frequency: 1000 + max_images: 64 + increase_log_steps: True + log_first_step: False + log_images_kwargs: + use_ema_scope: False + N: 64 + n_rows: 8 + + trainer: + devices: 0, + benchmark: True + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 + max_epochs: 20 \ No newline at end of file diff --git a/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml b/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f268c3295bd57888de3efc736d307903ee80a8f --- /dev/null +++ b/configs/example_training/txt2img-clipl-legacy-ucg-training.yaml @@ -0,0 +1,182 @@ +model: + base_learning_rate: 1.0e-4 + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + log_keys: + - txt + + scheduler_config: + target: sgm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [10000] + cycle_lengths: [10000000000000] + f_start: [1.e-6] + f_max: [1.] + f_min: [1.] 
+ + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [1, 2, 4] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + num_classes: sequential + adm_in_channels: 1792 + num_heads: 1 + transformer_depth: 1 + context_dim: 768 + spatial_transformer_attn_type: softmax-xformers + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: True + input_key: txt + ucg_rate: 0.1 + legacy_ucg_value: "" + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + always_return_pooled: True + + - is_trainable: False + ucg_rate: 0.1 + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: crop_coords_top_left + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKL + params: + ckpt_path: CKPT_PATH + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 4, 4 ] + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + num_idx: 1000 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 50 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + guider_config: + target: sgm.modules.diffusionmodules.guiders.VanillaCFG + params: + scale: 7.5 + +data: + target: sgm.data.dataset.StableDataModuleFromConfig + params: + train: + datapipeline: + urls: + # USER: adapt this path the root of your custom dataset + - DATA_PATH + pipeline_config: + shardshuffle: 10000 + sample_shuffle: 10000 # USER: you might wanna adapt depending on your available RAM + + decoders: + - pil + + postprocessors: + - target: sdata.mappers.TorchVisionImageTransforms + params: + key: jpg # USER: you might wanna adapt this for your custom dataset + transforms: + - target: torchvision.transforms.Resize + params: + size: 256 + interpolation: 3 + - target: torchvision.transforms.ToTensor + - target: sdata.mappers.Rescaler + - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare + # USER: you might wanna use non-default parameters due to your custom dataset + + loader: + batch_size: 64 + num_workers: 6 + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 25000 + + image_logger: + target: main.ImageLogger + params: + disabled: False + 
enable_autocast: False + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + log_first_step: False + log_images_kwargs: + use_ema_scope: False + N: 8 + n_rows: 2 + + trainer: + devices: 0, + benchmark: True + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 + max_epochs: 1000 \ No newline at end of file diff --git a/configs/example_training/txt2img-clipl.yaml b/configs/example_training/txt2img-clipl.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cb66ede901b1aa1acb18d162b88912a2e6eab0ce --- /dev/null +++ b/configs/example_training/txt2img-clipl.yaml @@ -0,0 +1,184 @@ +model: + base_learning_rate: 1.0e-4 + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + log_keys: + - txt + + scheduler_config: + target: sgm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: [10000] + cycle_lengths: [10000000000000] + f_start: [1.e-6] + f_max: [1.] + f_min: [1.] + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [1, 2, 4] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + num_classes: sequential + adm_in_channels: 1792 + num_heads: 1 + transformer_depth: 1 + context_dim: 768 + spatial_transformer_attn_type: softmax-xformers + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: True + input_key: txt + ucg_rate: 0.1 + legacy_ucg_value: "" + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + always_return_pooled: True + + - is_trainable: False + ucg_rate: 0.1 + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: crop_coords_top_left + ucg_rate: 0.1 + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKL + params: + ckpt_path: CKPT_PATH + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling + params: + num_idx: 1000 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 50 + + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + guider_config: + target: sgm.modules.diffusionmodules.guiders.VanillaCFG + params: + scale: 7.5 + +data: + target: sgm.data.dataset.StableDataModuleFromConfig + params: + train: + datapipeline: + urls: + # USER: 
adapt this path the root of your custom dataset + - DATA_PATH + pipeline_config: + shardshuffle: 10000 + sample_shuffle: 10000 + + + decoders: + - pil + + postprocessors: + - target: sdata.mappers.TorchVisionImageTransforms + params: + key: jpg # USER: you might wanna adapt this for your custom dataset + transforms: + - target: torchvision.transforms.Resize + params: + size: 256 + interpolation: 3 + - target: torchvision.transforms.ToTensor + - target: sdata.mappers.Rescaler + # USER: you might wanna use non-default parameters due to your custom dataset + - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare + # USER: you might wanna use non-default parameters due to your custom dataset + + loader: + batch_size: 64 + num_workers: 6 + +lightning: + modelcheckpoint: + params: + every_n_train_steps: 5000 + + callbacks: + metrics_over_trainsteps_checkpoint: + params: + every_n_train_steps: 25000 + + image_logger: + target: main.ImageLogger + params: + disabled: False + enable_autocast: False + batch_frequency: 1000 + max_images: 8 + increase_log_steps: True + log_first_step: False + log_images_kwargs: + use_ema_scope: False + N: 8 + n_rows: 2 + + trainer: + devices: 0, + benchmark: True + num_sanity_val_steps: 0 + accumulate_grad_batches: 1 + max_epochs: 1000 \ No newline at end of file diff --git a/configs/inference/sd_2_1.yaml b/configs/inference/sd_2_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6531c6c49fab2d5d9f21c75e53b0370cb8dad8dc --- /dev/null +++ b/configs/inference/sd_2_1.yaml @@ -0,0 +1,60 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.18215 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: true + layer: penultimate + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity \ No newline at end of file diff --git a/configs/inference/sd_2_1_768.yaml b/configs/inference/sd_2_1_768.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e2f9910a192745781edb2a8505fae1d3e1916f87 --- /dev/null +++ b/configs/inference/sd_2_1_768.yaml @@ -0,0 +1,60 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.18215 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScaling + 
discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: true + layer: penultimate + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity \ No newline at end of file diff --git a/configs/inference/sd_xl_base.yaml b/configs/inference/sd_xl_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6047379753a05224bb5b3f6746130fb7fb9f40aa --- /dev/null +++ b/configs/inference/sd_xl_base.yaml @@ -0,0 +1,93 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenCLIPEmbedder + params: + layer: hidden + layer_idx: 11 + + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: target_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git 
a/configs/inference/sd_xl_refiner.yaml b/configs/inference/sd_xl_refiner.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d5ab44e748c55f5f2e34ae5aefdb78a921a8d3f --- /dev/null +++ b/configs/inference/sd_xl_refiner.yaml @@ -0,0 +1,86 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.13025 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser + params: + num_idx: 1000 + + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization + + network_config: + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2560 + num_classes: sequential + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 384 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: 4 + context_dim: [1280, 1280, 1280, 1280] + spatial_transformer_attn_type: softmax-xformers + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: False + input_key: txt + target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder2 + params: + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + legacy: False + freeze: True + layer: penultimate + always_return_pooled: True + + - is_trainable: False + input_key: original_size_as_tuple + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: crop_coords_top_left + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - is_trainable: False + input_key: aesthetic_score + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity diff --git a/configs/inference/svd.yaml b/configs/inference/svd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2a0819ea77f1ed95dfedb2ab6ccdded4e6414e43 --- /dev/null +++ b/configs/inference/svd.yaml @@ -0,0 +1,131 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.18215 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + + network_config: + target: sgm.modules.diffusionmodules.video_model.VideoUNet + params: + adm_in_channels: 768 + num_classes: sequential + use_checkpoint: True + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + spatial_transformer_attn_type: softmax-xformers + extra_ff_mix_layer: True + use_spatial_context: True + merge_strategy: learned_with_images + video_kernel_size: [3, 1, 1] + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: 
False + input_key: cond_frames_without_noise + target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder + params: + n_cond_frames: 1 + n_copies: 1 + open_clip_embedding_config: + target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: True + + - input_key: fps_id + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - input_key: motion_bucket_id + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - input_key: cond_frames + is_trainable: False + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + disable_encoder_autocast: True + n_cond_frames: 1 + n_copies: 1 + is_ae: True + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + - input_key: cond_aug + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencodingEngine + params: + loss_config: + target: torch.nn.Identity + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + encoder_config: + target: sgm.modules.diffusionmodules.model.Encoder + params: + attn_type: vanilla + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + decoder_config: + target: sgm.modules.autoencoding.temporal_ae.VideoDecoder + params: + attn_type: vanilla + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + video_kernel_size: [3, 1, 1] \ No newline at end of file diff --git a/configs/inference/svd_image_decoder.yaml b/configs/inference/svd_image_decoder.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bb09177ad77a8154c20fcbb2e2fdfc0ac9b6c491 --- /dev/null +++ b/configs/inference/svd_image_decoder.yaml @@ -0,0 +1,114 @@ +model: + target: sgm.models.diffusion.DiffusionEngine + params: + scale_factor: 0.18215 + disable_first_stage_autocast: True + + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + + network_config: + target: sgm.modules.diffusionmodules.video_model.VideoUNet + params: + adm_in_channels: 768 + num_classes: sequential + use_checkpoint: True + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2, 1] + num_res_blocks: 2 + channel_mult: [1, 2, 4, 4] + num_head_channels: 64 + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + spatial_transformer_attn_type: softmax-xformers + extra_ff_mix_layer: True + use_spatial_context: True + merge_strategy: learned_with_images + video_kernel_size: [3, 1, 1] + + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: False + input_key: cond_frames_without_noise + target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder + params: + 
n_cond_frames: 1 + n_copies: 1 + open_clip_embedding_config: + target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: True + + - input_key: fps_id + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - input_key: motion_bucket_id + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + - input_key: cond_frames + is_trainable: False + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + disable_encoder_autocast: True + n_cond_frames: 1 + n_copies: 1 + is_ae: True + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + - input_key: cond_aug + is_trainable: False + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + + first_stage_config: + target: sgm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: True + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity \ No newline at end of file diff --git a/configs/inference/svd_mv.yaml b/configs/inference/svd_mv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a343094433e2a486d730f8c3a4561a371f3ad777 --- /dev/null +++ b/configs/inference/svd_mv.yaml @@ -0,0 +1,202 @@ +model: + base_learning_rate: 1.0e-05 + target: sgm.models.video_diffusion.DiffusionEngine + params: + ckpt_path: ckpts/svd_xt.safetensors + scale_factor: 0.18215 + disable_first_stage_autocast: true + scheduler_config: + target: sgm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: + - 1 + cycle_lengths: + - 10000000000000 + f_start: + - 1.0e-06 + f_max: + - 1.0 + f_min: + - 1.0 + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + network_config: + target: sgm.modules.diffusionmodules.video_model.VideoUNet + params: + adm_in_channels: 768 + num_classes: sequential + use_checkpoint: true + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + num_head_channels: 64 + use_linear_in_transformer: true + transformer_depth: 1 + context_dim: 1024 + spatial_transformer_attn_type: softmax-xformers + extra_ff_mix_layer: true + use_spatial_context: true + merge_strategy: learned_with_images + video_kernel_size: + - 3 + - 1 + - 1 + conditioner_config: + target: sgm.modules.GeneralConditioner + params: + emb_models: + - is_trainable: false + ucg_rate: 0.2 + input_key: cond_frames_without_noise + target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder + params: + n_cond_frames: 1 + n_copies: 1 + open_clip_embedding_config: + target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + params: + freeze: true + - input_key: fps_id + is_trainable: true + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + - input_key: 
motion_bucket_id + is_trainable: true + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + - input_key: cond_frames + is_trainable: false + ucg_rate: 0.2 + target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder + params: + disable_encoder_autocast: true + n_cond_frames: 1 + n_copies: 1 + is_ae: true + encoder_config: + target: sgm.models.autoencoder.AutoencoderKLModeOnly + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla-xformers + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + - input_key: cond_aug + is_trainable: true + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + first_stage_config: + target: sgm.models.autoencoder.AutoencodingEngine + params: + loss_config: + target: torch.nn.Identity + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + encoder_config: + target: sgm.modules.diffusionmodules.model.Encoder + params: + attn_type: vanilla + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + decoder_config: + target: sgm.modules.autoencoding.temporal_ae.VideoDecoder + params: + attn_type: vanilla + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + video_kernel_size: + - 3 + - 1 + - 1 + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 30 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + params: + sigma_max: 700.0 + guider_config: + target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider + params: + max_scale: 2.5 + min_scale: 1.0 + num_frames: 24 + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + batch2model_keys: + - num_video_frames + - image_only_indicator + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting + params: + sigma_data: 1.0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling + params: + p_mean: 0.3 + p_std: 1.2 +data: + target: sgm.data.objaverse.ObjaverseSpiralDataset + params: + root_dir: /mnt/mfs/zilong.chen/Downloads/objaverse-ndd-samples + random_front: true + batch_size: 2 + num_workers: 16 + cond_aug_mean: -0.0 diff --git a/mesh_recon/configs/neuralangelo-ortho-wmask.yaml b/mesh_recon/configs/neuralangelo-ortho-wmask.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4411ac8173d87e17717ec7381165b03c5b464d7f --- /dev/null +++ b/mesh_recon/configs/neuralangelo-ortho-wmask.yaml @@ -0,0 +1,145 @@ +name: ${basename:${dataset.scene}} +tag: "" +seed: 42 + +dataset: + name: ortho + root_dir: /home/xiaoxiao/Workplace/wonder3Dplus/outputs/joint-twice/aigc/cropsize-224-cfg1.0 + cam_pose_dir: null + scene: scene_name + imSize: [1024, 1024] # should use larger res, otherwise the exported mesh has wrong colors + camera_type: ortho + apply_mask: true + camera_params: null + view_weights: [1.0, 0.8, 0.2, 1.0, 0.4, 0.7] #['front', 'front_right', 'right', 'back', 'left', 'front_left'] + # view_weights: [1.0, 1.0, 1.0, 1.0, 1.0, 
1.0] + +model: + name: neus + radius: 1.0 + num_samples_per_ray: 1024 + train_num_rays: 256 + max_train_num_rays: 8192 + grid_prune: true + grid_prune_occ_thre: 0.001 + dynamic_ray_sampling: true + batch_image_sampling: true + randomized: true + ray_chunk: 2048 + cos_anneal_end: 20000 + learned_background: false + background_color: black + variance: + init_val: 0.3 + modulate: false + geometry: + name: volume-sdf + radius: ${model.radius} + feature_dim: 13 + grad_type: finite_difference + finite_difference_eps: progressive + isosurface: + method: mc + resolution: 192 + chunk: 2097152 + threshold: 0. + xyz_encoding_config: + otype: ProgressiveBandHashGrid + n_levels: 10 # 12 modify + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 32 + per_level_scale: 1.3195079107728942 + include_xyz: true + start_level: 4 + start_step: 0 + update_steps: 1000 + mlp_network_config: + otype: VanillaMLP + activation: ReLU + output_activation: none + n_neurons: 64 + n_hidden_layers: 1 + sphere_init: true + sphere_init_radius: 0.5 + weight_norm: true + texture: + name: volume-radiance + input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input + dir_encoding_config: + otype: SphericalHarmonics + degree: 4 + mlp_network_config: + otype: VanillaMLP + activation: ReLU + output_activation: none + n_neurons: 64 + n_hidden_layers: 2 + color_activation: sigmoid + +system: + name: ortho-neus-system + loss: + lambda_rgb_mse: 0.5 + lambda_rgb_l1: 0. + lambda_mask: 1.0 + lambda_eikonal: 0.2 # cannot be too large, will cause holes to thin objects + lambda_normal: 1.0 # cannot be too large + lambda_3d_normal_smooth: 1.0 + # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup + lambda_curvature: 0. + lambda_sparsity: 0.5 + lambda_distortion: 0.0 + lambda_distortion_bg: 0.0 + lambda_opaque: 0.0 + sparsity_scale: 100.0 + geo_aware: true + rgb_p_ratio: 0.8 + normal_p_ratio: 0.8 + mask_p_ratio: 0.9 + optimizer: + name: AdamW + args: + lr: 0.01 + betas: [0.9, 0.99] + eps: 1.e-15 + params: + geometry: + lr: 0.001 + texture: + lr: 0.01 + variance: + lr: 0.001 + constant_steps: 500 + scheduler: + name: SequentialLR + interval: step + milestones: + - ${system.constant_steps} + schedulers: + - name: ConstantLR + args: + factor: 1.0 + total_iters: ${system.constant_steps} + - name: ExponentialLR + args: + gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}} + +checkpoint: + save_top_k: -1 + every_n_train_steps: ${trainer.max_steps} + +export: + chunk_size: 2097152 + export_vertex_color: True + ortho_scale: 1.35 #modify + +trainer: + max_steps: 3000 + log_every_n_steps: 100 + num_sanity_val_steps: 0 + val_check_interval: 4000 + limit_train_batches: 1.0 + limit_val_batches: 2 + enable_progress_bar: true + precision: 16 diff --git a/mesh_recon/configs/v3d.yaml b/mesh_recon/configs/v3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3037c6453f2bf39e3bbd5fc82f9d1cf088bd1201 --- /dev/null +++ b/mesh_recon/configs/v3d.yaml @@ -0,0 +1,144 @@ +name: ${basename:${dataset.scene}} +tag: "" +seed: 42 + +dataset: + name: v3d + root_dir: ./spirals + cam_pose_dir: null + scene: pizza_man + apply_mask: true + train_split: train + test_split: train + val_split: train + img_wh: [1024, 1024] + +model: + name: neus + radius: 1.0 ## check this + num_samples_per_ray: 1024 + train_num_rays: 256 + max_train_num_rays: 8192 + grid_prune: true + grid_prune_occ_thre: 0.001 + dynamic_ray_sampling: true + batch_image_sampling: true + 
randomized: true + ray_chunk: 2048 + cos_anneal_end: 20000 + learned_background: false + background_color: black + variance: + init_val: 0.3 + modulate: false + geometry: + name: volume-sdf + radius: ${model.radius} + feature_dim: 13 + grad_type: finite_difference + finite_difference_eps: progressive + isosurface: + method: mc + resolution: 384 + chunk: 2097152 + threshold: 0. + xyz_encoding_config: + otype: ProgressiveBandHashGrid + n_levels: 10 # 12 modify + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 32 + per_level_scale: 1.3195079107728942 + include_xyz: true + start_level: 4 + start_step: 0 + update_steps: 1000 + mlp_network_config: + otype: VanillaMLP + activation: ReLU + output_activation: none + n_neurons: 64 + n_hidden_layers: 1 + sphere_init: true + sphere_init_radius: 0.5 + weight_norm: true + texture: + name: volume-radiance + input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input + dir_encoding_config: + otype: SphericalHarmonics + degree: 4 + mlp_network_config: + otype: VanillaMLP + activation: ReLU + output_activation: none + n_neurons: 64 + n_hidden_layers: 2 + color_activation: sigmoid + +system: + name: videonvs-neus-system + loss: + lambda_rgb_mse: 0.5 + lambda_rgb_l1: 0. + lambda_mask: 1.0 + lambda_eikonal: 0.2 # cannot be too large, will cause holes to thin objects + lambda_normal: 0.0 # cannot be too large + lambda_3d_normal_smooth: 1.0 + # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup + lambda_curvature: 0. + lambda_sparsity: 0.5 + lambda_distortion: 0.0 + lambda_distortion_bg: 0.0 + lambda_opaque: 0.0 + sparsity_scale: 100.0 + geo_aware: true + rgb_p_ratio: 0.8 + normal_p_ratio: 0.8 + mask_p_ratio: 0.9 + optimizer: + name: AdamW + args: + lr: 0.01 + betas: [0.9, 0.99] + eps: 1.e-15 + params: + geometry: + lr: 0.001 + texture: + lr: 0.01 + variance: + lr: 0.001 + constant_steps: 500 + scheduler: + name: SequentialLR + interval: step + milestones: + - ${system.constant_steps} + schedulers: + - name: ConstantLR + args: + factor: 1.0 + total_iters: ${system.constant_steps} + - name: ExponentialLR + args: + gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}} + +checkpoint: + save_top_k: -1 + every_n_train_steps: ${trainer.max_steps} + +export: + chunk_size: 2097152 + export_vertex_color: True + ortho_scale: null #modify + +trainer: + max_steps: 3000 + log_every_n_steps: 100 + num_sanity_val_steps: 0 + val_check_interval: 3000 + limit_train_batches: 1.0 + limit_val_batches: 2 + enable_progress_bar: true + precision: 16 diff --git a/mesh_recon/configs/videonvs.yaml b/mesh_recon/configs/videonvs.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c296a5cc659f612a5700f8535b2f3f1e21c640d1 --- /dev/null +++ b/mesh_recon/configs/videonvs.yaml @@ -0,0 +1,144 @@ +name: ${basename:${dataset.scene}} +tag: "" +seed: 42 + +dataset: + name: videonvs + root_dir: ./spirals + cam_pose_dir: null + scene: pizza_man + apply_mask: true + train_split: train + test_split: train + val_split: train + img_wh: [1024, 1024] + +model: + name: neus + radius: 1.0 ## check this + num_samples_per_ray: 1024 + train_num_rays: 256 + max_train_num_rays: 8192 + grid_prune: true + grid_prune_occ_thre: 0.001 + dynamic_ray_sampling: true + batch_image_sampling: true + randomized: true + ray_chunk: 2048 + cos_anneal_end: 20000 + learned_background: false + background_color: black + variance: + init_val: 0.3 + modulate: false + geometry: + name: volume-sdf + radius: 
${model.radius} + feature_dim: 13 + grad_type: finite_difference + finite_difference_eps: progressive + isosurface: + method: mc + resolution: 384 + chunk: 2097152 + threshold: 0. + xyz_encoding_config: + otype: ProgressiveBandHashGrid + n_levels: 10 # 12 modify + n_features_per_level: 2 + log2_hashmap_size: 19 + base_resolution: 32 + per_level_scale: 1.3195079107728942 + include_xyz: true + start_level: 4 + start_step: 0 + update_steps: 1000 + mlp_network_config: + otype: VanillaMLP + activation: ReLU + output_activation: none + n_neurons: 64 + n_hidden_layers: 1 + sphere_init: true + sphere_init_radius: 0.5 + weight_norm: true + texture: + name: volume-radiance + input_feature_dim: ${add:${model.geometry.feature_dim},3} # surface normal as additional input + dir_encoding_config: + otype: SphericalHarmonics + degree: 4 + mlp_network_config: + otype: VanillaMLP + activation: ReLU + output_activation: none + n_neurons: 64 + n_hidden_layers: 2 + color_activation: sigmoid + +system: + name: videonvs-neus-system + loss: + lambda_rgb_mse: 0.5 + lambda_rgb_l1: 0. + lambda_mask: 1.0 + lambda_eikonal: 0.2 # cannot be too large, will cause holes to thin objects + lambda_normal: 1.0 # cannot be too large + lambda_3d_normal_smooth: 1.0 + # lambda_curvature: [0, 0.0, 1.e-4, 1000] # topology warmup + lambda_curvature: 0. + lambda_sparsity: 0.5 + lambda_distortion: 0.0 + lambda_distortion_bg: 0.0 + lambda_opaque: 0.0 + sparsity_scale: 100.0 + geo_aware: true + rgb_p_ratio: 0.8 + normal_p_ratio: 0.8 + mask_p_ratio: 0.9 + optimizer: + name: AdamW + args: + lr: 0.01 + betas: [0.9, 0.99] + eps: 1.e-15 + params: + geometry: + lr: 0.001 + texture: + lr: 0.01 + variance: + lr: 0.001 + constant_steps: 500 + scheduler: + name: SequentialLR + interval: step + milestones: + - ${system.constant_steps} + schedulers: + - name: ConstantLR + args: + factor: 1.0 + total_iters: ${system.constant_steps} + - name: ExponentialLR + args: + gamma: ${calc_exp_lr_decay_rate:0.1,${sub:${trainer.max_steps},${system.constant_steps}}} + +checkpoint: + save_top_k: -1 + every_n_train_steps: ${trainer.max_steps} + +export: + chunk_size: 2097152 + export_vertex_color: True + ortho_scale: null #modify + +trainer: + max_steps: 3000 + log_every_n_steps: 100 + num_sanity_val_steps: 0 + val_check_interval: 3000 + limit_train_batches: 1.0 + limit_val_batches: 2 + enable_progress_bar: true + precision: 16 diff --git a/mesh_recon/datasets/__init__.py b/mesh_recon/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..adc86bd681a104460169f57a2e18ac29c9e24c33 --- /dev/null +++ b/mesh_recon/datasets/__init__.py @@ -0,0 +1,17 @@ +datasets = {} + + +def register(name): + def decorator(cls): + datasets[name] = cls + return cls + + return decorator + + +def make(name, config): + dataset = datasets[name](config) + return dataset + + +from . 
import blender, colmap, dtu, ortho, videonvs, videonvs_co3d, v3d diff --git a/mesh_recon/datasets/blender.py b/mesh_recon/datasets/blender.py new file mode 100644 index 0000000000000000000000000000000000000000..4bc643389da75f6e7a9c1331e2dc5f6e6c1dbf9a --- /dev/null +++ b/mesh_recon/datasets/blender.py @@ -0,0 +1,143 @@ +import os +import json +import math +import numpy as np +from PIL import Image + +import torch +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF + +import pytorch_lightning as pl + +import datasets +from models.ray_utils import get_ray_directions +from utils.misc import get_rank + + +class BlenderDatasetBase: + def setup(self, config, split): + self.config = config + self.split = split + self.rank = get_rank() + + self.has_mask = True + self.apply_mask = True + + with open( + os.path.join(self.config.root_dir, f"transforms_{self.split}.json"), "r" + ) as f: + meta = json.load(f) + + if "w" in meta and "h" in meta: + W, H = int(meta["w"]), int(meta["h"]) + else: + W, H = 800, 800 + + if "img_wh" in self.config: + w, h = self.config.img_wh + assert round(W / w * h) == H + elif "img_downscale" in self.config: + w, h = W // self.config.img_downscale, H // self.config.img_downscale + else: + raise KeyError("Either img_wh or img_downscale should be specified.") + + self.w, self.h = w, h + self.img_wh = (self.w, self.h) + + self.near, self.far = self.config.near_plane, self.config.far_plane + + self.focal = ( + 0.5 * w / math.tan(0.5 * meta["camera_angle_x"]) + ) # scaled focal length + + # ray directions for all pixels, same for all images (same H, W, focal) + self.directions = get_ray_directions( + self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2 + ).to( + self.rank + ) # (h, w, 3) + + self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] + + for i, frame in enumerate(meta["frames"]): + c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4]) + self.all_c2w.append(c2w) + + img_path = os.path.join(self.config.root_dir, f"{frame['file_path']}.png") + img = Image.open(img_path) + img = img.resize(self.img_wh, Image.BICUBIC) + img = TF.to_tensor(img).permute(1, 2, 0) # (4, h, w) => (h, w, 4) + + self.all_fg_masks.append(img[..., -1]) # (h, w) + self.all_images.append(img[..., :3]) + + self.all_c2w, self.all_images, self.all_fg_masks = ( + torch.stack(self.all_c2w, dim=0).float().to(self.rank), + torch.stack(self.all_images, dim=0).float().to(self.rank), + torch.stack(self.all_fg_masks, dim=0).float().to(self.rank), + ) + + +class BlenderDataset(Dataset, BlenderDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return {"index": index} + + +class BlenderIterableDataset(IterableDataset, BlenderDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register("blender") +class VideoNVSDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, "fit"]: + self.train_dataset = BlenderIterableDataset( + self.config, self.config.train_split + ) + if stage in [None, "fit", "validate"]: + self.val_dataset = BlenderDataset(self.config, self.config.val_split) + if stage in [None, "test"]: + self.test_dataset = BlenderDataset(self.config, self.config.test_split) + if stage in [None, "predict"]: + 
self.predict_dataset = BlenderDataset(self.config, self.config.train_split) + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler, + ) + + def train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/mesh_recon/datasets/colmap.py b/mesh_recon/datasets/colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..b6b389ebb09b8169019046ca8afbcce872e5d30a --- /dev/null +++ b/mesh_recon/datasets/colmap.py @@ -0,0 +1,332 @@ +import os +import math +import numpy as np +from PIL import Image + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF + +import pytorch_lightning as pl + +import datasets +from datasets.colmap_utils import \ + read_cameras_binary, read_images_binary, read_points3d_binary +from models.ray_utils import get_ray_directions +from utils.misc import get_rank + + +def get_center(pts): + center = pts.mean(0) + dis = (pts - center[None,:]).norm(p=2, dim=-1) + mean, std = dis.mean(), dis.std() + q25, q75 = torch.quantile(dis, 0.25), torch.quantile(dis, 0.75) + valid = (dis > mean - 1.5 * std) & (dis < mean + 1.5 * std) & (dis > mean - (q75 - q25) * 1.5) & (dis < mean + (q75 - q25) * 1.5) + center = pts[valid].mean(0) + return center + +def normalize_poses(poses, pts, up_est_method, center_est_method): + if center_est_method == 'camera': + # estimation scene center as the average of all camera positions + center = poses[...,3].mean(0) + elif center_est_method == 'lookat': + # estimation scene center as the average of the intersection of selected pairs of camera rays + cams_ori = poses[...,3] + cams_dir = poses[:,:3,:3] @ torch.as_tensor([0.,0.,-1.]) + cams_dir = F.normalize(cams_dir, dim=-1) + A = torch.stack([cams_dir, -cams_dir.roll(1,0)], dim=-1) + b = -cams_ori + cams_ori.roll(1,0) + t = torch.linalg.lstsq(A, b).solution + center = (torch.stack([cams_dir, cams_dir.roll(1,0)], dim=-1) * t[:,None,:] + torch.stack([cams_ori, cams_ori.roll(1,0)], dim=-1)).mean((0,2)) + elif center_est_method == 'point': + # first estimation scene center as the average of all camera positions + # later we'll use the center of all points bounded by the cameras as the final scene center + center = poses[...,3].mean(0) + else: + raise NotImplementedError(f'Unknown center estimation method: {center_est_method}') + + if up_est_method == 'ground': + # estimate up direction as the normal of the estimated ground plane + # use RANSAC to estimate the ground plane in the point cloud + import pyransac3d as pyrsc + ground = pyrsc.Plane() + plane_eq, inliers = ground.fit(pts.numpy(), thresh=0.01) # TODO: determine thresh based on scene scale + plane_eq = torch.as_tensor(plane_eq) # A, B, C, D in Ax + By + Cz + D = 0 + z = F.normalize(plane_eq[:3], dim=-1) # plane normal as up direction + signed_distance = (torch.cat([pts, torch.ones_like(pts[...,0:1])], dim=-1) * plane_eq).sum(-1) + if signed_distance.mean() < 0: + z = -z # flip the direction if points lie under the plane + elif up_est_method == 
'camera': + # estimate up direction as the average of all camera up directions + z = F.normalize((poses[...,3] - center).mean(0), dim=0) + else: + raise NotImplementedError(f'Unknown up estimation method: {up_est_method}') + + # new axis + y_ = torch.as_tensor([z[1], -z[0], 0.]) + x = F.normalize(y_.cross(z), dim=0) + y = z.cross(x) + + if center_est_method == 'point': + # rotation + Rc = torch.stack([x, y, z], dim=1) + R = Rc.T + poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1) + inv_trans = torch.cat([torch.cat([R, torch.as_tensor([[0.,0.,0.]]).T], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0) + poses_norm = (inv_trans @ poses_homo)[:,:3] + pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0] + + # translation and scaling + poses_min, poses_max = poses_norm[...,3].min(0)[0], poses_norm[...,3].max(0)[0] + pts_fg = pts[(poses_min[0] < pts[:,0]) & (pts[:,0] < poses_max[0]) & (poses_min[1] < pts[:,1]) & (pts[:,1] < poses_max[1])] + center = get_center(pts_fg) + tc = center.reshape(3, 1) + t = -tc + poses_homo = torch.cat([poses_norm, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses_norm.shape[0], -1, -1)], dim=1) + inv_trans = torch.cat([torch.cat([torch.eye(3), t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0) + poses_norm = (inv_trans @ poses_homo)[:,:3] + scale = poses_norm[...,3].norm(p=2, dim=-1).min() + poses_norm[...,3] /= scale + pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0] + pts = pts / scale + else: + # rotation and translation + Rc = torch.stack([x, y, z], dim=1) + tc = center.reshape(3, 1) + R, t = Rc.T, -Rc.T @ tc + poses_homo = torch.cat([poses, torch.as_tensor([[[0.,0.,0.,1.]]]).expand(poses.shape[0], -1, -1)], dim=1) + inv_trans = torch.cat([torch.cat([R, t], dim=1), torch.as_tensor([[0.,0.,0.,1.]])], dim=0) + poses_norm = (inv_trans @ poses_homo)[:,:3] # (N_images, 4, 4) + + # scaling + scale = poses_norm[...,3].norm(p=2, dim=-1).min() + poses_norm[...,3] /= scale + + # apply the transformation to the point cloud + pts = (inv_trans @ torch.cat([pts, torch.ones_like(pts[:,0:1])], dim=-1)[...,None])[:,:3,0] + pts = pts / scale + + return poses_norm, pts + +def create_spheric_poses(cameras, n_steps=120): + center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device) + mean_d = (cameras - center[None,:]).norm(p=2, dim=-1).mean() + mean_h = cameras[:,2].mean() + r = (mean_d**2 - mean_h**2).sqrt() + up = torch.as_tensor([0., 0., 1.], dtype=center.dtype, device=center.device) + + all_c2w = [] + for theta in torch.linspace(0, 2 * math.pi, n_steps): + cam_pos = torch.stack([r * theta.cos(), r * theta.sin(), mean_h]) + l = F.normalize(center - cam_pos, p=2, dim=0) + s = F.normalize(l.cross(up), p=2, dim=0) + u = F.normalize(s.cross(l), p=2, dim=0) + c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1) + all_c2w.append(c2w) + + all_c2w = torch.stack(all_c2w, dim=0) + + return all_c2w + +class ColmapDatasetBase(): + # the data only has to be processed once + initialized = False + properties = {} + + def setup(self, config, split): + self.config = config + self.split = split + self.rank = get_rank() + + if not ColmapDatasetBase.initialized: + camdata = read_cameras_binary(os.path.join(self.config.root_dir, 'sparse/0/cameras.bin')) + + H = int(camdata[1].height) + W = int(camdata[1].width) + + if 'img_wh' in self.config: + w, h = self.config.img_wh + assert round(W / w * h) == H + elif 
'img_downscale' in self.config: + w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5) + else: + raise KeyError("Either img_wh or img_downscale should be specified.") + + img_wh = (w, h) + factor = w / W + + if camdata[1].model == 'SIMPLE_RADIAL': + fx = fy = camdata[1].params[0] * factor + cx = camdata[1].params[1] * factor + cy = camdata[1].params[2] * factor + elif camdata[1].model in ['PINHOLE', 'OPENCV']: + fx = camdata[1].params[0] * factor + fy = camdata[1].params[1] * factor + cx = camdata[1].params[2] * factor + cy = camdata[1].params[3] * factor + else: + raise ValueError(f"Please parse the intrinsics for camera model {camdata[1].model}!") + + directions = get_ray_directions(w, h, fx, fy, cx, cy).to(self.rank) + + imdata = read_images_binary(os.path.join(self.config.root_dir, 'sparse/0/images.bin')) + + mask_dir = os.path.join(self.config.root_dir, 'masks') + has_mask = os.path.exists(mask_dir) # TODO: support partial masks + apply_mask = has_mask and self.config.apply_mask + + all_c2w, all_images, all_fg_masks = [], [], [] + + for i, d in enumerate(imdata.values()): + R = d.qvec2rotmat() + t = d.tvec.reshape(3, 1) + c2w = torch.from_numpy(np.concatenate([R.T, -R.T@t], axis=1)).float() + c2w[:,1:3] *= -1. # COLMAP => OpenGL + all_c2w.append(c2w) + if self.split in ['train', 'val']: + img_path = os.path.join(self.config.root_dir, 'images', d.name) + img = Image.open(img_path) + img = img.resize(img_wh, Image.BICUBIC) + img = TF.to_tensor(img).permute(1, 2, 0)[...,:3] + img = img.to(self.rank) if self.config.load_data_on_gpu else img.cpu() + if has_mask: + mask_paths = [os.path.join(mask_dir, d.name), os.path.join(mask_dir, d.name[3:])] + mask_paths = list(filter(os.path.exists, mask_paths)) + assert len(mask_paths) == 1 + mask = Image.open(mask_paths[0]).convert('L') # (H, W, 1) + mask = mask.resize(img_wh, Image.BICUBIC) + mask = TF.to_tensor(mask)[0] + else: + mask = torch.ones_like(img[...,0], device=img.device) + all_fg_masks.append(mask) # (h, w) + all_images.append(img) + + all_c2w = torch.stack(all_c2w, dim=0) + + pts3d = read_points3d_binary(os.path.join(self.config.root_dir, 'sparse/0/points3D.bin')) + pts3d = torch.from_numpy(np.array([pts3d[k].xyz for k in pts3d])).float() + all_c2w, pts3d = normalize_poses(all_c2w, pts3d, up_est_method=self.config.up_est_method, center_est_method=self.config.center_est_method) + + ColmapDatasetBase.properties = { + 'w': w, + 'h': h, + 'img_wh': img_wh, + 'factor': factor, + 'has_mask': has_mask, + 'apply_mask': apply_mask, + 'directions': directions, + 'pts3d': pts3d, + 'all_c2w': all_c2w, + 'all_images': all_images, + 'all_fg_masks': all_fg_masks + } + + ColmapDatasetBase.initialized = True + + for k, v in ColmapDatasetBase.properties.items(): + setattr(self, k, v) + + if self.split == 'test': + self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps) + self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32) + self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32) + else: + self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0).float(), torch.stack(self.all_fg_masks, dim=0).float() + + """ + # for debug use + from models.ray_utils import get_rays + rays_o, rays_d = get_rays(self.directions.cpu(), self.all_c2w, keepdim=True) + pts_out = [] + pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 0.0 0.0' for l in rays_o[:,0,0].reshape(-1, 
3).tolist()])) + + t_vals = torch.linspace(0, 1, 8) + z_vals = 0.05 * (1 - t_vals) + 0.5 * t_vals + + ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,0][..., None, :]) + pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 0.0' for l in ray_pts.view(-1, 3).tolist()])) + + ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,0][..., None, :]) + pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 1.0' for l in ray_pts.view(-1, 3).tolist()])) + + ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,0,self.w-1][..., None, :]) + pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()])) + + ray_pts = (rays_o[:,0,0][..., None, :] + z_vals[..., None] * rays_d[:,self.h-1,self.w-1][..., None, :]) + pts_out.append('\n'.join([' '.join([str(p) for p in l]) + ' 1.0 1.0 1.0' for l in ray_pts.view(-1, 3).tolist()])) + + open('cameras.txt', 'w').write('\n'.join(pts_out)) + open('scene.txt', 'w').write('\n'.join([' '.join([str(p) for p in l]) + ' 0.0 0.0 0.0' for l in self.pts3d.view(-1, 3).tolist()])) + + exit(1) + """ + + self.all_c2w = self.all_c2w.float().to(self.rank) + if self.config.load_data_on_gpu: + self.all_images = self.all_images.to(self.rank) + self.all_fg_masks = self.all_fg_masks.to(self.rank) + + +class ColmapDataset(Dataset, ColmapDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return { + 'index': index + } + + +class ColmapIterableDataset(IterableDataset, ColmapDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register('colmap') +class ColmapDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, 'fit']: + self.train_dataset = ColmapIterableDataset(self.config, 'train') + if stage in [None, 'fit', 'validate']: + self.val_dataset = ColmapDataset(self.config, self.config.get('val_split', 'train')) + if stage in [None, 'test']: + self.test_dataset = ColmapDataset(self.config, self.config.get('test_split', 'test')) + if stage in [None, 'predict']: + self.predict_dataset = ColmapDataset(self.config, 'train') + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler + ) + + def train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/mesh_recon/datasets/colmap_utils.py b/mesh_recon/datasets/colmap_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5064d53fc4b3a738fc8ab6e52c7a5fee853d16 --- /dev/null +++ b/mesh_recon/datasets/colmap_utils.py @@ -0,0 +1,295 @@ +# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. +# All rights reserved. 
+# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +# Author: Johannes L. Schoenberger (jsch at inf.ethz.ch) + +import os +import collections +import numpy as np +import struct + + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \ + for camera_model in CAMERA_MODELS]) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. 
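+    :example: read_next_bytes(fid, 8, "Q") unpacks one little-endian uint64, e.g. the camera/image/point
+        counts stored at the start of each COLMAP binary table.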
+ """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def read_cameras_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for camera_line_index in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, num_bytes=8*num_params, + format_char_sequence="d"*num_params) + cameras[camera_id] = Camera(id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params)) + assert len(cameras) == num_cameras + return cameras + + +def read_images_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for image_index in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, 
"c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, + format_char_sequence="ddq"*num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def read_points3d_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for point_line_index in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=8*track_length, + format_char_sequence="ii"*track_length) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, xyz=xyz, rgb=rgb, + error=error, image_ids=image_ids, + point2D_idxs=point2D_idxs) + return points3D + + +def read_model(path, ext): + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy 
- Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec diff --git a/mesh_recon/datasets/dtu.py b/mesh_recon/datasets/dtu.py new file mode 100644 index 0000000000000000000000000000000000000000..39e3a36c54e95ca436ca99cc1e4d94d291c52b11 --- /dev/null +++ b/mesh_recon/datasets/dtu.py @@ -0,0 +1,201 @@ +import os +import json +import math +import numpy as np +from PIL import Image +import cv2 + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF + +import pytorch_lightning as pl + +import datasets +from models.ray_utils import get_ray_directions +from utils.misc import get_rank + + +def load_K_Rt_from_P(P=None): + out = cv2.decomposeProjectionMatrix(P) + K = out[0] + R = out[1] + t = out[2] + + K = K / K[2, 2] + intrinsics = np.eye(4) + intrinsics[:3, :3] = K + + pose = np.eye(4, dtype=np.float32) + pose[:3, :3] = R.transpose() + pose[:3, 3] = (t[:3] / t[3])[:, 0] + + return intrinsics, pose + +def create_spheric_poses(cameras, n_steps=120): + center = torch.as_tensor([0.,0.,0.], dtype=cameras.dtype, device=cameras.device) + cam_center = F.normalize(cameras.mean(0), p=2, dim=-1) * cameras.mean(0).norm(2) + eigvecs = torch.linalg.eig(cameras.T @ cameras).eigenvectors + rot_axis = F.normalize(eigvecs[:,1].real.float(), p=2, dim=-1) + up = rot_axis + rot_dir = torch.cross(rot_axis, cam_center) + max_angle = (F.normalize(cameras, p=2, dim=-1) * F.normalize(cam_center, p=2, dim=-1)).sum(-1).acos().max() + + all_c2w = [] + for theta in torch.linspace(-max_angle, max_angle, n_steps): + cam_pos = cam_center * math.cos(theta) + rot_dir * math.sin(theta) + l = F.normalize(center - cam_pos, p=2, dim=0) + s = F.normalize(l.cross(up), p=2, dim=0) + u = F.normalize(s.cross(l), p=2, dim=0) + c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:,None]], axis=1) + all_c2w.append(c2w) + + all_c2w = torch.stack(all_c2w, dim=0) + + return all_c2w + +class DTUDatasetBase(): + def setup(self, config, split): + self.config = config + self.split = split + self.rank = get_rank() + + cams = np.load(os.path.join(self.config.root_dir, self.config.cameras_file)) + + img_sample = cv2.imread(os.path.join(self.config.root_dir, 'image', '000000.png')) + H, W = img_sample.shape[0], img_sample.shape[1] + + if 'img_wh' in self.config: + w, h = self.config.img_wh + assert round(W / w * h) == H + elif 'img_downscale' in self.config: + w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5) + else: + raise KeyError("Either img_wh or img_downscale should be specified.") + + self.w, self.h = w, h + self.img_wh = (w, h) + self.factor = w / W + + mask_dir = os.path.join(self.config.root_dir, 'mask') + self.has_mask = True + self.apply_mask = self.config.apply_mask + + self.directions = [] + self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] + + n_images = max([int(k.split('_')[-1]) for k in cams.keys()]) + 1 + + for i in range(n_images): + world_mat, scale_mat = cams[f'world_mat_{i}'], cams[f'scale_mat_{i}'] + P = (world_mat @ scale_mat)[:3,:4] + K, c2w = load_K_Rt_from_P(P) + fx, fy, cx, cy = K[0,0] * self.factor, K[1,1] * self.factor, K[0,2] * self.factor, K[1,2] * self.factor + directions = get_ray_directions(w, h, fx, fy, cx, cy) + 
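+            # intrinsics are decomposed per view from world_mat/scale_mat, so keep one ray-direction
+            # grid per image; they are stacked (or reduced to a single grid for the test trajectory) below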
self.directions.append(directions) + + c2w = torch.from_numpy(c2w).float() + + # blender follows opengl camera coordinates (right up back) + # NeuS DTU data coordinate system (right down front) is different from blender + # https://github.com/Totoro97/NeuS/issues/9 + # for c2w, flip the sign of input camera coordinate yz + c2w_ = c2w.clone() + c2w_[:3,1:3] *= -1. # flip input sign + self.all_c2w.append(c2w_[:3,:4]) + + if self.split in ['train', 'val']: + img_path = os.path.join(self.config.root_dir, 'image', f'{i:06d}.png') + img = Image.open(img_path) + img = img.resize(self.img_wh, Image.BICUBIC) + img = TF.to_tensor(img).permute(1, 2, 0)[...,:3] + + mask_path = os.path.join(mask_dir, f'{i:03d}.png') + mask = Image.open(mask_path).convert('L') # (H, W, 1) + mask = mask.resize(self.img_wh, Image.BICUBIC) + mask = TF.to_tensor(mask)[0] + + self.all_fg_masks.append(mask) # (h, w) + self.all_images.append(img) + + self.all_c2w = torch.stack(self.all_c2w, dim=0) + + if self.split == 'test': + self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.config.n_test_traj_steps) + self.all_images = torch.zeros((self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32) + self.all_fg_masks = torch.zeros((self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32) + self.directions = self.directions[0] + else: + self.all_images, self.all_fg_masks = torch.stack(self.all_images, dim=0), torch.stack(self.all_fg_masks, dim=0) + self.directions = torch.stack(self.directions, dim=0) + + self.directions = self.directions.float().to(self.rank) + self.all_c2w, self.all_images, self.all_fg_masks = \ + self.all_c2w.float().to(self.rank), \ + self.all_images.float().to(self.rank), \ + self.all_fg_masks.float().to(self.rank) + + +class DTUDataset(Dataset, DTUDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return { + 'index': index + } + + +class DTUIterableDataset(IterableDataset, DTUDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register('dtu') +class DTUDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, 'fit']: + self.train_dataset = DTUIterableDataset(self.config, 'train') + if stage in [None, 'fit', 'validate']: + self.val_dataset = DTUDataset(self.config, self.config.get('val_split', 'train')) + if stage in [None, 'test']: + self.test_dataset = DTUDataset(self.config, self.config.get('test_split', 'test')) + if stage in [None, 'predict']: + self.predict_dataset = DTUDataset(self.config, 'train') + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler + ) + + def train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/mesh_recon/datasets/fixed_poses/000_back_RT.txt b/mesh_recon/datasets/fixed_poses/000_back_RT.txt new file mode 100644 index 
0000000000000000000000000000000000000000..0b839ed2505438786e2d33bd779b77ed1eedb778 --- /dev/null +++ b/mesh_recon/datasets/fixed_poses/000_back_RT.txt @@ -0,0 +1,3 @@ +-1.000000238418579102e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 +0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 1.746665105883948854e-07 +0.000000000000000000e+00 1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00 diff --git a/mesh_recon/datasets/fixed_poses/000_back_left_RT.txt b/mesh_recon/datasets/fixed_poses/000_back_left_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..97b10e711b1a86782cb69798051df209e8943b19 --- /dev/null +++ b/mesh_recon/datasets/fixed_poses/000_back_left_RT.txt @@ -0,0 +1,3 @@ +-7.071069478988647461e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07 +0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08 +-7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 diff --git a/mesh_recon/datasets/fixed_poses/000_back_right_RT.txt b/mesh_recon/datasets/fixed_poses/000_back_right_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..0c7ce665f9ee958fe56e1589f52e4e772f3069e1 --- /dev/null +++ b/mesh_recon/datasets/fixed_poses/000_back_right_RT.txt @@ -0,0 +1,3 @@ +-7.071069478988647461e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07 +0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 9.863901340168013121e-08 +7.071068286895751953e-01 7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 diff --git a/mesh_recon/datasets/fixed_poses/000_front_RT.txt b/mesh_recon/datasets/fixed_poses/000_front_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..67db8bce2207aabc0b8fcf9db25a0af8b9dd9e7b --- /dev/null +++ b/mesh_recon/datasets/fixed_poses/000_front_RT.txt @@ -0,0 +1,3 @@ +1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 +0.000000000000000000e+00 -1.343588564850506373e-07 1.000000119209289551e+00 -1.746665105883948854e-07 +0.000000000000000000e+00 -1.000000119209289551e+00 -1.343588564850506373e-07 -1.300000071525573730e+00 diff --git a/mesh_recon/datasets/fixed_poses/000_front_left_RT.txt b/mesh_recon/datasets/fixed_poses/000_front_left_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..bed4b8cf8913b5fbf1ec092bceea4da0e4014133 --- /dev/null +++ b/mesh_recon/datasets/fixed_poses/000_front_left_RT.txt @@ -0,0 +1,3 @@ +7.071067690849304199e-01 -7.071068286895751953e-01 0.000000000000000000e+00 -1.192092895507812500e-07 +0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08 +-7.071068286895751953e-01 -7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 diff --git a/mesh_recon/datasets/fixed_poses/000_front_right_RT.txt b/mesh_recon/datasets/fixed_poses/000_front_right_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..56064b9ddb2afa5ae1db28cd70a93018c1f59c33 --- /dev/null +++ b/mesh_recon/datasets/fixed_poses/000_front_right_RT.txt @@ -0,0 +1,3 @@ +7.071067690849304199e-01 7.071068286895751953e-01 0.000000000000000000e+00 1.192092895507812500e-07 +0.000000000000000000e+00 -7.587616579485256807e-08 1.000000119209289551e+00 -9.863901340168013121e-08 +7.071068286895751953e-01 
-7.071068286895751953e-01 -7.587616579485256807e-08 -1.838477730751037598e+00 diff --git a/mesh_recon/datasets/fixed_poses/000_left_RT.txt b/mesh_recon/datasets/fixed_poses/000_left_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..465ebaee41f28ba09c6e44451a9c200d4c23bf95 --- /dev/null +++ b/mesh_recon/datasets/fixed_poses/000_left_RT.txt @@ -0,0 +1,3 @@ +-2.220446049250313081e-16 -1.000000000000000000e+00 0.000000000000000000e+00 -2.886579758146288598e-16 +0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 +-1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00 diff --git a/mesh_recon/datasets/fixed_poses/000_right_RT.txt b/mesh_recon/datasets/fixed_poses/000_right_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..2a0c740f885267b285a6585ad4058536205181c5 --- /dev/null +++ b/mesh_recon/datasets/fixed_poses/000_right_RT.txt @@ -0,0 +1,3 @@ +-2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 2.886579758146288598e-16 +0.000000000000000000e+00 -2.220446049250313081e-16 1.000000000000000000e+00 0.000000000000000000e+00 +1.000000000000000000e+00 0.000000000000000000e+00 -2.220446049250313081e-16 -1.299999952316284180e+00 diff --git a/mesh_recon/datasets/fixed_poses/000_top_RT.txt b/mesh_recon/datasets/fixed_poses/000_top_RT.txt new file mode 100644 index 0000000000000000000000000000000000000000..eba7ea36b7d091f390bae16d1428b52b5287bef0 --- /dev/null +++ b/mesh_recon/datasets/fixed_poses/000_top_RT.txt @@ -0,0 +1,3 @@ +1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 +0.000000000000000000e+00 1.000000000000000000e+00 0.000000000000000000e+00 0.000000000000000000e+00 +0.000000000000000000e+00 0.000000000000000000e+00 1.000000000000000000e+00 -1.299999952316284180e+00 diff --git a/mesh_recon/datasets/ortho.py b/mesh_recon/datasets/ortho.py new file mode 100644 index 0000000000000000000000000000000000000000..b29664e1ebda5baf64e57d56e21250cf4a7692ba --- /dev/null +++ b/mesh_recon/datasets/ortho.py @@ -0,0 +1,287 @@ +import os +import json +import math +import numpy as np +from PIL import Image +import cv2 + +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF + +import pytorch_lightning as pl + +import datasets +from models.ray_utils import get_ortho_ray_directions_origins, get_ortho_rays, get_ray_directions +from utils.misc import get_rank + +from glob import glob +import PIL.Image + + +def camNormal2worldNormal(rot_c2w, camNormal): + H,W,_ = camNormal.shape + normal_img = np.matmul(rot_c2w[None, :, :], camNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + + return normal_img + +def worldNormal2camNormal(rot_w2c, worldNormal): + H,W,_ = worldNormal.shape + normal_img = np.matmul(rot_w2c[None, :, :], worldNormal.reshape(-1,3)[:, :, None]).reshape([H, W, 3]) + + return normal_img + +def trans_normal(normal, RT_w2c, RT_w2c_target): + + normal_world = camNormal2worldNormal(np.linalg.inv(RT_w2c[:3,:3]), normal) + normal_target_cam = worldNormal2camNormal(RT_w2c_target[:3,:3], normal_world) + + return normal_target_cam + +def img2normal(img): + return (img/255.)*2-1 + +def normal2img(normal): + return np.uint8((normal*0.5+0.5)*255) + +def norm_normalize(normal, dim=-1): + + normal = normal/(np.linalg.norm(normal, axis=dim, keepdims=True)+1e-6) + + return normal + +def 
RT_opengl2opencv(RT): + # Build the coordinate transform matrix from world to computer vision camera + # R_world2cv = R_bcam2cv@R_world2bcam + # T_world2cv = R_bcam2cv@T_world2bcam + + R = RT[:3, :3] + t = RT[:3, 3] + + R_bcam2cv = np.asarray([[1, 0, 0], [0, -1, 0], [0, 0, -1]], np.float32) + + R_world2cv = R_bcam2cv @ R + t_world2cv = R_bcam2cv @ t + + RT = np.concatenate([R_world2cv,t_world2cv[:,None]],1) + + return RT + +def normal_opengl2opencv(normal): + H,W,C = np.shape(normal) + # normal_img = np.reshape(normal, (H*W,C)) + R_bcam2cv = np.array([1, -1, -1], np.float32) + normal_cv = normal * R_bcam2cv[None, None, :] + + print(np.shape(normal_cv)) + + return normal_cv + +def inv_RT(RT): + RT_h = np.concatenate([RT, np.array([[0,0,0,1]])], axis=0) + RT_inv = np.linalg.inv(RT_h) + + return RT_inv[:3, :] + + +def load_a_prediction(root_dir, test_object, imSize, view_types, load_color=False, cam_pose_dir=None, + normal_system='front', erode_mask=True, camera_type='ortho', cam_params=None): + + all_images = [] + all_normals = [] + all_normals_world = [] + all_masks = [] + all_color_masks = [] + all_poses = [] + all_w2cs = [] + directions = [] + ray_origins = [] + + RT_front = np.loadtxt(glob(os.path.join(cam_pose_dir, '*_%s_RT.txt'%( 'front')))[0]) # world2cam matrix + RT_front_cv = RT_opengl2opencv(RT_front) # convert normal from opengl to opencv + for idx, view in enumerate(view_types): + print(os.path.join(root_dir,test_object)) + normal_filepath = os.path.join(root_dir, test_object, 'normals_000_%s.png'%( view)) + # Load key frame + if load_color: # use bgr + image =np.array(PIL.Image.open(normal_filepath.replace("normals", "rgb")).resize(imSize))[:, :, :3] + + normal = np.array(PIL.Image.open(normal_filepath).resize(imSize)) + mask = normal[:, :, 3] + normal = normal[:, :, :3] + + color_mask = np.array(PIL.Image.open(os.path.join(root_dir,test_object, 'masked_colors/rgb_000_%s.png'%( view))).resize(imSize))[:, :, 3] + invalid_color_mask = color_mask < 255*0.5 + threshold = np.ones_like(image[:, :, 0]) * 250 + invalid_white_mask = (image[:, :, 0] > threshold) & (image[:, :, 1] > threshold) & (image[:, :, 2] > threshold) + invalid_color_mask_final = invalid_color_mask & invalid_white_mask + color_mask = (1 - invalid_color_mask_final) > 0 + + # if erode_mask: + # kernel = np.ones((3, 3), np.uint8) + # mask = cv2.erode(mask, kernel, iterations=1) + + RT = np.loadtxt(os.path.join(cam_pose_dir, '000_%s_RT.txt'%( view))) # world2cam matrix + + normal = img2normal(normal) + + normal[mask==0] = [0,0,0] + mask = mask> (0.5*255) + if load_color: + all_images.append(image) + + all_masks.append(mask) + all_color_masks.append(color_mask) + RT_cv = RT_opengl2opencv(RT) # convert normal from opengl to opencv + all_poses.append(inv_RT(RT_cv)) # cam2world + all_w2cs.append(RT_cv) + + # whether to + normal_cam_cv = normal_opengl2opencv(normal) + + if normal_system == 'front': + print("the loaded normals are defined in the system of front view") + normal_world = camNormal2worldNormal(inv_RT(RT_front_cv)[:3, :3], normal_cam_cv) + elif normal_system == 'self': + print("the loaded normals are in their independent camera systems") + normal_world = camNormal2worldNormal(inv_RT(RT_cv)[:3, :3], normal_cam_cv) + all_normals.append(normal_cam_cv) + all_normals_world.append(normal_world) + + if camera_type == 'ortho': + origins, dirs = get_ortho_ray_directions_origins(W=imSize[0], H=imSize[1]) + elif camera_type == 'pinhole': + dirs = get_ray_directions(W=imSize[0], H=imSize[1], + fx=cam_params[0], 
fy=cam_params[1], cx=cam_params[2], cy=cam_params[3]) + origins = dirs # occupy a position + else: + raise Exception("not support camera type") + ray_origins.append(origins) + directions.append(dirs) + + + if not load_color: + all_images = [normal2img(x) for x in all_normals_world] + + + return np.stack(all_images), np.stack(all_masks), np.stack(all_normals), \ + np.stack(all_normals_world), np.stack(all_poses), np.stack(all_w2cs), np.stack(ray_origins), np.stack(directions), np.stack(all_color_masks) + + +class OrthoDatasetBase(): + def setup(self, config, split): + self.config = config + self.split = split + self.rank = get_rank() + + self.data_dir = self.config.root_dir + self.object_name = self.config.scene + self.scene = self.config.scene + self.imSize = self.config.imSize + self.load_color = True + self.img_wh = [self.imSize[0], self.imSize[1]] + self.w = self.img_wh[0] + self.h = self.img_wh[1] + self.camera_type = self.config.camera_type + self.camera_params = self.config.camera_params # [fx, fy, cx, cy] + + self.view_types = ['front', 'front_right', 'right', 'back', 'left', 'front_left'] + + self.view_weights = torch.from_numpy(np.array(self.config.view_weights)).float().to(self.rank).view(-1) + self.view_weights = self.view_weights.view(-1,1,1).repeat(1, self.h, self.w) + + if self.config.cam_pose_dir is None: + self.cam_pose_dir = "./datasets/fixed_poses" + else: + self.cam_pose_dir = self.config.cam_pose_dir + + self.images_np, self.masks_np, self.normals_cam_np, self.normals_world_np, \ + self.pose_all_np, self.w2c_all_np, self.origins_np, self.directions_np, self.rgb_masks_np = load_a_prediction( + self.data_dir, self.object_name, self.imSize, self.view_types, + self.load_color, self.cam_pose_dir, normal_system='front', + camera_type=self.camera_type, cam_params=self.camera_params) + + self.has_mask = True + self.apply_mask = self.config.apply_mask + + self.all_c2w = torch.from_numpy(self.pose_all_np) + self.all_images = torch.from_numpy(self.images_np) / 255. 
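+        # images are loaded as uint8 RGB in [0, 255]; rescale to [0, 1] floats before they are cast and moved to the target rank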
+ self.all_fg_masks = torch.from_numpy(self.masks_np) + self.all_rgb_masks = torch.from_numpy(self.rgb_masks_np) + self.all_normals_world = torch.from_numpy(self.normals_world_np) + self.origins = torch.from_numpy(self.origins_np) + self.directions = torch.from_numpy(self.directions_np) + + self.directions = self.directions.float().to(self.rank) + self.origins = self.origins.float().to(self.rank) + self.all_rgb_masks = self.all_rgb_masks.float().to(self.rank) + self.all_c2w, self.all_images, self.all_fg_masks, self.all_normals_world = \ + self.all_c2w.float().to(self.rank), \ + self.all_images.float().to(self.rank), \ + self.all_fg_masks.float().to(self.rank), \ + self.all_normals_world.float().to(self.rank) + + +class OrthoDataset(Dataset, OrthoDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return { + 'index': index + } + + +class OrthoIterableDataset(IterableDataset, OrthoDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register('ortho') +class OrthoDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, 'fit']: + self.train_dataset = OrthoIterableDataset(self.config, 'train') + if stage in [None, 'fit', 'validate']: + self.val_dataset = OrthoDataset(self.config, self.config.get('val_split', 'train')) + if stage in [None, 'test']: + self.test_dataset = OrthoDataset(self.config, self.config.get('test_split', 'test')) + if stage in [None, 'predict']: + self.predict_dataset = OrthoDataset(self.config, 'train') + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler + ) + + def train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/mesh_recon/datasets/utils.py b/mesh_recon/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mesh_recon/datasets/v3d.py b/mesh_recon/datasets/v3d.py new file mode 100644 index 0000000000000000000000000000000000000000..5532605f786f7d14181b6c9d2b704af10d1c4396 --- /dev/null +++ b/mesh_recon/datasets/v3d.py @@ -0,0 +1,284 @@ +import os +import json +import math +import numpy as np +from PIL import Image + +import torch +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF +from torchvision.utils import make_grid, save_image +from einops import rearrange +from mediapy import read_video +from pathlib import Path +from rembg import remove, new_session + +import pytorch_lightning as pl + +import datasets +from models.ray_utils import get_ray_directions +from utils.misc import get_rank +from datasets.ortho import ( + inv_RT, + camNormal2worldNormal, + RT_opengl2opencv, + normal_opengl2opencv, +) +from utils.dpt import DPT + + +def get_c2w_from_up_and_look_at( + up, + look_at, + pos, + opengl=False, +): + up = up / 
np.linalg.norm(up) + z = look_at - pos + z = z / np.linalg.norm(z) + y = -up + x = np.cross(y, z) + x /= np.linalg.norm(x) + y = np.cross(z, x) + + c2w = np.zeros([4, 4], dtype=np.float32) + c2w[:3, 0] = x + c2w[:3, 1] = y + c2w[:3, 2] = z + c2w[:3, 3] = pos + c2w[3, 3] = 1.0 + + # opencv to opengl + if opengl: + c2w[..., 1:3] *= -1 + + return c2w + + +def get_uniform_poses(num_frames, radius, elevation, opengl=False): + T = num_frames + azimuths = np.deg2rad(np.linspace(0, 360, T + 1)[:T]) + elevations = np.full_like(azimuths, np.deg2rad(elevation)) + cam_dists = np.full_like(azimuths, radius) + + campos = np.stack( + [ + cam_dists * np.cos(elevations) * np.cos(azimuths), + cam_dists * np.cos(elevations) * np.sin(azimuths), + cam_dists * np.sin(elevations), + ], + axis=-1, + ) + + center = np.array([0, 0, 0], dtype=np.float32) + up = np.array([0, 0, 1], dtype=np.float32) + poses = [] + for t in range(T): + poses.append(get_c2w_from_up_and_look_at(up, center, campos[t], opengl=opengl)) + + return np.stack(poses, axis=0) + + +def blender2midas(img): + """Blender: rub + midas: lub + """ + img[..., 0] = -img[..., 0] + img[..., 1] = -img[..., 1] + img[..., -1] = -img[..., -1] + return img + + +def midas2blender(img): + """Blender: rub + midas: lub + """ + img[..., 0] = -img[..., 0] + img[..., 1] = -img[..., 1] + img[..., -1] = -img[..., -1] + return img + + +class BlenderDatasetBase: + def setup(self, config, split): + self.config = config + self.rank = get_rank() + + self.has_mask = True + self.apply_mask = True + + dpt = DPT(device=self.rank, mode="normal") + + # with open( + # os.path.join( + # self.config.root_dir, self.config.scene, f"transforms_train.json" + # ), + # "r", + # ) as f: + # meta = json.load(f) + + # if "w" in meta and "h" in meta: + # W, H = int(meta["w"]), int(meta["h"]) + # else: + # W, H = 800, 800 + frames = read_video(Path(self.config.root_dir) / f"{self.config.scene}") + rembg_session = new_session() + num_frames, H, W = frames.shape[:3] + + if "img_wh" in self.config: + w, h = self.config.img_wh + assert round(W / w * h) == H + elif "img_downscale" in self.config: + w, h = W // self.config.img_downscale, H // self.config.img_downscale + else: + raise KeyError("Either img_wh or img_downscale should be specified.") + + self.w, self.h = w, h + self.img_wh = (self.w, self.h) + + # self.near, self.far = self.config.near_plane, self.config.far_plane + + self.focal = 0.5 * w / math.tan(0.5 * np.deg2rad(60)) # scaled focal length + + # ray directions for all pixels, same for all images (same H, W, focal) + self.directions = get_ray_directions( + self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2 + ).to( + self.rank + ) # (h, w, 3) + + self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] + + radius = 2.0 + elevation = 0.0 + poses = get_uniform_poses(num_frames, radius, elevation, opengl=True) + for i, (c2w, frame) in enumerate(zip(poses, frames)): + c2w = torch.from_numpy(np.array(c2w)[:3, :4]) + self.all_c2w.append(c2w) + + img = Image.fromarray(frame) + img = remove(img, session=rembg_session) + img = img.resize(self.img_wh, Image.BICUBIC) + img = TF.to_tensor(img).permute(1, 2, 0) # (4, h, w) => (h, w, 4) + + self.all_fg_masks.append(img[..., -1]) # (h, w) + self.all_images.append(img[..., :3]) + + self.all_c2w, self.all_images, self.all_fg_masks = ( + torch.stack(self.all_c2w, dim=0).float().to(self.rank), + torch.stack(self.all_images, dim=0).float().to(self.rank), + torch.stack(self.all_fg_masks, dim=0).float().to(self.rank), + ) + + self.normals 
= dpt(self.all_images) + + self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1 + + self.normals = self.normals * 2.0 - 1.0 + self.normals = midas2blender(self.normals).cpu().numpy() + # self.normals = self.normals.cpu().numpy() + self.normals[..., 0] *= -1 + self.normals[~self.all_masks] = [0, 0, 0] + normals = rearrange(self.normals, "b h w c -> b c h w") + normals = normals * 0.5 + 0.5 + normals = torch.from_numpy(normals) + # save_image(make_grid(normals, nrow=4), "tmp/normals.png") + # exit(0) + + ( + self.all_poses, + self.all_normals, + self.all_normals_world, + self.all_w2cs, + self.all_color_masks, + ) = ([], [], [], [], []) + + for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals): + RT_opengl = inv_RT(c2w_opengl) + RT_opencv = RT_opengl2opencv(RT_opengl) + c2w_opencv = inv_RT(RT_opencv) + self.all_poses.append(c2w_opencv) + self.all_w2cs.append(RT_opencv) + normal = normal_opengl2opencv(normal) + normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal) + self.all_normals.append(normal) + self.all_normals_world.append(normal_world) + + self.directions = torch.stack([self.directions] * len(self.all_images)) + self.origins = self.directions + self.all_poses = np.stack(self.all_poses) + self.all_normals = np.stack(self.all_normals) + self.all_normals_world = np.stack(self.all_normals_world) + self.all_w2cs = np.stack(self.all_w2cs) + + self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank) + self.all_images = self.all_images.to(self.rank) + self.all_fg_masks = self.all_fg_masks.to(self.rank) + self.all_rgb_masks = self.all_fg_masks.to(self.rank) + self.all_normals_world = ( + torch.from_numpy(self.all_normals_world).float().to(self.rank) + ) + + +class BlenderDataset(Dataset, BlenderDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return {"index": index} + + +class BlenderIterableDataset(IterableDataset, BlenderDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register("v3d") +class BlenderDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, "fit"]: + self.train_dataset = BlenderIterableDataset( + self.config, self.config.train_split + ) + if stage in [None, "fit", "validate"]: + self.val_dataset = BlenderDataset(self.config, self.config.val_split) + if stage in [None, "test"]: + self.test_dataset = BlenderDataset(self.config, self.config.test_split) + if stage in [None, "predict"]: + self.predict_dataset = BlenderDataset(self.config, self.config.train_split) + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler, + ) + + def train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/mesh_recon/datasets/videonvs.py b/mesh_recon/datasets/videonvs.py new file mode 100644 index 
0000000000000000000000000000000000000000..6db218cde6308e56e868dac488a5bb962fff0eb2 --- /dev/null +++ b/mesh_recon/datasets/videonvs.py @@ -0,0 +1,256 @@ +import os +import json +import math +import numpy as np +from PIL import Image + +import torch +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF +from torchvision.utils import make_grid, save_image +from einops import rearrange + +import pytorch_lightning as pl + +import datasets +from models.ray_utils import get_ray_directions +from utils.misc import get_rank +from datasets.ortho import ( + inv_RT, + camNormal2worldNormal, + RT_opengl2opencv, + normal_opengl2opencv, +) +from utils.dpt import DPT + + +def blender2midas(img): + """Blender: rub + midas: lub + """ + img[..., 0] = -img[..., 0] + img[..., 1] = -img[..., 1] + img[..., -1] = -img[..., -1] + return img + + +def midas2blender(img): + """Blender: rub + midas: lub + """ + img[..., 0] = -img[..., 0] + img[..., 1] = -img[..., 1] + img[..., -1] = -img[..., -1] + return img + + +class BlenderDatasetBase: + def setup(self, config, split): + self.config = config + self.rank = get_rank() + + self.has_mask = True + self.apply_mask = True + + dpt = DPT(device=self.rank, mode="normal") + + with open( + os.path.join( + self.config.root_dir, self.config.scene, f"transforms_train.json" + ), + "r", + ) as f: + meta = json.load(f) + + if "w" in meta and "h" in meta: + W, H = int(meta["w"]), int(meta["h"]) + else: + W, H = 800, 800 + + if "img_wh" in self.config: + w, h = self.config.img_wh + assert round(W / w * h) == H + elif "img_downscale" in self.config: + w, h = W // self.config.img_downscale, H // self.config.img_downscale + else: + raise KeyError("Either img_wh or img_downscale should be specified.") + + self.w, self.h = w, h + self.img_wh = (self.w, self.h) + + # self.near, self.far = self.config.near_plane, self.config.far_plane + + self.focal = ( + 0.5 * w / math.tan(0.5 * meta["camera_angle_x"]) + ) # scaled focal length + + # ray directions for all pixels, same for all images (same H, W, focal) + self.directions = get_ray_directions( + self.w, self.h, self.focal, self.focal, self.w // 2, self.h // 2 + ).to( + self.rank + ) # (h, w, 3) + + self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] + + for i, frame in enumerate(meta["frames"]): + c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4]) + self.all_c2w.append(c2w) + + img_path = os.path.join( + self.config.root_dir, + self.config.scene, + f"{frame['file_path']}.png", + ) + img = Image.open(img_path) + img = img.resize(self.img_wh, Image.BICUBIC) + img = TF.to_tensor(img).permute(1, 2, 0) # (4, h, w) => (h, w, 4) + + self.all_fg_masks.append(img[..., -1]) # (h, w) + self.all_images.append(img[..., :3]) + + self.all_c2w, self.all_images, self.all_fg_masks = ( + torch.stack(self.all_c2w, dim=0).float().to(self.rank), + torch.stack(self.all_images, dim=0).float().to(self.rank), + torch.stack(self.all_fg_masks, dim=0).float().to(self.rank), + ) + + self.normals = dpt(self.all_images) + + self.all_masks = self.all_fg_masks.cpu().numpy() > 0.1 + + self.normals = self.normals * 2.0 - 1.0 + self.normals = midas2blender(self.normals).cpu().numpy() + # self.normals = self.normals.cpu().numpy() + self.normals[..., 0] *= -1 + self.normals[~self.all_masks] = [0, 0, 0] + normals = rearrange(self.normals, "b h w c -> b c h w") + normals = normals * 0.5 + 0.5 + normals = torch.from_numpy(normals) + save_image(make_grid(normals, nrow=4), "tmp/normals.png") + # 
exit(0) + + ( + self.all_poses, + self.all_normals, + self.all_normals_world, + self.all_w2cs, + self.all_color_masks, + ) = ([], [], [], [], []) + + for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals): + RT_opengl = inv_RT(c2w_opengl) + RT_opencv = RT_opengl2opencv(RT_opengl) + c2w_opencv = inv_RT(RT_opencv) + self.all_poses.append(c2w_opencv) + self.all_w2cs.append(RT_opencv) + normal = normal_opengl2opencv(normal) + normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal) + self.all_normals.append(normal) + self.all_normals_world.append(normal_world) + + self.directions = torch.stack([self.directions] * len(self.all_images)) + self.origins = self.directions + self.all_poses = np.stack(self.all_poses) + self.all_normals = np.stack(self.all_normals) + self.all_normals_world = np.stack(self.all_normals_world) + self.all_w2cs = np.stack(self.all_w2cs) + + self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank) + self.all_images = self.all_images.to(self.rank) + self.all_fg_masks = self.all_fg_masks.to(self.rank) + self.all_rgb_masks = self.all_fg_masks.to(self.rank) + self.all_normals_world = ( + torch.from_numpy(self.all_normals_world).float().to(self.rank) + ) + + # normals = rearrange(self.all_normals_world, "b h w c -> b c h w") + # normals = normals * 0.5 + 0.5 + # # normals = torch.from_numpy(normals) + # save_image(make_grid(normals, nrow=4), "tmp/normals_world.png") + # # exit(0) + + # # normals = (normals + 1) / 2.0 + # # for debug + # index = [0, 9] + # self.all_poses = self.all_poses[index] + # self.all_c2w = self.all_c2w[index] + # self.all_normals_world = self.all_normals_world[index] + # self.all_w2cs = self.all_w2cs[index] + # self.rgb_masks = self.all_rgb_masks[index] + # self.fg_masks = self.all_fg_masks[index] + # self.all_images = self.all_images[index] + # self.directions = self.directions[index] + # self.origins = self.origins[index] + + # images = rearrange(self.all_images, "b h w c -> b c h w") + # normals = rearrange(normals, "b h w c -> b c h w") + # save_image(make_grid(images, nrow=4), "tmp/images.png") + # save_image(make_grid(normals, nrow=4), "tmp/normals.png") + # breakpoint() + + # self.normals = self.normals * 2.0 - 1.0 + + +class BlenderDataset(Dataset, BlenderDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return {"index": index} + + +class BlenderIterableDataset(IterableDataset, BlenderDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register("videonvs") +class BlenderDataModule(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, "fit"]: + self.train_dataset = BlenderIterableDataset( + self.config, self.config.train_split + ) + if stage in [None, "fit", "validate"]: + self.val_dataset = BlenderDataset(self.config, self.config.val_split) + if stage in [None, "test"]: + self.test_dataset = BlenderDataset(self.config, self.config.test_split) + if stage in [None, "predict"]: + self.predict_dataset = BlenderDataset(self.config, self.config.train_split) + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler, + ) + + def 
train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/mesh_recon/datasets/videonvs_co3d.py b/mesh_recon/datasets/videonvs_co3d.py new file mode 100644 index 0000000000000000000000000000000000000000..98521cdeae6341e7422cdf6188d00abfb569d558 --- /dev/null +++ b/mesh_recon/datasets/videonvs_co3d.py @@ -0,0 +1,252 @@ +import os +import json +import math +import numpy as np +from PIL import Image + +import torch +from torch.utils.data import Dataset, DataLoader, IterableDataset +import torchvision.transforms.functional as TF +from torchvision.utils import make_grid, save_image +from einops import rearrange +from rembg import remove, new_session + +import pytorch_lightning as pl + +import datasets +from models.ray_utils import get_ray_directions +from utils.misc import get_rank +from datasets.ortho import ( + inv_RT, + camNormal2worldNormal, + RT_opengl2opencv, + normal_opengl2opencv, +) +from utils.dpt import DPT + + +def blender2midas(img): + """Blender: rub + midas: lub + """ + img[..., 0] = -img[..., 0] + img[..., 1] = -img[..., 1] + img[..., -1] = -img[..., -1] + return img + + +def midas2blender(img): + """Blender: rub + midas: lub + """ + img[..., 0] = -img[..., 0] + img[..., 1] = -img[..., 1] + img[..., -1] = -img[..., -1] + return img + + +class BlenderDatasetBase: + def setup(self, config, split): + self.config = config + self.rank = get_rank() + + self.has_mask = True + self.apply_mask = True + + dpt = DPT(device=self.rank, mode="normal") + + self.directions = [] + with open( + os.path.join(self.config.root_dir, self.config.scene, f"transforms.json"), + "r", + ) as f: + meta = json.load(f) + + if "w" in meta and "h" in meta: + W, H = int(meta["w"]), int(meta["h"]) + else: + W, H = 800, 800 + + if "img_wh" in self.config: + w, h = self.config.img_wh + assert round(W / w * h) == H + elif "img_downscale" in self.config: + w, h = W // self.config.img_downscale, H // self.config.img_downscale + else: + raise KeyError("Either img_wh or img_downscale should be specified.") + + self.w, self.h = w, h + self.img_wh = (self.w, self.h) + + # self.near, self.far = self.config.near_plane, self.config.far_plane + _session = new_session() + self.all_c2w, self.all_images, self.all_fg_masks = [], [], [] + + for i, frame in enumerate(meta["frames"]): + c2w = torch.from_numpy(np.array(frame["transform_matrix"])[:3, :4]) + self.all_c2w.append(c2w) + + img_path = os.path.join( + self.config.root_dir, + self.config.scene, + f"{frame['file_path']}", + ) + img = Image.open(img_path) + img = remove(img, session=_session) + img = img.resize(self.img_wh, Image.BICUBIC) + img = TF.to_tensor(img).permute(1, 2, 0) # (4, h, w) => (h, w, 4) + fx = frame["fl_x"] + fy = frame["fl_y"] + cx = frame["cx"] + cy = frame["cy"] + + self.all_fg_masks.append(img[..., -1]) # (h, w) + self.all_images.append(img[..., :3]) + + self.directions.append(get_ray_directions(self.w, self.h, fx, fy, cx, cy)) + + self.all_c2w, self.all_images, self.all_fg_masks = ( + torch.stack(self.all_c2w, dim=0).float().to(self.rank), + torch.stack(self.all_images, dim=0).float().to(self.rank), + torch.stack(self.all_fg_masks, dim=0).float().to(self.rank), + ) + + self.normals = dpt(self.all_images) + + self.all_masks = 
self.all_fg_masks.cpu().numpy() > 0.1 + + self.normals = self.normals * 2.0 - 1.0 + self.normals = midas2blender(self.normals).cpu().numpy() + # self.normals = self.normals.cpu().numpy() + self.normals[..., 0] *= -1 + self.normals[~self.all_masks] = [0, 0, 0] + normals = rearrange(self.normals, "b h w c -> b c h w") + normals = normals * 0.5 + 0.5 + normals = torch.from_numpy(normals) + save_image(make_grid(normals, nrow=4), "tmp/normals.png") + # exit(0) + + ( + self.all_poses, + self.all_normals, + self.all_normals_world, + self.all_w2cs, + self.all_color_masks, + ) = ([], [], [], [], []) + + for c2w_opengl, normal in zip(self.all_c2w.cpu().numpy(), self.normals): + RT_opengl = inv_RT(c2w_opengl) + RT_opencv = RT_opengl2opencv(RT_opengl) + c2w_opencv = inv_RT(RT_opencv) + self.all_poses.append(c2w_opencv) + self.all_w2cs.append(RT_opencv) + normal = normal_opengl2opencv(normal) + normal_world = camNormal2worldNormal(inv_RT(RT_opencv)[:3, :3], normal) + self.all_normals.append(normal) + self.all_normals_world.append(normal_world) + + self.directions = torch.stack(self.directions).to(self.rank) + self.origins = self.directions + self.all_poses = np.stack(self.all_poses) + self.all_normals = np.stack(self.all_normals) + self.all_normals_world = np.stack(self.all_normals_world) + self.all_w2cs = np.stack(self.all_w2cs) + + self.all_c2w = torch.from_numpy(self.all_poses).float().to(self.rank) + self.all_images = self.all_images.to(self.rank) + self.all_fg_masks = self.all_fg_masks.to(self.rank) + self.all_rgb_masks = self.all_fg_masks.to(self.rank) + self.all_normals_world = ( + torch.from_numpy(self.all_normals_world).float().to(self.rank) + ) + + # normals = rearrange(self.all_normals_world, "b h w c -> b c h w") + # normals = normals * 0.5 + 0.5 + # # normals = torch.from_numpy(normals) + # save_image(make_grid(normals, nrow=4), "tmp/normals_world.png") + # # exit(0) + + # # normals = (normals + 1) / 2.0 + # # for debug + # index = [0, 9] + # self.all_poses = self.all_poses[index] + # self.all_c2w = self.all_c2w[index] + # self.all_normals_world = self.all_normals_world[index] + # self.all_w2cs = self.all_w2cs[index] + # self.rgb_masks = self.all_rgb_masks[index] + # self.fg_masks = self.all_fg_masks[index] + # self.all_images = self.all_images[index] + # self.directions = self.directions[index] + # self.origins = self.origins[index] + + # images = rearrange(self.all_images, "b h w c -> b c h w") + # normals = rearrange(normals, "b h w c -> b c h w") + # save_image(make_grid(images, nrow=4), "tmp/images.png") + # save_image(make_grid(normals, nrow=4), "tmp/normals.png") + # breakpoint() + + # self.normals = self.normals * 2.0 - 1.0 + + +class BlenderDataset(Dataset, BlenderDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __len__(self): + return len(self.all_images) + + def __getitem__(self, index): + return {"index": index} + + +class BlenderIterableDataset(IterableDataset, BlenderDatasetBase): + def __init__(self, config, split): + self.setup(config, split) + + def __iter__(self): + while True: + yield {} + + +@datasets.register("videonvs-scene") +class VideoNVSScene(pl.LightningDataModule): + def __init__(self, config): + super().__init__() + self.config = config + + def setup(self, stage=None): + if stage in [None, "fit"]: + self.train_dataset = BlenderIterableDataset( + self.config, self.config.train_split + ) + if stage in [None, "fit", "validate"]: + self.val_dataset = BlenderDataset(self.config, self.config.val_split) + if stage in [None, 
"test"]: + self.test_dataset = BlenderDataset(self.config, self.config.test_split) + if stage in [None, "predict"]: + self.predict_dataset = BlenderDataset(self.config, self.config.train_split) + + def prepare_data(self): + pass + + def general_loader(self, dataset, batch_size): + sampler = None + return DataLoader( + dataset, + num_workers=os.cpu_count(), + batch_size=batch_size, + pin_memory=True, + sampler=sampler, + ) + + def train_dataloader(self): + return self.general_loader(self.train_dataset, batch_size=1) + + def val_dataloader(self): + return self.general_loader(self.val_dataset, batch_size=1) + + def test_dataloader(self): + return self.general_loader(self.test_dataset, batch_size=1) + + def predict_dataloader(self): + return self.general_loader(self.predict_dataset, batch_size=1) diff --git a/mesh_recon/launch.py b/mesh_recon/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..c09238df7f03d898f661c893c7bd5444b2520505 --- /dev/null +++ b/mesh_recon/launch.py @@ -0,0 +1,144 @@ +import sys +import argparse +import os +import time +import logging +from datetime import datetime + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True, help="path to config file") + parser.add_argument("--gpu", default="0", help="GPU(s) to be used") + parser.add_argument( + "--resume", default=None, help="path to the weights to be resumed" + ) + parser.add_argument( + "--resume_weights_only", + action="store_true", + help="specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only", + ) + + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--train", action="store_true") + group.add_argument("--validate", action="store_true") + group.add_argument("--test", action="store_true") + group.add_argument("--predict", action="store_true") + # group.add_argument('--export', action='store_true') # TODO: a separate export action + + parser.add_argument("--exp_dir", default="./exp") + parser.add_argument("--runs_dir", default="./runs") + parser.add_argument( + "--verbose", action="store_true", help="if true, set logging level to DEBUG" + ) + + args, extras = parser.parse_known_args() + + # set CUDA_VISIBLE_DEVICES then import pytorch-lightning + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu + n_gpus = len(args.gpu.split(",")) + + import datasets + import systems + import pytorch_lightning as pl + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor + from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger + from utils.callbacks import ( + CodeSnapshotCallback, + ConfigSnapshotCallback, + CustomProgressBar, + ) + from utils.misc import load_config + + # parse YAML config to OmegaConf + config = load_config(args.config, cli_args=extras) + config.cmd_args = vars(args) + + config.trial_name = config.get("trial_name") or ( + config.tag + datetime.now().strftime("@%Y%m%d-%H%M%S") + ) + config.exp_dir = config.get("exp_dir") or os.path.join(args.exp_dir, config.name) + config.save_dir = config.get("save_dir") or os.path.join( + config.exp_dir, config.trial_name, "save" + ) + config.ckpt_dir = config.get("ckpt_dir") or os.path.join( + config.exp_dir, config.trial_name, "ckpt" + ) + config.code_dir = config.get("code_dir") or os.path.join( + config.exp_dir, config.trial_name, "code" + ) + config.config_dir = config.get("config_dir") or 
os.path.join( + config.exp_dir, config.trial_name, "config" + ) + + logger = logging.getLogger("pytorch_lightning") + if args.verbose: + logger.setLevel(logging.DEBUG) + + if "seed" not in config: + config.seed = int(time.time() * 1000) % 1000 + pl.seed_everything(config.seed) + + dm = datasets.make(config.dataset.name, config.dataset) + system = systems.make( + config.system.name, + config, + load_from_checkpoint=None if not args.resume_weights_only else args.resume, + ) + + callbacks = [] + if args.train: + callbacks += [ + ModelCheckpoint(dirpath=config.ckpt_dir, **config.checkpoint), + LearningRateMonitor(logging_interval="step"), + # CodeSnapshotCallback( + # config.code_dir, use_version=False + # ), + ConfigSnapshotCallback(config, config.config_dir, use_version=False), + CustomProgressBar(refresh_rate=1), + ] + + loggers = [] + if args.train: + loggers += [ + TensorBoardLogger( + args.runs_dir, name=config.name, version=config.trial_name + ), + CSVLogger(config.exp_dir, name=config.trial_name, version="csv_logs"), + ] + + if sys.platform == "win32": + # does not support multi-gpu on windows + strategy = "dp" + assert n_gpus == 1 + else: + strategy = "ddp_find_unused_parameters_false" + + trainer = Trainer( + devices=n_gpus, + accelerator="gpu", + callbacks=callbacks, + logger=loggers, + strategy=strategy, + **config.trainer + ) + + if args.train: + if args.resume and not args.resume_weights_only: + # FIXME: different behavior in pytorch-lighting>1.9 ? + trainer.fit(system, datamodule=dm, ckpt_path=args.resume) + else: + trainer.fit(system, datamodule=dm) + trainer.test(system, datamodule=dm) + elif args.validate: + trainer.validate(system, datamodule=dm, ckpt_path=args.resume) + elif args.test: + trainer.test(system, datamodule=dm, ckpt_path=args.resume) + elif args.predict: + trainer.predict(system, datamodule=dm, ckpt_path=args.resume) + + +if __name__ == "__main__": + main() diff --git a/mesh_recon/mesh.py b/mesh_recon/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..225a271023da1f5fc3e6a709ce465a6fa91afb5d --- /dev/null +++ b/mesh_recon/mesh.py @@ -0,0 +1,845 @@ +import os +import cv2 +import torch +import trimesh +import numpy as np + +from kiui.op import safe_normalize, dot +from kiui.typing import * + +class Mesh: + """ + A torch-native trimesh class, with support for ``ply/obj/glb`` formats. + + Note: + This class only supports one mesh with a single texture image (an albedo texture and a metallic-roughness texture). + """ + def __init__( + self, + v: Optional[Tensor] = None, + f: Optional[Tensor] = None, + vn: Optional[Tensor] = None, + fn: Optional[Tensor] = None, + vt: Optional[Tensor] = None, + ft: Optional[Tensor] = None, + vc: Optional[Tensor] = None, # vertex color + albedo: Optional[Tensor] = None, + metallicRoughness: Optional[Tensor] = None, + device: Optional[torch.device] = None, + ): + """Init a mesh directly using all attributes. + + Args: + v (Optional[Tensor]): vertices, float [N, 3]. Defaults to None. + f (Optional[Tensor]): faces, int [M, 3]. Defaults to None. + vn (Optional[Tensor]): vertex normals, float [N, 3]. Defaults to None. + fn (Optional[Tensor]): faces for normals, int [M, 3]. Defaults to None. + vt (Optional[Tensor]): vertex uv coordinates, float [N, 2]. Defaults to None. + ft (Optional[Tensor]): faces for uvs, int [M, 3]. Defaults to None. + vc (Optional[Tensor]): vertex colors, float [N, 3]. Defaults to None. + albedo (Optional[Tensor]): albedo texture, float [H, W, 3], RGB format. Defaults to None. 
+ metallicRoughness (Optional[Tensor]): metallic-roughness texture, float [H, W, 3], metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1]. Defaults to None. + device (Optional[torch.device]): torch device. Defaults to None. + """ + self.device = device + self.v = v + self.vn = vn + self.vt = vt + self.f = f + self.fn = fn + self.ft = ft + # will first see if there is vertex color to use + self.vc = vc + # only support a single albedo image + self.albedo = albedo + # pbr extension, metallic(Blue) = metallicRoughness[..., 2], roughness(Green) = metallicRoughness[..., 1] + # ref: https://registry.khronos.org/glTF/specs/2.0/glTF-2.0.html + self.metallicRoughness = metallicRoughness + + self.ori_center = 0 + self.ori_scale = 1 + + @classmethod + def load(cls, path, resize=True, clean=False, renormal=True, retex=False, bound=0.9, front_dir='+z', **kwargs): + """load mesh from path. + + Args: + path (str): path to mesh file, supports ply, obj, glb. + clean (bool, optional): perform mesh cleaning at load (e.g., merge close vertices). Defaults to False. + resize (bool, optional): auto resize the mesh using ``bound`` into [-bound, bound]^3. Defaults to True. + renormal (bool, optional): re-calc the vertex normals. Defaults to True. + retex (bool, optional): re-calc the uv coordinates, will overwrite the existing uv coordinates. Defaults to False. + bound (float, optional): bound to resize. Defaults to 0.9. + front_dir (str, optional): front-view direction of the mesh, should be [+-][xyz][ 123]. Defaults to '+z'. + device (torch.device, optional): torch device. Defaults to None. + + Note: + a ``device`` keyword argument can be provided to specify the torch device. + If it's not provided, we will try to use ``'cuda'`` as the device if it's available. + + Returns: + Mesh: the loaded Mesh object. 
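+ + Example: + A minimal usage sketch; the asset paths here are hypothetical: + + mesh = Mesh.load("assets/example.glb", resize=True, bound=0.9) + print(mesh.v.shape, mesh.f.shape) + mesh.write("outputs/example.obj")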
+ """ + # obj supports face uv + if path.endswith(".obj"): + mesh = cls.load_obj(path, **kwargs) + # trimesh only supports vertex uv, but can load more formats + else: + mesh = cls.load_trimesh(path, **kwargs) + + # clean + if clean: + from kiui.mesh_utils import clean_mesh + vertices = mesh.v.detach().cpu().numpy() + triangles = mesh.f.detach().cpu().numpy() + vertices, triangles = clean_mesh(vertices, triangles, remesh=False) + mesh.v = torch.from_numpy(vertices).contiguous().float().to(mesh.device) + mesh.f = torch.from_numpy(triangles).contiguous().int().to(mesh.device) + + print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}") + # auto-normalize + if resize: + mesh.auto_size(bound=bound) + # auto-fix normal + if renormal or mesh.vn is None: + mesh.auto_normal() + print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}") + # auto-fix texcoords + if retex or (mesh.albedo is not None and mesh.vt is None): + mesh.auto_uv(cache_path=path) + print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}") + + # rotate front dir to +z + if front_dir != "+z": + # axis switch + if "-z" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, -1]], device=mesh.device, dtype=torch.float32) + elif "+x" in front_dir: + T = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) + elif "-x" in front_dir: + T = torch.tensor([[0, 0, -1], [0, 1, 0], [1, 0, 0]], device=mesh.device, dtype=torch.float32) + elif "+y" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) + elif "-y" in front_dir: + T = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=mesh.device, dtype=torch.float32) + else: + T = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + # rotation (how many 90 degrees) + if '1' in front_dir: + T @= torch.tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + elif '2' in front_dir: + T @= torch.tensor([[1, 0, 0], [0, -1, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + elif '3' in front_dir: + T @= torch.tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]], device=mesh.device, dtype=torch.float32) + mesh.v @= T + mesh.vn @= T + + return mesh + + # load from obj file + @classmethod + def load_obj(cls, path, albedo_path=None, device=None): + """load an ``obj`` mesh. + + Args: + path (str): path to mesh. + albedo_path (str, optional): path to the albedo texture image, will overwrite the existing texture path if specified in mtl. Defaults to None. + device (torch.device, optional): torch device. Defaults to None. + + Note: + We will try to read `mtl` path from `obj`, else we assume the file name is the same as `obj` but with `mtl` extension. + The `usemtl` statement is ignored, and we only use the last material path in `mtl` file. + + Returns: + Mesh: the loaded Mesh object. 
+ """ + assert os.path.splitext(path)[-1] == ".obj" + + mesh = cls() + + # device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + mesh.device = device + + # load obj + with open(path, "r") as f: + lines = f.readlines() + + def parse_f_v(fv): + # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided) + # supported forms: + # f v1 v2 v3 + # f v1/vt1 v2/vt2 v3/vt3 + # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3 + # f v1//vn1 v2//vn2 v3//vn3 + xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")] + xs.extend([-1] * (3 - len(xs))) + return xs[0], xs[1], xs[2] + + vertices, texcoords, normals = [], [], [] + faces, tfaces, nfaces = [], [], [] + mtl_path = None + + for line in lines: + split_line = line.split() + # empty line + if len(split_line) == 0: + continue + prefix = split_line[0].lower() + # mtllib + if prefix == "mtllib": + mtl_path = split_line[1] + # usemtl + elif prefix == "usemtl": + pass # ignored + # v/vn/vt + elif prefix == "v": + vertices.append([float(v) for v in split_line[1:]]) + elif prefix == "vn": + normals.append([float(v) for v in split_line[1:]]) + elif prefix == "vt": + val = [float(v) for v in split_line[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + elif prefix == "f": + vs = split_line[1:] + nv = len(vs) + v0, t0, n0 = parse_f_v(vs[0]) + for i in range(nv - 2): # triangulate (assume vertices are ordered) + v1, t1, n1 = parse_f_v(vs[i + 1]) + v2, t2, n2 = parse_f_v(vs[i + 2]) + faces.append([v0, v1, v2]) + tfaces.append([t0, t1, t2]) + nfaces.append([n0, n1, n2]) + + mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) + mesh.vt = ( + torch.tensor(texcoords, dtype=torch.float32, device=device) + if len(texcoords) > 0 + else None + ) + mesh.vn = ( + torch.tensor(normals, dtype=torch.float32, device=device) + if len(normals) > 0 + else None + ) + + mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) + mesh.ft = ( + torch.tensor(tfaces, dtype=torch.int32, device=device) + if len(texcoords) > 0 + else None + ) + mesh.fn = ( + torch.tensor(nfaces, dtype=torch.int32, device=device) + if len(normals) > 0 + else None + ) + + # see if there is vertex color + use_vertex_color = False + if mesh.v.shape[1] == 6: + use_vertex_color = True + mesh.vc = mesh.v[:, 3:] + mesh.v = mesh.v[:, :3] + print(f"[load_obj] use vertex color: {mesh.vc.shape}") + + # try to load texture image + if not use_vertex_color: + # try to retrieve mtl file + mtl_path_candidates = [] + if mtl_path is not None: + mtl_path_candidates.append(mtl_path) + mtl_path_candidates.append(os.path.join(os.path.dirname(path), mtl_path)) + mtl_path_candidates.append(path.replace(".obj", ".mtl")) + + mtl_path = None + for candidate in mtl_path_candidates: + if os.path.exists(candidate): + mtl_path = candidate + break + + # if albedo_path is not provided, try retrieve it from mtl + metallic_path = None + roughness_path = None + if mtl_path is not None and albedo_path is None: + with open(mtl_path, "r") as f: + lines = f.readlines() + + for line in lines: + split_line = line.split() + # empty line + if len(split_line) == 0: + continue + prefix = split_line[0] + + if "map_Kd" in prefix: + # assume relative path! 
+ albedo_path = os.path.join(os.path.dirname(path), split_line[1]) + print(f"[load_obj] use texture from: {albedo_path}") + elif "map_Pm" in prefix: + metallic_path = os.path.join(os.path.dirname(path), split_line[1]) + elif "map_Pr" in prefix: + roughness_path = os.path.join(os.path.dirname(path), split_line[1]) + + # if albedo_path is still not found, or the path doesn't exist + if albedo_path is None or not os.path.exists(albedo_path): + # initialize an empty (uniform gray) texture + print(f"[load_obj] init empty albedo!") + # albedo = np.random.rand(1024, 1024, 3).astype(np.float32) + albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) # default color + else: + albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED) + albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB) + albedo = albedo.astype(np.float32) / 255 + print(f"[load_obj] load texture: {albedo.shape}") + + mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device) + + # try to load metallic and roughness + if metallic_path is not None and roughness_path is not None: + print(f"[load_obj] load metallicRoughness from: {metallic_path}, {roughness_path}") + metallic = cv2.imread(metallic_path, cv2.IMREAD_UNCHANGED) + metallic = metallic.astype(np.float32) / 255 + roughness = cv2.imread(roughness_path, cv2.IMREAD_UNCHANGED) + roughness = roughness.astype(np.float32) / 255 + metallicRoughness = np.stack([np.zeros_like(metallic), roughness, metallic], axis=-1) + + mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous() + + return mesh + + @classmethod + def load_trimesh(cls, path, device=None): + """load a mesh using ``trimesh.load()``. + + Can load various formats like ``glb`` and serves as a fallback. + + Note: + We will try to merge all meshes if the glb contains more than one, + but **this may cause the texture to be lost**, since we only support one texture image! + + Args: + path (str): path to the mesh file. + device (torch.device, optional): torch device. Defaults to None. + + Returns: + Mesh: the loaded Mesh object.
+ """ + mesh = cls() + + # device + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + mesh.device = device + + # use trimesh to load ply/glb + _data = trimesh.load(path) + if isinstance(_data, trimesh.Scene): + if len(_data.geometry) == 1: + _mesh = list(_data.geometry.values())[0] + else: + print(f"[load_trimesh] concatenating {len(_data.geometry)} meshes.") + _concat = [] + # loop the scene graph and apply transform to each mesh + scene_graph = _data.graph.to_flattened() # dict {name: {transform: 4x4 mat, geometry: str}} + for k, v in scene_graph.items(): + name = v['geometry'] + if name in _data.geometry and isinstance(_data.geometry[name], trimesh.Trimesh): + transform = v['transform'] + _concat.append(_data.geometry[name].apply_transform(transform)) + _mesh = trimesh.util.concatenate(_concat) + else: + _mesh = _data + + if _mesh.visual.kind == 'vertex': + vertex_colors = _mesh.visual.vertex_colors + vertex_colors = np.array(vertex_colors[..., :3]).astype(np.float32) / 255 + mesh.vc = torch.tensor(vertex_colors, dtype=torch.float32, device=device) + print(f"[load_trimesh] use vertex color: {mesh.vc.shape}") + elif _mesh.visual.kind == 'texture': + _material = _mesh.visual.material + if isinstance(_material, trimesh.visual.material.PBRMaterial): + texture = np.array(_material.baseColorTexture).astype(np.float32) / 255 + # load metallicRoughness if present + if _material.metallicRoughnessTexture is not None: + metallicRoughness = np.array(_material.metallicRoughnessTexture).astype(np.float32) / 255 + mesh.metallicRoughness = torch.tensor(metallicRoughness, dtype=torch.float32, device=device).contiguous() + elif isinstance(_material, trimesh.visual.material.SimpleMaterial): + texture = np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255 + else: + raise NotImplementedError(f"material type {type(_material)} not supported!") + mesh.albedo = torch.tensor(texture[..., :3], dtype=torch.float32, device=device).contiguous() + print(f"[load_trimesh] load texture: {texture.shape}") + else: + texture = np.ones((1024, 1024, 3), dtype=np.float32) * np.array([0.5, 0.5, 0.5]) + mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device) + print(f"[load_trimesh] failed to load texture.") + + vertices = _mesh.vertices + + try: + texcoords = _mesh.visual.uv + texcoords[:, 1] = 1 - texcoords[:, 1] + except Exception as e: + texcoords = None + + try: + normals = _mesh.vertex_normals + except Exception as e: + normals = None + + # trimesh only support vertex uv... + faces = tfaces = nfaces = _mesh.faces + + mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device) + mesh.vt = ( + torch.tensor(texcoords, dtype=torch.float32, device=device) + if texcoords is not None + else None + ) + mesh.vn = ( + torch.tensor(normals, dtype=torch.float32, device=device) + if normals is not None + else None + ) + + mesh.f = torch.tensor(faces, dtype=torch.int32, device=device) + mesh.ft = ( + torch.tensor(tfaces, dtype=torch.int32, device=device) + if texcoords is not None + else None + ) + mesh.fn = ( + torch.tensor(nfaces, dtype=torch.int32, device=device) + if normals is not None + else None + ) + + return mesh + + # sample surface (using trimesh) + def sample_surface(self, count: int): + """sample points on the surface of the mesh. + + Args: + count (int): number of points to sample. + + Returns: + torch.Tensor: the sampled points, float [count, 3]. 
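+ + Example: + A small sketch; the sample count is arbitrary: + + points = mesh.sample_surface(2048) # float tensor [2048, 3] on mesh.device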
+ """ + _mesh = trimesh.Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.f.detach().cpu().numpy()) + points, face_idx = trimesh.sample.sample_surface(_mesh, count) + points = torch.from_numpy(points).float().to(self.device) + return points + + # aabb + def aabb(self): + """get the axis-aligned bounding box of the mesh. + + Returns: + Tuple[torch.Tensor]: the min xyz and max xyz of the mesh. + """ + return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values + + # unit size + @torch.no_grad() + def auto_size(self, bound=0.9): + """auto resize the mesh. + + Args: + bound (float, optional): resizing into ``[-bound, bound]^3``. Defaults to 0.9. + """ + vmin, vmax = self.aabb() + self.ori_center = (vmax + vmin) / 2 + self.ori_scale = 2 * bound / torch.max(vmax - vmin).item() + self.v = (self.v - self.ori_center) * self.ori_scale + + def auto_normal(self): + """auto calculate the vertex normals. + """ + i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long() + v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :] + + face_normals = torch.cross(v1 - v0, v2 - v0) + + # Splat face normals to vertices + vn = torch.zeros_like(self.v) + vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) + vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) + + # Normalize, replace zero (degenerated) normals with some default value + vn = torch.where( + dot(vn, vn) > 1e-20, + vn, + torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device), + ) + vn = safe_normalize(vn) + + self.vn = vn + self.fn = self.f + + def auto_uv(self, cache_path=None, vmap=True): + """auto calculate the uv coordinates. + + Args: + cache_path (str, optional): path to save/load the uv cache as a npz file, this can avoid calculating uv every time when loading the same mesh, which is time-consuming. Defaults to None. + vmap (bool, optional): remap vertices based on uv coordinates, so each v correspond to a unique vt (necessary for formats like gltf). + Usually this will duplicate the vertices on the edge of uv atlas. Defaults to True. + """ + # try to load cache + if cache_path is not None: + cache_path = os.path.splitext(cache_path)[0] + "_uv.npz" + if cache_path is not None and os.path.exists(cache_path): + data = np.load(cache_path) + vt_np, ft_np, vmapping = data["vt"], data["ft"], data["vmapping"] + else: + import xatlas + + v_np = self.v.detach().cpu().numpy() + f_np = self.f.detach().int().cpu().numpy() + atlas = xatlas.Atlas() + atlas.add_mesh(v_np, f_np) + chart_options = xatlas.ChartOptions() + # chart_options.max_iterations = 4 + atlas.generate(chart_options=chart_options) + vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2] + + # save to cache + if cache_path is not None: + np.savez(cache_path, vt=vt_np, ft=ft_np, vmapping=vmapping) + + vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device) + ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device) + self.vt = vt + self.ft = ft + + if vmap: + vmapping = torch.from_numpy(vmapping.astype(np.int64)).long().to(self.device) + self.align_v_to_vt(vmapping) + + def align_v_to_vt(self, vmapping=None): + """ remap v/f and vn/fn to vt/ft. + + Args: + vmapping (np.ndarray, optional): the mapping relationship from f to ft. Defaults to None. 
+ """ + if vmapping is None: + ft = self.ft.view(-1).long() + f = self.f.view(-1).long() + vmapping = torch.zeros(self.vt.shape[0], dtype=torch.long, device=self.device) + vmapping[ft] = f # scatter, randomly choose one if index is not unique + + self.v = self.v[vmapping] + self.f = self.ft + + if self.vn is not None: + self.vn = self.vn[vmapping] + self.fn = self.ft + + def to(self, device): + """move all tensor attributes to device. + + Args: + device (torch.device): target device. + + Returns: + Mesh: self. + """ + self.device = device + for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo", "vc", "metallicRoughness"]: + tensor = getattr(self, name) + if tensor is not None: + setattr(self, name, tensor.to(device)) + return self + + def write(self, path): + """write the mesh to a path. + + Args: + path (str): path to write, supports ply, obj and glb. + """ + if path.endswith(".ply"): + self.write_ply(path) + elif path.endswith(".obj"): + self.write_obj(path) + elif path.endswith(".glb") or path.endswith(".gltf"): + self.write_glb(path) + else: + raise NotImplementedError(f"format {path} not supported!") + + def write_ply(self, path): + """write the mesh in ply format. Only for geometry! + + Args: + path (str): path to write. + """ + + if self.albedo is not None: + print(f'[WARN] ply format does not support exporting texture, will ignore!') + + v_np = self.v.detach().cpu().numpy() + f_np = self.f.detach().cpu().numpy() + + _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np) + _mesh.export(path) + + + def write_glb(self, path): + """write the mesh in glb/gltf format. + This will create a scene with a single mesh. + + Args: + path (str): path to write. + """ + + # assert self.v.shape[0] == self.vn.shape[0] and self.v.shape[0] == self.vt.shape[0] + if self.vt is not None and self.v.shape[0] != self.vt.shape[0]: + self.align_v_to_vt() + + import pygltflib + + f_np = self.f.detach().cpu().numpy().astype(np.uint32) + f_np_blob = f_np.flatten().tobytes() + + v_np = self.v.detach().cpu().numpy().astype(np.float32) + v_np_blob = v_np.tobytes() + + blob = f_np_blob + v_np_blob + byteOffset = len(blob) + + # base mesh + gltf = pygltflib.GLTF2( + scene=0, + scenes=[pygltflib.Scene(nodes=[0])], + nodes=[pygltflib.Node(mesh=0)], + meshes=[pygltflib.Mesh(primitives=[pygltflib.Primitive( + # indices to accessors (0 is triangles) + attributes=pygltflib.Attributes( + POSITION=1, + ), + indices=0, + )])], + buffers=[ + pygltflib.Buffer(byteLength=len(f_np_blob) + len(v_np_blob)) + ], + # buffer view (based on dtype) + bufferViews=[ + # triangles; as flatten (element) array + pygltflib.BufferView( + buffer=0, + byteLength=len(f_np_blob), + target=pygltflib.ELEMENT_ARRAY_BUFFER, # GL_ELEMENT_ARRAY_BUFFER (34963) + ), + # positions; as vec3 array + pygltflib.BufferView( + buffer=0, + byteOffset=len(f_np_blob), + byteLength=len(v_np_blob), + byteStride=12, # vec3 + target=pygltflib.ARRAY_BUFFER, # GL_ARRAY_BUFFER (34962) + ), + ], + accessors=[ + # 0 = triangles + pygltflib.Accessor( + bufferView=0, + componentType=pygltflib.UNSIGNED_INT, # GL_UNSIGNED_INT (5125) + count=f_np.size, + type=pygltflib.SCALAR, + max=[int(f_np.max())], + min=[int(f_np.min())], + ), + # 1 = positions + pygltflib.Accessor( + bufferView=1, + componentType=pygltflib.FLOAT, # GL_FLOAT (5126) + count=len(v_np), + type=pygltflib.VEC3, + max=v_np.max(axis=0).tolist(), + min=v_np.min(axis=0).tolist(), + ), + ], + ) + + # append texture info + if self.vt is not None: + + vt_np = self.vt.detach().cpu().numpy().astype(np.float32) + vt_np_blob 
= vt_np.tobytes() + + albedo = self.albedo.detach().cpu().numpy() + albedo = (albedo * 255).astype(np.uint8) + albedo = cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR) + albedo_blob = cv2.imencode('.png', albedo)[1].tobytes() + + # update primitive + gltf.meshes[0].primitives[0].attributes.TEXCOORD_0 = 2 + gltf.meshes[0].primitives[0].material = 0 + + # update materials + gltf.materials.append(pygltflib.Material( + pbrMetallicRoughness=pygltflib.PbrMetallicRoughness( + baseColorTexture=pygltflib.TextureInfo(index=0, texCoord=0), + metallicFactor=0.0, + roughnessFactor=1.0, + ), + alphaMode=pygltflib.OPAQUE, + alphaCutoff=None, + doubleSided=True, + )) + + gltf.textures.append(pygltflib.Texture(sampler=0, source=0)) + gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT)) + gltf.images.append(pygltflib.Image(bufferView=3, mimeType="image/png")) + + # update buffers + gltf.bufferViews.append( + # index = 2, texcoords; as vec2 array + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(vt_np_blob), + byteStride=8, # vec2 + target=pygltflib.ARRAY_BUFFER, + ) + ) + + gltf.accessors.append( + # 2 = texcoords + pygltflib.Accessor( + bufferView=2, + componentType=pygltflib.FLOAT, + count=len(vt_np), + type=pygltflib.VEC2, + max=vt_np.max(axis=0).tolist(), + min=vt_np.min(axis=0).tolist(), + ) + ) + + blob += vt_np_blob + byteOffset += len(vt_np_blob) + + gltf.bufferViews.append( + # index = 3, albedo texture; as none target + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(albedo_blob), + ) + ) + + blob += albedo_blob + byteOffset += len(albedo_blob) + + gltf.buffers[0].byteLength = byteOffset + + # append metllic roughness + if self.metallicRoughness is not None: + metallicRoughness = self.metallicRoughness.detach().cpu().numpy() + metallicRoughness = (metallicRoughness * 255).astype(np.uint8) + metallicRoughness = cv2.cvtColor(metallicRoughness, cv2.COLOR_RGB2BGR) + metallicRoughness_blob = cv2.imencode('.png', metallicRoughness)[1].tobytes() + + # update texture definition + gltf.materials[0].pbrMetallicRoughness.metallicFactor = 1.0 + gltf.materials[0].pbrMetallicRoughness.roughnessFactor = 1.0 + gltf.materials[0].pbrMetallicRoughness.metallicRoughnessTexture = pygltflib.TextureInfo(index=1, texCoord=0) + + gltf.textures.append(pygltflib.Texture(sampler=1, source=1)) + gltf.samplers.append(pygltflib.Sampler(magFilter=pygltflib.LINEAR, minFilter=pygltflib.LINEAR_MIPMAP_LINEAR, wrapS=pygltflib.REPEAT, wrapT=pygltflib.REPEAT)) + gltf.images.append(pygltflib.Image(bufferView=4, mimeType="image/png")) + + # update buffers + gltf.bufferViews.append( + # index = 4, metallicRoughness texture; as none target + pygltflib.BufferView( + buffer=0, + byteOffset=byteOffset, + byteLength=len(metallicRoughness_blob), + ) + ) + + blob += metallicRoughness_blob + byteOffset += len(metallicRoughness_blob) + + gltf.buffers[0].byteLength = byteOffset + + + # set actual data + gltf.set_binary_blob(blob) + + # glb = b"".join(gltf.save_to_bytes()) + gltf.save(path) + + + def write_obj(self, path): + """write the mesh in obj format. Will also write the texture and mtl files. + + Args: + path (str): path to write. 
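+ + Example: + A minimal sketch with a hypothetical output path; the sibling .mtl file and texture PNGs are written next to it: + + mesh.write_obj("outputs/example.obj")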
+ """ + + mtl_path = path.replace(".obj", ".mtl") + albedo_path = path.replace(".obj", "_albedo.png") + metallic_path = path.replace(".obj", "_metallic.png") + roughness_path = path.replace(".obj", "_roughness.png") + + v_np = self.v.detach().cpu().numpy() + vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None + vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None + f_np = self.f.detach().cpu().numpy() + ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None + fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None + + with open(path, "w") as fp: + fp.write(f"mtllib {os.path.basename(mtl_path)} \n") + + for v in v_np: + fp.write(f"v {v[0]} {v[1]} {v[2]} \n") + + if vt_np is not None: + for v in vt_np: + fp.write(f"vt {v[0]} {1 - v[1]} \n") + + if vn_np is not None: + for v in vn_np: + fp.write(f"vn {v[0]} {v[1]} {v[2]} \n") + + fp.write(f"usemtl defaultMat \n") + for i in range(len(f_np)): + fp.write( + f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \ + {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \ + {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n' + ) + + with open(mtl_path, "w") as fp: + fp.write(f"newmtl defaultMat \n") + fp.write(f"Ka 1 1 1 \n") + fp.write(f"Kd 1 1 1 \n") + fp.write(f"Ks 0 0 0 \n") + fp.write(f"Tr 1 \n") + fp.write(f"illum 1 \n") + fp.write(f"Ns 0 \n") + if self.albedo is not None: + fp.write(f"map_Kd {os.path.basename(albedo_path)} \n") + if self.metallicRoughness is not None: + # ref: https://en.wikipedia.org/wiki/Wavefront_.obj_file#Physically-based_Rendering + fp.write(f"map_Pm {os.path.basename(metallic_path)} \n") + fp.write(f"map_Pr {os.path.basename(roughness_path)} \n") + + if self.albedo is not None: + albedo = self.albedo.detach().cpu().numpy() + albedo = (albedo * 255).astype(np.uint8) + cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR)) + + if self.metallicRoughness is not None: + metallicRoughness = self.metallicRoughness.detach().cpu().numpy() + metallicRoughness = (metallicRoughness * 255).astype(np.uint8) + cv2.imwrite(metallic_path, metallicRoughness[..., 2]) + cv2.imwrite(roughness_path, metallicRoughness[..., 1]) + diff --git a/mesh_recon/models/__init__.py b/mesh_recon/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2fa8c5785d3371ad963aca212ae0da32644c175 --- /dev/null +++ b/mesh_recon/models/__init__.py @@ -0,0 +1,16 @@ +models = {} + + +def register(name): + def decorator(cls): + models[name] = cls + return cls + return decorator + + +def make(name, config): + model = models[name](config) + return model + + +from . 
import nerf, neus, geometry, texture diff --git a/mesh_recon/models/base.py b/mesh_recon/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..47b853bc9502ff2581639b5a6bc7313ffe0ec9ec --- /dev/null +++ b/mesh_recon/models/base.py @@ -0,0 +1,32 @@ +import torch +import torch.nn as nn + +from utils.misc import get_rank + +class BaseModel(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.rank = get_rank() + self.setup() + if self.config.get('weights', None): + self.load_state_dict(torch.load(self.config.weights)) + + def setup(self): + raise NotImplementedError + + def update_step(self, epoch, global_step): + pass + + def train(self, mode=True): + return super().train(mode=mode) + + def eval(self): + return super().eval() + + def regularizations(self, out): + return {} + + @torch.no_grad() + def export(self, export_config): + return {} diff --git a/mesh_recon/models/geometry.py b/mesh_recon/models/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..861edbe2726bb19e7c705c419837f369de170f28 --- /dev/null +++ b/mesh_recon/models/geometry.py @@ -0,0 +1,238 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pytorch_lightning.utilities.rank_zero import rank_zero_info + +import models +from models.base import BaseModel +from models.utils import scale_anything, get_activation, cleanup, chunk_batch +from models.network_utils import get_encoding, get_mlp, get_encoding_with_network +from utils.misc import get_rank +from systems.utils import update_module_step +from nerfacc import ContractionType + + +def contract_to_unisphere(x, radius, contraction_type): + if contraction_type == ContractionType.AABB: + x = scale_anything(x, (-radius, radius), (0, 1)) + elif contraction_type == ContractionType.UN_BOUNDED_SPHERE: + x = scale_anything(x, (-radius, radius), (0, 1)) + x = x * 2 - 1 # aabb is at [-1, 1] + mag = x.norm(dim=-1, keepdim=True) + mask = mag.squeeze(-1) > 1 + x[mask] = (2 - 1 / mag[mask]) * (x[mask] / mag[mask]) + x = x / 4 + 0.5 # [-inf, inf] is at [0, 1] + else: + raise NotImplementedError + return x + + +class MarchingCubeHelper(nn.Module): + def __init__(self, resolution, use_torch=True): + super().__init__() + self.resolution = resolution + self.use_torch = use_torch + self.points_range = (0, 1) + if self.use_torch: + import torchmcubes + self.mc_func = torchmcubes.marching_cubes + else: + import mcubes + self.mc_func = mcubes.marching_cubes + self.verts = None + + def grid_vertices(self): + if self.verts is None: + x, y, z = torch.linspace(*self.points_range, self.resolution), torch.linspace(*self.points_range, self.resolution), torch.linspace(*self.points_range, self.resolution) + x, y, z = torch.meshgrid(x, y, z, indexing='ij') + verts = torch.cat([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], dim=-1).reshape(-1, 3) + self.verts = verts + return self.verts + + def forward(self, level, threshold=0.): + level = level.float().view(self.resolution, self.resolution, self.resolution) + if self.use_torch: + verts, faces = self.mc_func(level.to(get_rank()), threshold) + verts, faces = verts.cpu(), faces.cpu().long() + else: + verts, faces = self.mc_func(-level.numpy(), threshold) # transform to numpy + verts, faces = torch.from_numpy(verts.astype(np.float32)), torch.from_numpy(faces.astype(np.int64)) # transform back to pytorch + verts = verts / (self.resolution - 1.) 
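+ # verts lie in the normalized [0, 1] grid here; BaseImplicitGeometry.isosurface_ rescales them to the queried (vmin, vmax) bounds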
+ return { + 'v_pos': verts, + 't_pos_idx': faces + } + + +class BaseImplicitGeometry(BaseModel): + def __init__(self, config): + super().__init__(config) + if self.config.isosurface is not None: + assert self.config.isosurface.method in ['mc', 'mc-torch'] + if self.config.isosurface.method == 'mc-torch': + raise NotImplementedError("Please do not use mc-torch. It currently has some scaling issues I haven't fixed yet.") + self.helper = MarchingCubeHelper(self.config.isosurface.resolution, use_torch=self.config.isosurface.method=='mc-torch') + self.radius = self.config.radius + self.contraction_type = None # assigned in system + + def forward_level(self, points): + raise NotImplementedError + + def isosurface_(self, vmin, vmax): + def batch_func(x): + x = torch.stack([ + scale_anything(x[...,0], (0, 1), (vmin[0], vmax[0])), + scale_anything(x[...,1], (0, 1), (vmin[1], vmax[1])), + scale_anything(x[...,2], (0, 1), (vmin[2], vmax[2])), + ], dim=-1).to(self.rank) + rv = self.forward_level(x).cpu() + cleanup() + return rv + + level = chunk_batch(batch_func, self.config.isosurface.chunk, True, self.helper.grid_vertices()) + mesh = self.helper(level, threshold=self.config.isosurface.threshold) + mesh['v_pos'] = torch.stack([ + scale_anything(mesh['v_pos'][...,0], (0, 1), (vmin[0], vmax[0])), + scale_anything(mesh['v_pos'][...,1], (0, 1), (vmin[1], vmax[1])), + scale_anything(mesh['v_pos'][...,2], (0, 1), (vmin[2], vmax[2])) + ], dim=-1) + return mesh + + @torch.no_grad() + def isosurface(self): + if self.config.isosurface is None: + raise NotImplementedError + mesh_coarse = self.isosurface_((-self.radius, -self.radius, -self.radius), (self.radius, self.radius, self.radius)) + vmin, vmax = mesh_coarse['v_pos'].amin(dim=0), mesh_coarse['v_pos'].amax(dim=0) + vmin_ = (vmin - (vmax - vmin) * 0.1).clamp(-self.radius, self.radius) + vmax_ = (vmax + (vmax - vmin) * 0.1).clamp(-self.radius, self.radius) + mesh_fine = self.isosurface_(vmin_, vmax_) + return mesh_fine + + +@models.register('volume-density') +class VolumeDensity(BaseImplicitGeometry): + def setup(self): + self.n_input_dims = self.config.get('n_input_dims', 3) + self.n_output_dims = self.config.feature_dim + self.encoding_with_network = get_encoding_with_network(self.n_input_dims, self.n_output_dims, self.config.xyz_encoding_config, self.config.mlp_network_config) + + def forward(self, points): + points = contract_to_unisphere(points, self.radius, self.contraction_type) + out = self.encoding_with_network(points.view(-1, self.n_input_dims)).view(*points.shape[:-1], self.n_output_dims).float() + density, feature = out[...,0], out + if 'density_activation' in self.config: + density = get_activation(self.config.density_activation)(density + float(self.config.density_bias)) + if 'feature_activation' in self.config: + feature = get_activation(self.config.feature_activation)(feature) + return density, feature + + def forward_level(self, points): + points = contract_to_unisphere(points, self.radius, self.contraction_type) + density = self.encoding_with_network(points.reshape(-1, self.n_input_dims)).reshape(*points.shape[:-1], self.n_output_dims)[...,0] + if 'density_activation' in self.config: + density = get_activation(self.config.density_activation)(density + float(self.config.density_bias)) + return -density + + def update_step(self, epoch, global_step): + update_module_step(self.encoding_with_network, epoch, global_step) + + +@models.register('volume-sdf') +class VolumeSDF(BaseImplicitGeometry): + def setup(self): + self.n_output_dims = 
self.config.feature_dim + encoding = get_encoding(3, self.config.xyz_encoding_config) + network = get_mlp(encoding.n_output_dims, self.n_output_dims, self.config.mlp_network_config) + self.encoding, self.network = encoding, network + self.grad_type = self.config.grad_type + self.finite_difference_eps = self.config.get('finite_difference_eps', 1e-3) + # the actual value used in training + # will update at certain steps if finite_difference_eps="progressive" + self._finite_difference_eps = None + if self.grad_type == 'finite_difference': + rank_zero_info(f"Using finite difference to compute gradients with eps={self.finite_difference_eps}") + + def forward(self, points, with_grad=True, with_feature=True, with_laplace=False): + with torch.inference_mode(torch.is_inference_mode_enabled() and not (with_grad and self.grad_type == 'analytic')): + with torch.set_grad_enabled(self.training or (with_grad and self.grad_type == 'analytic')): + if with_grad and self.grad_type == 'analytic': + if not self.training: + points = points.clone() # points may be in inference mode, get a copy to enable grad + points.requires_grad_(True) + + points_ = points # points in the original scale + points = contract_to_unisphere(points, self.radius, self.contraction_type) # points normalized to (0, 1) + + out = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims).float() + sdf, feature = out[...,0], out + if 'sdf_activation' in self.config: + sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias)) + if 'feature_activation' in self.config: + feature = get_activation(self.config.feature_activation)(feature) + if with_grad: + if self.grad_type == 'analytic': + grad = torch.autograd.grad( + sdf, points_, grad_outputs=torch.ones_like(sdf), + create_graph=True, retain_graph=True, only_inputs=True + )[0] + elif self.grad_type == 'finite_difference': + eps = self._finite_difference_eps + offsets = torch.as_tensor( + [ + [eps, 0.0, 0.0], + [-eps, 0.0, 0.0], + [0.0, eps, 0.0], + [0.0, -eps, 0.0], + [0.0, 0.0, eps], + [0.0, 0.0, -eps], + ] + ).to(points_) + points_d_ = (points_[...,None,:] + offsets).clamp(-self.radius, self.radius) + points_d = scale_anything(points_d_, (-self.radius, self.radius), (0, 1)) + points_d_sdf = self.network(self.encoding(points_d.view(-1, 3)))[...,0].view(*points.shape[:-1], 6).float() + grad = 0.5 * (points_d_sdf[..., 0::2] - points_d_sdf[..., 1::2]) / eps + + if with_laplace: + laplace = (points_d_sdf[..., 0::2] + points_d_sdf[..., 1::2] - 2 * sdf[..., None]).sum(-1) / (eps ** 2) + + rv = [sdf] + if with_grad: + rv.append(grad) + if with_feature: + rv.append(feature) + if with_laplace: + assert self.config.grad_type == 'finite_difference', "Laplace computation is only supported with grad_type='finite_difference'" + rv.append(laplace) + rv = [v if self.training else v.detach() for v in rv] + return rv[0] if len(rv) == 1 else rv + + def forward_level(self, points): + points = contract_to_unisphere(points, self.radius, self.contraction_type) # points normalized to (0, 1) + sdf = self.network(self.encoding(points.view(-1, 3))).view(*points.shape[:-1], self.n_output_dims)[...,0] + if 'sdf_activation' in self.config: + sdf = get_activation(self.config.sdf_activation)(sdf + float(self.config.sdf_bias)) + return sdf + + def update_step(self, epoch, global_step): + update_module_step(self.encoding, epoch, global_step) + update_module_step(self.network, epoch, global_step) + if self.grad_type == 'finite_difference': + if 
isinstance(self.finite_difference_eps, float): + self._finite_difference_eps = self.finite_difference_eps + elif self.finite_difference_eps == 'progressive': + hg_conf = self.config.xyz_encoding_config + assert hg_conf.otype == "ProgressiveBandHashGrid", "finite_difference_eps='progressive' only works with ProgressiveBandHashGrid" + current_level = min( + hg_conf.start_level + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps, + hg_conf.n_levels + ) + grid_res = hg_conf.base_resolution * hg_conf.per_level_scale**(current_level - 1) + grid_size = 2 * self.config.radius / grid_res + if grid_size != self._finite_difference_eps: + rank_zero_info(f"Update finite_difference_eps to {grid_size}") + self._finite_difference_eps = grid_size + else: + raise ValueError(f"Unknown finite_difference_eps={self.finite_difference_eps}") diff --git a/mesh_recon/models/nerf.py b/mesh_recon/models/nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..64ce5f9b839828eb02292faa4828108694f5f6d1 --- /dev/null +++ b/mesh_recon/models/nerf.py @@ -0,0 +1,161 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import models +from models.base import BaseModel +from models.utils import chunk_batch +from systems.utils import update_module_step +from nerfacc import ContractionType, OccupancyGrid, ray_marching, render_weight_from_density, accumulate_along_rays + + +@models.register('nerf') +class NeRFModel(BaseModel): + def setup(self): + self.geometry = models.make(self.config.geometry.name, self.config.geometry) + self.texture = models.make(self.config.texture.name, self.config.texture) + self.register_buffer('scene_aabb', torch.as_tensor([-self.config.radius, -self.config.radius, -self.config.radius, self.config.radius, self.config.radius, self.config.radius], dtype=torch.float32)) + + if self.config.learned_background: + self.occupancy_grid_res = 256 + self.near_plane, self.far_plane = 0.2, 1e4 + self.cone_angle = 10**(math.log10(self.far_plane) / self.config.num_samples_per_ray) - 1. 
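# A quick sanity check of the cone_angle heuristic above: with the step size growing
# geometrically as t_{k+1} = t_k * (1 + cone_angle), setting
# cone_angle = 10**(log10(far_plane)/N) - 1 yields roughly N samples between t = 1 and
# t = far_plane. far_plane matches the value set above (1e4); N stands in for
# num_samples_per_ray and is illustrative.
import math

far_plane, N = 1e4, 1024
cone_angle = 10 ** (math.log10(far_plane) / N) - 1.0

t, steps = 1.0, 0
while t < far_plane:
    t *= 1.0 + cone_angle
    steps += 1
print(cone_angle, steps)  # steps comes out close to N (= 1024 here)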
# approximate + self.render_step_size = 0.01 # render_step_size = max(distance_to_camera * self.cone_angle, self.render_step_size) + self.contraction_type = ContractionType.UN_BOUNDED_SPHERE + else: + self.occupancy_grid_res = 128 + self.near_plane, self.far_plane = None, None + self.cone_angle = 0.0 + self.render_step_size = 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray + self.contraction_type = ContractionType.AABB + + self.geometry.contraction_type = self.contraction_type + + if self.config.grid_prune: + self.occupancy_grid = OccupancyGrid( + roi_aabb=self.scene_aabb, + resolution=self.occupancy_grid_res, + contraction_type=self.contraction_type + ) + self.randomized = self.config.randomized + self.background_color = None + + def update_step(self, epoch, global_step): + update_module_step(self.geometry, epoch, global_step) + update_module_step(self.texture, epoch, global_step) + + def occ_eval_fn(x): + density, _ = self.geometry(x) + # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size) based on taylor series + return density[...,None] * self.render_step_size + + if self.training and self.config.grid_prune: + self.occupancy_grid.every_n_step(step=global_step, occ_eval_fn=occ_eval_fn) + + def isosurface(self): + mesh = self.geometry.isosurface() + return mesh + + def forward_(self, rays): + n_rays = rays.shape[0] + rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) + + def sigma_fn(t_starts, t_ends, ray_indices): + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + positions = t_origins + t_dirs * (t_starts + t_ends) / 2. + density, _ = self.geometry(positions) + return density[...,None] + + def rgb_sigma_fn(t_starts, t_ends, ray_indices): + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + positions = t_origins + t_dirs * (t_starts + t_ends) / 2. + density, feature = self.geometry(positions) + rgb = self.texture(feature, t_dirs) + return rgb, density[...,None] + + with torch.no_grad(): + ray_indices, t_starts, t_ends = ray_marching( + rays_o, rays_d, + scene_aabb=None if self.config.learned_background else self.scene_aabb, + grid=self.occupancy_grid if self.config.grid_prune else None, + sigma_fn=sigma_fn, + near_plane=self.near_plane, far_plane=self.far_plane, + render_step_size=self.render_step_size, + stratified=self.randomized, + cone_angle=self.cone_angle, + alpha_thre=0.0 + ) + + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + midpoints = (t_starts + t_ends) / 2. 
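# occ_eval_fn above scores occupancy-grid cells with the first-order approximation
# density * render_step_size instead of the exact opacity
# 1 - exp(-density * render_step_size). A tiny numerical comparison of the two
# (densities and step size are illustrative):
import torch

density = torch.tensor([0.1, 1.0, 10.0, 100.0])
dt = 0.01
exact = 1.0 - torch.exp(-density * dt)
approx = density * dt
print(torch.stack([exact, approx], dim=-1))
# The two agree to first order; the approximation only overshoots once density * dt
# gets large, which is fine for thresholding the occupancy grid.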
+ positions = t_origins + t_dirs * midpoints + intervals = t_ends - t_starts + + density, feature = self.geometry(positions) + rgb = self.texture(feature, t_dirs) + + weights = render_weight_from_density(t_starts, t_ends, density[...,None], ray_indices=ray_indices, n_rays=n_rays) + opacity = accumulate_along_rays(weights, ray_indices, values=None, n_rays=n_rays) + depth = accumulate_along_rays(weights, ray_indices, values=midpoints, n_rays=n_rays) + comp_rgb = accumulate_along_rays(weights, ray_indices, values=rgb, n_rays=n_rays) + comp_rgb = comp_rgb + self.background_color * (1.0 - opacity) + + out = { + 'comp_rgb': comp_rgb, + 'opacity': opacity, + 'depth': depth, + 'rays_valid': opacity > 0, + 'num_samples': torch.as_tensor([len(t_starts)], dtype=torch.int32, device=rays.device) + } + + if self.training: + out.update({ + 'weights': weights.view(-1), + 'points': midpoints.view(-1), + 'intervals': intervals.view(-1), + 'ray_indices': ray_indices.view(-1) + }) + + return out + + def forward(self, rays): + if self.training: + out = self.forward_(rays) + else: + out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays) + return { + **out, + } + + def train(self, mode=True): + self.randomized = mode and self.config.randomized + return super().train(mode=mode) + + def eval(self): + self.randomized = False + return super().eval() + + def regularizations(self, out): + losses = {} + losses.update(self.geometry.regularizations(out)) + losses.update(self.texture.regularizations(out)) + return losses + + @torch.no_grad() + def export(self, export_config): + mesh = self.isosurface() + if export_config.export_vertex_color: + _, feature = chunk_batch(self.geometry, export_config.chunk_size, False, mesh['v_pos'].to(self.rank)) + viewdirs = torch.zeros(feature.shape[0], 3).to(feature) + viewdirs[...,2] = -1. # set the viewing directions to be -z (looking down) + rgb = self.texture(feature, viewdirs).clamp(0,1) + mesh['v_rgb'] = rgb.cpu() + return mesh diff --git a/mesh_recon/models/network_utils.py b/mesh_recon/models/network_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bf1c4ab64487b68118e62cbc834dc2f1ff908ad7 --- /dev/null +++ b/mesh_recon/models/network_utils.py @@ -0,0 +1,215 @@ +import math +import numpy as np + +import torch +import torch.nn as nn +import tinycudann as tcnn + +from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info + +from utils.misc import config_to_primitive, get_rank +from models.utils import get_activation +from systems.utils import update_module_step + +class VanillaFrequency(nn.Module): + def __init__(self, in_channels, config): + super().__init__() + self.N_freqs = config['n_frequencies'] + self.in_channels, self.n_input_dims = in_channels, in_channels + self.funcs = [torch.sin, torch.cos] + self.freq_bands = 2**torch.linspace(0, self.N_freqs-1, self.N_freqs) + self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs) + self.n_masking_step = config.get('n_masking_step', 0) + self.update_step(None, None) # mask should be updated at the beginning each step + + def forward(self, x): + out = [] + for freq, mask in zip(self.freq_bands, self.mask): + for func in self.funcs: + out += [func(freq*x) * mask] + return torch.cat(out, -1) + + def update_step(self, epoch, global_step): + if self.n_masking_step <= 0 or global_step is None: + self.mask = torch.ones(self.N_freqs, dtype=torch.float32) + else: + self.mask = (1. 
- torch.cos(math.pi * (global_step / self.n_masking_step * self.N_freqs - torch.arange(0, self.N_freqs)).clamp(0, 1))) / 2. + rank_zero_debug(f'Update mask: {global_step}/{self.n_masking_step} {self.mask}') + + +class ProgressiveBandHashGrid(nn.Module): + def __init__(self, in_channels, config): + super().__init__() + self.n_input_dims = in_channels + encoding_config = config.copy() + encoding_config['otype'] = 'HashGrid' + with torch.cuda.device(get_rank()): + self.encoding = tcnn.Encoding(in_channels, encoding_config) + self.n_output_dims = self.encoding.n_output_dims + self.n_level = config['n_levels'] + self.n_features_per_level = config['n_features_per_level'] + self.start_level, self.start_step, self.update_steps = config['start_level'], config['start_step'], config['update_steps'] + self.current_level = self.start_level + self.mask = torch.zeros(self.n_level * self.n_features_per_level, dtype=torch.float32, device=get_rank()) + + def forward(self, x): + enc = self.encoding(x) + enc = enc * self.mask + return enc + + def update_step(self, epoch, global_step): + current_level = min(self.start_level + max(global_step - self.start_step, 0) // self.update_steps, self.n_level) + if current_level > self.current_level: + rank_zero_info(f'Update grid level to {current_level}') + self.current_level = current_level + self.mask[:self.current_level * self.n_features_per_level] = 1. + + +class CompositeEncoding(nn.Module): + def __init__(self, encoding, include_xyz=False, xyz_scale=1., xyz_offset=0.): + super(CompositeEncoding, self).__init__() + self.encoding = encoding + self.include_xyz, self.xyz_scale, self.xyz_offset = include_xyz, xyz_scale, xyz_offset + self.n_output_dims = int(self.include_xyz) * self.encoding.n_input_dims + self.encoding.n_output_dims + + def forward(self, x, *args): + return self.encoding(x, *args) if not self.include_xyz else torch.cat([x * self.xyz_scale + self.xyz_offset, self.encoding(x, *args)], dim=-1) + + def update_step(self, epoch, global_step): + update_module_step(self.encoding, epoch, global_step) + + +def get_encoding(n_input_dims, config): + # input suppose to be range [0, 1] + if config.otype == 'VanillaFrequency': + encoding = VanillaFrequency(n_input_dims, config_to_primitive(config)) + elif config.otype == 'ProgressiveBandHashGrid': + encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config)) + else: + with torch.cuda.device(get_rank()): + encoding = tcnn.Encoding(n_input_dims, config_to_primitive(config)) + encoding = CompositeEncoding(encoding, include_xyz=config.get('include_xyz', False), xyz_scale=2., xyz_offset=-1.) 
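# A standalone sketch of the coarse-to-fine mask used by VanillaFrequency above
# (a BARF-style schedule): each frequency band k is faded in with a raised-cosine
# window as training progresses. The step counts and band count below are
# illustrative, not the project defaults.
import math
import torch

def frequency_mask(global_step: int, n_masking_step: int, n_freqs: int) -> torch.Tensor:
    progress = global_step / n_masking_step * n_freqs  # how many bands are "open"
    k = torch.arange(0, n_freqs, dtype=torch.float32)
    return (1.0 - torch.cos(math.pi * (progress - k).clamp(0, 1))) / 2.0

print(frequency_mask(global_step=0, n_masking_step=1000, n_freqs=6))     # all zeros
print(frequency_mask(global_step=500, n_masking_step=1000, n_freqs=6))   # low bands on
print(frequency_mask(global_step=1000, n_masking_step=1000, n_freqs=6))  # all ones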
+ return encoding + + +class VanillaMLP(nn.Module): + def __init__(self, dim_in, dim_out, config): + super().__init__() + self.n_neurons, self.n_hidden_layers = config['n_neurons'], config['n_hidden_layers'] + self.sphere_init, self.weight_norm = config.get('sphere_init', False), config.get('weight_norm', False) + self.sphere_init_radius = config.get('sphere_init_radius', 0.5) + self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()] + for i in range(self.n_hidden_layers - 1): + self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), self.make_activation()] + self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)] + self.layers = nn.Sequential(*self.layers) + self.output_activation = get_activation(config['output_activation']) + + @torch.cuda.amp.autocast(False) + def forward(self, x): + x = self.layers(x.float()) + x = self.output_activation(x) + return x + + def make_linear(self, dim_in, dim_out, is_first, is_last): + layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality + if self.sphere_init: + if is_last: + torch.nn.init.constant_(layer.bias, -self.sphere_init_radius) + torch.nn.init.normal_(layer.weight, mean=math.sqrt(math.pi) / math.sqrt(dim_in), std=0.0001) + elif is_first: + torch.nn.init.constant_(layer.bias, 0.0) + torch.nn.init.constant_(layer.weight[:, 3:], 0.0) + torch.nn.init.normal_(layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out)) + else: + torch.nn.init.constant_(layer.bias, 0.0) + torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out)) + else: + torch.nn.init.constant_(layer.bias, 0.0) + torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu') + + if self.weight_norm: + layer = nn.utils.weight_norm(layer) + return layer + + def make_activation(self): + if self.sphere_init: + return nn.Softplus(beta=100) + else: + return nn.ReLU(inplace=True) + + +def sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network): + rank_zero_debug('Initialize tcnn MLP to approximately represent a sphere.') + """ + from https://github.com/NVlabs/tiny-cuda-nn/issues/96 + It's the weight matrices of each layer laid out in row-major order and then concatenated. + Notably: inputs and output dimensions are padded to multiples of 8 (CutlassMLP) or 16 (FullyFusedMLP). + The padded input dimensions get a constant value of 1.0, + whereas the padded output dimensions are simply ignored, + so the weights pertaining to those can have any value. 
+ """ + padto = 16 if config.otype == 'FullyFusedMLP' else 8 + n_input_dims = n_input_dims + (padto - n_input_dims % padto) % padto + n_output_dims = n_output_dims + (padto - n_output_dims % padto) % padto + data = list(network.parameters())[0].data + assert data.shape[0] == (n_input_dims + n_output_dims) * config.n_neurons + (config.n_hidden_layers - 1) * config.n_neurons**2 + new_data = [] + # first layer + weight = torch.zeros((config.n_neurons, n_input_dims)).to(data) + torch.nn.init.constant_(weight[:, 3:], 0.0) + torch.nn.init.normal_(weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(config.n_neurons)) + new_data.append(weight.flatten()) + # hidden layers + for i in range(config.n_hidden_layers - 1): + weight = torch.zeros((config.n_neurons, config.n_neurons)).to(data) + torch.nn.init.normal_(weight, 0.0, math.sqrt(2) / math.sqrt(config.n_neurons)) + new_data.append(weight.flatten()) + # last layer + weight = torch.zeros((n_output_dims, config.n_neurons)).to(data) + torch.nn.init.normal_(weight, mean=math.sqrt(math.pi) / math.sqrt(config.n_neurons), std=0.0001) + new_data.append(weight.flatten()) + new_data = torch.cat(new_data) + data.copy_(new_data) + + +def get_mlp(n_input_dims, n_output_dims, config): + if config.otype == 'VanillaMLP': + network = VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config)) + else: + with torch.cuda.device(get_rank()): + network = tcnn.Network(n_input_dims, n_output_dims, config_to_primitive(config)) + if config.get('sphere_init', False): + sphere_init_tcnn_network(n_input_dims, n_output_dims, config, network) + return network + + +class EncodingWithNetwork(nn.Module): + def __init__(self, encoding, network): + super().__init__() + self.encoding, self.network = encoding, network + + def forward(self, x): + return self.network(self.encoding(x)) + + def update_step(self, epoch, global_step): + update_module_step(self.encoding, epoch, global_step) + update_module_step(self.network, epoch, global_step) + + +def get_encoding_with_network(n_input_dims, n_output_dims, encoding_config, network_config): + # input suppose to be range [0, 1] + if encoding_config.otype in ['VanillaFrequency', 'ProgressiveBandHashGrid'] \ + or network_config.otype in ['VanillaMLP']: + encoding = get_encoding(n_input_dims, encoding_config) + network = get_mlp(encoding.n_output_dims, n_output_dims, network_config) + encoding_with_network = EncodingWithNetwork(encoding, network) + else: + with torch.cuda.device(get_rank()): + encoding_with_network = tcnn.NetworkWithInputEncoding( + n_input_dims=n_input_dims, + n_output_dims=n_output_dims, + encoding_config=config_to_primitive(encoding_config), + network_config=config_to_primitive(network_config) + ) + return encoding_with_network diff --git a/mesh_recon/models/neus.py b/mesh_recon/models/neus.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9e9c781efd4dd49b6edc1e0f48a25f63c0c01d --- /dev/null +++ b/mesh_recon/models/neus.py @@ -0,0 +1,441 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import models +from models.base import BaseModel +from models.utils import chunk_batch +from systems.utils import update_module_step +from nerfacc import ( + ContractionType, + OccupancyGrid, + ray_marching, + render_weight_from_density, + render_weight_from_alpha, + accumulate_along_rays, +) +from nerfacc.intersection import ray_aabb_intersect + +import pdb + + +class VarianceNetwork(nn.Module): + def __init__(self, config): + super(VarianceNetwork, self).__init__() + 
self.config = config + self.init_val = self.config.init_val + self.register_parameter( + "variance", nn.Parameter(torch.tensor(self.config.init_val)) + ) + self.modulate = self.config.get("modulate", False) + if self.modulate: + self.mod_start_steps = self.config.mod_start_steps + self.reach_max_steps = self.config.reach_max_steps + self.max_inv_s = self.config.max_inv_s + + @property + def inv_s(self): + val = torch.exp(self.variance * 10.0) + if self.modulate and self.do_mod: + val = val.clamp_max(self.mod_val) + return val + + def forward(self, x): + return torch.ones([len(x), 1], device=self.variance.device) * self.inv_s + + def update_step(self, epoch, global_step): + if self.modulate: + self.do_mod = global_step > self.mod_start_steps + if not self.do_mod: + self.prev_inv_s = self.inv_s.item() + else: + self.mod_val = min( + (global_step / self.reach_max_steps) + * (self.max_inv_s - self.prev_inv_s) + + self.prev_inv_s, + self.max_inv_s, + ) + + +@models.register("neus") +class NeuSModel(BaseModel): + def setup(self): + self.geometry = models.make(self.config.geometry.name, self.config.geometry) + self.texture = models.make(self.config.texture.name, self.config.texture) + self.geometry.contraction_type = ContractionType.AABB + + if self.config.learned_background: + self.geometry_bg = models.make( + self.config.geometry_bg.name, self.config.geometry_bg + ) + self.texture_bg = models.make( + self.config.texture_bg.name, self.config.texture_bg + ) + self.geometry_bg.contraction_type = ContractionType.UN_BOUNDED_SPHERE + self.near_plane_bg, self.far_plane_bg = 0.1, 1e3 + self.cone_angle_bg = ( + 10 + ** (math.log10(self.far_plane_bg) / self.config.num_samples_per_ray_bg) + - 1.0 + ) + self.render_step_size_bg = 0.01 + + self.variance = VarianceNetwork(self.config.variance) + self.register_buffer( + "scene_aabb", + torch.as_tensor( + [ + -self.config.radius, + -self.config.radius, + -self.config.radius, + self.config.radius, + self.config.radius, + self.config.radius, + ], + dtype=torch.float32, + ), + ) + if self.config.grid_prune: + self.occupancy_grid = OccupancyGrid( + roi_aabb=self.scene_aabb, + resolution=128, + contraction_type=ContractionType.AABB, + ) + if self.config.learned_background: + self.occupancy_grid_bg = OccupancyGrid( + roi_aabb=self.scene_aabb, + resolution=256, + contraction_type=ContractionType.UN_BOUNDED_SPHERE, + ) + self.randomized = self.config.randomized + self.background_color = None + self.render_step_size = ( + 1.732 * 2 * self.config.radius / self.config.num_samples_per_ray + ) + + def update_step(self, epoch, global_step): + update_module_step(self.geometry, epoch, global_step) + update_module_step(self.texture, epoch, global_step) + if self.config.learned_background: + update_module_step(self.geometry_bg, epoch, global_step) + update_module_step(self.texture_bg, epoch, global_step) + update_module_step(self.variance, epoch, global_step) + + cos_anneal_end = self.config.get("cos_anneal_end", 0) + self.cos_anneal_ratio = ( + 1.0 if cos_anneal_end == 0 else min(1.0, global_step / cos_anneal_end) + ) + + def occ_eval_fn(x): + sdf = self.geometry(x, with_grad=False, with_feature=False) + inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip(1e-6, 1e6) + inv_s = inv_s.expand(sdf.shape[0], 1) + estimated_next_sdf = sdf[..., None] - self.render_step_size * 0.5 + estimated_prev_sdf = sdf[..., None] + self.render_step_size * 0.5 + prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) + next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) + p = prev_cdf - next_cdf + 
c = prev_cdf + alpha = ((p + 1e-5) / (c + 1e-5)).view(-1, 1).clip(0.0, 1.0) + return alpha + + def occ_eval_fn_bg(x): + density, _ = self.geometry_bg(x) + # approximate for 1 - torch.exp(-density[...,None] * self.render_step_size_bg) based on taylor series + return density[..., None] * self.render_step_size_bg + + if self.training and self.config.grid_prune: + self.occupancy_grid.every_n_step( + step=global_step, + occ_eval_fn=occ_eval_fn, + occ_thre=self.config.get("grid_prune_occ_thre", 0.01), + ) + if self.config.learned_background: + self.occupancy_grid_bg.every_n_step( + step=global_step, + occ_eval_fn=occ_eval_fn_bg, + occ_thre=self.config.get("grid_prune_occ_thre_bg", 0.01), + ) + + def isosurface(self): + mesh = self.geometry.isosurface() + return mesh + + def get_alpha(self, sdf, normal, dirs, dists): + inv_s = self.variance(torch.zeros([1, 3]))[:, :1].clip( + 1e-6, 1e6 + ) # Single parameter + inv_s = inv_s.expand(sdf.shape[0], 1) + + true_cos = (dirs * normal).sum(-1, keepdim=True) + + # "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes + # the cos value "not dead" at the beginning training iterations, for better convergence. + iter_cos = -( + F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio) + + F.relu(-true_cos) * self.cos_anneal_ratio + ) # always non-positive + + # Estimate signed distances at section points + estimated_next_sdf = sdf[..., None] + iter_cos * dists.reshape(-1, 1) * 0.5 + estimated_prev_sdf = sdf[..., None] - iter_cos * dists.reshape(-1, 1) * 0.5 + + prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_s) + next_cdf = torch.sigmoid(estimated_next_sdf * inv_s) + + p = prev_cdf - next_cdf + c = prev_cdf + + alpha = ((p + 1e-5) / (c + 1e-5)).view(-1).clip(0.0, 1.0) + return alpha + + def forward_bg_(self, rays): + n_rays = rays.shape[0] + rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) + + def sigma_fn(t_starts, t_ends, ray_indices): + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0 + density, _ = self.geometry_bg(positions) + return density[..., None] + + _, t_max = ray_aabb_intersect(rays_o, rays_d, self.scene_aabb) + # if the ray intersects with the bounding box, start from the farther intersection point + # otherwise start from self.far_plane_bg + # note that in nerfacc t_max is set to 1e10 if there is no intersection + near_plane = torch.where(t_max > 1e9, self.near_plane_bg, t_max) + with torch.no_grad(): + ray_indices, t_starts, t_ends = ray_marching( + rays_o, + rays_d, + scene_aabb=None, + grid=self.occupancy_grid_bg if self.config.grid_prune else None, + sigma_fn=sigma_fn, + near_plane=near_plane, + far_plane=self.far_plane_bg, + render_step_size=self.render_step_size_bg, + stratified=self.randomized, + cone_angle=self.cone_angle_bg, + alpha_thre=0.0, + ) + + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + midpoints = (t_starts + t_ends) / 2.0 + positions = t_origins + t_dirs * midpoints + intervals = t_ends - t_starts + + density, feature = self.geometry_bg(positions) + rgb = self.texture_bg(feature, t_dirs) + + weights = render_weight_from_density( + t_starts, t_ends, density[..., None], ray_indices=ray_indices, n_rays=n_rays + ) + opacity = accumulate_along_rays( + weights, ray_indices, values=None, n_rays=n_rays + ) + depth = accumulate_along_rays( + weights, ray_indices, values=midpoints, n_rays=n_rays + ) + 
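# A minimal sketch of the NeuS alpha used by get_alpha / occ_eval_fn above: the SDF at
# the two ends of a ray section is estimated from the mid-point SDF and the (annealed)
# cosine, pushed through a logistic CDF with sharpness inv_s, and the per-section alpha
# is the relative drop of that CDF. Shapes and values here are illustrative.
import torch

def neus_alpha(sdf, iter_cos, dists, inv_s, eps=1e-5):
    est_next = sdf + iter_cos * dists * 0.5   # sdf at the far end of the section
    est_prev = sdf - iter_cos * dists * 0.5   # sdf at the near end
    prev_cdf = torch.sigmoid(est_prev * inv_s)
    next_cdf = torch.sigmoid(est_next * inv_s)
    return ((prev_cdf - next_cdf + eps) / (prev_cdf + eps)).clamp(0.0, 1.0)

sdf = torch.tensor([0.05, 0.0, -0.05])        # samples crossing the surface
print(neus_alpha(sdf, iter_cos=torch.tensor(-1.0), dists=torch.tensor(0.02), inv_s=64.0))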
comp_rgb = accumulate_along_rays( + weights, ray_indices, values=rgb, n_rays=n_rays + ) + comp_rgb = comp_rgb + self.background_color * (1.0 - opacity) + + out = { + "comp_rgb": comp_rgb, + "opacity": opacity, + "depth": depth, + "rays_valid": opacity > 0, + "num_samples": torch.as_tensor( + [len(t_starts)], dtype=torch.int32, device=rays.device + ), + } + + if self.training: + out.update( + { + "weights": weights.view(-1), + "points": midpoints.view(-1), + "intervals": intervals.view(-1), + "ray_indices": ray_indices.view(-1), + } + ) + + return out + + def forward_(self, rays): + n_rays = rays.shape[0] + rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3) + + with torch.no_grad(): + ray_indices, t_starts, t_ends = ray_marching( + rays_o, + rays_d, + scene_aabb=self.scene_aabb, + grid=self.occupancy_grid if self.config.grid_prune else None, + alpha_fn=None, + near_plane=None, + far_plane=None, + render_step_size=self.render_step_size, + stratified=self.randomized, + cone_angle=0.0, + alpha_thre=0.0, + ) + + ray_indices = ray_indices.long() + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + midpoints = (t_starts + t_ends) / 2.0 + positions = t_origins + t_dirs * midpoints + dists = t_ends - t_starts + + if self.config.geometry.grad_type == "finite_difference": + sdf, sdf_grad, feature, sdf_laplace = self.geometry( + positions, with_grad=True, with_feature=True, with_laplace=True + ) + else: + sdf, sdf_grad, feature = self.geometry( + positions, with_grad=True, with_feature=True + ) + + normal = F.normalize(sdf_grad, p=2, dim=-1) + alpha = self.get_alpha(sdf, normal, t_dirs, dists)[..., None] + rgb = self.texture(feature, t_dirs, normal) + + weights = render_weight_from_alpha( + alpha, ray_indices=ray_indices, n_rays=n_rays + ) + opacity = accumulate_along_rays( + weights, ray_indices, values=None, n_rays=n_rays + ) + depth = accumulate_along_rays( + weights, ray_indices, values=midpoints, n_rays=n_rays + ) + comp_rgb = accumulate_along_rays( + weights, ray_indices, values=rgb, n_rays=n_rays + ) + + comp_normal = accumulate_along_rays( + weights, ray_indices, values=normal, n_rays=n_rays + ) + comp_normal = F.normalize(comp_normal, p=2, dim=-1) + + pts_random = ( + torch.rand([1024 * 2, 3]).to(sdf.dtype).to(sdf.device) * 2 - 1 + ) # normalized to (-1, 1) + + if self.config.geometry.grad_type == "finite_difference": + random_sdf, random_sdf_grad, _ = self.geometry( + pts_random, with_grad=True, with_feature=False, with_laplace=True + ) + _, normal_perturb, _ = self.geometry( + pts_random + torch.randn_like(pts_random) * 1e-2, + with_grad=True, + with_feature=False, + with_laplace=True, + ) + else: + random_sdf, random_sdf_grad = self.geometry( + pts_random, with_grad=True, with_feature=False + ) + _, normal_perturb = self.geometry( + positions + torch.randn_like(positions) * 1e-2, + with_grad=True, + with_feature=False, + ) + + # pdb.set_trace() + out = { + "comp_rgb": comp_rgb, + "comp_normal": comp_normal, + "opacity": opacity, + "depth": depth, + "rays_valid": opacity > 0, + "num_samples": torch.as_tensor( + [len(t_starts)], dtype=torch.int32, device=rays.device + ), + } + + if self.training: + out.update( + { + "sdf_samples": sdf, + "sdf_grad_samples": sdf_grad, + "random_sdf": random_sdf, + "random_sdf_grad": random_sdf_grad, + "normal_perturb": normal_perturb, + "weights": weights.view(-1), + "points": midpoints.view(-1), + "intervals": dists.view(-1), + "ray_indices": ray_indices.view(-1), + } + ) + if self.config.geometry.grad_type == "finite_difference": + 
out.update({"sdf_laplace_samples": sdf_laplace}) + + if self.config.learned_background: + out_bg = self.forward_bg_(rays) + else: + out_bg = { + "comp_rgb": self.background_color[None, :].expand(*comp_rgb.shape), + "num_samples": torch.zeros_like(out["num_samples"]), + "rays_valid": torch.zeros_like(out["rays_valid"]), + } + + out_full = { + "comp_rgb": out["comp_rgb"] + out_bg["comp_rgb"] * (1.0 - out["opacity"]), + "num_samples": out["num_samples"] + out_bg["num_samples"], + "rays_valid": out["rays_valid"] | out_bg["rays_valid"], + } + + return { + **out, + **{k + "_bg": v for k, v in out_bg.items()}, + **{k + "_full": v for k, v in out_full.items()}, + } + + def forward(self, rays): + if self.training: + out = self.forward_(rays) + else: + out = chunk_batch(self.forward_, self.config.ray_chunk, True, rays) + return {**out, "inv_s": self.variance.inv_s} + + def train(self, mode=True): + self.randomized = mode and self.config.randomized + return super().train(mode=mode) + + def eval(self): + self.randomized = False + return super().eval() + + def regularizations(self, out): + losses = {} + losses.update(self.geometry.regularizations(out)) + losses.update(self.texture.regularizations(out)) + return losses + + @torch.no_grad() + def export(self, export_config): + mesh = self.isosurface() + if export_config.export_vertex_color: + _, sdf_grad, feature = chunk_batch( + self.geometry, + export_config.chunk_size, + False, + mesh["v_pos"].to(self.rank), + with_grad=True, + with_feature=True, + ) + normal = F.normalize(sdf_grad, p=2, dim=-1) + rgb = self.texture( + feature, -normal, normal + ) # set the viewing directions to the normal to get "albedo" + mesh["v_rgb"] = rgb.cpu() + return mesh diff --git a/mesh_recon/models/ray_utils.py b/mesh_recon/models/ray_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ca1866fa43aedb83e111233af1c5d0e37dbedf75 --- /dev/null +++ b/mesh_recon/models/ray_utils.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +def cast_rays(ori, dir, z_vals): + return ori[..., None, :] + z_vals[..., None] * dir[..., None, :] + + +def get_ray_directions(W, H, fx, fy, cx, cy, use_pixel_centers=True): + pixel_center = 0.5 if use_pixel_centers else 0 + i, j = np.meshgrid( + np.arange(W, dtype=np.float32) + pixel_center, + np.arange(H, dtype=np.float32) + pixel_center, + indexing='xy' + ) + i, j = torch.from_numpy(i), torch.from_numpy(j) + + # directions = torch.stack([(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1) # (H, W, 3) + # opencv system + directions = torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], -1) # (H, W, 3) + + return directions + + +def get_ortho_ray_directions_origins(W, H, use_pixel_centers=True): + pixel_center = 0.5 if use_pixel_centers else 0 + i, j = np.meshgrid( + np.arange(W, dtype=np.float32) + pixel_center, + np.arange(H, dtype=np.float32) + pixel_center, + indexing='xy' + ) + i, j = torch.from_numpy(i), torch.from_numpy(j) + + origins = torch.stack([(i/W-0.5)*2, (j/H-0.5)*2, torch.zeros_like(i)], dim=-1) # W, H, 3 + directions = torch.stack([torch.zeros_like(i), torch.zeros_like(j), torch.ones_like(i)], dim=-1) # W, H, 3 + + return origins, directions + + +def get_rays(directions, c2w, keepdim=False): + # Rotate ray directions from camera coordinate to the world coordinate + # rays_d = directions @ c2w[:, :3].T # (H, W, 3) # slow? 
+ assert directions.shape[-1] == 3 + + if directions.ndim == 2: # (N_rays, 3) + assert c2w.ndim == 3 # (N_rays, 4, 4) / (1, 4, 4) + rays_d = (directions[:,None,:] * c2w[:,:3,:3]).sum(-1) # (N_rays, 3) + rays_o = c2w[:,:,3].expand(rays_d.shape) + elif directions.ndim == 3: # (H, W, 3) + if c2w.ndim == 2: # (4, 4) + rays_d = (directions[:,:,None,:] * c2w[None,None,:3,:3]).sum(-1) # (H, W, 3) + rays_o = c2w[None,None,:,3].expand(rays_d.shape) + elif c2w.ndim == 3: # (B, 4, 4) + rays_d = (directions[None,:,:,None,:] * c2w[:,None,None,:3,:3]).sum(-1) # (B, H, W, 3) + rays_o = c2w[:,None,None,:,3].expand(rays_d.shape) + + if not keepdim: + rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) + + return rays_o, rays_d + + +# rays_v = torch.matmul(self.pose_all[img_idx, None, None, :3, :3].cuda(), rays_v[:, :, :, None].cuda()).squeeze() # W, H, 3 + +# rays_o = torch.matmul(self.pose_all[img_idx, None, None, :3, :3].cuda(), q[:, :, :, None].cuda()).squeeze() # W, H, 3 +# rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(rays_v.shape).cuda() + rays_o # W, H, 3 + +def get_ortho_rays(origins, directions, c2w, keepdim=False): + # Rotate ray directions from camera coordinate to the world coordinate + # rays_d = directions @ c2w[:, :3].T # (H, W, 3) # slow? + assert directions.shape[-1] == 3 + assert origins.shape[-1] == 3 + + if directions.ndim == 2: # (N_rays, 3) + assert c2w.ndim == 3 # (N_rays, 4, 4) / (1, 4, 4) + rays_d = torch.matmul(c2w[:, :3, :3], directions[:, :, None]).squeeze() # (N_rays, 3) + rays_o = torch.matmul(c2w[:, :3, :3], origins[:, :, None]).squeeze() # (N_rays, 3) + rays_o = c2w[:,:3,3].expand(rays_d.shape) + rays_o + elif directions.ndim == 3: # (H, W, 3) + if c2w.ndim == 2: # (4, 4) + rays_d = torch.matmul(c2w[None, None, :3, :3], directions[:, :, :, None]).squeeze() # (H, W, 3) + rays_o = torch.matmul(c2w[None, None, :3, :3], origins[:, :, :, None]).squeeze() # (H, W, 3) + rays_o = c2w[None, None,:3,3].expand(rays_d.shape) + rays_o + elif c2w.ndim == 3: # (B, 4, 4) + rays_d = torch.matmul(c2w[:,None, None, :3, :3], directions[None, :, :, :, None]).squeeze() # # (B, H, W, 3) + rays_o = torch.matmul(c2w[:,None, None, :3, :3], origins[None, :, :, :, None]).squeeze() # # (B, H, W, 3) + rays_o = c2w[:,None, None, :3,3].expand(rays_d.shape) + rays_o + + if not keepdim: + rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3) + + return rays_o, rays_d diff --git a/mesh_recon/models/texture.py b/mesh_recon/models/texture.py new file mode 100644 index 0000000000000000000000000000000000000000..4a83c9775c89d812cf6009155a414771c5462ebf --- /dev/null +++ b/mesh_recon/models/texture.py @@ -0,0 +1,75 @@ +import torch +import torch.nn as nn + +import models +from models.utils import get_activation +from models.network_utils import get_encoding, get_mlp +from systems.utils import update_module_step + + +@models.register('volume-radiance') +class VolumeRadiance(nn.Module): + def __init__(self, config): + super(VolumeRadiance, self).__init__() + self.config = config + self.with_viewdir = False #self.config.get('wo_viewdir', False) + self.n_dir_dims = self.config.get('n_dir_dims', 3) + self.n_output_dims = 3 + + if self.with_viewdir: + encoding = get_encoding(self.n_dir_dims, self.config.dir_encoding_config) + self.n_input_dims = self.config.input_feature_dim + encoding.n_output_dims + # self.network_base = get_mlp(self.config.input_feature_dim, self.n_output_dims, self.config.mlp_network_config) + else: + encoding = None + self.n_input_dims = self.config.input_feature_dim + 
+ network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config) + self.encoding = encoding + self.network = network + + def forward(self, features, dirs, *args): + + # features = features.detach() + if self.with_viewdir: + dirs = (dirs + 1.) / 2. # (-1, 1) => (0, 1) + dirs_embd = self.encoding(dirs.view(-1, self.n_dir_dims)) + network_inp = torch.cat([features.view(-1, features.shape[-1]), dirs_embd] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) + # network_inp_base = torch.cat([features.view(-1, features.shape[-1])] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) + color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() + # color_base = self.network_base(network_inp_base).view(*features.shape[:-1], self.n_output_dims).float() + # color = color + color_base + else: + network_inp = torch.cat([features.view(-1, features.shape[-1])] + [arg.view(-1, arg.shape[-1]) for arg in args], dim=-1) + color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() + + if 'color_activation' in self.config: + color = get_activation(self.config.color_activation)(color) + return color + + def update_step(self, epoch, global_step): + update_module_step(self.encoding, epoch, global_step) + + def regularizations(self, out): + return {} + + +@models.register('volume-color') +class VolumeColor(nn.Module): + def __init__(self, config): + super(VolumeColor, self).__init__() + self.config = config + self.n_output_dims = 3 + self.n_input_dims = self.config.input_feature_dim + network = get_mlp(self.n_input_dims, self.n_output_dims, self.config.mlp_network_config) + self.network = network + + def forward(self, features, *args): + network_inp = features.view(-1, features.shape[-1]) + color = self.network(network_inp).view(*features.shape[:-1], self.n_output_dims).float() + if 'color_activation' in self.config: + color = get_activation(self.config.color_activation)(color) + return color + + def regularizations(self, out): + return {} diff --git a/mesh_recon/models/utils.py b/mesh_recon/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5c3cf19dd3e8f277783db68f1435c8f9755e96 --- /dev/null +++ b/mesh_recon/models/utils.py @@ -0,0 +1,119 @@ +import gc +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + +import tinycudann as tcnn + + +def chunk_batch(func, chunk_size, move_to_cpu, *args, **kwargs): + B = None + for arg in args: + if isinstance(arg, torch.Tensor): + B = arg.shape[0] + break + out = defaultdict(list) + out_type = None + for i in range(0, B, chunk_size): + out_chunk = func(*[arg[i:i+chunk_size] if isinstance(arg, torch.Tensor) else arg for arg in args], **kwargs) + if out_chunk is None: + continue + out_type = type(out_chunk) + if isinstance(out_chunk, torch.Tensor): + out_chunk = {0: out_chunk} + elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): + chunk_length = len(out_chunk) + out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} + elif isinstance(out_chunk, dict): + pass + else: + print(f'Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}.') + exit(1) + for k, v in out_chunk.items(): + v = v if torch.is_grad_enabled() else v.detach() + v = v.cpu() if move_to_cpu else v + out[k].append(v) + + if out_type is None: + return + + out = {k: torch.cat(v, dim=0) for 
k, v in out.items()} + if out_type is torch.Tensor: + return out[0] + elif out_type in [tuple, list]: + return out_type([out[i] for i in range(chunk_length)]) + elif out_type is dict: + return out + + +class _TruncExp(Function): # pylint: disable=abstract-method + # Implementation from torch-ngp: + # https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py + @staticmethod + @custom_fwd(cast_inputs=torch.float32) + def forward(ctx, x): # pylint: disable=arguments-differ + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): # pylint: disable=arguments-differ + x = ctx.saved_tensors[0] + return g * torch.exp(torch.clamp(x, max=15)) + +trunc_exp = _TruncExp.apply + + +def get_activation(name): + if name is None: + return lambda x: x + name = name.lower() + if name == 'none': + return lambda x: x + elif name.startswith('scale'): + scale_factor = float(name[5:]) + return lambda x: x.clamp(0., scale_factor) / scale_factor + elif name.startswith('clamp'): + clamp_max = float(name[5:]) + return lambda x: x.clamp(0., clamp_max) + elif name.startswith('mul'): + mul_factor = float(name[3:]) + return lambda x: x * mul_factor + elif name == 'lin2srgb': + return lambda x: torch.where(x > 0.0031308, torch.pow(torch.clamp(x, min=0.0031308), 1.0/2.4)*1.055 - 0.055, 12.92*x).clamp(0., 1.) + elif name == 'trunc_exp': + return trunc_exp + elif name.startswith('+') or name.startswith('-'): + return lambda x: x + float(name) + elif name == 'sigmoid': + return lambda x: torch.sigmoid(x) + elif name == 'tanh': + return lambda x: torch.tanh(x) + else: + return getattr(F, name) + + +def dot(x, y): + return torch.sum(x*y, -1, keepdim=True) + + +def reflect(x, n): + return 2 * dot(x, n) * n - x + + +def scale_anything(dat, inp_scale, tgt_scale): + if inp_scale is None: + inp_scale = [dat.min(), dat.max()] + dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) + dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] + return dat + + +def cleanup(): + gc.collect() + torch.cuda.empty_cache() + tcnn.free_temporary_memory() diff --git a/mesh_recon/refine.py b/mesh_recon/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..f8109f4f7efe86845afd72c4709e62b124c0fc29 --- /dev/null +++ b/mesh_recon/refine.py @@ -0,0 +1,288 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import nvdiffrast.torch as dr +import kiui +from kiui.mesh import Mesh +import json +from pathlib import Path +import tqdm +from PIL import Image +from torchvision.transforms.functional import to_tensor +from torchvision.utils import save_image +import trimesh +from mediapy import write_image, write_video +from einops import rearrange + +from kiui.op import uv_padding, safe_normalize, inverse_sigmoid +from kiui.cam import orbit_camera, get_perspective + +from torchmetrics.image import LearnedPerceptualImagePatchSimilarity + +from mesh import Mesh +from mediapy import read_video +import tyro + +from datasets.v3d import get_uniform_poses + + +class Refiner(nn.Module): + def __init__(self, mesh_filename, video, num_opt=4, lpips: float = 0.0) -> None: + super().__init__() + self.output_size = 512 + znear = 0.1 + zfar = 10 + self.mesh = Mesh.load_obj(mesh_filename) + # self.mesh.v[..., 1], self.mesh.v[..., 2] = ( + # self.mesh.v[..., 2], + # self.mesh.v[..., 1], + # ) + self.glctx = dr.RasterizeGLContext() + + self.device = torch.device("cuda") + self.lpips_meter = 
LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=True + ).to(self.device) + self.lpips = lpips + + fov = 60 + + frames = read_video(video) + self.name = Path(video).stem + frames = frames.astype(np.float32) / 255.0 + frames = np.moveaxis(frames, -1, 1) + num_frames, h, w, c = frames.shape + self.poses = get_uniform_poses(num_frames, 2.0, 0.0, opengl=True) + frames = frames.astype(np.float32) / 255.0 + + self.image_gt = torch.from_numpy(frames).to(self.device) + + self.n_frames = len(self.poses) + self.opt_frames = np.linspace(0, self.n_frames, num_opt + 1)[:num_opt].astype( + int + ) + print(self.opt_frames) + + # gs renderer + self.tan_half_fov = np.tan(0.5 * np.deg2rad(fov)) + self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device) + self.proj_matrix[0, 0] = 1 / self.tan_half_fov + self.proj_matrix[1, 1] = 1 / self.tan_half_fov + self.proj_matrix[2, 2] = (zfar + znear) / (zfar - znear) + self.proj_matrix[3, 2] = -(zfar * znear) / (zfar - znear) + self.proj_matrix[2, 3] = 1 + + self.glctx = dr.RasterizeGLContext() + + self.proj = torch.from_numpy(get_perspective(fov)).float().to(self.device) + + self.v = self.mesh.v.contiguous().float().to(self.device) + self.f = self.mesh.f.contiguous().int().to(self.device) + self.vc = self.mesh.vc.contiguous().float().to(self.device) + # self.vt = self.mesh.vt + # self.ft = self.mesh.ft + + def render_normal(self, pose): + h = w = self.output_size + + v = self.v + f = self.f + + if not hasattr(self.mesh, "vn") or self.mesh.vn is None: + self.mesh.auto_normal() + vc = self.mesh.vn.to(self.device) + + pose = torch.from_numpy(pose.astype(np.float32)).to(v.device) + + vc = torch.einsum("ij, kj -> ki", pose[:3, :3].T, vc).contiguous() + + # get v_clip and render rgb + v_cam = ( + torch.matmul( + F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T + ) + .float() + .unsqueeze(0) + ) + v_clip = v_cam @ self.proj.T + + rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w)) + + alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1] + alpha = ( + dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0) + ) # [H, W] important to enable gradients! + + # color, texc_db = dr.interpolate( + # self.vc.unsqueeze(0), rast, f, rast_db=rast_db, diff_attrs="all" + # ) + color, texc_db = dr.interpolate(vc.unsqueeze(0), rast, f) + color = dr.antialias(color, rast, v_clip, f) + # image = torch.sigmoid( + # dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db) + # ) # [1, H, W, 3] + + image = color.view(1, h, w, 3) + # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1) + image = image.squeeze(0).permute(2, 0, 1).contiguous() # [3, H, W] + image = (image + 1) / 2.0 + image = alpha * image + (1 - alpha) + + return image, alpha + + def render_mesh(self, pose, use_sigmoid=True): + h = w = self.output_size + + v = self.v + f = self.f + if use_sigmoid: + vc = torch.sigmoid(self.vc) + else: + vc = self.vc + + pose = torch.from_numpy(pose.astype(np.float32)).to(v.device) + + # get v_clip and render rgb + v_cam = ( + torch.matmul( + F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T + ) + .float() + .unsqueeze(0) + ) + v_clip = v_cam @ self.proj.T + + rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w)) + + alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1] + alpha = ( + dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0) + ) # [H, W] important to enable gradients! 
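# The world -> clip transform used by render_mesh / render_normal above, written out
# for a single point without nvdiffrast: pose is camera-to-world, so vertices are
# multiplied by its inverse (world -> camera) and then by the perspective matrix from
# kiui.cam.get_perspective. The fov and test vertex are illustrative.
import torch
import torch.nn.functional as F
from kiui.cam import get_perspective

fov = 60
proj = torch.from_numpy(get_perspective(fov)).float()     # (4, 4)
pose = torch.eye(4)                                       # camera at origin (OpenGL, looking down -z)
v = torch.tensor([[0.0, 0.0, -2.0]])                      # one vertex, 2 units in front

v_hom = F.pad(v, pad=(0, 1), mode="constant", value=1.0)  # (1, 4) homogeneous
v_cam = v_hom @ torch.inverse(pose).T
v_clip = v_cam @ proj.T
print(v_clip / v_clip[..., 3:])                           # perspective divide; visible points land in [-1, 1] in x, y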
+ + # color, texc_db = dr.interpolate( + # self.vc.unsqueeze(0), rast, f, rast_db=rast_db, diff_attrs="all" + # ) + color, texc_db = dr.interpolate(vc.unsqueeze(0), rast, f) + color = dr.antialias(color, rast, v_clip, f) + # image = torch.sigmoid( + # dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db) + # ) # [1, H, W, 3] + + image = color.view(1, h, w, 3) + # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1) + image = image.squeeze(0).permute(2, 0, 1).contiguous() # [3, H, W] + image = alpha * image + (1 - alpha) + + return image, alpha + + def refine_texture(self, texture_resolution: int = 512, iters: int = 5000): + h = w = texture_resolution + albedo = torch.ones(h * w, 3, device=self.device, dtype=torch.float32) * 0.5 + albedo = albedo.view(h, w, -1) + vc_original = self.vc.clone() + self.vc = nn.Parameter(inverse_sigmoid(vc_original)).to(self.device) + + optimizer = torch.optim.Adam( + [ + {"params": self.vc, "lr": 1e-3}, + ] + ) + + pbar = tqdm.trange(iters) + for i in pbar: + index = np.random.choice(self.opt_frames) + pose = self.poses[index] + image_gt = self.image_gt[index] + + image_pred, _ = self.render_mesh(pose) + + # if i % 1000 == 0: + # save_image(image_pred, f"tmp/image_pred_{i}.png") + # save_image(image_gt, f"tmp/image_gt_{i}.png") + + loss = F.mse_loss(image_pred, image_gt) + if self.lpips > 0.0: + loss += ( + self.lpips_meter( + image_gt.clamp(0, 1)[None], image_pred.clamp(0, 1)[None] + ) + * self.lpips + ) + # * 10.0 + + loss.backward() + optimizer.step() + optimizer.zero_grad() + + pbar.set_description(f"MSE = {loss.item():.6f}") + + @torch.no_grad() + def render_spiral(self): + images = [] + for i, pose in enumerate(self.poses): + image, _ = self.render_mesh(pose, use_sigmoid=False) + images.append(image) + + images = torch.stack(images) + images = images.cpu().numpy() + images = rearrange(images, "b c h w -> b h w c") + if not Path("renders").exists(): + Path("renders").mkdir(parents=True, exist_ok=True) + write_video(f"renders/{self.name}.mp4", images, fps=3) + + @torch.no_grad() + def render_normal_spiral(self): + images = [] + for i, pose in enumerate(self.poses): + image, _ = self.render_normal(pose) + images.append(image) + + images = torch.stack(images) + images = images.cpu().numpy() + images = rearrange(images, "b c h w -> b h w c") + Path("renders").mkdir(exist_ok=True, parents=True) + write_video(f"renders/{self.name}_normal.mp4", images, fps=3) + + def export(self, filename): + mesh = trimesh.Trimesh( + vertices=self.mesh.v.cpu().numpy(), + faces=self.mesh.f.cpu().numpy(), + vertex_colors=torch.sigmoid(self.vc.detach()).cpu().numpy(), + ) + self.vc.data = torch.sigmoid(self.vc.detach()) + trimesh.repair.fix_inversion(mesh) + mesh.export(filename) + + +def do_refine( + mesh: str, + scene: str, + num_opt: int = 4, + iters: int = 2000, + skip_refine: bool = False, + render_normal: bool = True, + lpips: float = 1.0, +): + refiner = Refiner( + # "tmp/corgi_size_1.obj", + mesh, + scene, + num_opt=num_opt, + lpips=lpips, + ) + if not skip_refine: + refiner.refine_texture(512, iters) + save_path = Path("refined") / f"{Path(scene).stem}.obj" + if not save_path.parent.exists(): + save_path.parent.mkdir(exist_ok=True, parents=True) + refiner.export(str(save_path)) + + refiner.render_spiral() + if render_normal: + refiner.render_normal_spiral() + + +if __name__ == "__main__": + tyro.cli(do_refine) diff --git a/mesh_recon/run.sh b/mesh_recon/run.sh new file mode 100644 index 
0000000000000000000000000000000000000000..617143ccecba268e77b2aeb48cb3ec266d098c40 --- /dev/null +++ b/mesh_recon/run.sh @@ -0,0 +1 @@ +python launch.py --config configs/neuralangelo-ortho-wmask.yaml --gpu 0 --train dataset.root_dir=$1 dataset.scene=$2 \ No newline at end of file diff --git a/mesh_recon/scripts/imgs2poses.py b/mesh_recon/scripts/imgs2poses.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b6e0b19c7192fceee0518b2cde691bfabd4ff4 --- /dev/null +++ b/mesh_recon/scripts/imgs2poses.py @@ -0,0 +1,85 @@ + +""" +This file is adapted from https://github.com/Fyusion/LLFF. +""" + +import os +import sys +import argparse +import subprocess + + +def run_colmap(basedir, match_type): + logfile_name = os.path.join(basedir, 'colmap_output.txt') + logfile = open(logfile_name, 'w') + + feature_extractor_args = [ + 'colmap', 'feature_extractor', + '--database_path', os.path.join(basedir, 'database.db'), + '--image_path', os.path.join(basedir, 'images'), + '--ImageReader.single_camera', '1' + ] + feat_output = ( subprocess.check_output(feature_extractor_args, universal_newlines=True) ) + logfile.write(feat_output) + print('Features extracted') + + exhaustive_matcher_args = [ + 'colmap', match_type, + '--database_path', os.path.join(basedir, 'database.db'), + ] + + match_output = ( subprocess.check_output(exhaustive_matcher_args, universal_newlines=True) ) + logfile.write(match_output) + print('Features matched') + + p = os.path.join(basedir, 'sparse') + if not os.path.exists(p): + os.makedirs(p) + + mapper_args = [ + 'colmap', 'mapper', + '--database_path', os.path.join(basedir, 'database.db'), + '--image_path', os.path.join(basedir, 'images'), + '--output_path', os.path.join(basedir, 'sparse'), # --export_path changed to --output_path in colmap 3.6 + '--Mapper.num_threads', '16', + '--Mapper.init_min_tri_angle', '4', + '--Mapper.multiple_models', '0', + '--Mapper.extract_colors', '0', + ] + + map_output = ( subprocess.check_output(mapper_args, universal_newlines=True) ) + logfile.write(map_output) + logfile.close() + print('Sparse map created') + + print( 'Finished running COLMAP, see {} for logs'.format(logfile_name) ) + + +def gen_poses(basedir, match_type): + files_needed = ['{}.bin'.format(f) for f in ['cameras', 'images', 'points3D']] + if os.path.exists(os.path.join(basedir, 'sparse/0')): + files_had = os.listdir(os.path.join(basedir, 'sparse/0')) + else: + files_had = [] + if not all([f in files_had for f in files_needed]): + print( 'Need to run COLMAP' ) + run_colmap(basedir, match_type) + else: + print('Don\'t need to run COLMAP') + + return True + + +if __name__=='__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--match_type', type=str, + default='exhaustive_matcher', help='type of matcher used. Valid options: \ + exhaustive_matcher sequential_matcher. Other matchers not supported at this time') + parser.add_argument('scenedir', type=str, + help='input scene directory') + args = parser.parse_args() + + if args.match_type != 'exhaustive_matcher' and args.match_type != 'sequential_matcher': + print('ERROR: matcher type ' + args.match_type + ' is not valid. 
Aborting') + sys.exit() + gen_poses(args.scenedir, args.match_type) diff --git a/mesh_recon/systems/__init__.py b/mesh_recon/systems/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c20369c0f9b733366b6b54cfe79529774c8c8131 --- /dev/null +++ b/mesh_recon/systems/__init__.py @@ -0,0 +1,22 @@ +systems = {} + + +def register(name): + def decorator(cls): + systems[name] = cls + return cls + + return decorator + + +def make(name, config, load_from_checkpoint=None): + if load_from_checkpoint is None: + system = systems[name](config) + else: + system = systems[name].load_from_checkpoint( + load_from_checkpoint, strict=False, config=config + ) + return system + + +from . import neus, neus_ortho, neus_pinhole, neus_videonvs diff --git a/mesh_recon/systems/base.py b/mesh_recon/systems/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bcdbdc76d810548f85ebbaf64870a33f5ddaf1 --- /dev/null +++ b/mesh_recon/systems/base.py @@ -0,0 +1,128 @@ +import pytorch_lightning as pl + +import models +from systems.utils import parse_optimizer, parse_scheduler, update_module_step +from utils.mixins import SaverMixin +from utils.misc import config_to_primitive, get_rank + + +class BaseSystem(pl.LightningModule, SaverMixin): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. rank_zero_info: use the logging module + """ + def __init__(self, config): + super().__init__() + self.config = config + self.rank = get_rank() + self.prepare() + self.model = models.make(self.config.model.name, self.config.model) + + def prepare(self): + pass + + def forward(self, batch): + raise NotImplementedError + + def C(self, value): + if isinstance(value, int) or isinstance(value, float): + pass + else: + value = config_to_primitive(value) + if not isinstance(value, list): + raise TypeError('Scalar specification only supports list, got', type(value)) + if len(value) == 3: + value = [0] + value + assert len(value) == 4 + start_step, start_value, end_value, end_step = value + if isinstance(end_step, int): + current_step = self.global_step + value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0) + elif isinstance(end_step, float): + current_step = self.current_epoch + value = start_value + (end_value - start_value) * max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0) + return value + + def preprocess_data(self, batch, stage): + pass + + """ + Implementing on_after_batch_transfer of DataModule does the same. + But on_after_batch_transfer does not support DP. 
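The C() helper defined above turns a loss-weight entry into either a constant or a linearly interpolated schedule: a 4-element list [start_step, start_value, end_value, end_step] ramps the value between the two endpoints (a 3-element list gets start_step=0 prepended), keyed to global_step when end_step is an int and to the current epoch when it is a float. A standalone sketch of the same rule, using a hypothetical lambda schedule purely for illustration:

def scheduled_value(spec, current_step):
    # Mirrors the list branch of BaseSystem.C: [start_step, start_value, end_value, end_step]
    if isinstance(spec, (int, float)):
        return spec
    if len(spec) == 3:
        spec = [0] + list(spec)
    start_step, start_value, end_value, end_step = spec
    t = max(min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0)
    return start_value + (end_value - start_value) * t

spec = [0.0, 0.1, 5000]               # hypothetical: ramp a weight from 0.0 to 0.1 over the first 5000 steps
print(scheduled_value(spec, 0))       # 0.0
print(scheduled_value(spec, 2500))    # 0.05
print(scheduled_value(spec, 10000))   # 0.1 (clamped once past end_step)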
+ """ + def on_train_batch_start(self, batch, batch_idx, unused=0): + self.dataset = self.trainer.datamodule.train_dataloader().dataset + self.preprocess_data(batch, 'train') + update_module_step(self.model, self.current_epoch, self.global_step) + + def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): + self.dataset = self.trainer.datamodule.val_dataloader().dataset + self.preprocess_data(batch, 'validation') + update_module_step(self.model, self.current_epoch, self.global_step) + + def on_test_batch_start(self, batch, batch_idx, dataloader_idx): + self.dataset = self.trainer.datamodule.test_dataloader().dataset + self.preprocess_data(batch, 'test') + update_module_step(self.model, self.current_epoch, self.global_step) + + def on_predict_batch_start(self, batch, batch_idx, dataloader_idx): + self.dataset = self.trainer.datamodule.predict_dataloader().dataset + self.preprocess_data(batch, 'predict') + update_module_step(self.model, self.current_epoch, self.global_step) + + def training_step(self, batch, batch_idx): + raise NotImplementedError + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + raise NotImplementedError + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def validation_epoch_end(self, out): + """ + Gather metrics from all devices, compute mean. + Purge repeated results using data index. + """ + raise NotImplementedError + + def test_step(self, batch, batch_idx): + raise NotImplementedError + + def test_epoch_end(self, out): + """ + Gather metrics from all devices, compute mean. + Purge repeated results using data index. 
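"Purge repeated results using data index" refers to the pattern the concrete systems use after self.all_gather in validation_epoch_end/test_epoch_end: under DDP the gathered outputs can contain the same validation sample more than once (e.g. from sampler padding), so per-step results are folded into a dict keyed by dataset index before averaging, and later duplicates simply overwrite earlier ones. A minimal sketch of that reduction with invented indices and PSNR values:

import torch

# Pretend this came back from all_gather across 2 ranks; index 3 appears twice.
gathered = [
    {"index": torch.tensor([[3], [7]]), "psnr": torch.tensor([30.0, 28.0])},
    {"index": torch.tensor([[5], [3]]), "psnr": torch.tensor([32.0, 30.0])},
]

out_set = {}
for step_out in gathered:
    for oi, index in enumerate(step_out["index"]):
        out_set[index[0].item()] = {"psnr": step_out["psnr"][oi]}  # dedupe by dataset index

psnr = torch.mean(torch.stack([o["psnr"] for o in out_set.values()]))
print(sorted(out_set.keys()), psnr.item())  # [3, 5, 7] 30.0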
+ """ + raise NotImplementedError + + def export(self): + raise NotImplementedError + + def configure_optimizers(self): + optim = parse_optimizer(self.config.system.optimizer, self.model) + ret = { + 'optimizer': optim, + } + if 'scheduler' in self.config.system: + ret.update({ + 'lr_scheduler': parse_scheduler(self.config.system.scheduler, optim), + }) + return ret + diff --git a/mesh_recon/systems/criterions.py b/mesh_recon/systems/criterions.py new file mode 100644 index 0000000000000000000000000000000000000000..b101032ec7bc8d9943dd5df47557c4b6d3aa465b --- /dev/null +++ b/mesh_recon/systems/criterions.py @@ -0,0 +1,164 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class WeightedLoss(nn.Module): + @property + def func(self): + raise NotImplementedError + + def forward(self, inputs, targets, weight=None, reduction='mean'): + assert reduction in ['none', 'sum', 'mean', 'valid_mean'] + loss = self.func(inputs, targets, reduction='none') + if weight is not None: + while weight.ndim < inputs.ndim: + weight = weight[..., None] + loss *= weight.float() + if reduction == 'none': + return loss + elif reduction == 'sum': + return loss.sum() + elif reduction == 'mean': + return loss.mean() + elif reduction == 'valid_mean': + return loss.sum() / weight.float().sum() + + +class MSELoss(WeightedLoss): + @property + def func(self): + return F.mse_loss + + +class L1Loss(WeightedLoss): + @property + def func(self): + return F.l1_loss + + +class PSNR(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, inputs, targets, valid_mask=None, reduction='mean'): + assert reduction in ['mean', 'none'] + value = (inputs - targets)**2 + if valid_mask is not None: + value = value[valid_mask] + if reduction == 'mean': + return -10 * torch.log10(torch.mean(value)) + elif reduction == 'none': + return -10 * torch.log10(torch.mean(value, dim=tuple(range(value.ndim)[1:]))) + + +class SSIM(): + def __init__(self, data_range=(0, 1), kernel_size=(11, 11), sigma=(1.5, 1.5), k1=0.01, k2=0.03, gaussian=True): + self.kernel_size = kernel_size + self.sigma = sigma + self.gaussian = gaussian + + if any(x % 2 == 0 or x <= 0 for x in self.kernel_size): + raise ValueError(f"Expected kernel_size to have odd positive number. Got {kernel_size}.") + if any(y <= 0 for y in self.sigma): + raise ValueError(f"Expected sigma to have positive number. 
Got {sigma}.") + + data_scale = data_range[1] - data_range[0] + self.c1 = (k1 * data_scale)**2 + self.c2 = (k2 * data_scale)**2 + self.pad_h = (self.kernel_size[0] - 1) // 2 + self.pad_w = (self.kernel_size[1] - 1) // 2 + self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma) + + def _uniform(self, kernel_size): + max, min = 2.5, -2.5 + ksize_half = (kernel_size - 1) * 0.5 + kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + for i, j in enumerate(kernel): + if min <= j <= max: + kernel[i] = 1 / (max - min) + else: + kernel[i] = 0 + + return kernel.unsqueeze(dim=0) # (1, kernel_size) + + def _gaussian(self, kernel_size, sigma): + ksize_half = (kernel_size - 1) * 0.5 + kernel = torch.linspace(-ksize_half, ksize_half, steps=kernel_size) + gauss = torch.exp(-0.5 * (kernel / sigma).pow(2)) + return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) + + def _gaussian_or_uniform_kernel(self, kernel_size, sigma): + if self.gaussian: + kernel_x = self._gaussian(kernel_size[0], sigma[0]) + kernel_y = self._gaussian(kernel_size[1], sigma[1]) + else: + kernel_x = self._uniform(kernel_size[0]) + kernel_y = self._uniform(kernel_size[1]) + + return torch.matmul(kernel_x.t(), kernel_y) # (kernel_size, 1) * (1, kernel_size) + + def __call__(self, output, target, reduction='mean'): + if output.dtype != target.dtype: + raise TypeError( + f"Expected output and target to have the same data type. Got output: {output.dtype} and y: {target.dtype}." + ) + + if output.shape != target.shape: + raise ValueError( + f"Expected output and target to have the same shape. Got output: {output.shape} and y: {target.shape}." + ) + + if len(output.shape) != 4 or len(target.shape) != 4: + raise ValueError( + f"Expected output and target to have BxCxHxW shape. Got output: {output.shape} and y: {target.shape}." + ) + + assert reduction in ['mean', 'sum', 'none'] + + channel = output.size(1) + if len(self._kernel.shape) < 4: + self._kernel = self._kernel.expand(channel, 1, -1, -1) + + output = F.pad(output, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect") + target = F.pad(target, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect") + + input_list = torch.cat([output, target, output * output, target * target, output * target]) + outputs = F.conv2d(input_list, self._kernel, groups=channel) + + output_list = [outputs[x * output.size(0) : (x + 1) * output.size(0)] for x in range(len(outputs))] + + mu_pred_sq = output_list[0].pow(2) + mu_target_sq = output_list[1].pow(2) + mu_pred_target = output_list[0] * output_list[1] + + sigma_pred_sq = output_list[2] - mu_pred_sq + sigma_target_sq = output_list[3] - mu_target_sq + sigma_pred_target = output_list[4] - mu_pred_target + + a1 = 2 * mu_pred_target + self.c1 + a2 = 2 * sigma_pred_target + self.c2 + b1 = mu_pred_sq + mu_target_sq + self.c1 + b2 = sigma_pred_sq + sigma_target_sq + self.c2 + + ssim_idx = (a1 * a2) / (b1 * b2) + _ssim = torch.mean(ssim_idx, (1, 2, 3)) + + if reduction == 'none': + return _ssim + elif reduction == 'sum': + return _ssim.sum() + elif reduction == 'mean': + return _ssim.mean() + + +def binary_cross_entropy(input, target, reduction='mean'): + """ + F.binary_cross_entropy is not numerically stable in mixed-precision training. 
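Because this hand-rolled binary_cross_entropy evaluates log(input) and log(1 - input) directly, callers must keep the predictions strictly inside (0, 1); that is why the NeuS systems below clamp the rendered opacity to [1e-3, 1 - 1e-3] before passing it in. A small sketch of the safe call pattern (the tensors are invented; the clamp bounds mirror the ones used later in training_step):

import torch

def binary_cross_entropy(input, target, reduction="mean"):
    loss = -(target * torch.log(input) + (1 - target) * torch.log(1 - input))
    return loss.mean() if reduction == "mean" else loss

opacity = torch.tensor([0.0, 0.4, 1.0])               # raw opacities can hit exactly 0 or 1
fg_mask = torch.tensor([0.0, 1.0, 1.0])

opacity = torch.clamp(opacity, 1.0e-3, 1.0 - 1.0e-3)  # keep both log() terms finite
print(binary_cross_entropy(opacity, fg_mask))         # finite; without the clamp, 0 * log(0) yields NaN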
+ """ + loss = -(target * torch.log(input) + (1 - target) * torch.log(1 - input)) + + if reduction == 'mean': + return loss.mean() + elif reduction == 'none': + return loss diff --git a/mesh_recon/systems/nerf.py b/mesh_recon/systems/nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..c5fc821f430aeee62880b7240e75a185ca9b15f2 --- /dev/null +++ b/mesh_recon/systems/nerf.py @@ -0,0 +1,218 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_efficient_distloss import flatten_eff_distloss + +import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug + +import models +from models.ray_utils import get_rays +import systems +from systems.base import BaseSystem +from systems.criterions import PSNR + + +@systems.register('nerf-system') +class NeRFSystem(BaseSystem): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. rank_zero_info: use the logging module + """ + def prepare(self): + self.criterions = { + 'psnr': PSNR() + } + self.train_num_samples = self.config.model.train_num_rays * self.config.model.num_samples_per_ray + self.train_num_rays = self.config.model.train_num_rays + + def forward(self, batch): + return self.model(batch['rays']) + + def preprocess_data(self, batch, stage): + if 'index' in batch: # validation / testing + index = batch['index'] + else: + if self.config.model.batch_image_sampling: + index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device) + else: + index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device) + if stage in ['train']: + c2w = self.dataset.all_c2w[index] + x = torch.randint( + 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + y = torch.randint( + 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions[y, x] + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index, y, x] + rays_o, rays_d = get_rays(directions, c2w) + rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) + else: + c2w = self.dataset.all_c2w[index][0] + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index][0] + rays_o, rays_d = get_rays(directions, c2w) + rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) + + rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) + + if stage in ['train']: + if self.config.model.background_color == 'white': + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + elif self.config.model.background_color == 'random': + self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank) + else: + raise NotImplementedError + else: + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + + if self.dataset.apply_mask: + rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None]) + + batch.update({ + 'rays': rays, + 
'rgb': rgb, + 'fg_mask': fg_mask + }) + + def training_step(self, batch, batch_idx): + out = self(batch) + + loss = 0. + + # update train_num_rays + if self.config.model.dynamic_ray_sampling: + train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples'].sum().item())) + self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays) + + loss_rgb = F.smooth_l1_loss(out['comp_rgb'][out['rays_valid'][...,0]], batch['rgb'][out['rays_valid'][...,0]]) + self.log('train/loss_rgb', loss_rgb) + loss += loss_rgb * self.C(self.config.system.loss.lambda_rgb) + + # distortion loss proposed in MipNeRF360 + # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss, but still slows down training by ~30% + if self.C(self.config.system.loss.lambda_distortion) > 0: + loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices']) + self.log('train/loss_distortion', loss_distortion) + loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) + + losses_model_reg = self.model.regularizations(out) + for name, value in losses_model_reg.items(): + self.log(f'train/loss_{name}', value) + loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) + loss += loss_ + + for name, value in self.config.system.loss.items(): + if name.startswith('lambda'): + self.log(f'train_params/{name}', self.C(value)) + + self.log('train/num_rays', float(self.train_num_rays), prog_bar=True) + + return { + 'loss': loss + } + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + {'type': 'grayscale', 'img': out['opacity'].view(H, W), 'kwargs': {'cmap': None, 'data_range': (0, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def validation_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True) + + def test_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 
'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + {'type': 'grayscale', 'img': out['opacity'].view(H, W), 'kwargs': {'cmap': None, 'data_range': (0, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + def test_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True) + + self.save_img_sequence( + f"it{self.global_step}-test", + f"it{self.global_step}-test", + '(\d+)\.png', + save_format='mp4', + fps=30 + ) + + self.export() + + def export(self): + mesh = self.model.export(self.config.export) + self.save_mesh( + f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", + **mesh + ) diff --git a/mesh_recon/systems/neus.py b/mesh_recon/systems/neus.py new file mode 100644 index 0000000000000000000000000000000000000000..ce273d0790a1fbea4d795b07e285c1318573a562 --- /dev/null +++ b/mesh_recon/systems/neus.py @@ -0,0 +1,265 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_efficient_distloss import flatten_eff_distloss + +import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug + +import models +from models.utils import cleanup +from models.ray_utils import get_rays +import systems +from systems.base import BaseSystem +from systems.criterions import PSNR, binary_cross_entropy + + +@systems.register('neus-system') +class NeuSSystem(BaseSystem): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. 
rank_zero_info: use the logging module + """ + def prepare(self): + self.criterions = { + 'psnr': PSNR() + } + self.train_num_samples = self.config.model.train_num_rays * (self.config.model.num_samples_per_ray + self.config.model.get('num_samples_per_ray_bg', 0)) + self.train_num_rays = self.config.model.train_num_rays + + def forward(self, batch): + return self.model(batch['rays']) + + def preprocess_data(self, batch, stage): + if 'index' in batch: # validation / testing + index = batch['index'] + else: + if self.config.model.batch_image_sampling: + index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device) + else: + index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device) + if stage in ['train']: + c2w = self.dataset.all_c2w[index] + x = torch.randint( + 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + y = torch.randint( + 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions[y, x] + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index, y, x] + rays_o, rays_d = get_rays(directions, c2w) + rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) + else: + c2w = self.dataset.all_c2w[index][0] + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index][0] + rays_o, rays_d = get_rays(directions, c2w) + rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) + + rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) + + if stage in ['train']: + if self.config.model.background_color == 'white': + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + elif self.config.model.background_color == 'random': + self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank) + else: + raise NotImplementedError + else: + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + + if self.dataset.apply_mask: + rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None]) + + batch.update({ + 'rays': rays, + 'rgb': rgb, + 'fg_mask': fg_mask + }) + + def training_step(self, batch, batch_idx): + out = self(batch) + + loss = 0. 
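The training_step below opens with the dynamic_ray_sampling update: it rescales the ray budget so that rays-per-step times samples-per-ray stays close to train_num_samples, smooths the proposal with a 0.9/0.1 exponential moving average, and caps it at max_train_num_rays. A toy numeric sketch of that update (every number here is invented for illustration):

train_num_rays = 512
train_num_samples = 512 * 1024    # target ray-samples per optimization step
max_train_num_rays = 8192

num_samples_this_step = 262_144   # stand-in for out['num_samples_full'].sum().item()

proposal = int(train_num_rays * (train_num_samples / num_samples_this_step))
train_num_rays = min(int(train_num_rays * 0.9 + proposal * 0.1), max_train_num_rays)
print(proposal, train_num_rays)   # 1024 563 -- the budget grows, but smoothly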
+ + # update train_num_rays + if self.config.model.dynamic_ray_sampling: + train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples_full'].sum().item())) + self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays) + + loss_rgb_mse = F.mse_loss(out['comp_rgb_full'][out['rays_valid_full'][...,0]], batch['rgb'][out['rays_valid_full'][...,0]]) + self.log('train/loss_rgb_mse', loss_rgb_mse) + loss += loss_rgb_mse * self.C(self.config.system.loss.lambda_rgb_mse) + + loss_rgb_l1 = F.l1_loss(out['comp_rgb_full'][out['rays_valid_full'][...,0]], batch['rgb'][out['rays_valid_full'][...,0]]) + self.log('train/loss_rgb', loss_rgb_l1) + loss += loss_rgb_l1 * self.C(self.config.system.loss.lambda_rgb_l1) + + loss_eikonal = ((torch.linalg.norm(out['sdf_grad_samples'], ord=2, dim=-1) - 1.)**2).mean() + self.log('train/loss_eikonal', loss_eikonal) + loss += loss_eikonal * self.C(self.config.system.loss.lambda_eikonal) + + opacity = torch.clamp(out['opacity'].squeeze(-1), 1.e-3, 1.-1.e-3) + loss_mask = binary_cross_entropy(opacity, batch['fg_mask'].float()) + self.log('train/loss_mask', loss_mask) + loss += loss_mask * (self.C(self.config.system.loss.lambda_mask) if self.dataset.has_mask else 0.0) + + loss_opaque = binary_cross_entropy(opacity, opacity) + self.log('train/loss_opaque', loss_opaque) + loss += loss_opaque * self.C(self.config.system.loss.lambda_opaque) + + loss_sparsity = torch.exp(-self.config.system.loss.sparsity_scale * out['sdf_samples'].abs()).mean() + self.log('train/loss_sparsity', loss_sparsity) + loss += loss_sparsity * self.C(self.config.system.loss.lambda_sparsity) + + if self.C(self.config.system.loss.lambda_curvature) > 0: + assert 'sdf_laplace_samples' in out, "Need geometry.grad_type='finite_difference' to get SDF Laplace samples" + loss_curvature = out['sdf_laplace_samples'].abs().mean() + self.log('train/loss_curvature', loss_curvature) + loss += loss_curvature * self.C(self.config.system.loss.lambda_curvature) + + # distortion loss proposed in MipNeRF360 + # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss + if self.C(self.config.system.loss.lambda_distortion) > 0: + loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices']) + self.log('train/loss_distortion', loss_distortion) + loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) + + if self.config.model.learned_background and self.C(self.config.system.loss.lambda_distortion_bg) > 0: + loss_distortion_bg = flatten_eff_distloss(out['weights_bg'], out['points_bg'], out['intervals_bg'], out['ray_indices_bg']) + self.log('train/loss_distortion_bg', loss_distortion_bg) + loss += loss_distortion_bg * self.C(self.config.system.loss.lambda_distortion_bg) + + losses_model_reg = self.model.regularizations(out) + for name, value in losses_model_reg.items(): + self.log(f'train/loss_{name}', value) + loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) + loss += loss_ + + self.log('train/inv_s', out['inv_s'], prog_bar=True) + + for name, value in self.config.system.loss.items(): + if name.startswith('lambda'): + self.log(f'train_params/{name}', self.C(value)) + + self.log('train/num_rays', float(self.train_num_rays), prog_bar=True) + + return { + 'loss': loss + } + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def 
training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} + ] + ([ + {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + ] if self.config.model.learned_background else []) + [ + {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def validation_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True) + + def test_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} + ] + ([ + {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + ] if self.config.model.learned_background else []) + [ + {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + def test_epoch_end(self, out): + """ + Synchronize devices. + Generate image sequence using test outputs. 
+ """ + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True) + + self.save_img_sequence( + f"it{self.global_step}-test", + f"it{self.global_step}-test", + '(\d+)\.png', + save_format='mp4', + fps=30 + ) + + self.export() + + def export(self): + mesh = self.model.export(self.config.export) + self.save_mesh( + f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", + **mesh + ) diff --git a/mesh_recon/systems/neus_ortho.py b/mesh_recon/systems/neus_ortho.py new file mode 100644 index 0000000000000000000000000000000000000000..803b2a84564e491e16883ee0177979c4280e8b3d --- /dev/null +++ b/mesh_recon/systems/neus_ortho.py @@ -0,0 +1,358 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_efficient_distloss import flatten_eff_distloss + +import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug + +import models +from models.utils import cleanup +from models.ray_utils import get_ortho_rays +import systems +from systems.base import BaseSystem +from systems.criterions import PSNR, binary_cross_entropy + +import pdb + +def ranking_loss(error, penalize_ratio=0.7, extra_weights=None , type='mean'): + error, indices = torch.sort(error) + # only sum relatively small errors + s_error = torch.index_select(error, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + if extra_weights is not None: + weights = torch.index_select(extra_weights, 0, index=indices[:int(penalize_ratio * indices.shape[0])]) + s_error = s_error * weights + + if type == 'mean': + return torch.mean(s_error) + elif type == 'sum': + return torch.sum(s_error) + +@systems.register('ortho-neus-system') +class OrthoNeuSSystem(BaseSystem): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. 
rank_zero_info: use the logging module + """ + def prepare(self): + self.criterions = { + 'psnr': PSNR() + } + self.train_num_samples = self.config.model.train_num_rays * (self.config.model.num_samples_per_ray + self.config.model.get('num_samples_per_ray_bg', 0)) + self.train_num_rays = self.config.model.train_num_rays + self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) + + def forward(self, batch): + return self.model(batch['rays']) + + def preprocess_data(self, batch, stage): + if 'index' in batch: # validation / testing + index = batch['index'] + else: + if self.config.model.batch_image_sampling: + index = torch.randint(0, len(self.dataset.all_images), size=(self.train_num_rays,), device=self.dataset.all_images.device) + else: + index = torch.randint(0, len(self.dataset.all_images), size=(1,), device=self.dataset.all_images.device) + if stage in ['train']: + c2w = self.dataset.all_c2w[index] + x = torch.randint( + 0, self.dataset.w, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + y = torch.randint( + 0, self.dataset.h, size=(self.train_num_rays,), device=self.dataset.all_images.device + ) + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions[y, x] + origins = self.dataset.origins[y, x] + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index, y, x] + origins = self.dataset.origins[index, y, x] + rays_o, rays_d = get_ortho_rays(origins, directions, c2w) + rgb = self.dataset.all_images[index, y, x].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + normal = self.dataset.all_normals_world[index, y, x].view(-1, self.dataset.all_normals_world.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) + rgb_mask = self.dataset.all_rgb_masks[index, y, x].view(-1).to(self.rank) + view_weights = self.dataset.view_weights[index, y, x].view(-1).to(self.rank) + else: + c2w = self.dataset.all_c2w[index][0] + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions + origins = self.dataset.origins + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index][0] + origins = self.dataset.origins[index][0] + rays_o, rays_d = get_ortho_rays(origins, directions, c2w) + rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + normal = self.dataset.all_normals_world[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank) + fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) + rgb_mask = self.dataset.all_rgb_masks[index].view(-1).to(self.rank) + view_weights = None + + cosines = self.cos(rays_d, normal) + rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) + + if stage in ['train']: + if self.config.model.background_color == 'white': + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + elif self.config.model.background_color == 'black': + self.model.background_color = torch.zeros((3,), dtype=torch.float32, device=self.rank) + elif self.config.model.background_color == 'random': + self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank) + else: + raise NotImplementedError + else: + self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank) + + if self.dataset.apply_mask: + rgb = rgb * fg_mask[...,None] + self.model.background_color * (1 - fg_mask[...,None]) + + batch.update({ + 'rays': rays, + 'rgb': rgb, + 
'normal': normal, + 'fg_mask': fg_mask, + 'rgb_mask': rgb_mask, + 'cosines': cosines, + 'view_weights': view_weights + }) + + def training_step(self, batch, batch_idx): + out = self(batch) + + cosines = batch['cosines'] + fg_mask = batch['fg_mask'] + rgb_mask = batch['rgb_mask'] + view_weights = batch['view_weights'] + + cosines[cosines > -0.1] = 0 + mask = ((fg_mask > 0) & (cosines < -0.1)) + rgb_mask = out['rays_valid_full'][...,0] & (rgb_mask > 0) + + grad_cosines = self.cos(batch['rays'][...,3:], out['comp_normal']).detach() + # grad_cosines = cosines + + loss = 0. + + # update train_num_rays + if self.config.model.dynamic_ray_sampling: + train_num_rays = int(self.train_num_rays * (self.train_num_samples / out['num_samples_full'].sum().item())) + self.train_num_rays = min(int(self.train_num_rays * 0.9 + train_num_rays * 0.1), self.config.model.max_train_num_rays) + + erros_rgb_mse = F.mse_loss(out['comp_rgb_full'][rgb_mask], batch['rgb'][rgb_mask], reduction='none') + # erros_rgb_mse = erros_rgb_mse * torch.exp(grad_cosines.abs())[:, None][rgb_mask] / torch.exp(grad_cosines.abs()[rgb_mask]).sum() + # loss_rgb_mse = ranking_loss(erros_rgb_mse.sum(dim=1), penalize_ratio=0.7, type='sum') + loss_rgb_mse = ranking_loss(erros_rgb_mse.sum(dim=1), + penalize_ratio=self.config.system.loss.rgb_p_ratio, type='mean') + self.log('train/loss_rgb_mse', loss_rgb_mse, prog_bar=True, rank_zero_only=True) + loss += loss_rgb_mse * self.C(self.config.system.loss.lambda_rgb_mse) + + loss_rgb_l1 = F.l1_loss(out['comp_rgb_full'][rgb_mask], batch['rgb'][rgb_mask], reduction='none') + loss_rgb_l1 = ranking_loss(loss_rgb_l1.sum(dim=1), + # extra_weights=view_weights[rgb_mask], + penalize_ratio=0.8) + self.log('train/loss_rgb', loss_rgb_l1) + loss += loss_rgb_l1 * self.C(self.config.system.loss.lambda_rgb_l1) + + normal_errors = 1 - F.cosine_similarity(out['comp_normal'], batch['normal'], dim=1) + # normal_errors = normal_errors * cosines.abs() / cosines.abs().sum() + if self.config.system.loss.geo_aware: + normal_errors = normal_errors * torch.exp(cosines.abs()) / torch.exp(cosines.abs()).sum() + loss_normal = ranking_loss(normal_errors[mask], + penalize_ratio=self.config.system.loss.normal_p_ratio, + extra_weights=view_weights[mask], + type='sum') + else: + loss_normal = ranking_loss(normal_errors[mask], + penalize_ratio=self.config.system.loss.normal_p_ratio, + extra_weights=view_weights[mask], + type='mean') + + self.log('train/loss_normal', loss_normal, prog_bar=True, rank_zero_only=True) + loss += loss_normal * self.C(self.config.system.loss.lambda_normal) + + loss_eikonal = ((torch.linalg.norm(out['sdf_grad_samples'], ord=2, dim=-1) - 1.)**2).mean() + self.log('train/loss_eikonal', loss_eikonal, prog_bar=True, rank_zero_only=True) + loss += loss_eikonal * self.C(self.config.system.loss.lambda_eikonal) + + opacity = torch.clamp(out['opacity'].squeeze(-1), 1.e-3, 1.-1.e-3) + loss_mask = binary_cross_entropy(opacity, batch['fg_mask'].float(), reduction='none') + loss_mask = ranking_loss(loss_mask, + penalize_ratio=self.config.system.loss.mask_p_ratio, + extra_weights=view_weights) + self.log('train/loss_mask', loss_mask, prog_bar=True, rank_zero_only=True) + loss += loss_mask * (self.C(self.config.system.loss.lambda_mask) if self.dataset.has_mask else 0.0) + + loss_opaque = binary_cross_entropy(opacity, opacity) + self.log('train/loss_opaque', loss_opaque) + loss += loss_opaque * self.C(self.config.system.loss.lambda_opaque) + + loss_sparsity = torch.exp(-self.config.system.loss.sparsity_scale * 
out['random_sdf'].abs()).mean() + self.log('train/loss_sparsity', loss_sparsity, prog_bar=True, rank_zero_only=True) + loss += loss_sparsity * self.C(self.config.system.loss.lambda_sparsity) + + if self.C(self.config.system.loss.lambda_curvature) > 0: + assert 'sdf_laplace_samples' in out, "Need geometry.grad_type='finite_difference' to get SDF Laplace samples" + loss_curvature = out['sdf_laplace_samples'].abs().mean() + self.log('train/loss_curvature', loss_curvature) + loss += loss_curvature * self.C(self.config.system.loss.lambda_curvature) + + # distortion loss proposed in MipNeRF360 + # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss + if self.C(self.config.system.loss.lambda_distortion) > 0: + loss_distortion = flatten_eff_distloss(out['weights'], out['points'], out['intervals'], out['ray_indices']) + self.log('train/loss_distortion', loss_distortion) + loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) + + if self.config.model.learned_background and self.C(self.config.system.loss.lambda_distortion_bg) > 0: + loss_distortion_bg = flatten_eff_distloss(out['weights_bg'], out['points_bg'], out['intervals_bg'], out['ray_indices_bg']) + self.log('train/loss_distortion_bg', loss_distortion_bg) + loss += loss_distortion_bg * self.C(self.config.system.loss.lambda_distortion_bg) + + if self.C(self.config.system.loss.lambda_3d_normal_smooth) > 0: + if "random_sdf_grad" not in out: + raise ValueError( + "random_sdf_grad is required for normal smooth loss, no normal is found in the output." + ) + if "normal_perturb" not in out: + raise ValueError( + "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output." + ) + normals_3d = out["random_sdf_grad"] + normals_perturb_3d = out["normal_perturb"] + loss_3d_normal_smooth = (normals_3d - normals_perturb_3d).abs().mean() + self.log('train/loss_3d_normal_smooth', loss_3d_normal_smooth, prog_bar=True ) + + loss += loss_3d_normal_smooth * self.C(self.config.system.loss.lambda_3d_normal_smooth) + + losses_model_reg = self.model.regularizations(out) + for name, value in losses_model_reg.items(): + self.log(f'train/loss_{name}', value) + loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) + loss += loss_ + + self.log('train/inv_s', out['inv_s'], prog_bar=True) + + for name, value in self.config.system.loss.items(): + if name.startswith('lambda'): + self.log(f'train_params/{name}', self.C(value)) + + self.log('train/num_rays', float(self.train_num_rays), prog_bar=True) + + return { + 'loss': loss + } + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) + W, H = self.dataset.img_wh + self.save_image_grid(f"it{self.global_step}-{batch['index'][0].item()}.png", [ + {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} + ] + ([ + {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + ] if self.config.model.learned_background else []) + [ + {'type': 'grayscale', 'img': out['depth'].view(H, 
W), 'kwargs': {}}, + {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} + ]) + return { + 'psnr': psnr, + 'index': batch['index'] + } + + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def validation_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out['index'].ndim == 1: + out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # DDP + else: + for oi, index in enumerate(step_out['index']): + out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True) + self.export() + + # def test_step(self, batch, batch_idx): + # out = self(batch) + # psnr = self.criterions['psnr'](out['comp_rgb_full'].to(batch['rgb']), batch['rgb']) + # W, H = self.dataset.img_wh + # self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [ + # {'type': 'rgb', 'img': batch['rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + # {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}} + # ] + ([ + # {'type': 'rgb', 'img': out['comp_rgb_bg'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + # {'type': 'rgb', 'img': out['comp_rgb'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}}, + # ] if self.config.model.learned_background else []) + [ + # {'type': 'grayscale', 'img': out['depth'].view(H, W), 'kwargs': {}}, + # {'type': 'rgb', 'img': out['comp_normal'].view(H, W, 3), 'kwargs': {'data_format': 'HWC', 'data_range': (-1, 1)}} + # ]) + # return { + # 'psnr': psnr, + # 'index': batch['index'] + # } + + def test_step(self, batch, batch_idx): + pass + + def test_epoch_end(self, out): + """ + Synchronize devices. + Generate image sequence using test outputs. 
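ranking_loss, used throughout this system's training_step, is a trimmed loss: it sorts the per-ray errors, keeps only the smallest penalize_ratio fraction (optionally reweighted per ray), and reduces by mean or sum, so a handful of very large residuals (mis-segmented or view-inconsistent pixels) cannot dominate the gradient. The snippet below is a simplified sketch of that idea rather than a byte-for-byte copy of the helper above, with invented error values:

import torch

def trimmed_loss(error, penalize_ratio=0.7, extra_weights=None, type="mean"):
    sorted_error, indices = torch.sort(error)
    keep = int(penalize_ratio * error.shape[0])
    s_error = sorted_error[:keep]                          # keep only the smallest errors
    if extra_weights is not None:
        s_error = s_error * extra_weights[indices[:keep]]  # weights follow their original rays
    return s_error.mean() if type == "mean" else s_error.sum()

errors = torch.tensor([0.3, 10.0, 0.1, 0.4, 0.2])   # one outlier ray
print(trimmed_loss(errors, penalize_ratio=0.8))     # mean of the 4 smallest = 0.25
print(errors.mean())                                # 2.2 -- a plain mean is dominated by the outlier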
+ """ + # out = self.all_gather(out) + if self.trainer.is_global_zero: + # out_set = {} + # for step_out in out: + # # DP + # if step_out['index'].ndim == 1: + # out_set[step_out['index'].item()] = {'psnr': step_out['psnr']} + # # DDP + # else: + # for oi, index in enumerate(step_out['index']): + # out_set[index[0].item()] = {'psnr': step_out['psnr'][oi]} + # psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()])) + # self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True) + + # self.save_img_sequence( + # f"it{self.global_step}-test", + # f"it{self.global_step}-test", + # '(\d+)\.png', + # save_format='mp4', + # fps=30 + # ) + + self.export() + + def export(self): + mesh = self.model.export(self.config.export) + # pdb.set_trace() + self.save_mesh( + f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", + ortho_scale=self.config.export.ortho_scale, + **mesh + ) diff --git a/mesh_recon/systems/neus_pinhole.py b/mesh_recon/systems/neus_pinhole.py new file mode 100644 index 0000000000000000000000000000000000000000..12e8abb7df7e439468dc1b932d471693d92728f0 --- /dev/null +++ b/mesh_recon/systems/neus_pinhole.py @@ -0,0 +1,501 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_efficient_distloss import flatten_eff_distloss + +import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug + +import models +from models.utils import cleanup +from models.ray_utils import get_rays +import systems +from systems.base import BaseSystem +from systems.criterions import PSNR, binary_cross_entropy + +import pdb + + +def ranking_loss(error, penalize_ratio=0.7, extra_weights=None, type="mean"): + error, indices = torch.sort(error) + # only sum relatively small errors + s_error = torch.index_select( + error, 0, index=indices[: int(penalize_ratio * indices.shape[0])] + ) + if extra_weights is not None: + weights = torch.index_select( + extra_weights, 0, index=indices[: int(penalize_ratio * indices.shape[0])] + ) + s_error = s_error * weights + + if type == "mean": + return torch.mean(s_error) + elif type == "sum": + return torch.sum(s_error) + + +@systems.register("pinhole-neus-system") +class PinholeNeuSSystem(BaseSystem): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. 
rank_zero_info: use the logging module + """ + + def prepare(self): + self.criterions = {"psnr": PSNR()} + self.train_num_samples = self.config.model.train_num_rays * ( + self.config.model.num_samples_per_ray + + self.config.model.get("num_samples_per_ray_bg", 0) + ) + self.train_num_rays = self.config.model.train_num_rays + self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) + + def forward(self, batch): + return self.model(batch["rays"]) + + def preprocess_data(self, batch, stage): + if "index" in batch: # validation / testing + index = batch["index"] + else: + if self.config.model.batch_image_sampling: + index = torch.randint( + 0, + len(self.dataset.all_images), + size=(self.train_num_rays,), + device=self.dataset.all_images.device, + ) + else: + index = torch.randint( + 0, + len(self.dataset.all_images), + size=(1,), + device=self.dataset.all_images.device, + ) + if stage in ["train"]: + c2w = self.dataset.all_c2w[index] + x = torch.randint( + 0, + self.dataset.w, + size=(self.train_num_rays,), + device=self.dataset.all_images.device, + ) + y = torch.randint( + 0, + self.dataset.h, + size=(self.train_num_rays,), + device=self.dataset.all_images.device, + ) + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions[y, x] + # origins = self.dataset.origins[y, x] + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index, y, x] + # origins = self.dataset.origins[index, y, x] + rays_o, rays_d = get_rays(directions, c2w) + rgb = ( + self.dataset.all_images[index, y, x] + .view(-1, self.dataset.all_images.shape[-1]) + .to(self.rank) + ) + normal = ( + self.dataset.all_normals_world[index, y, x] + .view(-1, self.dataset.all_normals_world.shape[-1]) + .to(self.rank) + ) + fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) + rgb_mask = self.dataset.all_rgb_masks[index, y, x].view(-1).to(self.rank) + view_weights = self.dataset.view_weights[index, y, x].view(-1).to(self.rank) + else: + c2w = self.dataset.all_c2w[index][0] + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions + # origins = self.dataset.origins + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index][0] + # origins = self.dataset.origins[index][0] + rays_o, rays_d = get_rays(directions, c2w) + rgb = ( + self.dataset.all_images[index] + .view(-1, self.dataset.all_images.shape[-1]) + .to(self.rank) + ) + normal = ( + self.dataset.all_normals_world[index] + .view(-1, self.dataset.all_images.shape[-1]) + .to(self.rank) + ) + fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) + rgb_mask = self.dataset.all_rgb_masks[index].view(-1).to(self.rank) + view_weights = None + + cosines = self.cos(rays_d, normal) + rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) + + if stage in ["train"]: + if self.config.model.background_color == "white": + self.model.background_color = torch.ones( + (3,), dtype=torch.float32, device=self.rank + ) + elif self.config.model.background_color == "black": + self.model.background_color = torch.zeros( + (3,), dtype=torch.float32, device=self.rank + ) + elif self.config.model.background_color == "random": + self.model.background_color = torch.rand( + (3,), dtype=torch.float32, device=self.rank + ) + else: + raise NotImplementedError + else: + self.model.background_color = torch.ones( + (3,), dtype=torch.float32, device=self.rank + ) + + if self.dataset.apply_mask: + rgb = rgb * fg_mask[..., None] + 
self.model.background_color * ( + 1 - fg_mask[..., None] + ) + + batch.update( + { + "rays": rays, + "rgb": rgb, + "normal": normal, + "fg_mask": fg_mask, + "rgb_mask": rgb_mask, + "cosines": cosines, + "view_weights": view_weights, + } + ) + + def training_step(self, batch, batch_idx): + out = self(batch) + + cosines = batch["cosines"] + fg_mask = batch["fg_mask"] + rgb_mask = batch["rgb_mask"] + view_weights = batch["view_weights"] + + cosines[cosines > -0.1] = 0 + mask = (fg_mask > 0) & (cosines < -0.1) + rgb_mask = out["rays_valid_full"][..., 0] & (rgb_mask > 0) + + grad_cosines = self.cos(batch["rays"][..., 3:], out["comp_normal"]).detach() + # grad_cosines = cosines + + loss = 0.0 + + # update train_num_rays + if self.config.model.dynamic_ray_sampling: + train_num_rays = int( + self.train_num_rays + * (self.train_num_samples / out["num_samples_full"].sum().item()) + ) + self.train_num_rays = min( + int(self.train_num_rays * 0.9 + train_num_rays * 0.1), + self.config.model.max_train_num_rays, + ) + + erros_rgb_mse = F.mse_loss( + out["comp_rgb_full"][rgb_mask], batch["rgb"][rgb_mask], reduction="none" + ) + # erros_rgb_mse = erros_rgb_mse * torch.exp(grad_cosines.abs())[:, None][rgb_mask] / torch.exp(grad_cosines.abs()[rgb_mask]).sum() + # loss_rgb_mse = ranking_loss(erros_rgb_mse.sum(dim=1), penalize_ratio=0.7, type='sum') + loss_rgb_mse = ranking_loss( + erros_rgb_mse.sum(dim=1), penalize_ratio=0.7, type="mean" + ) + self.log("train/loss_rgb_mse", loss_rgb_mse, prog_bar=True, rank_zero_only=True) + loss += loss_rgb_mse * self.C(self.config.system.loss.lambda_rgb_mse) + + loss_rgb_l1 = F.l1_loss( + out["comp_rgb_full"][rgb_mask], batch["rgb"][rgb_mask], reduction="none" + ) + loss_rgb_l1 = ranking_loss( + loss_rgb_l1.sum(dim=1), + extra_weights=view_weights[rgb_mask], + penalize_ratio=0.8, + ) + self.log("train/loss_rgb", loss_rgb_l1) + loss += loss_rgb_l1 * self.C(self.config.system.loss.lambda_rgb_l1) + + normal_errors = 1 - F.cosine_similarity( + out["comp_normal"], batch["normal"], dim=1 + ) + # normal_errors = normal_errors * cosines.abs() / cosines.abs().sum() + normal_errors = ( + normal_errors * torch.exp(cosines.abs()) / torch.exp(cosines.abs()).sum() + ) + loss_normal = ranking_loss( + normal_errors[mask], + penalize_ratio=0.8, + # extra_weights=view_weights[mask], + type="sum", + ) + self.log("train/loss_normal", loss_normal, prog_bar=True, rank_zero_only=True) + loss += loss_normal * self.C(self.config.system.loss.lambda_normal) + + loss_eikonal = ( + (torch.linalg.norm(out["sdf_grad_samples"], ord=2, dim=-1) - 1.0) ** 2 + ).mean() + self.log("train/loss_eikonal", loss_eikonal, prog_bar=True, rank_zero_only=True) + loss += loss_eikonal * self.C(self.config.system.loss.lambda_eikonal) + + opacity = torch.clamp(out["opacity"].squeeze(-1), 1.0e-3, 1.0 - 1.0e-3) + loss_mask = binary_cross_entropy( + opacity, batch["fg_mask"].float(), reduction="none" + ) + loss_mask = ranking_loss( + loss_mask, penalize_ratio=0.9, extra_weights=view_weights + ) + self.log("train/loss_mask", loss_mask, prog_bar=True, rank_zero_only=True) + loss += loss_mask * ( + self.C(self.config.system.loss.lambda_mask) + if self.dataset.has_mask + else 0.0 + ) + + loss_opaque = binary_cross_entropy(opacity, opacity) + self.log("train/loss_opaque", loss_opaque) + loss += loss_opaque * self.C(self.config.system.loss.lambda_opaque) + + loss_sparsity = torch.exp( + -self.config.system.loss.sparsity_scale * out["random_sdf"].abs() + ).mean() + self.log( + "train/loss_sparsity", loss_sparsity, prog_bar=True, 
rank_zero_only=True + ) + loss += loss_sparsity * self.C(self.config.system.loss.lambda_sparsity) + + if self.C(self.config.system.loss.lambda_curvature) > 0: + assert ( + "sdf_laplace_samples" in out + ), "Need geometry.grad_type='finite_difference' to get SDF Laplace samples" + loss_curvature = out["sdf_laplace_samples"].abs().mean() + self.log("train/loss_curvature", loss_curvature) + loss += loss_curvature * self.C(self.config.system.loss.lambda_curvature) + + # distortion loss proposed in MipNeRF360 + # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss + if self.C(self.config.system.loss.lambda_distortion) > 0: + loss_distortion = flatten_eff_distloss( + out["weights"], out["points"], out["intervals"], out["ray_indices"] + ) + self.log("train/loss_distortion", loss_distortion) + loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) + + if ( + self.config.model.learned_background + and self.C(self.config.system.loss.lambda_distortion_bg) > 0 + ): + loss_distortion_bg = flatten_eff_distloss( + out["weights_bg"], + out["points_bg"], + out["intervals_bg"], + out["ray_indices_bg"], + ) + self.log("train/loss_distortion_bg", loss_distortion_bg) + loss += loss_distortion_bg * self.C( + self.config.system.loss.lambda_distortion_bg + ) + + if self.C(self.config.system.loss.lambda_3d_normal_smooth) > 0: + if "random_sdf_grad" not in out: + raise ValueError( + "random_sdf_grad is required for normal smooth loss, no normal is found in the output." + ) + if "normal_perturb" not in out: + raise ValueError( + "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output." + ) + normals_3d = out["random_sdf_grad"] + normals_perturb_3d = out["normal_perturb"] + loss_3d_normal_smooth = (normals_3d - normals_perturb_3d).abs().mean() + self.log( + "train/loss_3d_normal_smooth", loss_3d_normal_smooth, prog_bar=True + ) + + loss += loss_3d_normal_smooth * self.C( + self.config.system.loss.lambda_3d_normal_smooth + ) + + losses_model_reg = self.model.regularizations(out) + for name, value in losses_model_reg.items(): + self.log(f"train/loss_{name}", value) + loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) + loss += loss_ + + self.log("train/inv_s", out["inv_s"], prog_bar=True) + + for name, value in self.config.system.loss.items(): + if name.startswith("lambda"): + self.log(f"train_params/{name}", self.C(value)) + + self.log("train/num_rays", float(self.train_num_rays), prog_bar=True) + + return {"loss": loss} + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions["psnr"]( + out["comp_rgb_full"].to(batch["rgb"]), batch["rgb"] + ) + W, H = self.dataset.img_wh + self.save_image_grid( + f"it{self.global_step}-{batch['index'][0].item()}.png", + [ + { + "type": "rgb", + "img": batch["rgb"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": out["comp_rgb_full"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + ] + + ( + [ + { + "type": "rgb", + "img": out["comp_rgb_bg"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": out["comp_rgb"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + ] + if self.config.model.learned_background + else [] + ) + + [ + {"type": "grayscale", 
"img": out["depth"].view(H, W), "kwargs": {}}, + { + "type": "rgb", + "img": out["comp_normal"].view(H, W, 3), + "kwargs": {"data_format": "HWC", "data_range": (-1, 1)}, + }, + ], + ) + return {"psnr": psnr, "index": batch["index"]} + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def validation_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out["index"].ndim == 1: + out_set[step_out["index"].item()] = {"psnr": step_out["psnr"]} + # DDP + else: + for oi, index in enumerate(step_out["index"]): + out_set[index[0].item()] = {"psnr": step_out["psnr"][oi]} + psnr = torch.mean(torch.stack([o["psnr"] for o in out_set.values()])) + self.log("val/psnr", psnr, prog_bar=True, rank_zero_only=True) + self.export() + + def test_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions["psnr"]( + out["comp_rgb_full"].to(batch["rgb"]), batch["rgb"] + ) + W, H = self.dataset.img_wh + self.save_image_grid( + f"it{self.global_step}-test/{batch['index'][0].item()}.png", + [ + { + "type": "rgb", + "img": batch["rgb"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": out["comp_rgb_full"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + ] + + ( + [ + { + "type": "rgb", + "img": out["comp_rgb_bg"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": out["comp_rgb"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + ] + if self.config.model.learned_background + else [] + ) + + [ + {"type": "grayscale", "img": out["depth"].view(H, W), "kwargs": {}}, + { + "type": "rgb", + "img": out["comp_normal"].view(H, W, 3), + "kwargs": {"data_format": "HWC", "data_range": (-1, 1)}, + }, + ], + ) + return {"psnr": psnr, "index": batch["index"]} + + def test_epoch_end(self, out): + """ + Synchronize devices. + Generate image sequence using test outputs. 
+ """ + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out["index"].ndim == 1: + out_set[step_out["index"].item()] = {"psnr": step_out["psnr"]} + # DDP + else: + for oi, index in enumerate(step_out["index"]): + out_set[index[0].item()] = {"psnr": step_out["psnr"][oi]} + psnr = torch.mean(torch.stack([o["psnr"] for o in out_set.values()])) + self.log("test/psnr", psnr, prog_bar=True, rank_zero_only=True) + + self.save_img_sequence( + f"it{self.global_step}-test", + f"it{self.global_step}-test", + "(\d+)\.png", + save_format="mp4", + fps=30, + ) + + self.export() + + def export(self): + mesh = self.model.export(self.config.export) + self.save_mesh( + f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", + ortho_scale=self.config.export.ortho_scale, + **mesh, + ) diff --git a/mesh_recon/systems/neus_videonvs.py b/mesh_recon/systems/neus_videonvs.py new file mode 100644 index 0000000000000000000000000000000000000000..6cc7b9a1574d6b93da8e4c0718f848427cf63415 --- /dev/null +++ b/mesh_recon/systems/neus_videonvs.py @@ -0,0 +1,503 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_efficient_distloss import flatten_eff_distloss + +import pytorch_lightning as pl +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug + +import models +from models.utils import cleanup +from models.ray_utils import get_rays +import systems +from systems.base import BaseSystem +from systems.criterions import PSNR, binary_cross_entropy + +import pdb + + +def ranking_loss(error, penalize_ratio=0.7, extra_weights=None, type="mean"): + # error, indices = torch.sort(error) + # # only sum relatively small errors + # s_error = torch.index_select( + # error, 0, index=indices[: int(penalize_ratio * indices.shape[0])] + # ) + # if extra_weights is not None: + # weights = torch.index_select( + # extra_weights, 0, index=indices[: int(penalize_ratio * indices.shape[0])] + # ) + # s_error = s_error * weights + + if type == "mean": + return torch.mean(error) + elif type == "sum": + return torch.sum(error) + + +@systems.register("videonvs-neus-system") +class PinholeNeuSSystem(BaseSystem): + """ + Two ways to print to console: + 1. self.print: correctly handle progress bar + 2. 
rank_zero_info: use the logging module + """ + + def prepare(self): + self.criterions = {"psnr": PSNR()} + self.train_num_samples = self.config.model.train_num_rays * ( + self.config.model.num_samples_per_ray + + self.config.model.get("num_samples_per_ray_bg", 0) + ) + self.train_num_rays = self.config.model.train_num_rays + self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) + + def forward(self, batch): + return self.model(batch["rays"]) + + def preprocess_data(self, batch, stage): + if "index" in batch: # validation / testing + index = batch["index"] + else: + if self.config.model.batch_image_sampling: + index = torch.randint( + 0, + len(self.dataset.all_images), + size=(self.train_num_rays,), + device=self.dataset.all_images.device, + ) + else: + index = torch.randint( + 0, + len(self.dataset.all_images), + size=(1,), + device=self.dataset.all_images.device, + ) + if stage in ["train"]: + c2w = self.dataset.all_c2w[index] + x = torch.randint( + 0, + self.dataset.w, + size=(self.train_num_rays,), + device=self.dataset.all_images.device, + ) + y = torch.randint( + 0, + self.dataset.h, + size=(self.train_num_rays,), + device=self.dataset.all_images.device, + ) + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions[y, x] + # origins = self.dataset.origins[y, x] + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index, y, x] + # origins = self.dataset.origins[index, y, x] + rays_o, rays_d = get_rays(directions, c2w) + rgb = ( + self.dataset.all_images[index, y, x] + .view(-1, self.dataset.all_images.shape[-1]) + .to(self.rank) + ) + normal = ( + self.dataset.all_normals_world[index, y, x] + .view(-1, self.dataset.all_normals_world.shape[-1]) + .to(self.rank) + ) + fg_mask = self.dataset.all_fg_masks[index, y, x].view(-1).to(self.rank) + rgb_mask = self.dataset.all_rgb_masks[index, y, x].view(-1).to(self.rank) + # view_weights = self.dataset.view_weights[index, y, x].view(-1).to(self.rank) + view_weights = None + else: + c2w = self.dataset.all_c2w[index][0] + if self.dataset.directions.ndim == 3: # (H, W, 3) + directions = self.dataset.directions + # origins = self.dataset.origins + elif self.dataset.directions.ndim == 4: # (N, H, W, 3) + directions = self.dataset.directions[index][0] + # origins = self.dataset.origins[index][0] + rays_o, rays_d = get_rays(directions, c2w) + rgb = ( + self.dataset.all_images[index] + .view(-1, self.dataset.all_images.shape[-1]) + .to(self.rank) + ) + normal = ( + self.dataset.all_normals_world[index] + .view(-1, self.dataset.all_images.shape[-1]) + .to(self.rank) + ) + fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank) + rgb_mask = self.dataset.all_rgb_masks[index].view(-1).to(self.rank) + view_weights = None + + cosines = self.cos(rays_d, normal) + rays = torch.cat([rays_o, F.normalize(rays_d, p=2, dim=-1)], dim=-1) + + if stage in ["train"]: + if self.config.model.background_color == "white": + self.model.background_color = torch.ones( + (3,), dtype=torch.float32, device=self.rank + ) + elif self.config.model.background_color == "black": + self.model.background_color = torch.zeros( + (3,), dtype=torch.float32, device=self.rank + ) + elif self.config.model.background_color == "random": + self.model.background_color = torch.rand( + (3,), dtype=torch.float32, device=self.rank + ) + else: + raise NotImplementedError + else: + self.model.background_color = torch.ones( + (3,), dtype=torch.float32, device=self.rank + ) + + if self.dataset.apply_mask: + rgb = rgb * 
fg_mask[..., None] + self.model.background_color * ( + 1 - fg_mask[..., None] + ) + + batch.update( + { + "rays": rays, + "rgb": rgb, + "normal": normal, + "fg_mask": fg_mask, + "rgb_mask": rgb_mask, + "cosines": cosines, + "view_weights": view_weights, + } + ) + + def training_step(self, batch, batch_idx): + out = self(batch) + + cosines = batch["cosines"] + fg_mask = batch["fg_mask"] + rgb_mask = batch["rgb_mask"] + view_weights = batch["view_weights"] + + cosines[cosines > -0.1] = 0 + mask = (fg_mask > 0) & (cosines < -0.1) + rgb_mask = out["rays_valid_full"][..., 0] & (rgb_mask > 0) + + grad_cosines = self.cos(batch["rays"][..., 3:], out["comp_normal"]).detach() + # grad_cosines = cosines + + loss = 0.0 + + # update train_num_rays + if self.config.model.dynamic_ray_sampling: + train_num_rays = int( + self.train_num_rays + * (self.train_num_samples / out["num_samples_full"].sum().item()) + ) + self.train_num_rays = min( + int(self.train_num_rays * 0.9 + train_num_rays * 0.1), + self.config.model.max_train_num_rays, + ) + + erros_rgb_mse = F.mse_loss( + out["comp_rgb_full"][rgb_mask], batch["rgb"][rgb_mask], reduction="none" + ) + # erros_rgb_mse = erros_rgb_mse * torch.exp(grad_cosines.abs())[:, None][rgb_mask] / torch.exp(grad_cosines.abs()[rgb_mask]).sum() + # loss_rgb_mse = ranking_loss(erros_rgb_mse.sum(dim=1), penalize_ratio=0.7, type='sum') + # loss_rgb_mse = ranking_loss( + # erros_rgb_mse.sum(dim=1), penalize_ratio=0.7, type="mean" + # ) + loss_rgb_mse = ranking_loss( + erros_rgb_mse.sum(dim=1), penalize_ratio=1.0, type="mean" + ) + self.log("train/loss_rgb_mse", loss_rgb_mse, prog_bar=True, rank_zero_only=True) + loss += loss_rgb_mse * self.C(self.config.system.loss.lambda_rgb_mse) + + loss_rgb_l1 = F.l1_loss( + out["comp_rgb_full"][rgb_mask], batch["rgb"][rgb_mask], reduction="none" + ) + loss_rgb_l1 = ranking_loss( + loss_rgb_l1.sum(dim=1), + extra_weights=1.0, + penalize_ratio=1.0, + ) + self.log("train/loss_rgb", loss_rgb_l1) + loss += loss_rgb_l1 * self.C(self.config.system.loss.lambda_rgb_l1) + + normal_errors = 1 - F.cosine_similarity( + out["comp_normal"], batch["normal"], dim=1 + ) + # normal_errors = normal_errors * cosines.abs() / cosines.abs().sum() + normal_errors = ( + normal_errors * torch.exp(cosines.abs()) / torch.exp(cosines.abs()).sum() + ) + loss_normal = ranking_loss( + normal_errors[mask], + penalize_ratio=0.7, + # extra_weights=view_weights[mask], + type="sum", + ) + self.log("train/loss_normal", loss_normal, prog_bar=True, rank_zero_only=True) + loss += loss_normal * self.C(self.config.system.loss.lambda_normal) + + loss_eikonal = ( + (torch.linalg.norm(out["sdf_grad_samples"], ord=2, dim=-1) - 1.0) ** 2 + ).mean() + self.log("train/loss_eikonal", loss_eikonal, prog_bar=True, rank_zero_only=True) + loss += loss_eikonal * self.C(self.config.system.loss.lambda_eikonal) + + opacity = torch.clamp(out["opacity"].squeeze(-1), 1.0e-3, 1.0 - 1.0e-3) + loss_mask = binary_cross_entropy( + opacity, batch["fg_mask"].float(), reduction="none" + ) + loss_mask = ranking_loss(loss_mask, penalize_ratio=1.0, extra_weights=1.0) + self.log("train/loss_mask", loss_mask, prog_bar=True, rank_zero_only=True) + loss += loss_mask * ( + self.C(self.config.system.loss.lambda_mask) + if self.dataset.has_mask + else 0.0 + ) + + loss_opaque = binary_cross_entropy(opacity, opacity) + self.log("train/loss_opaque", loss_opaque) + loss += loss_opaque * self.C(self.config.system.loss.lambda_opaque) + + loss_sparsity = torch.exp( + -self.config.system.loss.sparsity_scale * 
out["random_sdf"].abs() + ).mean() + self.log( + "train/loss_sparsity", loss_sparsity, prog_bar=True, rank_zero_only=True + ) + loss += loss_sparsity * self.C(self.config.system.loss.lambda_sparsity) + + if self.C(self.config.system.loss.lambda_curvature) > 0: + assert ( + "sdf_laplace_samples" in out + ), "Need geometry.grad_type='finite_difference' to get SDF Laplace samples" + loss_curvature = out["sdf_laplace_samples"].abs().mean() + self.log("train/loss_curvature", loss_curvature) + loss += loss_curvature * self.C(self.config.system.loss.lambda_curvature) + + # distortion loss proposed in MipNeRF360 + # an efficient implementation from https://github.com/sunset1995/torch_efficient_distloss + if self.C(self.config.system.loss.lambda_distortion) > 0: + loss_distortion = flatten_eff_distloss( + out["weights"], out["points"], out["intervals"], out["ray_indices"] + ) + self.log("train/loss_distortion", loss_distortion) + loss += loss_distortion * self.C(self.config.system.loss.lambda_distortion) + + if ( + self.config.model.learned_background + and self.C(self.config.system.loss.lambda_distortion_bg) > 0 + ): + loss_distortion_bg = flatten_eff_distloss( + out["weights_bg"], + out["points_bg"], + out["intervals_bg"], + out["ray_indices_bg"], + ) + self.log("train/loss_distortion_bg", loss_distortion_bg) + loss += loss_distortion_bg * self.C( + self.config.system.loss.lambda_distortion_bg + ) + + if self.C(self.config.system.loss.lambda_3d_normal_smooth) > 0: + if "random_sdf_grad" not in out: + raise ValueError( + "random_sdf_grad is required for normal smooth loss, no normal is found in the output." + ) + if "normal_perturb" not in out: + raise ValueError( + "normal_perturb is required for normal smooth loss, no normal_perturb is found in the output." 
+ ) + normals_3d = out["random_sdf_grad"] + normals_perturb_3d = out["normal_perturb"] + loss_3d_normal_smooth = (normals_3d - normals_perturb_3d).abs().mean() + self.log( + "train/loss_3d_normal_smooth", loss_3d_normal_smooth, prog_bar=True + ) + + loss += loss_3d_normal_smooth * self.C( + self.config.system.loss.lambda_3d_normal_smooth + ) + + losses_model_reg = self.model.regularizations(out) + for name, value in losses_model_reg.items(): + self.log(f"train/loss_{name}", value) + loss_ = value * self.C(self.config.system.loss[f"lambda_{name}"]) + loss += loss_ + + self.log("train/inv_s", out["inv_s"], prog_bar=True) + + for name, value in self.config.system.loss.items(): + if name.startswith("lambda"): + self.log(f"train_params/{name}", self.C(value)) + + self.log("train/num_rays", float(self.train_num_rays), prog_bar=True) + + return {"loss": loss} + + """ + # aggregate outputs from different devices (DP) + def training_step_end(self, out): + pass + """ + + """ + # aggregate outputs from different iterations + def training_epoch_end(self, out): + pass + """ + + def validation_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions["psnr"]( + out["comp_rgb_full"].to(batch["rgb"]), batch["rgb"] + ) + W, H = self.dataset.img_wh + self.save_image_grid( + f"it{self.global_step}-{batch['index'][0].item()}.png", + [ + { + "type": "rgb", + "img": batch["rgb"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": out["comp_rgb_full"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + ] + + ( + [ + { + "type": "rgb", + "img": out["comp_rgb_bg"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": out["comp_rgb"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + ] + if self.config.model.learned_background + else [] + ) + + [ + {"type": "grayscale", "img": out["depth"].view(H, W), "kwargs": {}}, + { + "type": "rgb", + "img": out["comp_normal"].view(H, W, 3), + "kwargs": {"data_format": "HWC", "data_range": (-1, 1)}, + }, + ], + ) + return {"psnr": psnr, "index": batch["index"]} + + """ + # aggregate outputs from different devices when using DP + def validation_step_end(self, out): + pass + """ + + def validation_epoch_end(self, out): + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out["index"].ndim == 1: + out_set[step_out["index"].item()] = {"psnr": step_out["psnr"]} + # DDP + else: + for oi, index in enumerate(step_out["index"]): + out_set[index[0].item()] = {"psnr": step_out["psnr"][oi]} + psnr = torch.mean(torch.stack([o["psnr"] for o in out_set.values()])) + self.log("val/psnr", psnr, prog_bar=True, rank_zero_only=True) + self.export() + + def test_step(self, batch, batch_idx): + out = self(batch) + psnr = self.criterions["psnr"]( + out["comp_rgb_full"].to(batch["rgb"]), batch["rgb"] + ) + W, H = self.dataset.img_wh + self.save_image_grid( + f"it{self.global_step}-test/{batch['index'][0].item()}.png", + [ + { + "type": "rgb", + "img": batch["rgb"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": out["comp_rgb_full"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + ] + + ( + [ + { + "type": "rgb", + "img": out["comp_rgb_bg"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + { + "type": "rgb", + "img": out["comp_rgb"].view(H, W, 3), + "kwargs": {"data_format": "HWC"}, + }, + ] + if self.config.model.learned_background + else [] + ) + + [ + {"type": "grayscale", "img": 
out["depth"].view(H, W), "kwargs": {}}, + { + "type": "rgb", + "img": out["comp_normal"].view(H, W, 3), + "kwargs": {"data_format": "HWC", "data_range": (-1, 1)}, + }, + ], + ) + return {"psnr": psnr, "index": batch["index"]} + + def test_epoch_end(self, out): + """ + Synchronize devices. + Generate image sequence using test outputs. + """ + out = self.all_gather(out) + if self.trainer.is_global_zero: + out_set = {} + for step_out in out: + # DP + if step_out["index"].ndim == 1: + out_set[step_out["index"].item()] = {"psnr": step_out["psnr"]} + # DDP + else: + for oi, index in enumerate(step_out["index"]): + out_set[index[0].item()] = {"psnr": step_out["psnr"][oi]} + psnr = torch.mean(torch.stack([o["psnr"] for o in out_set.values()])) + self.log("test/psnr", psnr, prog_bar=True, rank_zero_only=True) + + self.save_img_sequence( + f"it{self.global_step}-test", + f"it{self.global_step}-test", + "(\d+)\.png", + save_format="mp4", + fps=30, + ) + + self.export() + + def export(self): + mesh = self.model.export(self.config.export) + self.save_mesh( + f"it{self.global_step}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj", + ortho_scale=self.config.export.ortho_scale, + **mesh, + ) diff --git a/mesh_recon/systems/utils.py b/mesh_recon/systems/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dafae78295305113fd1854e9104bf44be24f4727 --- /dev/null +++ b/mesh_recon/systems/utils.py @@ -0,0 +1,351 @@ +import sys +import warnings +from bisect import bisect_right + +import torch +import torch.nn as nn +from torch.optim import lr_scheduler + +from pytorch_lightning.utilities.rank_zero import rank_zero_debug + + +class ChainedScheduler(lr_scheduler._LRScheduler): + """Chains list of learning rate schedulers. It takes a list of chainable learning + rate schedulers and performs consecutive step() functions belong to them by just + one call. + + Args: + schedulers (list): List of chained schedulers. + + Example: + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.09 if epoch == 0 + >>> # lr = 0.081 if epoch == 1 + >>> # lr = 0.729 if epoch == 2 + >>> # lr = 0.6561 if epoch == 3 + >>> # lr = 0.59049 if epoch >= 4 + >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) + >>> scheduler = ChainedScheduler([scheduler1, scheduler2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, schedulers): + for scheduler_idx in range(1, len(schedulers)): + if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): + raise ValueError( + "ChainedScheduler expects all schedulers to belong to the same optimizer, but " + "got schedulers at index {} and {} to be different".format(0, scheduler_idx) + ) + self._schedulers = list(schedulers) + self.optimizer = optimizer + + def step(self): + for scheduler in self._schedulers: + scheduler.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. 
+ """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} + state_dict['_schedulers'] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict['_schedulers'][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop('_schedulers') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['_schedulers'] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class SequentialLR(lr_scheduler._LRScheduler): + """Receives the list of schedulers that is expected to be called sequentially during + optimization process and milestone points that provides exact intervals to reflect + which scheduler is supposed to be called at a given epoch. + + Args: + schedulers (list): List of chained schedulers. + milestones (list): List of integers that reflects milestone points. + + Example: + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.1 if epoch == 0 + >>> # lr = 0.1 if epoch == 1 + >>> # lr = 0.9 if epoch == 2 + >>> # lr = 0.81 if epoch == 3 + >>> # lr = 0.729 if epoch == 4 + >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) + >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False): + for scheduler_idx in range(1, len(schedulers)): + if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): + raise ValueError( + "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " + "got schedulers at index {} and {} to be different".format(0, scheduler_idx) + ) + if (len(milestones) != len(schedulers) - 1): + raise ValueError( + "Sequential Schedulers expects number of schedulers provided to be one more " + "than the number of milestone points, but got number of schedulers {} and the " + "number of milestones to be equal to {}".format(len(schedulers), len(milestones)) + ) + self._schedulers = schedulers + self._milestones = milestones + self.last_epoch = last_epoch + 1 + self.optimizer = optimizer + + def step(self): + self.last_epoch += 1 + idx = bisect_right(self._milestones, self.last_epoch) + if idx > 0 and self._milestones[idx - 1] == self.last_epoch: + self._schedulers[idx].step(0) + else: + self._schedulers[idx].step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} + state_dict['_schedulers'] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict['_schedulers'][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. 
Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop('_schedulers') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['_schedulers'] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class ConstantLR(lr_scheduler._LRScheduler): + """Decays the learning rate of each parameter group by a small constant factor until the + number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply learning rate until the milestone. Default: 1./3. + total_iters (int): The number of steps that the scheduler decays the learning rate. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): + if factor > 1.0 or factor < 0: + raise ValueError('Constant multiplicative factor expected to be between 0 and 1.') + + self.factor = factor + self.total_iters = total_iters + super(ConstantLR, self).__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] * self.factor for group in self.optimizer.param_groups] + + if (self.last_epoch > self.total_iters or + (self.last_epoch != self.total_iters)): + return [group['lr'] for group in self.optimizer.param_groups] + + if (self.last_epoch == self.total_iters): + return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs] + + +class LinearLR(lr_scheduler._LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small + multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + start_factor (float): The number we multiply learning rate in the first epoch. + The multiplication factor changes towards end_factor in the following epochs. + Default: 1./3. + end_factor (float): The number we multiply learning rate at the end of linear changing + process. Default: 1.0. + total_iters (int): The number of iterations that multiplicative factor reaches to 1. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. 
+ verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1, + verbose=False): + if start_factor > 1.0 or start_factor < 0: + raise ValueError('Starting multiplicative factor expected to be between 0 and 1.') + + if end_factor > 1.0 or end_factor < 0: + raise ValueError('Ending multiplicative factor expected to be between 0 and 1.') + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + super(LinearLR, self).__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] * self.start_factor for group in self.optimizer.param_groups] + + if (self.last_epoch > self.total_iters): + return [group['lr'] for group in self.optimizer.param_groups] + + return [group['lr'] * (1. + (self.end_factor - self.start_factor) / + (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor))) + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * (self.start_factor + + (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters) + for base_lr in self.base_lrs] + + +custom_schedulers = ['ConstantLR', 'LinearLR'] +def get_scheduler(name): + if hasattr(lr_scheduler, name): + return getattr(lr_scheduler, name) + elif name in custom_schedulers: + return getattr(sys.modules[__name__], name) + else: + raise NotImplementedError + + +def getattr_recursive(m, attr): + for name in attr.split('.'): + m = getattr(m, name) + return m + + +def get_parameters(model, name): + module = getattr_recursive(model, name) + if isinstance(module, nn.Module): + return module.parameters() + elif isinstance(module, nn.Parameter): + return module + return [] + + +def parse_optimizer(config, model): + if hasattr(config, 'params'): + params = [{'params': get_parameters(model, name), 'name': name, **args} for name, args in config.params.items()] + rank_zero_debug('Specify optimizer params:', config.params) + else: + params = model.parameters() + if config.name in ['FusedAdam']: + import apex + optim = getattr(apex.optimizers, config.name)(params, **config.args) + else: + optim = getattr(torch.optim, config.name)(params, **config.args) + return optim + + +def parse_scheduler(config, optimizer): + interval = config.get('interval', 'epoch') + assert interval in ['epoch', 'step'] + if config.name == 'SequentialLR': + scheduler = { + 'scheduler': SequentialLR(optimizer, [parse_scheduler(conf, optimizer)['scheduler'] for conf in config.schedulers], milestones=config.milestones), + 'interval': interval + } + elif config.name == 'Chained': + scheduler = { + 'scheduler': ChainedScheduler([parse_scheduler(conf, optimizer)['scheduler'] for conf in config.schedulers]), + 'interval': interval + } + else: + scheduler = { + 'scheduler': 
get_scheduler(config.name)(optimizer, **config.args), + 'interval': interval + } + return scheduler + + +def update_module_step(m, epoch, global_step): + if hasattr(m, 'update_step'): + m.update_step(epoch, global_step) diff --git a/mesh_recon/utils/__init__.py b/mesh_recon/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/mesh_recon/utils/callbacks.py b/mesh_recon/utils/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..22f39efdb2f381ff677f5311c0586fbad88ae34f --- /dev/null +++ b/mesh_recon/utils/callbacks.py @@ -0,0 +1,99 @@ +import os +import subprocess +import shutil +from utils.misc import dump_config, parse_version + + +import pytorch_lightning +if parse_version(pytorch_lightning.__version__) > parse_version('1.8'): + from pytorch_lightning.callbacks import Callback +else: + from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn +from pytorch_lightning.callbacks.progress import TQDMProgressBar + + +class VersionedCallback(Callback): + def __init__(self, save_root, version=None, use_version=True): + self.save_root = save_root + self._version = version + self.use_version = use_version + + @property + def version(self) -> int: + """Get the experiment version. + + Returns: + The experiment version if specified else the next version. + """ + if self._version is None: + self._version = self._get_next_version() + return self._version + + def _get_next_version(self): + existing_versions = [] + if os.path.isdir(self.save_root): + for f in os.listdir(self.save_root): + bn = os.path.basename(f) + if bn.startswith("version_"): + dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") + existing_versions.append(int(dir_ver)) + if len(existing_versions) == 0: + return 0 + return max(existing_versions) + 1 + + @property + def savedir(self): + if not self.use_version: + return self.save_root + return os.path.join(self.save_root, self.version if isinstance(self.version, str) else f"version_{self.version}") + + +class CodeSnapshotCallback(VersionedCallback): + def __init__(self, save_root, version=None, use_version=True): + super().__init__(save_root, version, use_version) + + def get_file_list(self): + return [ + b.decode() for b in + set(subprocess.check_output('git ls-files', shell=True).splitlines()) | + set(subprocess.check_output('git ls-files --others --exclude-standard', shell=True).splitlines()) + ] + + @rank_zero_only + def save_code_snapshot(self): + os.makedirs(self.savedir, exist_ok=True) + for f in self.get_file_list(): + if not os.path.exists(f) or os.path.isdir(f): + continue + os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) + shutil.copyfile(f, os.path.join(self.savedir, f)) + + def on_fit_start(self, trainer, pl_module): + try: + self.save_code_snapshot() + except: + rank_zero_warn("Code snapshot is not saved. 
Please make sure you have git installed and are in a git repository.") + + +class ConfigSnapshotCallback(VersionedCallback): + def __init__(self, config, save_root, version=None, use_version=True): + super().__init__(save_root, version, use_version) + self.config = config + + @rank_zero_only + def save_config_snapshot(self): + os.makedirs(self.savedir, exist_ok=True) + dump_config(os.path.join(self.savedir, 'parsed.yaml'), self.config) + shutil.copyfile(self.config.cmd_args['config'], os.path.join(self.savedir, 'raw.yaml')) + + def on_fit_start(self, trainer, pl_module): + self.save_config_snapshot() + + +class CustomProgressBar(TQDMProgressBar): + def get_metrics(self, *args, **kwargs): + # don't show the version number + items = super().get_metrics(*args, **kwargs) + items.pop("v_num", None) + return items diff --git a/mesh_recon/utils/dpt.py b/mesh_recon/utils/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..84cfd76f485cf0848f9e95ce81da386386ad3551 --- /dev/null +++ b/mesh_recon/utils/dpt.py @@ -0,0 +1,1071 @@ +""" DPT Model for monocular depth estimation, adopted from https://github1s.com/ashawkey/stable-dreamfusion/blob/HEAD/preprocess_image.py""" + +import math +import types +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +from pathlib import Path + +import timm + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device("cpu")) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) + + +def unflatten_with_named_tensor(input, dim, sizes): + """Workaround for unflattening with named tensor.""" + # tracer acts up with unflatten. 
See https://github.com/pytorch/pytorch/issues/49538 + new_shape = list(input.shape)[:dim] + list(sizes) + list(input.shape)[dim + 1 :] + return input.view(*new_shape) + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflattened_dim = 2 + unflattened_size = ( + int(torch.div(h, pretrained.model.patch_size[1], rounding_mode="floor")), + int(torch.div(w, pretrained.model.patch_size[0], rounding_mode="floor")), + ) + unflatten = nn.Sequential(nn.Unflatten(unflattened_dim, unflattened_size)) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten_with_named_tensor( + layer_3, unflattened_dim, unflattened_size + ) + if layer_4.ndim == 3: + layer_4 = unflatten_with_named_tensor( + layer_4, unflattened_dim, unflattened_size + ) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(posemb_grid.shape[0])) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, + torch.div(h, self.patch_size[1], rounding_mode="floor"), + torch.div(w, self.patch_size[0], 
rounding_mode="floor"), + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + 
padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + 
nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + + +def _make_encoder( + backbone, + features, + use_pretrained, + groups=1, + expand=False, + exportable=True, + hooks=None, + use_vit_only=False, + use_readout="ignore", +): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch( + [256, 512, 1024, 2048], features, groups=groups, expand=expand + ) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3( + use_pretrained, exportable=exportable + ) + 
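# descriptive note: the [32, 48, 136, 384] widths below correspond to the four feature stages returned by _make_efficientnet_backbone (layer1-layer4) + 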
scratch = _make_scratch( + [32, 48, 136, 384], features, groups=groups, expand=expand + ) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand == True: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], + out_shape1, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], + out_shape2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], + out_shape3, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], + out_shape4, + kernel_size=3, + stride=1, + padding=1, + bias=False, + groups=groups, + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable, + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + +class Interpolate(nn.Module): + """Interpolation module.""" + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, + scale_factor=self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features): + """Init. + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. 
+ Args: + x (tensor): input + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__(self, features): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module.""" + + def __init__(self, features, activation, bn): + """Init. + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups = 1 + + self.conv1 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + + self.conv2 = nn.Conv2d( + features, + features, + kernel_size=3, + stride=1, + padding=1, + bias=True, + groups=self.groups, + ) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + Args: + x (tensor): input + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block.""" + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + ): + """Init. + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups = 1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d( + features, + out_features, + kernel_size=1, + stride=1, + padding=0, + bias=True, + groups=1, + ) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. 
+ Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT_(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + super(DPT_, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + True, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT_): + def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + + +def download_if_need(path, url): + if Path(path).exists(): + return + import wget + + path.parent.mkdir(parents=True, exist_ok=True) + wget.download(url, out=str(path)) + + +class DPT: + def __init__(self, device, mode="depth"): + self.mode = mode + self.device = device + + if self.mode == "depth": + path = ".cache/dpt/omnidata_dpt_depth_v2.ckpt" + self.model = DPTDepthModel(backbone="vitb_rn50_384") + self.aug = transforms.Compose( + [ + transforms.Resize((384, 384)), + transforms.Normalize(mean=0.5, std=0.5), + ] + ) + elif self.mode == "normal": + path = "../ckpts/omnidata_dpt_normal_v2.ckpt" + download_if_need( + path, + 
"https://huggingface.co/clay3d/omnidata/resolve/main/omnidata_dpt_normal_v2.ckpt", + ) + self.model = DPTDepthModel(backbone="vitb_rn50_384", num_channels=3) + self.aug = transforms.Compose( + [ + transforms.Resize((384, 384)), + ] + ) + else: + raise ValueError(f"Unknown mode {mode} for DPT") + + checkpoint = torch.load(path, map_location="cpu") + if "state_dict" in checkpoint: + state_dict = {} + for k, v in checkpoint["state_dict"].items(): + state_dict[k[6:]] = v + else: + state_dict = checkpoint + self.model.load_state_dict(state_dict) + self.model.eval().to(self.device) + + @torch.no_grad() + def __call__(self, x): + # x.shape: [B H W 3] + x = x.to(self.device) + H, W = x.shape[1], x.shape[2] + x = x.moveaxis(-1, 1) # [B 3 H W] + x = self.aug(x) + + if self.mode == "depth": + depth = self.model(x).clamp(0, 1) + depth = F.interpolate( + depth.unsqueeze(1), size=(H, W), mode="bicubic", align_corners=False + ) + # depth = depth.cpu().numpy() + return depth.moveaxis(1, -1) + elif self.mode == "normal": + normal = self.model(x).clamp(0, 1) + normal = F.interpolate( + normal, size=(H, W), mode="bicubic", align_corners=False + ) + # normal = normal.cpu().numpy() + return normal.moveaxis(1, -1) + else: + assert False diff --git a/mesh_recon/utils/loggers.py b/mesh_recon/utils/loggers.py new file mode 100644 index 0000000000000000000000000000000000000000..2d1a92302a431a75e0c920327208ab11e9559ec8 --- /dev/null +++ b/mesh_recon/utils/loggers.py @@ -0,0 +1,41 @@ +import re +import pprint +import logging + +from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment +from pytorch_lightning.utilities.rank_zero import rank_zero_only + + +class ConsoleLogger(LightningLoggerBase): + def __init__(self, log_keys=[]): + super().__init__() + self.log_keys = [re.compile(k) for k in log_keys] + self.dict_printer = pprint.PrettyPrinter(indent=2, compact=False).pformat + + def match_log_keys(self, s): + return True if not self.log_keys else any(r.search(s) for r in self.log_keys) + + @property + def name(self): + return 'console' + + @property + def version(self): + return '0' + + @property + @rank_zero_experiment + def experiment(self): + return logging.getLogger('pytorch_lightning') + + @rank_zero_only + def log_hyperparams(self, params): + pass + + @rank_zero_only + def log_metrics(self, metrics, step): + metrics_ = {k: v for k, v in metrics.items() if self.match_log_keys(k)} + if not metrics_: + return + self.experiment.info(f"\nEpoch{metrics['epoch']} Step{step}\n{self.dict_printer(metrics_)}") + diff --git a/mesh_recon/utils/misc.py b/mesh_recon/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..c16fafa2ab8e7b934be711c41aed6e12001444fd --- /dev/null +++ b/mesh_recon/utils/misc.py @@ -0,0 +1,54 @@ +import os +from omegaconf import OmegaConf +from packaging import version + + +# ============ Register OmegaConf Recolvers ============= # +OmegaConf.register_new_resolver('calc_exp_lr_decay_rate', lambda factor, n: factor**(1./n)) +OmegaConf.register_new_resolver('add', lambda a, b: a + b) +OmegaConf.register_new_resolver('sub', lambda a, b: a - b) +OmegaConf.register_new_resolver('mul', lambda a, b: a * b) +OmegaConf.register_new_resolver('div', lambda a, b: a / b) +OmegaConf.register_new_resolver('idiv', lambda a, b: a // b) +OmegaConf.register_new_resolver('basename', lambda p: os.path.basename(p)) +# ======================================================= # + + +def prompt(question): + inp = input(f"{question} (y/n)").lower().strip() + if inp and 
inp == 'y': + return True + if inp and inp == 'n': + return False + return prompt(question) + + +def load_config(*yaml_files, cli_args=[]): + yaml_confs = [OmegaConf.load(f) for f in yaml_files] + cli_conf = OmegaConf.from_cli(cli_args) + conf = OmegaConf.merge(*yaml_confs, cli_conf) + OmegaConf.resolve(conf) + return conf + + +def config_to_primitive(config, resolve=True): + return OmegaConf.to_container(config, resolve=resolve) + + +def dump_config(path, config): + with open(path, 'w') as fp: + OmegaConf.save(config=config, f=fp) + +def get_rank(): + # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, + # therefore LOCAL_RANK needs to be checked first + rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") + for key in rank_keys: + rank = os.environ.get(key) + if rank is not None: + return int(rank) + return 0 + + +def parse_version(ver): + return version.parse(ver) diff --git a/mesh_recon/utils/mixins.py b/mesh_recon/utils/mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb639ac25f0e495c89917f5ad77c69f787c0afc --- /dev/null +++ b/mesh_recon/utils/mixins.py @@ -0,0 +1,331 @@ +import os +import re +import shutil +import numpy as np +import cv2 +import imageio +from matplotlib import cm +from matplotlib.colors import LinearSegmentedColormap +import json + +import torch + +from utils.obj import write_obj + + +class SaverMixin: + @property + def save_dir(self): + return self.config.save_dir + + def convert_data(self, data): + if isinstance(data, np.ndarray): + return data + elif isinstance(data, torch.Tensor): + return data.cpu().numpy() + elif isinstance(data, list): + return [self.convert_data(d) for d in data] + elif isinstance(data, dict): + return {k: self.convert_data(v) for k, v in data.items()} + else: + raise TypeError( + "Data must be in type numpy.ndarray, torch.Tensor, list or dict, getting", + type(data), + ) + + def get_save_path(self, filename): + save_path = os.path.join(self.save_dir, filename) + os.makedirs(os.path.dirname(save_path), exist_ok=True) + return save_path + + DEFAULT_RGB_KWARGS = {"data_format": "CHW", "data_range": (0, 1)} + DEFAULT_UV_KWARGS = { + "data_format": "CHW", + "data_range": (0, 1), + "cmap": "checkerboard", + } + DEFAULT_GRAYSCALE_KWARGS = {"data_range": None, "cmap": "jet"} + + def get_rgb_image_(self, img, data_format, data_range): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + img = img.clip(min=data_range[0], max=data_range[1]) + img = ((img - data_range[0]) / (data_range[1] - data_range[0]) * 255.0).astype( + np.uint8 + ) + imgs = [img[..., start : start + 3] for start in range(0, img.shape[-1], 3)] + imgs = [ + ( + img_ + if img_.shape[-1] == 3 + else np.concatenate( + [ + img_, + np.zeros( + (img_.shape[0], img_.shape[1], 3 - img_.shape[2]), + dtype=img_.dtype, + ), + ], + axis=-1, + ) + ) + for img_ in imgs + ] + img = np.concatenate(imgs, axis=1) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + return img + + def save_rgb_image( + self, + filename, + img, + data_format=DEFAULT_RGB_KWARGS["data_format"], + data_range=DEFAULT_RGB_KWARGS["data_range"], + ): + img = self.get_rgb_image_(img, data_format, data_range) + cv2.imwrite(self.get_save_path(filename), img) + + def get_uv_image_(self, img, data_format, data_range, cmap): + img = self.convert_data(img) + assert data_format in ["CHW", "HWC"] + if data_format == "CHW": + img = img.transpose(1, 2, 0) + img = img.clip(min=data_range[0], 
max=data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in ["checkerboard", "color"] + if cmap == "checkerboard": + n_grid = 64 + mask = (img * n_grid).astype(int) + mask = (mask[..., 0] + mask[..., 1]) % 2 == 0 + img = np.ones((img.shape[0], img.shape[1], 3), dtype=np.uint8) * 255 + img[mask] = np.array([255, 0, 255], dtype=np.uint8) + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif cmap == "color": + img_ = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) + img_[..., 0] = (img[..., 0] * 255).astype(np.uint8) + img_[..., 1] = (img[..., 1] * 255).astype(np.uint8) + img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) + img = img_ + return img + + def save_uv_image( + self, + filename, + img, + data_format=DEFAULT_UV_KWARGS["data_format"], + data_range=DEFAULT_UV_KWARGS["data_range"], + cmap=DEFAULT_UV_KWARGS["cmap"], + ): + img = self.get_uv_image_(img, data_format, data_range, cmap) + cv2.imwrite(self.get_save_path(filename), img) + + def get_grayscale_image_(self, img, data_range, cmap): + img = self.convert_data(img) + img = np.nan_to_num(img) + if data_range is None: + img = (img - img.min()) / (img.max() - img.min()) + else: + img = img.clip(data_range[0], data_range[1]) + img = (img - data_range[0]) / (data_range[1] - data_range[0]) + assert cmap in [None, "jet", "magma"] + if cmap == None: + img = (img * 255.0).astype(np.uint8) + img = np.repeat(img[..., None], 3, axis=2) + elif cmap == "jet": + img = (img * 255.0).astype(np.uint8) + img = cv2.applyColorMap(img, cv2.COLORMAP_JET) + elif cmap == "magma": + img = 1.0 - img + base = cm.get_cmap("magma") + num_bins = 256 + colormap = LinearSegmentedColormap.from_list( + f"{base.name}{num_bins}", base(np.linspace(0, 1, num_bins)), num_bins + )(np.linspace(0, 1, num_bins))[:, :3] + a = np.floor(img * 255.0) + b = (a + 1).clip(max=255.0) + f = img * 255.0 - a + a = a.astype(np.uint16).clip(0, 255) + b = b.astype(np.uint16).clip(0, 255) + img = colormap[a] + (colormap[b] - colormap[a]) * f[..., None] + img = (img * 255.0).astype(np.uint8) + return img + + def save_grayscale_image( + self, + filename, + img, + data_range=DEFAULT_GRAYSCALE_KWARGS["data_range"], + cmap=DEFAULT_GRAYSCALE_KWARGS["cmap"], + ): + img = self.get_grayscale_image_(img, data_range, cmap) + cv2.imwrite(self.get_save_path(filename), img) + + def get_image_grid_(self, imgs): + if isinstance(imgs[0], list): + return np.concatenate([self.get_image_grid_(row) for row in imgs], axis=0) + cols = [] + for col in imgs: + assert col["type"] in ["rgb", "uv", "grayscale"] + if col["type"] == "rgb": + rgb_kwargs = self.DEFAULT_RGB_KWARGS.copy() + rgb_kwargs.update(col["kwargs"]) + cols.append(self.get_rgb_image_(col["img"], **rgb_kwargs)) + elif col["type"] == "uv": + uv_kwargs = self.DEFAULT_UV_KWARGS.copy() + uv_kwargs.update(col["kwargs"]) + cols.append(self.get_uv_image_(col["img"], **uv_kwargs)) + elif col["type"] == "grayscale": + grayscale_kwargs = self.DEFAULT_GRAYSCALE_KWARGS.copy() + grayscale_kwargs.update(col["kwargs"]) + cols.append(self.get_grayscale_image_(col["img"], **grayscale_kwargs)) + return np.concatenate(cols, axis=1) + + def save_image_grid(self, filename, imgs): + img = self.get_image_grid_(imgs) + cv2.imwrite(self.get_save_path(filename), img) + + def save_image(self, filename, img): + img = self.convert_data(img) + assert img.dtype == np.uint8 + if img.shape[-1] == 3: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + elif img.shape[-1] == 4: + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA) + 
cv2.imwrite(self.get_save_path(filename), img) + + def save_cubemap(self, filename, img, data_range=(0, 1)): + img = self.convert_data(img) + assert img.ndim == 4 and img.shape[0] == 6 and img.shape[1] == img.shape[2] + + imgs_full = [] + for start in range(0, img.shape[-1], 3): + img_ = img[..., start : start + 3] + img_ = np.stack( + [ + self.get_rgb_image_(img_[i], "HWC", data_range) + for i in range(img_.shape[0]) + ], + axis=0, + ) + size = img_.shape[1] + placeholder = np.zeros((size, size, 3), dtype=np.float32) + img_full = np.concatenate( + [ + np.concatenate( + [placeholder, img_[2], placeholder, placeholder], axis=1 + ), + np.concatenate([img_[1], img_[4], img_[0], img_[5]], axis=1), + np.concatenate( + [placeholder, img_[3], placeholder, placeholder], axis=1 + ), + ], + axis=0, + ) + img_full = cv2.cvtColor(img_full, cv2.COLOR_RGB2BGR) + imgs_full.append(img_full) + + imgs_full = np.concatenate(imgs_full, axis=1) + cv2.imwrite(self.get_save_path(filename), imgs_full) + + def save_data(self, filename, data): + data = self.convert_data(data) + if isinstance(data, dict): + if not filename.endswith(".npz"): + filename += ".npz" + np.savez(self.get_save_path(filename), **data) + else: + if not filename.endswith(".npy"): + filename += ".npy" + np.save(self.get_save_path(filename), data) + + def save_state_dict(self, filename, data): + torch.save(data, self.get_save_path(filename)) + + def save_img_sequence(self, filename, img_dir, matcher, save_format="gif", fps=30): + assert save_format in ["gif", "mp4"] + if not filename.endswith(save_format): + filename += f".{save_format}" + matcher = re.compile(matcher) + img_dir = os.path.join(self.save_dir, img_dir) + imgs = [] + for f in os.listdir(img_dir): + if matcher.search(f): + imgs.append(f) + imgs = sorted(imgs, key=lambda f: int(matcher.search(f).groups()[0])) + imgs = [cv2.imread(os.path.join(img_dir, f)) for f in imgs] + + if save_format == "gif": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave( + self.get_save_path(filename), imgs, fps=fps, palettesize=256 + ) + elif save_format == "mp4": + imgs = [cv2.cvtColor(i, cv2.COLOR_BGR2RGB) for i in imgs] + imageio.mimsave(self.get_save_path(filename), imgs, fps=fps) + + def save_mesh( + self, + filename, + v_pos, + t_pos_idx, + v_tex=None, + t_tex_idx=None, + v_rgb=None, + ortho_scale=1, + ): + v_pos, t_pos_idx = self.convert_data(v_pos), self.convert_data(t_pos_idx) + if v_rgb is not None: + v_rgb = self.convert_data(v_rgb) + + if ortho_scale is not None: + print("ortho scale is: ", ortho_scale) + v_pos = v_pos * ortho_scale * 0.5 + + # change to front-facing + v_pos_copy = np.zeros_like(v_pos) + # v_pos_copy[:, 0] = v_pos[:, 0] + # v_pos_copy[:, 1] = v_pos[:, 2] + # v_pos_copy[:, 2] = v_pos[:, 1] + v_pos_copy[:, 0] = v_pos[:, 0] + v_pos_copy[:, 1] = v_pos[:, 1] + v_pos_copy[:, 2] = v_pos[:, 2] + + import trimesh + + mesh = trimesh.Trimesh( + vertices=v_pos_copy, faces=t_pos_idx, vertex_colors=v_rgb + ) + trimesh.repair.fix_inversion(mesh) + mesh.export(self.get_save_path(filename)) + # mesh.export(self.get_save_path(filename.replace(".obj", "-meshlab.obj"))) + + # v_pos_copy[:, 0] = v_pos[:, 1] * -1 + # v_pos_copy[:, 1] = v_pos[:, 0] + # v_pos_copy[:, 2] = v_pos[:, 2] + + # mesh = trimesh.Trimesh( + # vertices=v_pos_copy, + # faces=t_pos_idx, + # vertex_colors=v_rgb + # ) + # mesh.export(self.get_save_path(filename.replace(".obj", "-blender.obj"))) + + # v_pos_copy[:, 0] = v_pos[:, 0] + # v_pos_copy[:, 1] = v_pos[:, 1] * -1 + # v_pos_copy[:, 2] = v_pos[:, 
2] * -1 + + # mesh = trimesh.Trimesh( + # vertices=v_pos_copy, + # faces=t_pos_idx, + # vertex_colors=v_rgb + # ) + # mesh.export(self.get_save_path(filename.replace(".obj", "-opengl.obj"))) + + def save_file(self, filename, src_path): + shutil.copyfile(src_path, self.get_save_path(filename)) + + def save_json(self, filename, payload): + with open(self.get_save_path(filename), "w") as f: + f.write(json.dumps(payload)) diff --git a/mesh_recon/utils/obj.py b/mesh_recon/utils/obj.py new file mode 100644 index 0000000000000000000000000000000000000000..da6d11938c244a6b982ec8c3f36a90a7a5fd2831 --- /dev/null +++ b/mesh_recon/utils/obj.py @@ -0,0 +1,74 @@ +import numpy as np + + +def load_obj(filename): + # Read entire file + with open(filename, 'r') as f: + lines = f.readlines() + + # load vertices + vertices, texcoords = [], [] + for line in lines: + if len(line.split()) == 0: + continue + + prefix = line.split()[0].lower() + if prefix == 'v': + vertices.append([float(v) for v in line.split()[1:]]) + elif prefix == 'vt': + val = [float(v) for v in line.split()[1:]] + texcoords.append([val[0], 1.0 - val[1]]) + + uv = len(texcoords) > 0 + faces, tfaces = [], [] + for line in lines: + if len(line.split()) == 0: + continue + prefix = line.split()[0].lower() + if prefix == 'usemtl': # Track used materials + pass + elif prefix == 'f': # Parse face + vs = line.split()[1:] + nv = len(vs) + vv = vs[0].split('/') + v0 = int(vv[0]) - 1 + if uv: + t0 = int(vv[1]) - 1 if vv[1] != "" else -1 + for i in range(nv - 2): # Triangulate polygons + vv1 = vs[i + 1].split('/') + v1 = int(vv1[0]) - 1 + vv2 = vs[i + 2].split('/') + v2 = int(vv2[0]) - 1 + faces.append([v0, v1, v2]) + if uv: + t1 = int(vv1[1]) - 1 if vv1[1] != "" else -1 + t2 = int(vv2[1]) - 1 if vv2[1] != "" else -1 + tfaces.append([t0, t1, t2]) + vertices = np.array(vertices, dtype=np.float32) + faces = np.array(faces, dtype=np.int64) + if uv: + assert len(tfaces) == len(faces) + texcoords = np.array(texcoords, dtype=np.float32) + tfaces = np.array(tfaces, dtype=np.int64) + else: + texcoords, tfaces = None, None + + return vertices, faces, texcoords, tfaces + + +def write_obj(filename, v_pos, t_pos_idx, v_tex, t_tex_idx): + with open(filename, "w") as f: + for v in v_pos: + f.write('v {} {} {} \n'.format(v[0], v[1], v[2])) + + if v_tex is not None: + assert(len(t_pos_idx) == len(t_tex_idx)) + for v in v_tex: + f.write('vt {} {} \n'.format(v[0], 1.0 - v[1])) + + # Write faces + for i in range(len(t_pos_idx)): + f.write("f ") + for j in range(3): + f.write(' %s/%s' % (str(t_pos_idx[i][j]+1), '' if v_tex is None else str(t_tex_idx[i][j]+1))) + f.write("\n") diff --git a/readme.md b/readme.md new file mode 100644 index 0000000000000000000000000000000000000000..ceb1be97c0df3de99d8596ac51916139355f715a --- /dev/null +++ b/readme.md @@ -0,0 +1,110 @@ +# V3D: Video Diffusion Models are Effective 3D Generators +Zilong Chen1,2, Yikai Wang1, Feng Wang1, Zhengyi Wang1,2, Huaping Liu1 + +1Tsinghua University, 2ShengShu + +This repository contains the official implementation of [V3D: Video Diffusion Models are Effective 3D Generators](https://arxiv.org/abs/2403.06738). + +### [Work in Progress] + +We are currently working on making this completely publicly available (including refactoring code, uploading weights, etc.), so please be patient. 
+ +### [arXiv](https://arxiv.org/abs/2403.06738) | [Paper](assets/pdf/V3D.pdf) | [Project Page](https://heheyas.github.io/V3D) | [HF Demo](TBD) + +### Video results +Single Image to 3D + +Generated Multi-views + +https://github.com/heheyas/V3D/assets/44675551/bb724ed1-b9a6-4aa7-9a49-f1a8c8756c2f + + +https://github.com/heheyas/V3D/assets/44675551/4bfaea91-6c5b-40da-8682-30286a916979 + +Reconstructed 3D Gaussian Splats + + +https://github.com/heheyas/V3D/assets/44675551/894444eb-a454-4bc9-921b-cd0d5764a14d + + + +https://github.com/heheyas/V3D/assets/44675551/eda05891-e2c7-4f44-af12-9ccd0bce61d1 + + + +https://github.com/heheyas/V3D/assets/44675551/27d61245-b416-4289-ba98-97219ad199a3 + + + +https://github.com/heheyas/V3D/assets/44675551/e94d71ff-b8bc-410c-ad2c-3cfb1fbef7fa + + + +https://github.com/heheyas/V3D/assets/44675551/a0d1e971-0f8f-4f05-a73e-45271e37a31f + + + +https://github.com/heheyas/V3D/assets/44675551/0dac3189-fc59-4e9b-8151-10ebe2711d71 + + +Sparse view scene generation (On CO3D `hydrant` category) + + +https://github.com/heheyas/V3D/assets/44675551/33c87468-b6c0-4fa2-a9bf-6f396b3fa089 + + +https://github.com/heheyas/V3D/assets/44675551/3c03d015-2e56-44de-8210-e33e7ec810bb + + + +https://github.com/heheyas/V3D/assets/44675551/1e73958b-04b2-4faa-bbc3-675399f21956 + + + +https://github.com/heheyas/V3D/assets/44675551/f70cc259-7d50-4bf9-9c1b-0d4143ae8958 + + + +https://github.com/heheyas/V3D/assets/44675551/f6407b02-5ee7-4f8f-8559-4a893e6fd912 + + + + + +### Instructions: +1. Install the requirements: +``` +pip install -r requirements.txt +``` +2. Download our weights for V3D +``` +wget https://huggingface.co/heheyas/V3D/resolve/main/V3D.ckpt -O ckpts/V3D_512.ckpt +wget https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/resolve/main/svd_xt.safetensors -O ckpts/svd_xt.safetensors +``` +3. Run the V3D Video diffusion to generate dense multi-views +``` +PYTHONPATH="." python scripts/pub/V3D_512.py --input_path --save --border_ratio 0.3 --min_guidance_scale 4.5 --max_guidance_scale 4.5 --output-folder +``` +4. Reconstruct 3D assets from generated multi-views +Using 3D Gaussian Splatting +``` +PYTHONPATH="." python recon/train_from_vid.py -w --sh_degree 0 --iterations 4000 --lambda_dssim 1.0 --lambda_lpips 2.0 --save_iterations 4000 --num_pts 100_000 --video +``` +Or using (NeuS) instant-nsr-pl: +``` +cd mesh_recon +PYTHONPATH="." python launch.py --config configs/videonvs.yaml --gpu --train system.loss.lambda_normal=0.1 dataset.scene= dataset.root_dir= dataset.img_wh='[512, 512]' +``` +Refine texture +``` +python refine.py --mesh --scene --num-opt 16 --lpips 1.0 --iters 500 +``` + +## Acknowledgement +This code base is built upon the following awesome open-source projects: +- [Stable Video Diffusion](https://github.com/Stability-AI/generative-models) +- [3D Gaussian Splatting](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) +- [kiuikit](https://github.com/ashawkey/kiuikit) +- [Instant-nsr-pl](https://github.com/bennyguo/instant-nsr-pl) + +Thank the authors for their remarkable job ! 
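+
+For quick reference, the commands from the Instructions section above chain together as sketched below. The input image, output folder, and video path (`assets/example.png`, `outputs/example`, `outputs/example/example.mp4`) are hypothetical placeholders standing in for the elided arguments; substitute your own paths.
+```
+# 1. Generate dense multi-views around the input image
+PYTHONPATH="." python scripts/pub/V3D_512.py --input_path assets/example.png --save \
+    --border_ratio 0.3 --min_guidance_scale 4.5 --max_guidance_scale 4.5 \
+    --output-folder outputs/example
+
+# 2. Fit 3D Gaussians to the generated views (the --video argument is a placeholder path)
+PYTHONPATH="." python recon/train_from_vid.py -w --sh_degree 0 --iterations 4000 \
+    --lambda_dssim 1.0 --lambda_lpips 2.0 --save_iterations 4000 --num_pts 100_000 \
+    --video outputs/example/example.mp4
+```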
diff --git a/recon/.gitignore b/recon/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..114376106621fc03ff8f923c03536e9dafd0d3f5 --- /dev/null +++ b/recon/.gitignore @@ -0,0 +1,8 @@ +*.pyc +.vscode +output +build +diff_rasterization/diff_rast.egg-info +diff_rasterization/dist +tensorboard_3d +screenshots \ No newline at end of file diff --git a/recon/.gitmodules b/recon/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..d20bef20b74670643e1f7849848528cc41000160 --- /dev/null +++ b/recon/.gitmodules @@ -0,0 +1,9 @@ +[submodule "submodules/simple-knn"] + path = submodules/simple-knn + url = https://gitlab.inria.fr/bkerbl/simple-knn.git +[submodule "submodules/diff-gaussian-rasterization"] + path = submodules/diff-gaussian-rasterization + url = https://github.com/graphdeco-inria/diff-gaussian-rasterization +[submodule "SIBR_viewers"] + path = SIBR_viewers + url = https://gitlab.inria.fr/sibr/sibr_core.git diff --git a/recon/LICENSE.md b/recon/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..c869e695fa63bfde6f887d63a24a2a71f03480ac --- /dev/null +++ b/recon/LICENSE.md @@ -0,0 +1,83 @@ +Gaussian-Splatting License +=========================== + +**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. +The *Software* is in the process of being registered with the Agence pour la Protection des +Programmes (APP). + +The *Software* is still being developed by the *Licensor*. + +*Licensor*'s goal is to allow the research community to use, test and evaluate +the *Software*. + +## 1. Definitions + +*Licensee* means any person or entity that uses the *Software* and distributes +its *Work*. + +*Licensor* means the owners of the *Software*, i.e Inria and MPII + +*Software* means the original work of authorship made available under this +License ie gaussian-splatting. + +*Work* means the *Software* and any additions to or derivative works of the +*Software* that are made available under this License. + + +## 2. Purpose +This license is intended to define the rights granted to the *Licensee* by +Licensors under the *Software*. + +## 3. Rights granted + +For the above reasons Licensors have decided to distribute the *Software*. +Licensors grant non-exclusive rights to use the *Software* for research purposes +to research users (both academic and industrial), free of charge, without right +to sublicense.. The *Software* may be used "non-commercially", i.e., for research +and/or evaluation purposes only. + +Subject to the terms and conditions of this License, you are granted a +non-exclusive, royalty-free, license to reproduce, prepare derivative works of, +publicly display, publicly perform and distribute its *Work* and any resulting +derivative works in any form. + +## 4. Limitations + +**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do +so under this License, (b) you include a complete copy of this License with +your distribution, and (c) you retain without modification any copyright, +patent, trademark, or attribution notices that are present in the *Work*. 
+ +**4.2 Derivative Works.** You may specify that additional or different terms apply +to the use, reproduction, and distribution of your derivative works of the *Work* +("Your Terms") only if (a) Your Terms provide that the use limitation in +Section 2 applies to your derivative works, and (b) you identify the specific +derivative works that are subject to Your Terms. Notwithstanding Your Terms, +this License (including the redistribution requirements in Section 3.1) will +continue to apply to the *Work* itself. + +**4.3** Any other use without of prior consent of Licensors is prohibited. Research +users explicitly acknowledge having received from Licensors all information +allowing to appreciate the adequacy between of the *Software* and their needs and +to undertake all necessary precautions for its execution and use. + +**4.4** The *Software* is provided both as a compiled library file and as source +code. In case of using the *Software* for a publication or other results obtained +through the use of the *Software*, users are strongly encouraged to cite the +corresponding publications as explained in the documentation of the *Software*. + +## 5. Disclaimer + +THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES +WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY +UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL +CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES +OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL +USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR +ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE +AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR +IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. diff --git a/recon/README.md b/recon/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a0e39a33faa39779076d7e8fffeb55dccfae1423 --- /dev/null +++ b/recon/README.md @@ -0,0 +1,513 @@ +# 3D Gaussian Splatting for Real-Time Radiance Field Rendering +Bernhard Kerbl*, Georgios Kopanas*, Thomas Leimkühler, George Drettakis (* indicates equal contribution)
+| [Webpage](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/) | [Full Paper](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/3d_gaussian_splatting_high.pdf) | [Video](https://youtu.be/T_kXY43VZnk) | [Other GRAPHDECO Publications](http://www-sop.inria.fr/reves/publis/gdindex.php) | [FUNGRAPH project page](https://fungraph.inria.fr) |
+| [T&T+DB COLMAP (650MB)](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/input/tandt_db.zip) | [Pre-trained Models (14 GB)](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/pretrained/models.zip) | [Viewers for Windows (60MB)](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/binaries/viewers.zip) | [Evaluation Images (7 GB)](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/evaluation/images.zip) |
+![Teaser image](assets/teaser.png) + +This repository contains the official authors implementation associated with the paper "3D Gaussian Splatting for Real-Time Radiance Field Rendering", which can be found [here](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/). We further provide the reference images used to create the error metrics reported in the paper, as well as recently created, pre-trained models. + + + + + + +Abstract: *Radiance Field methods have recently revolutionized novel-view synthesis of scenes captured with multiple photos or videos. However, achieving high visual quality still requires neural networks that are costly to train and render, while recent faster methods inevitably trade off speed for quality. For unbounded and complete scenes (rather than isolated objects) and 1080p resolution rendering, no current method can achieve real-time display rates. We introduce three key elements that allow us to achieve state-of-the-art visual quality while maintaining competitive training times and importantly allow high-quality real-time (≥ 30 fps) novel-view synthesis at 1080p resolution. First, starting from sparse points produced during camera calibration, we represent the scene with 3D Gaussians that preserve desirable properties of continuous volumetric radiance fields for scene optimization while avoiding unnecessary computation in empty space; Second, we perform interleaved optimization/density control of the 3D Gaussians, notably optimizing anisotropic covariance to achieve an accurate representation of the scene; Third, we develop a fast visibility-aware rendering algorithm that supports anisotropic splatting and both accelerates training and allows realtime rendering. We demonstrate state-of-the-art visual quality and real-time rendering on several established datasets.* + +
+## BibTeX
+```bibtex
+@Article{kerbl3Dgaussians,
+      author       = {Kerbl, Bernhard and Kopanas, Georgios and Leimk{\"u}hler, Thomas and Drettakis, George},
+      title        = {3D Gaussian Splatting for Real-Time Radiance Field Rendering},
+      journal      = {ACM Transactions on Graphics},
+      number       = {4},
+      volume       = {42},
+      month        = {July},
+      year         = {2023},
+      url          = {https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/}
+}
+```
+ + +## Funding and Acknowledgments + +This research was funded by the ERC Advanced grant FUNGRAPH No 788065. The authors are grateful to Adobe for generous donations, the OPAL infrastructure from Université Côte d’Azur and for the HPC resources from GENCI–IDRIS (Grant 2022-AD011013409). The authors thank the anonymous reviewers for their valuable feedback, P. Hedman and A. Tewari for proofreading earlier drafts also T. Müller, A. Yu and S. Fridovich-Keil for helping with the comparisons. + +## Step-by-step Tutorial + +Jonathan Stephens made a fantastic step-by-step tutorial for setting up Gaussian Splatting on your machine, along with instructions for creating usable datasets from videos. If the instructions below are too dry for you, go ahead and check it out [here](https://www.youtube.com/watch?v=UXtuigy_wYc). + +## Colab + +User [camenduru](https://github.com/camenduru) was kind enough to provide a Colab template that uses this repo's source (status: August 2023!) for quick and easy access to the method. Please check it out [here](https://github.com/camenduru/gaussian-splatting-colab). + +## Cloning the Repository + +The repository contains submodules, thus please check it out with +```shell +# SSH +git clone git@github.com:graphdeco-inria/gaussian-splatting.git --recursive +``` +or +```shell +# HTTPS +git clone https://github.com/graphdeco-inria/gaussian-splatting --recursive +``` + +## Overview + +The codebase has 4 main components: +- A PyTorch-based optimizer to produce a 3D Gaussian model from SfM inputs +- A network viewer that allows to connect to and visualize the optimization process +- An OpenGL-based real-time viewer to render trained models in real-time. +- A script to help you turn your own images into optimization-ready SfM data sets + +The components have different requirements w.r.t. both hardware and software. They have been tested on Windows 10 and Ubuntu Linux 22.04. Instructions for setting up and running each of them are found in the sections below. + +## Optimizer + +The optimizer uses PyTorch and CUDA extensions in a Python environment to produce trained models. + +### Hardware Requirements + +- CUDA-ready GPU with Compute Capability 7.0+ +- 24 GB VRAM (to train to paper evaluation quality) +- Please see FAQ for smaller VRAM configurations + +### Software Requirements +- Conda (recommended for easy setup) +- C++ Compiler for PyTorch extensions (we used Visual Studio 2019 for Windows) +- CUDA SDK 11 for PyTorch extensions, install *after* Visual Studio (we used 11.8, **known issues with 11.6**) +- C++ Compiler and CUDA SDK must be compatible + +### Setup + +#### Local Setup + +Our default, provided install method is based on Conda package and environment management: +```shell +SET DISTUTILS_USE_SDK=1 # Windows only +conda env create --file environment.yml +conda activate gaussian_splatting +``` +Please note that this process assumes that you have CUDA SDK **11** installed, not **12**. For modifications, see below. + +Tip: Downloading packages and creating a new environment with Conda can require a significant amount of disk space. By default, Conda will use the main system hard drive. 
You can avoid this by specifying a different package download location and an environment on a different drive: + +```shell +conda config --add pkgs_dirs / +conda env create --file environment.yml --prefix //gaussian_splatting +conda activate //gaussian_splatting +``` + +#### Modifications + +If you can afford the disk space, we recommend using our environment files for setting up a training environment identical to ours. If you want to make modifications, please note that major version changes might affect the results of our method. However, our (limited) experiments suggest that the codebase works just fine inside a more up-to-date environment (Python 3.8, PyTorch 2.0.0, CUDA 12). Make sure to create an environment where PyTorch and its CUDA runtime version match and the installed CUDA SDK has no major version difference with PyTorch's CUDA version. + +#### Known Issues + +Some users experience problems building the submodules on Windows (```cl.exe: File not found``` or similar). Please consider the workaround for this problem from the FAQ. + +### Running + +To run the optimizer, simply use + +```shell +python train.py -s +``` + +
+Command Line Arguments for train.py + + #### --source_path / -s + Path to the source directory containing a COLMAP or Synthetic NeRF data set. + #### --model_path / -m + Path where the trained model should be stored (```output/``` by default). + #### --images / -i + Alternative subdirectory for COLMAP images (```images``` by default). + #### --eval + Add this flag to use a MipNeRF360-style training/test split for evaluation. + #### --resolution / -r + Specifies resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. **If not set and input image width exceeds 1.6K pixels, inputs are automatically rescaled to this target.** + #### --data_device + Specifies where to put the source image data, ```cuda``` by default, recommended to use ```cpu``` if training on large/high-resolution dataset, will reduce VRAM consumption, but slightly slow down training. Thanks to [HrsPythonix](https://github.com/HrsPythonix). + #### --white_background / -w + Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset. + #### --sh_degree + Order of spherical harmonics to be used (no larger than 3). ```3``` by default. + #### --convert_SHs_python + Flag to make pipeline compute forward and backward of SHs with PyTorch instead of ours. + #### --convert_cov3D_python + Flag to make pipeline compute forward and backward of the 3D covariance with PyTorch instead of ours. + #### --debug + Enables debug mode if you experience erros. If the rasterizer fails, a ```dump``` file is created that you may forward to us in an issue so we can take a look. + #### --debug_from + Debugging is **slow**. You may specify an iteration (starting from 0) after which the above debugging becomes active. + #### --iterations + Number of total iterations to train for, ```30_000``` by default. + #### --ip + IP to start GUI server on, ```127.0.0.1``` by default. + #### --port + Port to use for GUI server, ```6009``` by default. + #### --test_iterations + Space-separated iterations at which the training script computes L1 and PSNR over test set, ```7000 30000``` by default. + #### --save_iterations + Space-separated iterations at which the training script saves the Gaussian model, ```7000 30000 ``` by default. + #### --checkpoint_iterations + Space-separated iterations at which to store a checkpoint for continuing later, saved in the model directory. + #### --start_checkpoint + Path to a saved checkpoint to continue training from. + #### --quiet + Flag to omit any text written to standard out pipe. + #### --feature_lr + Spherical harmonics features learning rate, ```0.0025``` by default. + #### --opacity_lr + Opacity learning rate, ```0.05``` by default. + #### --scaling_lr + Scaling learning rate, ```0.005``` by default. + #### --rotation_lr + Rotation learning rate, ```0.001``` by default. + #### --position_lr_max_steps + Number of steps (from 0) where position learning rate goes from ```initial``` to ```final```. ```30_000``` by default. + #### --position_lr_init + Initial 3D position learning rate, ```0.00016``` by default. + #### --position_lr_final + Final 3D position learning rate, ```0.0000016``` by default. + #### --position_lr_delay_mult + Position learning rate multiplier (cf. Plenoxels), ```0.01``` by default. + #### --densify_from_iter + Iteration where densification starts, ```500``` by default. 
+ #### --densify_until_iter + Iteration where densification stops, ```15_000``` by default. + #### --densify_grad_threshold + Limit that decides if points should be densified based on 2D position gradient, ```0.0002``` by default. + #### --densification_interval + How frequently to densify, ```100``` (every 100 iterations) by default. + #### --opacity_reset_interval + How frequently to reset opacity, ```3_000``` by default. + #### --lambda_dssim + Influence of SSIM on total loss from 0 to 1, ```0.2``` by default. + #### --percent_dense + Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default. + +
+
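+As a concrete illustration of the arguments above, a hypothetical run that holds out a test split and keeps the source images in CPU memory might look like this (```data/garden``` and ```output/garden_eval``` are placeholder paths):
+```shell
+# Placeholder dataset/model paths; all flags are documented above.
+python train.py -s data/garden -m output/garden_eval --eval \
+    --data_device cpu \
+    --test_iterations 7000 15000 30000 \
+    --save_iterations 7000 30000
+```
+Setting ```--data_device cpu``` trades a little training speed for lower VRAM use, as noted above.
+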
+ +Note that similar to MipNeRF360, we target images at resolutions in the 1-1.6K pixel range. For convenience, arbitrary-size inputs can be passed and will be automatically resized if their width exceeds 1600 pixels. We recommend to keep this behavior, but you may force training to use your higher-resolution images by setting ```-r 1```. + +The MipNeRF360 scenes are hosted by the paper authors [here](https://jonbarron.info/mipnerf360/). You can find our SfM data sets for Tanks&Temples and Deep Blending [here](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/input/tandt_db.zip). If you do not provide an output model directory (```-m```), trained models are written to folders with randomized unique names inside the ```output``` directory. At this point, the trained models may be viewed with the real-time viewer (see further below). + +### Evaluation +By default, the trained models use all available images in the dataset. To train them while withholding a test set for evaluation, use the ```--eval``` flag. This way, you can render training/test sets and produce error metrics as follows: +```shell +python train.py -s --eval # Train with train/test split +python render.py -m # Generate renderings +python metrics.py -m # Compute error metrics on renderings +``` + +If you want to evaluate our [pre-trained models](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/datasets/pretrained/models.zip), you will have to download the corresponding source data sets and indicate their location to ```render.py``` with an additional ```--source_path/-s``` flag. Note: The pre-trained models were created with the release codebase. This code base has been cleaned up and includes bugfixes, hence the metrics you get from evaluating them will differ from those in the paper. +```shell +python render.py -m -s +python metrics.py -m +``` + +
+Command Line Arguments for render.py + + #### --model_path / -m + Path to the trained model directory you want to create renderings for. + #### --skip_train + Flag to skip rendering the training set. + #### --skip_test + Flag to skip rendering the test set. + #### --quiet + Flag to omit any text written to standard out pipe. + + **The below parameters will be read automatically from the model path, based on what was used for training. However, you may override them by providing them explicitly on the command line.** + + #### --source_path / -s + Path to the source directory containing a COLMAP or Synthetic NeRF data set. + #### --images / -i + Alternative subdirectory for COLMAP images (```images``` by default). + #### --eval + Add this flag to use a MipNeRF360-style training/test split for evaluation. + #### --resolution / -r + Changes the resolution of the loaded images before training. If provided ```1, 2, 4``` or ```8```, uses original, 1/2, 1/4 or 1/8 resolution, respectively. For all other values, rescales the width to the given number while maintaining image aspect. ```1``` by default. + #### --white_background / -w + Add this flag to use white background instead of black (default), e.g., for evaluation of NeRF Synthetic dataset. + #### --convert_SHs_python + Flag to make pipeline render with computed SHs from PyTorch instead of ours. + #### --convert_cov3D_python + Flag to make pipeline render with computed 3D covariance from PyTorch instead of ours. + +
+ +
+Command Line Arguments for metrics.py + + #### --model_paths / -m + Space-separated list of model paths for which metrics should be computed. +
+
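+Putting ```render.py``` and ```metrics.py``` together, evaluating the hypothetical model trained above could look like this (paths are placeholders):
+```shell
+# Render only the held-out test views of a model trained with --eval, then compute error metrics on them.
+python render.py -m output/garden_eval --skip_train
+python metrics.py -m output/garden_eval
+```
+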
+ +We further provide the ```full_eval.py``` script. This script specifies the routine used in our evaluation and demonstrates the use of some additional parameters, e.g., ```--images (-i)``` to define alternative image directories within COLMAP data sets. If you have downloaded and extracted all the training data, you can run it like this: +```shell +python full_eval.py -m360 -tat -db +``` +In the current version, this process takes about 7h on our reference machine containing an A6000. If you want to do the full evaluation on our pre-trained models, you can specify their download location and skip training. +```shell +python full_eval.py -o --skip_training -m360 -tat -db +``` + +If you want to compute the metrics on our paper's [evaluation images](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/evaluation/images.zip), you can also skip rendering. In this case it is not necessary to provide the source datasets. You can compute metrics for multiple image sets at a time. +```shell +python full_eval.py -m /garden ... --skip_training --skip_rendering +``` + +
+Command Line Arguments for full_eval.py + + #### --skip_training + Flag to skip training stage. + #### --skip_rendering + Flag to skip rendering stage. + #### --skip_metrics + Flag to skip metrics calculation stage. + #### --output_path + Directory to put renderings and results in, ```./eval``` by default, set to pre-trained model location if evaluating them. + #### --mipnerf360 / -m360 + Path to MipNeRF360 source datasets, required if training or rendering. + #### --tanksandtemples / -tat + Path to Tanks&Temples source datasets, required if training or rendering. + #### --deepblending / -db + Path to Deep Blending source datasets, required if training or rendering. +
+
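+For instance, assuming the benchmark datasets have been extracted to placeholder directories under ```data/```, the full routine could be launched as:
+```shell
+# Placeholder dataset locations; see the argument list above.
+python full_eval.py -m360 data/mipnerf360 -tat data/tandt -db data/db
+# Or, to evaluate downloaded pre-trained models without re-training:
+python full_eval.py -o data/pretrained_models --skip_training -m360 data/mipnerf360 -tat data/tandt -db data/db
+```
+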
+ +## Interactive Viewers +We provide two interactive viewers for our method: remote and real-time. Our viewing solutions are based on the [SIBR](https://sibr.gitlabpages.inria.fr/) framework, developed by the GRAPHDECO group for several novel-view synthesis projects. + +### Hardware Requirements +- OpenGL 4.5-ready GPU and drivers (or latest MESA software) +- 4 GB VRAM recommended +- CUDA-ready GPU with Compute Capability 7.0+ (only for Real-Time Viewer) + +### Software Requirements +- Visual Studio or g++, **not Clang** (we used Visual Studio 2019 for Windows) +- CUDA SDK 11, install *after* Visual Studio (we used 11.8) +- CMake (recent version, we used 3.24) +- 7zip (only on Windows) + +### Pre-built Windows Binaries +We provide pre-built binaries for Windows [here](https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/binaries/viewers.zip). We recommend using them on Windows for an efficient setup, since the building of SIBR involves several external dependencies that must be downloaded and compiled on-the-fly. + +### Installation from Source +If you cloned with submodules (e.g., using ```--recursive```), the source code for the viewers is found in ```SIBR_viewers```. The network viewer runs within the SIBR framework for Image-based Rendering applications. + +#### Windows +CMake should take care of your dependencies. +```shell +cd SIBR_viewers +cmake -Bbuild . +cmake --build build --target install --config RelWithDebInfo +``` +You may specify a different configuration, e.g. ```Debug``` if you need more control during development. + +#### Ubuntu 22.04 +You will need to install a few dependencies before running the project setup. +```shell +# Dependencies +sudo apt install -y libglew-dev libassimp-dev libboost-all-dev libgtk-3-dev libopencv-dev libglfw3-dev libavdevice-dev libavcodec-dev libeigen3-dev libxxf86vm-dev libembree-dev +# Project setup +cd SIBR_viewers +cmake -Bbuild . -DCMAKE_BUILD_TYPE=Release # add -G Ninja to build faster +cmake --build build -j24 --target install +``` + +#### Ubuntu 20.04 +Backwards compatibility with Focal Fossa is not fully tested, but building SIBR with CMake should still work after invoking +```shell +git checkout fossa_compatibility +``` + +### Navigation in SIBR Viewers +The SIBR interface provides several methods of navigating the scene. By default, you will be started with an FPS navigator, which you can control with ```W, A, S, D, Q, E``` for camera translation and ```I, K, J, L, U, O``` for rotation. Alternatively, you may want to use a Trackball-style navigator (select from the floating menu). You can also snap to a camera from the data set with the ```Snap to``` button or find the closest camera with ```Snap to closest```. The floating menues also allow you to change the navigation speed. You can use the ```Scaling Modifier``` to control the size of the displayed Gaussians, or show the initial point cloud. + +### Running the Network Viewer + + + +https://github.com/graphdeco-inria/gaussian-splatting/assets/40643808/90a2e4d3-cf2e-4633-b35f-bfe284e28ff7 + + + +After extracting or installing the viewers, you may run the compiled ```SIBR_remoteGaussian_app[_config]``` app in ```/bin```, e.g.: +```shell +.//bin/SIBR_remoteGaussian_app +``` +The network viewer allows you to connect to a running training process on the same or a different machine. If you are training on the same machine and OS, no command line parameters should be required: the optimizer communicates the location of the training data to the network viewer. 
By default, optimizer and network viewer will try to establish a connection on **localhost** on port **6009**. You can change this behavior by providing matching ```--ip``` and ```--port``` parameters to both the optimizer and the network viewer. If for some reason the path used by the optimizer to find the training data is not reachable by the network viewer (e.g., due to them running on different (virtual) machines), you may specify an override location to the viewer by using ```-s ```. + +
+Primary Command Line Arguments for Network Viewer + + #### --path / -s + Argument to override model's path to source dataset. + #### --ip + IP to use for connection to a running training script. + #### --port + Port to use for connection to a running training script. + #### --rendering-size + Takes two space separated numbers to define the resolution at which network rendering occurs, ```1200``` width by default. + Note that to enforce an aspect that differs from the input images, you need ```--force-aspect-ratio``` too. + #### --load_images + Flag to load source dataset images to be displayed in the top view for each camera. +
+
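+For example, to monitor a training run on a remote machine, matching ```--ip```/```--port``` values can be passed to both sides. In the sketch below, ```$SIBR_INSTALL```, the address ```192.168.1.42```, and ```data/garden``` are placeholders:
+```shell
+# On the training machine (192.168.1.42 is a placeholder address): expose the GUI server beyond localhost.
+python train.py -s data/garden --ip 192.168.1.42 --port 6009
+# On the viewing machine: connect with the same --ip/--port; -s points to a locally reachable copy of the data.
+$SIBR_INSTALL/bin/SIBR_remoteGaussian_app --ip 192.168.1.42 --port 6009 -s data/garden
+```
+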
+ +### Running the Real-Time Viewer + + + + +https://github.com/graphdeco-inria/gaussian-splatting/assets/40643808/0940547f-1d82-4c2f-a616-44eabbf0f816 + + + + +After extracting or installing the viewers, you may run the compiled ```SIBR_gaussianViewer_app[_config]``` app in ```/bin```, e.g.: +```shell +.//bin/SIBR_gaussianViewer_app -m +``` + +It should suffice to provide the ```-m``` parameter pointing to a trained model directory. Alternatively, you can specify an override location for training input data using ```-s```. To use a specific resolution other than the auto-chosen one, specify ```--rendering-size ```. Combine it with ```--force-aspect-ratio``` if you want the exact resolution and don't mind image distortion. + +**To unlock the full frame rate, please disable V-Sync on your machine and also in the application (Menu → Display). In a multi-GPU system (e.g., laptop) your OpenGL/Display GPU should be the same as your CUDA GPU (e.g., by setting the application's GPU preference on Windows, see below) for maximum performance.** + +![Teaser image](assets/select.png) + +In addition to the initial point cloud and the splats, you also have the option to visualize the Gaussians by rendering them as ellipsoids from the floating menu. +SIBR has many other functionalities, please see the [documentation](https://sibr.gitlabpages.inria.fr/) for more details on the viewer, navigation options etc. There is also a Top View (available from the menu) that shows the placement of the input cameras and the original SfM point cloud; please note that Top View slows rendering when enabled. The real-time viewer also uses slightly more aggressive, fast culling, which can be toggled in the floating menu. If you ever encounter an issue that can be solved by turning fast culling off, please let us know. + +
+Primary Command Line Arguments for Real-Time Viewer + + #### --model-path / -m + Path to trained model. + #### --iteration + Specifies which saved state (iteration) to load if multiple are available. Defaults to the latest available iteration. + #### --path / -s + Argument to override model's path to source dataset. + #### --rendering-size + Takes two space separated numbers to define the resolution at which real-time rendering occurs, ```1200``` width by default. Note that to enforce an aspect that differs from the input images, you need ```--force-aspect-ratio``` too. + #### --load_images + Flag to load source dataset images to be displayed in the top view for each camera. + #### --device + Index of CUDA device to use for rasterization if multiple are available, ```0``` by default. + #### --no_interop + Disables CUDA/GL interop forcibly. Use on systems that may not behave according to spec (e.g., WSL2 with MESA GL 4.5 software rendering). +
+
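+For example, to view the hypothetical model trained above at a fixed 1080p resolution (```$SIBR_INSTALL``` again stands in for the viewer install directory):
+```shell
+# Force an exact 1920x1080 rendering resolution, accepting possible image distortion.
+$SIBR_INSTALL/bin/SIBR_gaussianViewer_app -m output/garden_eval \
+    --rendering-size 1920 1080 --force-aspect-ratio
+```
+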
+ +## Processing your own Scenes + +Our COLMAP loaders expect the following dataset structure in the source path location: + +``` + +|---images +| |--- +| |--- +| |---... +|---sparse + |---0 + |---cameras.bin + |---images.bin + |---points3D.bin +``` + +For rasterization, the camera models must be either a SIMPLE_PINHOLE or PINHOLE camera. We provide a converter script ```convert.py```, to extract undistorted images and SfM information from input images. Optionally, you can use ImageMagick to resize the undistorted images. This rescaling is similar to MipNeRF360, i.e., it creates images with 1/2, 1/4 and 1/8 the original resolution in corresponding folders. To use them, please first install a recent version of COLMAP (ideally CUDA-powered) and ImageMagick. Put the images you want to use in a directory ```/input```. +``` + +|---input + |--- + |--- + |---... +``` + If you have COLMAP and ImageMagick on your system path, you can simply run +```shell +python convert.py -s [--resize] #If not resizing, ImageMagick is not needed +``` +Alternatively, you can use the optional parameters ```--colmap_executable``` and ```--magick_executable``` to point to the respective paths. Please note that on Windows, the executable should point to the COLMAP ```.bat``` file that takes care of setting the execution environment. Once done, `````` will contain the expected COLMAP data set structure with undistorted, resized input images, in addition to your original images and some temporary (distorted) data in the directory ```distorted```. + +If you have your own COLMAP dataset without undistortion (e.g., using ```OPENCV``` camera), you can try to just run the last part of the script: Put the images in ```input``` and the COLMAP info in a subdirectory ```distorted```: +``` + +|---input +| |--- +| |--- +| |---... +|---distorted + |---database.db + |---sparse + |---0 + |---... +``` +Then run +```shell +python convert.py -s --skip_matching [--resize] #If not resizing, ImageMagick is not needed +``` + +
+Command Line Arguments for convert.py + + #### --no_gpu + Flag to avoid using GPU in COLMAP. + #### --skip_matching + Flag to indicate that COLMAP info is available for images. + #### --source_path / -s + Location of the inputs. + #### --camera + Which camera model to use for the early matching steps, ```OPENCV``` by default. + #### --resize + Flag for creating resized versions of input images. + #### --colmap_executable + Path to the COLMAP executable (```.bat``` on Windows). + #### --magick_executable + Path to the ImageMagick executable. +
+
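+As an illustration, a hypothetical Windows invocation with explicit executable locations and resized image copies could look like this (all paths are placeholders):
+```shell
+# Placeholder paths for a local COLMAP/ImageMagick install; flags are documented above.
+python convert.py -s data/my_scene --resize --colmap_executable "C:/COLMAP/COLMAP.bat" --magick_executable "C:/ImageMagick/magick.exe"
+```
+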
+ +## FAQ +- *Where do I get data sets, e.g., those referenced in ```full_eval.py```?* The MipNeRF360 data set is provided by the authors of the original paper on the project site. Note that two of the data sets cannot be openly shared and require you to consult the authors directly. For Tanks&Temples and Deep Blending, please use the download links provided at the top of the page. Alternatively, you may access the cloned data (status: August 2023!) from [HuggingFace](https://huggingface.co/camenduru/gaussian-splatting) + + +- *How can I use this for a much larger dataset, like a city district?* The current method was not designed for these, but given enough memory, it should work out. However, the approach can struggle in multi-scale detail scenes (extreme close-ups, mixed with far-away shots). This is usually the case in, e.g., driving data sets (cars close up, buildings far away). For such scenes, you can lower the ```--position_lr_init```, ```--position_lr_final``` and ```--scaling_lr``` (x0.3, x0.1, ...). The more extensive the scene, the lower these values should be. Below, we use default learning rates (left) and ```--position_lr_init 0.000016 --scaling_lr 0.001"``` (right). + +| ![Default learning rate result](assets/worse.png "title-1") | ![Reduced learning rate result](assets/better.png "title-2") | +| --- | --- | + +- *I'm on Windows and I can't manage to build the submodules, what do I do?* Consider following the steps in the excellent video tutorial [here](https://www.youtube.com/watch?v=UXtuigy_wYc), hopefully they should help. The order in which the steps are done is important! Alternatively, consider using the linked Colab template. + +- *It still doesn't work. It says something about ```cl.exe```. What do I do?* User Henry Pearce found a workaround. You can you try adding the visual studio path to your environment variables (your version number might differ); +```C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\VC\Tools\MSVC\14.29.30133\bin\Hostx64\x64``` +Then make sure you start a new conda prompt and cd to your repo location and try this; +``` +conda activate gaussian_splatting +cd /gaussian-splatting +pip install submodules\diff-gaussian-rasterization +pip install submodules\simple-knn +``` + +- *I'm on macOS/Puppy Linux/Greenhat and I can't manage to build, what do I do?* Sorry, we can't provide support for platforms outside of the ones we list in this README. Consider using the linked Colab template. + +- *I don't have 24 GB of VRAM for training, what do I do?* The VRAM consumption is determined by the number of points that are being optimized, which increases over time. If you only want to train to 7k iterations, you will need significantly less. To do the full training routine and avoid running out of memory, you can increase the ```--densify_grad_threshold```, ```--densification_interval``` or reduce the value of ```--densify_until_iter```. Note however that this will affect the quality of the result. Also try setting ```--test_iterations``` to ```-1``` to avoid memory spikes during testing. If ```--densify_grad_threshold``` is very high, no densification should occur and training should complete if the scene itself loads successfully. + +- *24 GB of VRAM for reference quality training is still a lot! Can't we do it with less?* Yes, most likely. By our calculations it should be possible with **way** less memory (~8GB). If we can find the time we will try to achieve this. 
If some PyTorch veteran out there wants to tackle this, we look forward to your pull request! + + +- *How can I use the differentiable Gaussian rasterizer for my own project?* Easy, it is included in this repo as a submodule ```diff-gaussian-rasterization```. Feel free to check out and install the package. It's not really documented, but using it from the Python side is very straightforward (cf. ```gaussian_renderer/__init__.py```). + +- *Wait, but `````` isn't optimized and could be much better?* There are several parts we didn't even have time to think about improving (yet). The performance you get with this prototype is probably a rather slow baseline for what is physically possible. + +- *Something is broken, how did this happen?* We tried hard to provide a solid and comprehensible basis to make use of the paper's method. We have refactored the code quite a bit, but we have limited capacity to test all possible usage scenarios. Thus, if part of the website, the code or the performance is lacking, please create an issue. If we find the time, we will do our best to address it. diff --git a/recon/arguments/__init__.py b/recon/arguments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36afc54420cdea24425ff3a2953b826f339160bf --- /dev/null +++ b/recon/arguments/__init__.py @@ -0,0 +1,132 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from argparse import ArgumentParser, Namespace +import sys +import os + + +class GroupParams: + pass + + +class ParamGroup: + def __init__(self, parser: ArgumentParser, name: str, fill_none=False): + group = parser.add_argument_group(name) + for key, value in vars(self).items(): + shorthand = False + if key.startswith("_"): + shorthand = True + key = key[1:] + t = type(value) + value = value if not fill_none else None + if shorthand: + if t == bool: + group.add_argument( + "--" + key, ("-" + key[0:1]), default=value, action="store_true" + ) + else: + group.add_argument( + "--" + key, ("-" + key[0:1]), default=value, type=t + ) + else: + if t == bool: + group.add_argument("--" + key, default=value, action="store_true") + else: + group.add_argument("--" + key, default=value, type=t) + + def extract(self, args): + group = GroupParams() + for arg in vars(args).items(): + if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): + setattr(group, arg[0], arg[1]) + return group + + +class ModelParams(ParamGroup): + def __init__(self, parser, sentinel=False): + self.sh_degree = 3 + self._source_path = "" + self._model_path = "" + # self._images = "images" + self._resolution = -1 + self._white_background = False + self.data_device = "cuda" + self.eval = False + self.num_frames = 18 + self.radius = 2.0 + self.elevation = 0.0 + self.fov = 60.0 + self.reso = 512 + self.images = [] + self.masks = [] + self.num_pts = 100_000 + self.train = True + super().__init__(parser, "Loading Parameters", sentinel) + + def extract(self, args): + g = super().extract(args) + g.source_path = os.path.abspath(g.source_path) + return g + + +class PipelineParams(ParamGroup): + def __init__(self, parser): + self.convert_SHs_python = False + self.compute_cov3D_python = False + self.debug = False + super().__init__(parser, "Pipeline Parameters") + + +class OptimizationParams(ParamGroup): + def __init__(self, parser): + 
self.iterations = 30_000 + self.position_lr_init = 0.00016 + self.position_lr_final = 0.0000016 + self.position_lr_delay_mult = 0.01 + self.position_lr_max_steps = 30_000 + self.feature_lr = 0.0025 + self.opacity_lr = 0.05 + self.scaling_lr = 0.005 + self.rotation_lr = 0.001 + self.percent_dense = 0.01 + self.lambda_dssim = 0.2 + self.lambda_lpips = 0.2 + self.densification_interval = 100 + self.opacity_reset_interval = 3000 + self.densify_from_iter = 500 + self.densify_until_iter = 15_000 + self.densify_grad_threshold = 0.0002 + self.random_background = False + super().__init__(parser, "Optimization Parameters") + + +def get_combined_args(parser: ArgumentParser): + cmdlne_string = sys.argv[1:] + cfgfile_string = "Namespace()" + args_cmdline = parser.parse_args(cmdlne_string) + + try: + cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") + print("Looking for config file in", cfgfilepath) + with open(cfgfilepath) as cfg_file: + print("Config file found: {}".format(cfgfilepath)) + cfgfile_string = cfg_file.read() + except TypeError: + print("Config file not found at") + pass + args_cfgfile = eval(cfgfile_string) + + merged_dict = vars(args_cfgfile).copy() + for k, v in vars(args_cmdline).items(): + if v != None: + merged_dict[k] = v + return Namespace(**merged_dict) diff --git a/recon/convert.py b/recon/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..78948848f4849a88d686542790cd04f34f34beb0 --- /dev/null +++ b/recon/convert.py @@ -0,0 +1,124 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import logging +from argparse import ArgumentParser +import shutil + +# This Python script is based on the shell converter script provided in the MipNerF 360 repository. +parser = ArgumentParser("Colmap converter") +parser.add_argument("--no_gpu", action='store_true') +parser.add_argument("--skip_matching", action='store_true') +parser.add_argument("--source_path", "-s", required=True, type=str) +parser.add_argument("--camera", default="OPENCV", type=str) +parser.add_argument("--colmap_executable", default="", type=str) +parser.add_argument("--resize", action="store_true") +parser.add_argument("--magick_executable", default="", type=str) +args = parser.parse_args() +colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap" +magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick" +use_gpu = 1 if not args.no_gpu else 0 + +if not args.skip_matching: + os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True) + + ## Feature extraction + feat_extracton_cmd = colmap_command + " feature_extractor "\ + "--database_path " + args.source_path + "/distorted/database.db \ + --image_path " + args.source_path + "/input \ + --ImageReader.single_camera 1 \ + --ImageReader.camera_model " + args.camera + " \ + --SiftExtraction.use_gpu " + str(use_gpu) + exit_code = os.system(feat_extracton_cmd) + if exit_code != 0: + logging.error(f"Feature extraction failed with code {exit_code}. 
Exiting.") + exit(exit_code) + + ## Feature matching + feat_matching_cmd = colmap_command + " exhaustive_matcher \ + --database_path " + args.source_path + "/distorted/database.db \ + --SiftMatching.use_gpu " + str(use_gpu) + exit_code = os.system(feat_matching_cmd) + if exit_code != 0: + logging.error(f"Feature matching failed with code {exit_code}. Exiting.") + exit(exit_code) + + ### Bundle adjustment + # The default Mapper tolerance is unnecessarily large, + # decreasing it speeds up bundle adjustment steps. + mapper_cmd = (colmap_command + " mapper \ + --database_path " + args.source_path + "/distorted/database.db \ + --image_path " + args.source_path + "/input \ + --output_path " + args.source_path + "/distorted/sparse \ + --Mapper.ba_global_function_tolerance=0.000001") + exit_code = os.system(mapper_cmd) + if exit_code != 0: + logging.error(f"Mapper failed with code {exit_code}. Exiting.") + exit(exit_code) + +### Image undistortion +## We need to undistort our images into ideal pinhole intrinsics. +img_undist_cmd = (colmap_command + " image_undistorter \ + --image_path " + args.source_path + "/input \ + --input_path " + args.source_path + "/distorted/sparse/0 \ + --output_path " + args.source_path + "\ + --output_type COLMAP") +exit_code = os.system(img_undist_cmd) +if exit_code != 0: + logging.error(f"Mapper failed with code {exit_code}. Exiting.") + exit(exit_code) + +files = os.listdir(args.source_path + "/sparse") +os.makedirs(args.source_path + "/sparse/0", exist_ok=True) +# Copy each file from the source directory to the destination directory +for file in files: + if file == '0': + continue + source_file = os.path.join(args.source_path, "sparse", file) + destination_file = os.path.join(args.source_path, "sparse", "0", file) + shutil.move(source_file, destination_file) + +if(args.resize): + print("Copying and resizing...") + + # Resize images. + os.makedirs(args.source_path + "/images_2", exist_ok=True) + os.makedirs(args.source_path + "/images_4", exist_ok=True) + os.makedirs(args.source_path + "/images_8", exist_ok=True) + # Get the list of files in the source directory + files = os.listdir(args.source_path + "/images") + # Copy each file from the source directory to the destination directory + for file in files: + source_file = os.path.join(args.source_path, "images", file) + + destination_file = os.path.join(args.source_path, "images_2", file) + shutil.copy2(source_file, destination_file) + exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file) + if exit_code != 0: + logging.error(f"50% resize failed with code {exit_code}. Exiting.") + exit(exit_code) + + destination_file = os.path.join(args.source_path, "images_4", file) + shutil.copy2(source_file, destination_file) + exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file) + if exit_code != 0: + logging.error(f"25% resize failed with code {exit_code}. Exiting.") + exit(exit_code) + + destination_file = os.path.join(args.source_path, "images_8", file) + shutil.copy2(source_file, destination_file) + exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file) + if exit_code != 0: + logging.error(f"12.5% resize failed with code {exit_code}. 
Exiting.") + exit(exit_code) + +print("Done.") diff --git a/recon/convert_mesh.py b/recon/convert_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/recon/convert_nerf_mesh.py b/recon/convert_nerf_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5cac9b7311e8832bedfc40a5a7ea8d63036df6 --- /dev/null +++ b/recon/convert_nerf_mesh.py @@ -0,0 +1,539 @@ +import os +import tyro +import tqdm +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from lgm.options import AllConfigs, Options +from lgm.gs import GaussianRenderer + +import mcubes +import nerfacc +import nvdiffrast.torch as dr + +import kiui +from kiui.mesh import Mesh +from kiui.mesh_utils import clean_mesh, decimate_mesh +from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency +from kiui.op import uv_padding, safe_normalize, inverse_sigmoid +from kiui.cam import orbit_camera, get_perspective +from kiui.nn import MLP, trunc_exp +from kiui.gridencoder import GridEncoder + + +def get_rays(pose, h, w, fovy, opengl=True): + x, y = torch.meshgrid( + torch.arange(w, device=pose.device), + torch.arange(h, device=pose.device), + indexing="xy", + ) + x = x.flatten() + y = y.flatten() + + cx = w * 0.5 + cy = h * 0.5 + focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy)) + + camera_dirs = F.pad( + torch.stack( + [ + (x - cx + 0.5) / focal, + (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0), + ], + dim=-1, + ), + (0, 1), + value=(-1.0 if opengl else 1.0), + ) # [hw, 3] + + rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1) # [hw, 3] + rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d) # [hw, 3] + + rays_d = safe_normalize(rays_d) + + return rays_o, rays_d + + +# Triple renderer of gaussians, gaussian, and diso mesh. 
+# gaussian --> nerf --> mesh +class Converter(nn.Module): + def __init__(self, opt: Options): + super().__init__() + + self.opt = opt + self.device = torch.device("cuda") + + # gs renderer + self.tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy)) + self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=self.device) + self.proj_matrix[0, 0] = 1 / self.tan_half_fov + self.proj_matrix[1, 1] = 1 / self.tan_half_fov + self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[2, 3] = 1 + + self.gs_renderer = GaussianRenderer(opt) + + self.gaussians = self.gs_renderer.load_ply(opt.test_path).to(self.device) + + # nerf renderer + if not self.opt.force_cuda_rast: + self.glctx = dr.RasterizeGLContext() + else: + self.glctx = dr.RasterizeCudaContext() + + self.step = 0 + self.render_step_size = 5e-3 + self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device) + self.estimator = nerfacc.OccGridEstimator( + roi_aabb=self.aabb, resolution=64, levels=1 + ) + + self.encoder_density = GridEncoder( + num_levels=12 + ) # VMEncoder(output_dim=16, mode='sum') + self.encoder = GridEncoder(num_levels=12) + self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False) + self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False) + + # mesh renderer + self.proj = ( + torch.from_numpy(get_perspective(self.opt.fovy)).float().to(self.device) + ) + self.v = self.f = None + self.vt = self.ft = None + self.deform = None + self.albedo = None + + @torch.no_grad() + def render_gs(self, pose): + cam_poses = torch.from_numpy(pose).unsqueeze(0).to(self.device) + cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction + + # cameras needed by gaussian rasterizer + cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4] + cam_view_proj = cam_view @ self.proj_matrix # [V, 4, 4] + cam_pos = -cam_poses[:, :3, 3] # [V, 3] + + out = self.gs_renderer.render( + self.gaussians.unsqueeze(0), + cam_view.unsqueeze(0), + cam_view_proj.unsqueeze(0), + cam_pos.unsqueeze(0), + ) + image = out["image"].squeeze(1).squeeze(0) # [C, H, W] + alpha = out["alpha"].squeeze(2).squeeze(1).squeeze(0) # [H, W] + + return image, alpha + + def get_density(self, xs): + # xs: [..., 3] + prefix = xs.shape[:-1] + xs = xs.view(-1, 3) + feats = self.encoder_density(xs) + density = trunc_exp(self.mlp_density(feats)) + density = density.view(*prefix, 1) + return density + + def render_nerf(self, pose): + pose = torch.from_numpy(pose.astype(np.float32)).to(self.device) + + # get rays + resolution = self.opt.output_size + rays_o, rays_d = get_rays(pose, resolution, resolution, self.opt.fovy) + + # update occ grid + if self.training: + + def occ_eval_fn(xs): + sigmas = self.get_density(xs) + return self.render_step_size * sigmas + + self.estimator.update_every_n_steps( + self.step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8 + ) + self.step += 1 + + # render + def sigma_fn(t_starts, t_ends, ray_indices): + t_origins = rays_o[ray_indices] + t_dirs = rays_d[ray_indices] + xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 + sigmas = self.get_density(xs) + return sigmas.squeeze(-1) + + with torch.no_grad(): + ray_indices, t_starts, t_ends = self.estimator.sampling( + rays_o, + rays_d, + sigma_fn=sigma_fn, + near_plane=0.01, + far_plane=100, + render_step_size=self.render_step_size, + stratified=self.training, + cone_angle=0, + ) + + t_origins = rays_o[ray_indices] + t_dirs = 
rays_d[ray_indices] + xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0 + sigmas = self.get_density(xs).squeeze(-1) + rgbs = torch.sigmoid(self.mlp(self.encoder(xs))) + + n_rays = rays_o.shape[0] + weights, trans, alphas = nerfacc.render_weight_from_density( + t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays + ) + color = nerfacc.accumulate_along_rays( + weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays + ) + alpha = nerfacc.accumulate_along_rays( + weights, values=None, ray_indices=ray_indices, n_rays=n_rays + ) + + color = color + 1 * (1.0 - alpha) + + color = ( + color.view(resolution, resolution, 3) + .clamp(0, 1) + .permute(2, 0, 1) + .contiguous() + ) + alpha = alpha.view(resolution, resolution).clamp(0, 1).contiguous() + + return color, alpha + + def fit_nerf(self, iters=512, resolution=128): + self.opt.output_size = resolution + + optimizer = torch.optim.Adam( + [ + {"params": self.encoder_density.parameters(), "lr": 1e-2}, + {"params": self.encoder.parameters(), "lr": 1e-2}, + {"params": self.mlp_density.parameters(), "lr": 1e-3}, + {"params": self.mlp.parameters(), "lr": 1e-3}, + ] + ) + + print(f"[INFO] fitting nerf...") + pbar = tqdm.trange(iters) + for i in pbar: + ver = np.random.randint(-45, 45) + hor = np.random.randint(-180, 180) + rad = np.random.uniform(1.5, 3.0) + + pose = orbit_camera(ver, hor, rad) + + image_gt, alpha_gt = self.render_gs(pose) + image_pred, alpha_pred = self.render_nerf(pose) + + # if i % 200 == 0: + # kiui.vis.plot_image(image_gt, alpha_gt, image_pred, alpha_pred) + + loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss( + alpha_pred, alpha_gt + ) + loss = loss_mse # + 0.1 * self.encoder_density.tv_loss() #+ 0.0001 * self.encoder_density.density_loss() + + loss.backward() + self.encoder_density.grad_total_variation(1e-8) + + optimizer.step() + optimizer.zero_grad() + + pbar.set_description(f"MSE = {loss_mse.item():.6f}") + + print(f"[INFO] finished fitting nerf!") + + def render_mesh(self, pose): + h = w = self.opt.output_size + + v = self.v + self.deform + f = self.f + + pose = torch.from_numpy(pose.astype(np.float32)).to(v.device) + + # get v_clip and render rgb + v_cam = ( + torch.matmul( + F.pad(v, pad=(0, 1), mode="constant", value=1.0), torch.inverse(pose).T + ) + .float() + .unsqueeze(0) + ) + v_clip = v_cam @ self.proj.T + + rast, rast_db = dr.rasterize(self.glctx, v_clip, f, (h, w)) + + alpha = torch.clamp(rast[..., -1:], 0, 1).contiguous() # [1, H, W, 1] + alpha = ( + dr.antialias(alpha, rast, v_clip, f).clamp(0, 1).squeeze(-1).squeeze(0) + ) # [H, W] important to enable gradients! 
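+        # Before a texture has been baked, self.albedo is None and per-pixel colors are queried from the NeRF color MLP at the rasterized surface points; after fit_mesh_uv, the learned UV texture map is sampled instead.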
+ + if self.albedo is None: + xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f) # [1, H, W, 3] + xyzs = xyzs.view(-1, 3) + mask = (alpha > 0).view(-1) + image = torch.zeros_like(xyzs, dtype=torch.float32) + if mask.any(): + masked_albedo = torch.sigmoid( + self.mlp(self.encoder(xyzs[mask].detach(), bound=1)) + ) + image[mask] = masked_albedo.float() + else: + texc, texc_db = dr.interpolate( + self.vt.unsqueeze(0), rast, self.ft, rast_db=rast_db, diff_attrs="all" + ) + image = torch.sigmoid( + dr.texture(self.albedo.unsqueeze(0), texc, uv_da=texc_db) + ) # [1, H, W, 3] + + image = image.view(1, h, w, 3) + # image = dr.antialias(image, rast, v_clip, f).clamp(0, 1) + image = image.squeeze(0).permute(2, 0, 1).contiguous() # [3, H, W] + image = alpha * image + (1 - alpha) + + return image, alpha + + def fit_mesh(self, iters=2048, resolution=512, decimate_target=5e4): + self.opt.output_size = resolution + + # init mesh from nerf + grid_size = 256 + sigmas = np.zeros([grid_size, grid_size, grid_size], dtype=np.float32) + + S = 128 + density_thresh = 10 + + X = torch.linspace(-1, 1, grid_size).split(S) + Y = torch.linspace(-1, 1, grid_size).split(S) + Z = torch.linspace(-1, 1, grid_size).split(S) + + for xi, xs in enumerate(X): + for yi, ys in enumerate(Y): + for zi, zs in enumerate(Z): + xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing="ij") + pts = torch.cat( + [xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], + dim=-1, + ) # [S, 3] + val = self.get_density(pts.to(self.device)) + sigmas[ + xi * S : xi * S + len(xs), + yi * S : yi * S + len(ys), + zi * S : zi * S + len(zs), + ] = ( + val.reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() + ) # [S, 1] --> [x, y, z] + + print( + f"[INFO] marching cubes thresh: {density_thresh} ({sigmas.min()} ~ {sigmas.max()})" + ) + + vertices, triangles = mcubes.marching_cubes(sigmas, density_thresh) + vertices = vertices / (grid_size - 1.0) * 2 - 1 + + # clean + vertices = vertices.astype(np.float32) + triangles = triangles.astype(np.int32) + vertices, triangles = clean_mesh( + vertices, triangles, remesh=True, remesh_size=0.01 + ) + if triangles.shape[0] > decimate_target: + vertices, triangles = decimate_mesh( + vertices, triangles, decimate_target, optimalplacement=False + ) + + self.v = torch.from_numpy(vertices).contiguous().float().to(self.device) + self.f = torch.from_numpy(triangles).contiguous().int().to(self.device) + self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device) + + # fit mesh from gs + lr_factor = 1 + optimizer = torch.optim.Adam( + [ + {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor}, + {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor}, + {"params": self.deform, "lr": 1e-4}, + ] + ) + + print(f"[INFO] fitting mesh...") + pbar = tqdm.trange(iters) + for i in pbar: + ver = np.random.randint(-10, 10) + hor = np.random.randint(-180, 180) + rad = self.opt.cam_radius # np.random.uniform(1, 2) + + pose = orbit_camera(ver, hor, rad) + + image_gt, alpha_gt = self.render_gs(pose) + image_pred, alpha_pred = self.render_mesh(pose) + + loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss( + alpha_pred, alpha_gt + ) + # loss_lap = laplacian_smooth_loss(self.v + self.deform, self.f) + loss_normal = normal_consistency(self.v + self.deform, self.f) + loss_offsets = (self.deform**2).sum(-1).mean() + loss = loss_mse + 0.001 * loss_normal + 0.1 * loss_offsets + + loss.backward() + + optimizer.step() + optimizer.zero_grad() + + # remesh periodically + if i > 0 and i % 512 == 0: + vertices = (self.v + 
self.deform).detach().cpu().numpy() + triangles = self.f.detach().cpu().numpy() + vertices, triangles = clean_mesh( + vertices, triangles, remesh=True, remesh_size=0.01 + ) + if triangles.shape[0] > decimate_target: + vertices, triangles = decimate_mesh( + vertices, triangles, decimate_target, optimalplacement=False + ) + self.v = torch.from_numpy(vertices).contiguous().float().to(self.device) + self.f = torch.from_numpy(triangles).contiguous().int().to(self.device) + self.deform = nn.Parameter(torch.zeros_like(self.v)).to(self.device) + lr_factor *= 0.5 + optimizer = torch.optim.Adam( + [ + {"params": self.encoder.parameters(), "lr": 1e-3 * lr_factor}, + {"params": self.mlp.parameters(), "lr": 1e-3 * lr_factor}, + {"params": self.deform, "lr": 1e-4}, + ] + ) + + pbar.set_description(f"MSE = {loss_mse.item():.6f}") + + # last clean + vertices = (self.v + self.deform).detach().cpu().numpy() + triangles = self.f.detach().cpu().numpy() + vertices, triangles = clean_mesh(vertices, triangles, remesh=False) + self.v = torch.from_numpy(vertices).contiguous().float().to(self.device) + self.f = torch.from_numpy(triangles).contiguous().int().to(self.device) + self.deform = nn.Parameter(torch.zeros_like(self.v).to(self.device)) + + print(f"[INFO] finished fitting mesh!") + + # uv mesh refine + def fit_mesh_uv( + self, iters=512, resolution=512, texture_resolution=1024, padding=2 + ): + self.opt.output_size = resolution + + # unwrap uv + print(f"[INFO] uv unwrapping...") + mesh = Mesh(v=self.v, f=self.f, albedo=None, device=self.device) + mesh.auto_normal() + mesh.auto_uv() + + self.vt = mesh.vt + self.ft = mesh.ft + + # render uv maps + h = w = texture_resolution + uv = mesh.vt * 2.0 - 1.0 # uvs to range [-1, 1] + uv = torch.cat( + (uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1 + ) # [N, 4] + + rast, _ = dr.rasterize( + self.glctx, uv.unsqueeze(0), mesh.ft, (h, w) + ) # [1, h, w, 4] + xyzs, _ = dr.interpolate(mesh.v.unsqueeze(0), rast, mesh.f) # [1, h, w, 3] + mask, _ = dr.interpolate( + torch.ones_like(mesh.v[:, :1]).unsqueeze(0), rast, mesh.f + ) # [1, h, w, 1] + + # masked query + xyzs = xyzs.view(-1, 3) + mask = (mask > 0).view(-1) + + albedo = torch.zeros(h * w, 3, device=self.device, dtype=torch.float32) + + if mask.any(): + print(f"[INFO] querying texture...") + + xyzs = xyzs[mask] # [M, 3] + + # batched inference to avoid OOM + batch = [] + head = 0 + while head < xyzs.shape[0]: + tail = min(head + 640000, xyzs.shape[0]) + batch.append( + torch.sigmoid(self.mlp(self.encoder(xyzs[head:tail]))).float() + ) + head += 640000 + + albedo[mask] = torch.cat(batch, dim=0) + + albedo = albedo.view(h, w, -1) + mask = mask.view(h, w) + albedo = uv_padding(albedo, mask, padding) + + # optimize texture + self.albedo = nn.Parameter(inverse_sigmoid(albedo)).to(self.device) + + optimizer = torch.optim.Adam( + [ + {"params": self.albedo, "lr": 1e-3}, + ] + ) + + print(f"[INFO] fitting mesh texture...") + pbar = tqdm.trange(iters) + for i in pbar: + # shrink to front view as we care more about it... 
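+            # The narrow elevation/azimuth sampling ranges below keep the supervision cameras near the frontal view, so texture optimization concentrates on front-facing detail.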
+ ver = np.random.randint(-5, 5) + hor = np.random.randint(-15, 15) + rad = self.opt.cam_radius # np.random.uniform(1, 2) + + pose = orbit_camera(ver, hor, rad) + + image_gt, alpha_gt = self.render_gs(pose) + image_pred, alpha_pred = self.render_mesh(pose) + + loss_mse = F.mse_loss(image_pred, image_gt) + loss = loss_mse + + loss.backward() + + optimizer.step() + optimizer.zero_grad() + + pbar.set_description(f"MSE = {loss_mse.item():.6f}") + + print(f"[INFO] finished fitting mesh texture!") + + @torch.no_grad() + def export_mesh(self, path): + mesh = Mesh( + v=self.v, + f=self.f, + vt=self.vt, + ft=self.ft, + albedo=torch.sigmoid(self.albedo), + device=self.device, + ) + mesh.auto_normal() + mesh.write(path) + + +opt = tyro.cli(AllConfigs) + +# load a saved ply and convert to mesh +assert opt.test_path.endswith( + ".ply" +), "--test_path must be a .ply file saved by infer.py" + +converter = Converter(opt).cuda() +converter.fit_nerf() +converter.fit_mesh() +converter.fit_mesh_uv() +converter.export_mesh(opt.test_path.replace(".ply", ".glb")) diff --git a/recon/convert_to_blender.py b/recon/convert_to_blender.py new file mode 100644 index 0000000000000000000000000000000000000000..9f3d85dc1eff4efbdfa1708f5a57e33c2b85c51b --- /dev/null +++ b/recon/convert_to_blender.py @@ -0,0 +1,102 @@ +import json +import torch +from scene import Scene +from pathlib import Path +from PIL import Image +import numpy as np +import sys +import os +from tqdm import tqdm +from os import makedirs +from gaussian_renderer import render +import torchvision +from utils.general_utils import safe_state +from argparse import ArgumentParser +from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams +from gaussian_renderer import GaussianModel +from mediapy import write_video +from tqdm import tqdm +from einops import rearrange +from utils.camera_utils import get_uniform_poses +from mediapy import write_image + + +@torch.no_grad() +def render_spiral(dataset, opt, pipe, model_path): + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False) + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + viewpoint_stack = scene.getTrainCameras().copy() + views = [] + alphas = [] + for view_cam in tqdm(viewpoint_stack): + bg = torch.rand((3), device="cuda") if opt.random_background else background + render_pkg = render(view_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = ( + render_pkg["render"], + render_pkg["viewspace_points"], + render_pkg["visibility_filter"], + render_pkg["radii"], + ) + views.append(image) + alphas.append(render_pkg["alpha"]) + views = torch.stack(views) + alphas = torch.stack(alphas) + + png_images = ( + (torch.cat([views, alphas], dim=1).clamp(0.0, 1.0) * 255) + .cpu() + .numpy() + .astype(np.uint8) + ) + png_images = rearrange(png_images, "t c h w -> t h w c") + + poses = get_uniform_poses( + dataset.num_frames, dataset.radius, dataset.elevation, opengl=True + ) + camera_angle_x = np.deg2rad(dataset.fov) + name = Path(dataset.model_path).stem + meta_dir = Path(f"blenders/{name}") + meta_dir.mkdir(exist_ok=True, parents=True) + meta = {} + meta["camera_angle_x"] = camera_angle_x + meta["frames"] = [] + for idx, (pose, image) in enumerate(zip(poses, png_images)): + this_frames = {} + this_frames["file_path"] = f"{idx:06d}" + this_frames["transform_matrix"] = pose.tolist() + 
meta["frames"].append(this_frames) + write_image(meta_dir / f"{idx:06d}.png", image) + + with open(meta_dir / "transforms_train.json", "w") as f: + json.dump(meta, f, indent=4) + with open(meta_dir / "transforms_val.json", "w") as f: + json.dump(meta, f, indent=4) + with open(meta_dir / "transforms_test.json", "w") as f: + json.dump(meta, f, indent=4) + + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument("--iteration", default=-1, type=int) + parser.add_argument("--skip_train", action="store_true") + parser.add_argument("--skip_test", action="store_true") + parser.add_argument("--quiet", action="store_true") + args = parser.parse_args(sys.argv[1:]) + print("Rendering " + args.model_path) + lp = lp.extract(args) + fake_image = Image.fromarray(np.zeros([512, 512, 3], dtype=np.uint8)) + lp.images = [fake_image] * args.num_frames + + # Initialize system state (RNG) + render_spiral( + lp, + op.extract(args), + pp.extract(args), + model_path=args.model_path, + ) diff --git a/recon/environment.yml b/recon/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..b17a50f6570ee66e157c8fd168b62c45bfba2fee --- /dev/null +++ b/recon/environment.yml @@ -0,0 +1,17 @@ +name: gaussian_splatting +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - cudatoolkit=11.6 + - plyfile=0.8.1 + - python=3.7.13 + - pip=22.3.1 + - pytorch=1.12.1 + - torchaudio=0.12.1 + - torchvision=0.13.1 + - tqdm + - pip: + - submodules/diff-gaussian-rasterization + - submodules/simple-knn \ No newline at end of file diff --git a/recon/full_eval.py b/recon/full_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..8fbb12247724b25563e215b4409ded9af1cbdd04 --- /dev/null +++ b/recon/full_eval.py @@ -0,0 +1,75 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import os +from argparse import ArgumentParser + +mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"] +mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"] +tanks_and_temples_scenes = ["truck", "train"] +deep_blending_scenes = ["drjohnson", "playroom"] + +parser = ArgumentParser(description="Full evaluation script parameters") +parser.add_argument("--skip_training", action="store_true") +parser.add_argument("--skip_rendering", action="store_true") +parser.add_argument("--skip_metrics", action="store_true") +parser.add_argument("--output_path", default="./eval") +args, _ = parser.parse_known_args() + +all_scenes = [] +all_scenes.extend(mipnerf360_outdoor_scenes) +all_scenes.extend(mipnerf360_indoor_scenes) +all_scenes.extend(tanks_and_temples_scenes) +all_scenes.extend(deep_blending_scenes) + +if not args.skip_training or not args.skip_rendering: + parser.add_argument('--mipnerf360', "-m360", required=True, type=str) + parser.add_argument("--tanksandtemples", "-tat", required=True, type=str) + parser.add_argument("--deepblending", "-db", required=True, type=str) + args = parser.parse_args() + +if not args.skip_training: + common_args = " --quiet --eval --test_iterations -1 " + for scene in mipnerf360_outdoor_scenes: + source = args.mipnerf360 + "/" + scene + os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args) + for scene in mipnerf360_indoor_scenes: + source = args.mipnerf360 + "/" + scene + os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args) + for scene in tanks_and_temples_scenes: + source = args.tanksandtemples + "/" + scene + os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) + for scene in deep_blending_scenes: + source = args.deepblending + "/" + scene + os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args) + +if not args.skip_rendering: + all_sources = [] + for scene in mipnerf360_outdoor_scenes: + all_sources.append(args.mipnerf360 + "/" + scene) + for scene in mipnerf360_indoor_scenes: + all_sources.append(args.mipnerf360 + "/" + scene) + for scene in tanks_and_temples_scenes: + all_sources.append(args.tanksandtemples + "/" + scene) + for scene in deep_blending_scenes: + all_sources.append(args.deepblending + "/" + scene) + + common_args = " --quiet --eval --skip_train" + for scene, source in zip(all_scenes, all_sources): + os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) + os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args) + +if not args.skip_metrics: + scenes_string = "" + for scene in all_scenes: + scenes_string += "\"" + args.output_path + "/" + scene + "\" " + + os.system("python metrics.py -m " + scenes_string) \ No newline at end of file diff --git a/recon/gaussian_renderer/__init__.py b/recon/gaussian_renderer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..10bb78a587ecde17969800e3f781402d1d9a42f7 --- /dev/null +++ b/recon/gaussian_renderer/__init__.py @@ -0,0 +1,134 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +from diff_gaussian_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, +) +from scene.gaussian_model import GaussianModel +from utils.sh_utils import eval_sh + + +def render( + viewpoint_camera, + pc: GaussianModel, + pipe, + bg_color: torch.Tensor, + scaling_modifier=1.0, + override_color=None, +): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = ( + torch.zeros_like( + pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda" + ) + + 0 + ) + try: + screenspace_points.retain_grad() + except: + pass + + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=pc.active_sh_degree, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=pipe.debug, + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means3D = pc.get_xyz + means2D = screenspace_points + opacity = pc.get_opacity + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. + scales = None + rotations = None + cov3D_precomp = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier) + else: + scales = pc.get_scaling + rotations = pc.get_rotation + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. + shs = None + colors_precomp = None + if override_color is None: + if pipe.convert_SHs_python: + shs_view = pc.get_features.transpose(1, 2).view( + -1, 3, (pc.max_sh_degree + 1) ** 2 + ) + dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat( + pc.get_features.shape[0], 1 + ) + dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True) + sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) + colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) + else: + shs = pc.get_features + else: + colors_precomp = override_color + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + rendered_image, radii, depth, alpha = rasterizer( + means3D=means3D, + means2D=means2D, + shs=shs, + colors_precomp=colors_precomp, + opacities=opacity, + scales=scales, + rotations=rotations, + cov3D_precomp=cov3D_precomp, + ) + # rendered_image, radii = rasterizer( + # means3D = means3D, + # means2D = means2D, + # shs = shs, + # colors_precomp = colors_precomp, + # opacities = opacity, + # scales = scales, + # rotations = rotations, + # cov3D_precomp = cov3D_precomp) + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. 
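+    # In addition to the original 3DGS outputs, the rasterizer used here also returns per-pixel depth and alpha maps, which are exposed below for the depth/spiral rendering and Blender-export scripts.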
+ return { + "render": rendered_image, + "viewspace_points": screenspace_points, + "visibility_filter": radii > 0, + "radii": radii, + "depth": depth, + "alpha": alpha, + } diff --git a/recon/gaussian_renderer/network_gui.py b/recon/gaussian_renderer/network_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..df2f9dae782b24527ae5b09f91ca4009361de53f --- /dev/null +++ b/recon/gaussian_renderer/network_gui.py @@ -0,0 +1,86 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import traceback +import socket +import json +from scene.cameras import MiniCam + +host = "127.0.0.1" +port = 6009 + +conn = None +addr = None + +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + +def init(wish_host, wish_port): + global host, port, listener + host = wish_host + port = wish_port + listener.bind((host, port)) + listener.listen() + listener.settimeout(0) + +def try_connect(): + global conn, addr, listener + try: + conn, addr = listener.accept() + print(f"\nConnected by {addr}") + conn.settimeout(None) + except Exception as inst: + pass + +def read(): + global conn + messageLength = conn.recv(4) + messageLength = int.from_bytes(messageLength, 'little') + message = conn.recv(messageLength) + return json.loads(message.decode("utf-8")) + +def send(message_bytes, verify): + global conn + if message_bytes != None: + conn.sendall(message_bytes) + conn.sendall(len(verify).to_bytes(4, 'little')) + conn.sendall(bytes(verify, 'ascii')) + +def receive(): + message = read() + + width = message["resolution_x"] + height = message["resolution_y"] + + if width != 0 and height != 0: + try: + do_training = bool(message["train"]) + fovy = message["fov_y"] + fovx = message["fov_x"] + znear = message["z_near"] + zfar = message["z_far"] + do_shs_python = bool(message["shs_python"]) + do_rot_scale_python = bool(message["rot_scale_python"]) + keep_alive = bool(message["keep_alive"]) + scaling_modifier = message["scaling_modifier"] + world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() + world_view_transform[:,1] = -world_view_transform[:,1] + world_view_transform[:,2] = -world_view_transform[:,2] + full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() + full_proj_transform[:,1] = -full_proj_transform[:,1] + custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) + except Exception as e: + print("") + traceback.print_exc() + raise e + return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier + else: + return None, None, None, None, None, None \ No newline at end of file diff --git a/recon/lgm/gs.py b/recon/lgm/gs.py new file mode 100644 index 0000000000000000000000000000000000000000..c67469d0c3ed92fa7f6f7575daf609b360bc98a5 --- /dev/null +++ b/recon/lgm/gs.py @@ -0,0 +1,213 @@ +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diff_gaussian_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, +) + +from .options import Options + +import kiui + + +class GaussianRenderer: + def __init__(self, opt: Options): + self.opt = opt + self.bg_color = torch.tensor([1, 1, 1], dtype=torch.float32, device="cuda") + 
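+        # Note: the projection matrix below is stored in transposed (row-vector) form; callers compose the view-projection as cam_view @ proj_matrix (cf. render_gs in convert_nerf_mesh.py).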
+ # intrinsics + self.tan_half_fov = np.tan(0.5 * np.deg2rad(self.opt.fovy)) + self.proj_matrix = torch.zeros(4, 4, dtype=torch.float32) + self.proj_matrix[0, 0] = 1 / self.tan_half_fov + self.proj_matrix[1, 1] = 1 / self.tan_half_fov + self.proj_matrix[2, 2] = (opt.zfar + opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear) + self.proj_matrix[2, 3] = 1 + + def render( + self, + gaussians, + cam_view, + cam_view_proj, + cam_pos, + bg_color=None, + scale_modifier=1, + ): + # gaussians: [B, N, 14] + # cam_view, cam_view_proj: [B, V, 4, 4] + # cam_pos: [B, V, 3] + + device = gaussians.device + B, V = cam_view.shape[:2] + + # loop of loop... + images = [] + alphas = [] + for b in range(B): + # pos, opacity, scale, rotation, shs + means3D = gaussians[b, :, 0:3].contiguous().float() + opacity = gaussians[b, :, 3:4].contiguous().float() + scales = gaussians[b, :, 4:7].contiguous().float() + rotations = gaussians[b, :, 7:11].contiguous().float() + rgbs = gaussians[b, :, 11:].contiguous().float() # [N, 3] + + for v in range(V): + # render novel views + view_matrix = cam_view[b, v].float() + view_proj_matrix = cam_view_proj[b, v].float() + campos = cam_pos[b, v].float() + + raster_settings = GaussianRasterizationSettings( + image_height=self.opt.output_size, + image_width=self.opt.output_size, + tanfovx=self.tan_half_fov, + tanfovy=self.tan_half_fov, + bg=self.bg_color if bg_color is None else bg_color, + scale_modifier=scale_modifier, + viewmatrix=view_matrix, + projmatrix=view_proj_matrix, + sh_degree=0, + campos=campos, + prefiltered=False, + debug=False, + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + # Rasterize visible Gaussians to image, obtain their radii (on screen). 
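+                # sh_degree=0 together with colors_precomp means each Gaussian carries a plain RGB color (the last 3 of its 14 channels) rather than SH coefficients.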
+ rendered_image, radii, rendered_depth, rendered_alpha = rasterizer( + means3D=means3D, + means2D=torch.zeros_like( + means3D, dtype=torch.float32, device=device + ), + shs=None, + colors_precomp=rgbs, + opacities=opacity, + scales=scales, + rotations=rotations, + cov3D_precomp=None, + ) + + rendered_image = rendered_image.clamp(0, 1) + + images.append(rendered_image) + alphas.append(rendered_alpha) + + images = torch.stack(images, dim=0).view( + B, V, 3, self.opt.output_size, self.opt.output_size + ) + alphas = torch.stack(alphas, dim=0).view( + B, V, 1, self.opt.output_size, self.opt.output_size + ) + + return { + "image": images, # [B, V, 3, H, W] + "alpha": alphas, # [B, V, 1, H, W] + } + + def save_ply(self, gaussians, path, compatible=True): + # gaussians: [B, N, 14] + # compatible: save pre-activated gaussians as in the original paper + + assert gaussians.shape[0] == 1, "only support batch size 1" + + from plyfile import PlyData, PlyElement + + means3D = gaussians[0, :, 0:3].contiguous().float() + opacity = gaussians[0, :, 3:4].contiguous().float() + scales = gaussians[0, :, 4:7].contiguous().float() + rotations = gaussians[0, :, 7:11].contiguous().float() + shs = gaussians[0, :, 11:].unsqueeze(1).contiguous().float() # [N, 1, 3] + + # prune by opacity + mask = opacity.squeeze(-1) >= 0.005 + means3D = means3D[mask] + opacity = opacity[mask] + scales = scales[mask] + rotations = rotations[mask] + shs = shs[mask] + + # invert activation to make it compatible with the original ply format + if compatible: + opacity = kiui.op.inverse_sigmoid(opacity) + scales = torch.log(scales + 1e-8) + shs = (shs - 0.5) / 0.28209479177387814 + + xyzs = means3D.detach().cpu().numpy() + f_dc = ( + shs.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + ) + opacities = opacity.detach().cpu().numpy() + scales = scales.detach().cpu().numpy() + rotations = rotations.detach().cpu().numpy() + + l = ["x", "y", "z"] + # All channels except the 3 DC + for i in range(f_dc.shape[1]): + l.append("f_dc_{}".format(i)) + l.append("opacity") + for i in range(scales.shape[1]): + l.append("scale_{}".format(i)) + for i in range(rotations.shape[1]): + l.append("rot_{}".format(i)) + + dtype_full = [(attribute, "f4") for attribute in l] + + elements = np.empty(xyzs.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyzs, f_dc, opacities, scales, rotations), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, "vertex") + + PlyData([el]).write(path) + + def load_ply(self, path, compatible=True): + from plyfile import PlyData, PlyElement + + plydata = PlyData.read(path) + + xyz = np.stack( + ( + np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"]), + ), + axis=1, + ) + print("Number of points at loading : ", xyz.shape[0]) + + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + shs = np.zeros((xyz.shape[0], 3)) + shs[:, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + shs[:, 1] = np.asarray(plydata.elements[0]["f_dc_1"]) + shs[:, 2] = np.asarray(plydata.elements[0]["f_dc_2"]) + + scale_names = [ + p.name + for p in plydata.elements[0].properties + if p.name.startswith("scale_") + ] + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [ + p.name for p in plydata.elements[0].properties if p.name.startswith("rot_") + ] + rots = np.zeros((xyz.shape[0], 
len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + gaussians = np.concatenate([xyz, opacities, scales, rots, shs], axis=1) + gaussians = torch.from_numpy(gaussians).float() # cpu + + if compatible: + gaussians[..., 3:4] = torch.sigmoid(gaussians[..., 3:4]) + gaussians[..., 4:7] = torch.exp(gaussians[..., 4:7]) + gaussians[..., 11:] = 0.28209479177387814 * gaussians[..., 11:] + 0.5 + + return gaussians diff --git a/recon/lgm/options.py b/recon/lgm/options.py new file mode 100644 index 0000000000000000000000000000000000000000..4cc31944f89ff14a4387204f1828edd785bc3498 --- /dev/null +++ b/recon/lgm/options.py @@ -0,0 +1,120 @@ +import tyro +from dataclasses import dataclass +from typing import Tuple, Literal, Dict, Optional + + +@dataclass +class Options: + ### model + # Unet image input size + input_size: int = 256 + # Unet definition + down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024, 1024) + down_attention: Tuple[bool, ...] = (False, False, False, True, True, True) + mid_attention: bool = True + up_channels: Tuple[int, ...] = (1024, 1024, 512, 256) + up_attention: Tuple[bool, ...] = (True, True, True, False) + # Unet output size, dependent on the input_size and U-Net structure! + splat_size: int = 64 + # gaussian render size + output_size: int = 256 + + ### dataset + # data mode (only support s3 now) + data_mode: Literal['s3'] = 's3' + # fovy of the dataset + fovy: float = 49.1 + # camera near plane + znear: float = 0.5 + # camera far plane + zfar: float = 2.5 + # number of all views (input + output) + num_views: int = 12 + # number of views + num_input_views: int = 4 + # camera radius + cam_radius: float = 1.5 # to better use [-1, 1]^3 space + # num workers + num_workers: int = 8 + + ### training + # workspace + workspace: str = './workspace' + # resume + resume: Optional[str] = None + # batch size (per-GPU) + batch_size: int = 8 + # gradient accumulation + gradient_accumulation_steps: int = 1 + # training epochs + num_epochs: int = 30 + # lpips loss weight + lambda_lpips: float = 1.0 + # gradient clip + gradient_clip: float = 1.0 + # mixed precision + mixed_precision: str = 'bf16' + # learning rate + lr: float = 4e-4 + # augmentation prob for grid distortion + prob_grid_distortion: float = 0.5 + # augmentation prob for camera jitter + prob_cam_jitter: float = 0.5 + + ### testing + # test image path + test_path: Optional[str] = None + + ### misc + # nvdiffrast backend setting + force_cuda_rast: bool = False + # render fancy video with gaussian scaling effect + fancy_video: bool = False + + +# all the default settings +config_defaults: Dict[str, Options] = {} +config_doc: Dict[str, str] = {} + +config_doc['lrm'] = 'the default settings for LGM' +config_defaults['lrm'] = Options() + +config_doc['small'] = 'small model with lower resolution Gaussians' +config_defaults['small'] = Options( + input_size=256, + splat_size=64, + output_size=256, + batch_size=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + +config_doc['big'] = 'big model with higher resolution Gaussians' +config_defaults['big'] = Options( + input_size=256, + up_channels=(1024, 1024, 512, 256, 128), # one more decoder + up_attention=(True, True, True, False, False), + splat_size=128, + output_size=512, # render & supervise Gaussians at a higher resolution. 
+ batch_size=8, + num_views=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + +config_doc['tiny'] = 'tiny model for ablation' +config_defaults['tiny'] = Options( + input_size=256, + down_channels=(32, 64, 128, 256, 512), + down_attention=(False, False, False, False, True), + up_channels=(512, 256, 128), + up_attention=(True, False, False, False), + splat_size=64, + output_size=256, + batch_size=16, + num_views=8, + gradient_accumulation_steps=1, + mixed_precision='bf16', +) + +AllConfigs = tyro.extras.subcommand_type_from_defaults(config_defaults, config_doc) diff --git a/recon/lpipsPyTorch/__init__.py b/recon/lpipsPyTorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2a6297daa457d1d041c9491dfdf6a75994ffe06e --- /dev/null +++ b/recon/lpipsPyTorch/__init__.py @@ -0,0 +1,21 @@ +import torch + +from .modules.lpips import LPIPS + + +def lpips(x: torch.Tensor, + y: torch.Tensor, + net_type: str = 'alex', + version: str = '0.1'): + r"""Function that measures + Learned Perceptual Image Patch Similarity (LPIPS). + + Arguments: + x, y (torch.Tensor): the input tensors to compare. + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. + """ + device = x.device + criterion = LPIPS(net_type, version).to(device) + return criterion(x, y) diff --git a/recon/lpipsPyTorch/modules/lpips.py b/recon/lpipsPyTorch/modules/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..ea2d1da4daa267cc5e6c2ce11e1dddec3a5e9406 --- /dev/null +++ b/recon/lpipsPyTorch/modules/lpips.py @@ -0,0 +1,38 @@ +import torch +import torch.nn as nn + +from .networks import get_network, LinLayers +from .utils import get_state_dict + + +class LPIPS(nn.Module): + r"""Creates a criterion that measures + Learned Perceptual Image Patch Similarity (LPIPS). + + Arguments: + net_type (str): the network type to compare the features: + 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. + version (str): the version of LPIPS. Default: 0.1. 
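+        Note: the linear calibration weights are downloaded from the official
+        PerceptualSimilarity repository on first use (see modules/utils.get_state_dict).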
+ """ + + def __init__(self, net_type: str = "alex", version: str = "0.1"): + + assert version in ["0.1"], "v0.1 is only supported now" + + super(LPIPS, self).__init__() + + # pretrained network + self.net = get_network(net_type) + + # linear layers + self.lin = LinLayers(self.net.n_channels_list) + self.lin.load_state_dict(get_state_dict(net_type, version)) + self.eval() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + feat_x, feat_y = self.net(x), self.net(y) + + diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] + res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] + + return torch.sum(torch.cat(res, 0), 0, True) diff --git a/recon/lpipsPyTorch/modules/networks.py b/recon/lpipsPyTorch/modules/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..d36c6a56163004d49c321da5e26404af9baa4c2a --- /dev/null +++ b/recon/lpipsPyTorch/modules/networks.py @@ -0,0 +1,96 @@ +from typing import Sequence + +from itertools import chain + +import torch +import torch.nn as nn +from torchvision import models + +from .utils import normalize_activation + + +def get_network(net_type: str): + if net_type == 'alex': + return AlexNet() + elif net_type == 'squeeze': + return SqueezeNet() + elif net_type == 'vgg': + return VGG16() + else: + raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') + + +class LinLayers(nn.ModuleList): + def __init__(self, n_channels_list: Sequence[int]): + super(LinLayers, self).__init__([ + nn.Sequential( + nn.Identity(), + nn.Conv2d(nc, 1, 1, 1, 0, bias=False) + ) for nc in n_channels_list + ]) + + for param in self.parameters(): + param.requires_grad = False + + +class BaseNet(nn.Module): + def __init__(self): + super(BaseNet, self).__init__() + + # register buffer + self.register_buffer( + 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer( + 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def set_requires_grad(self, state: bool): + for param in chain(self.parameters(), self.buffers()): + param.requires_grad = state + + def z_score(self, x: torch.Tensor): + return (x - self.mean) / self.std + + def forward(self, x: torch.Tensor): + x = self.z_score(x) + + output = [] + for i, (_, layer) in enumerate(self.layers._modules.items(), 1): + x = layer(x) + if i in self.target_layers: + output.append(normalize_activation(x)) + if len(output) == len(self.target_layers): + break + return output + + +class SqueezeNet(BaseNet): + def __init__(self): + super(SqueezeNet, self).__init__() + + self.layers = models.squeezenet1_1(True).features + self.target_layers = [2, 5, 8, 10, 11, 12, 13] + self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] + + self.set_requires_grad(False) + + +class AlexNet(BaseNet): + def __init__(self): + super(AlexNet, self).__init__() + + self.layers = models.alexnet(True).features + self.target_layers = [2, 5, 8, 10, 12] + self.n_channels_list = [64, 192, 384, 256, 256] + + self.set_requires_grad(False) + + +class VGG16(BaseNet): + def __init__(self): + super(VGG16, self).__init__() + + self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features + self.target_layers = [4, 9, 16, 23, 30] + self.n_channels_list = [64, 128, 256, 512, 512] + + self.set_requires_grad(False) diff --git a/recon/lpipsPyTorch/modules/utils.py b/recon/lpipsPyTorch/modules/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3d15a0983775810ef6239c561c67939b2b9ee3b5 --- /dev/null +++ b/recon/lpipsPyTorch/modules/utils.py 
@@ -0,0 +1,30 @@ +from collections import OrderedDict + +import torch + + +def normalize_activation(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def get_state_dict(net_type: str = 'alex', version: str = '0.1'): + # build url + url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ + + f'master/lpips/weights/v{version}/{net_type}.pth' + + # download + old_state_dict = torch.hub.load_state_dict_from_url( + url, progress=True, + map_location=None if torch.cuda.is_available() else torch.device('cpu') + ) + + # rename keys + new_state_dict = OrderedDict() + for key, val in old_state_dict.items(): + new_key = key + new_key = new_key.replace('lin', '') + new_key = new_key.replace('model.', '') + new_state_dict[new_key] = val + + return new_state_dict diff --git a/recon/metrics.py b/recon/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..faa3b698a68296a9a6226bc51c78407320c106fe --- /dev/null +++ b/recon/metrics.py @@ -0,0 +1,131 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from pathlib import Path +import os +from PIL import Image +import torch +import torchvision.transforms.functional as tf +from utils.loss_utils import ssim +from lpipsPyTorch import lpips +import json +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser + + +def readImages(renders_dir, gt_dir): + renders = [] + gts = [] + image_names = [] + for fname in os.listdir(renders_dir): + render = Image.open(renders_dir / fname) + gt = Image.open(gt_dir / fname) + renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda()) + gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda()) + image_names.append(fname) + return renders, gts, image_names + + +def evaluate(model_paths): + + full_dict = {} + per_view_dict = {} + full_dict_polytopeonly = {} + per_view_dict_polytopeonly = {} + print("") + + for scene_dir in model_paths: + try: + print("Scene:", scene_dir) + full_dict[scene_dir] = {} + per_view_dict[scene_dir] = {} + full_dict_polytopeonly[scene_dir] = {} + per_view_dict_polytopeonly[scene_dir] = {} + + test_dir = Path(scene_dir) / "test" + + for method in os.listdir(test_dir): + print("Method:", method) + + full_dict[scene_dir][method] = {} + per_view_dict[scene_dir][method] = {} + full_dict_polytopeonly[scene_dir][method] = {} + per_view_dict_polytopeonly[scene_dir][method] = {} + + method_dir = test_dir / method + gt_dir = method_dir / "gt" + renders_dir = method_dir / "renders" + renders, gts, image_names = readImages(renders_dir, gt_dir) + + ssims = [] + psnrs = [] + lpipss = [] + + for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"): + ssims.append(ssim(renders[idx], gts[idx])) + psnrs.append(psnr(renders[idx], gts[idx])) + lpipss.append(lpips(renders[idx], gts[idx], net_type="vgg")) + + print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean(), ".5")) + print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean(), ".5")) + print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean(), ".5")) + print("") + + full_dict[scene_dir][method].update( + { + "SSIM": torch.tensor(ssims).mean().item(), + "PSNR": torch.tensor(psnrs).mean().item(), + "LPIPS": torch.tensor(lpipss).mean().item(), 
+ } + ) + per_view_dict[scene_dir][method].update( + { + "SSIM": { + name: ssim + for ssim, name in zip( + torch.tensor(ssims).tolist(), image_names + ) + }, + "PSNR": { + name: psnr + for psnr, name in zip( + torch.tensor(psnrs).tolist(), image_names + ) + }, + "LPIPS": { + name: lp + for lp, name in zip( + torch.tensor(lpipss).tolist(), image_names + ) + }, + } + ) + + with open(scene_dir + "/results.json", "w") as fp: + json.dump(full_dict[scene_dir], fp, indent=True) + with open(scene_dir + "/per_view.json", "w") as fp: + json.dump(per_view_dict[scene_dir], fp, indent=True) + except: + print("Unable to compute metrics for model", scene_dir) + + +if __name__ == "__main__": + device = torch.device("cuda:0") + torch.cuda.set_device(device) + + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + parser.add_argument( + "--model_paths", "-m", required=True, nargs="+", type=str, default=[] + ) + args = parser.parse_args() + evaluate(args.model_paths) diff --git a/recon/render.py b/recon/render.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d66379bf3227127ea18ed10b82bfe53ea2726d --- /dev/null +++ b/recon/render.py @@ -0,0 +1,65 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +from scene import Scene +import os +from tqdm import tqdm +from os import makedirs +from gaussian_renderer import render +import torchvision +from utils.general_utils import safe_state +from argparse import ArgumentParser +from arguments import ModelParams, PipelineParams, get_combined_args +from gaussian_renderer import GaussianModel + +def render_set(model_path, name, iteration, views, gaussians, pipeline, background): + render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders") + gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt") + + makedirs(render_path, exist_ok=True) + makedirs(gts_path, exist_ok=True) + + for idx, view in enumerate(tqdm(views, desc="Rendering progress")): + rendering = render(view, gaussians, pipeline, background)["render"] + gt = view.original_image[0:3, :, :] + torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) + torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) + +def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): + with torch.no_grad(): + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) + + bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + if not skip_train: + render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) + + if not skip_test: + render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background) + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Testing script parameters") + model = ModelParams(parser, sentinel=True) + pipeline = PipelineParams(parser) + parser.add_argument("--iteration", default=-1, type=int) + 
parser.add_argument("--skip_train", action="store_true") + parser.add_argument("--skip_test", action="store_true") + parser.add_argument("--quiet", action="store_true") + args = get_combined_args(parser) + print("Rendering " + args.model_path) + + # Initialize system state (RNG) + + render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) \ No newline at end of file diff --git a/recon/render_depth.py b/recon/render_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..2a09acda32d991f257522f50d0c83c5f718f2cd9 --- /dev/null +++ b/recon/render_depth.py @@ -0,0 +1,79 @@ +import torch +from scene import Scene +from pathlib import Path +from PIL import Image +import numpy as np +import sys +import os +from tqdm import tqdm +from os import makedirs +from gaussian_renderer import render +import torchvision +from utils.general_utils import safe_state +from argparse import ArgumentParser +from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams +from gaussian_renderer import GaussianModel +from mediapy import write_video +from tqdm import tqdm +from einops import rearrange +from utils.colormaps import apply_depth_colormap + + +@torch.no_grad() +def render_spiral(dataset, opt, pipe, model_path): + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False) + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + viewpoint_stack = scene.getTrainCameras().copy() + views = [] + for view_cam in tqdm(viewpoint_stack): + bg = torch.rand((3), device="cuda") if opt.random_background else background + render_pkg = render(view_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = ( + render_pkg["depth"], + render_pkg["viewspace_points"], + render_pkg["visibility_filter"], + render_pkg["radii"], + ) + views.append( + rearrange( + apply_depth_colormap( + image[0][..., None], + accumulation=render_pkg["alpha"][0][..., None], + ), + "h w c -> c h w", + ) + ) + views = torch.stack(views) + + write_video( + f"./depth_spirals/{Path(dataset.model_path).stem}.mp4", + rearrange(views.cpu().numpy(), "t c h w -> t h w c"), + fps=3, + ) + + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument("--iteration", default=-1, type=int) + parser.add_argument("--skip_train", action="store_true") + parser.add_argument("--skip_test", action="store_true") + parser.add_argument("--quiet", action="store_true") + args = parser.parse_args(sys.argv[1:]) + print("Rendering " + args.model_path) + lp = lp.extract(args) + fake_image = Image.fromarray(np.zeros([512, 512, 3], dtype=np.uint8)) + lp.images = [fake_image] * args.num_frames + + # Initialize system state (RNG) + render_spiral( + lp, + op.extract(args), + pp.extract(args), + model_path=args.model_path, + ) diff --git a/recon/render_points.py b/recon/render_points.py new file mode 100644 index 0000000000000000000000000000000000000000..8ad86e10151d899474d9dcecfe3a19afc2e34258 --- /dev/null +++ b/recon/render_points.py @@ -0,0 +1,70 @@ +import torch +from scene import Scene +from pathlib import Path +from PIL import Image +import numpy as np +import sys +import os +from tqdm import tqdm +from os import makedirs +from 
gaussian_renderer import render +import torchvision +from utils.general_utils import safe_state +from argparse import ArgumentParser +from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams +from gaussian_renderer import GaussianModel +from mediapy import write_video +from tqdm import tqdm +from einops import rearrange + + +@torch.no_grad() +def render_spiral(dataset, opt, pipe, model_path): + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False) + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + viewpoint_stack = scene.getTrainCameras().copy() + views = [] + for view_cam in tqdm(viewpoint_stack): + bg = torch.rand((3), device="cuda") if opt.random_background else background + render_pkg = render(view_cam, gaussians, pipe, bg, scaling_modifier=0.1) + image, viewspace_point_tensor, visibility_filter, radii = ( + render_pkg["render"], + render_pkg["viewspace_points"], + render_pkg["visibility_filter"], + render_pkg["radii"], + ) + views.append(image) + views = torch.stack(views) + + write_video( + f"./paper/specials/{Path(dataset.model_path).stem}.mp4", + rearrange(views.cpu().numpy(), "t c h w -> t h w c"), + fps=30, + ) + + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument("--iteration", default=-1, type=int) + parser.add_argument("--skip_train", action="store_true") + parser.add_argument("--skip_test", action="store_true") + parser.add_argument("--quiet", action="store_true") + args = parser.parse_args(sys.argv[1:]) + print("Rendering " + args.model_path) + lp = lp.extract(args) + fake_image = Image.fromarray(np.zeros([512, 512, 3], dtype=np.uint8)) + lp.images = [fake_image] * args.num_frames + + # Initialize system state (RNG) + render_spiral( + lp, + op.extract(args), + pp.extract(args), + model_path=args.model_path, + ) diff --git a/recon/render_spiral.py b/recon/render_spiral.py new file mode 100644 index 0000000000000000000000000000000000000000..23aea3afc972638315311e96e30c3aec8d75aee1 --- /dev/null +++ b/recon/render_spiral.py @@ -0,0 +1,75 @@ +import torch +from scene import Scene +from pathlib import Path +from PIL import Image +import numpy as np +import sys +import os +from tqdm import tqdm +from os import makedirs +from gaussian_renderer import render +import torchvision +from utils.general_utils import safe_state +from argparse import ArgumentParser +from arguments import ModelParams, PipelineParams, get_combined_args, OptimizationParams +from gaussian_renderer import GaussianModel +from mediapy import write_video +from tqdm import tqdm +from einops import rearrange + + +@torch.no_grad() +def render_spiral(dataset, opt, pipe, model_path): + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, load_iteration=-1, shuffle=False) + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + viewpoint_stack = scene.getTrainCameras().copy() + views = [] + for view_cam in tqdm(viewpoint_stack): + bg = torch.rand((3), device="cuda") if opt.random_background else background + render_pkg = render(view_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = ( + 
render_pkg["render"], + render_pkg["viewspace_points"], + render_pkg["visibility_filter"], + render_pkg["radii"], + ) + views.append(image) + views = torch.stack(views) + + write_video( + f"./spirals/{Path(dataset.model_path).stem}.mp4", + rearrange(views.cpu().numpy(), "t c h w -> t h w c"), + fps=30, + ) + write_video( + f"tmp/test_spiral.mp4", + rearrange(views.cpu().numpy(), "t c h w -> t h w c"), + fps=30, + ) + + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument("--iteration", default=-1, type=int) + parser.add_argument("--skip_train", action="store_true") + parser.add_argument("--skip_test", action="store_true") + parser.add_argument("--quiet", action="store_true") + args = parser.parse_args(sys.argv[1:]) + print("Rendering " + args.model_path) + lp = lp.extract(args) + fake_image = Image.fromarray(np.zeros([512, 512, 3], dtype=np.uint8)) + lp.images = [fake_image] * args.num_frames + + # Initialize system state (RNG) + render_spiral( + lp, + op.extract(args), + pp.extract(args), + model_path=args.model_path, + ) diff --git a/recon/restore.py b/recon/restore.py new file mode 100644 index 0000000000000000000000000000000000000000..5c907a210bf62ef2453ccc9d06559c1cfe3b33f2 --- /dev/null +++ b/recon/restore.py @@ -0,0 +1,3 @@ +import torch + +pass diff --git a/recon/scene/__init__.py b/recon/scene/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..faea2365d4aa15be843b7adfdb7efbf22dc60554 --- /dev/null +++ b/recon/scene/__init__.py @@ -0,0 +1,139 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import random +import json +from utils.system_utils import searchForMaxIteration +from scene.dataset_readers import sceneLoadTypeCallbacks +from scene.gaussian_model import GaussianModel +from arguments import ModelParams +from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON + + +class Scene: + gaussians: GaussianModel + + def __init__( + self, + args: ModelParams, + gaussians: GaussianModel, + load_iteration=None, + shuffle=True, + resolution_scales=[1.0], + skip_gaussians=False, + ): + """b + :param path: Path to colmap scene main folder. 
+ """ + self.model_path = args.model_path + self.loaded_iter = None + self.gaussians = gaussians + + if load_iteration: + if load_iteration == -1: + self.loaded_iter = searchForMaxIteration( + os.path.join(self.model_path, "point_cloud") + ) + else: + self.loaded_iter = load_iteration + print("Loading trained model at iteration {}".format(self.loaded_iter)) + + self.train_cameras = {} + self.test_cameras = {} + + if os.path.exists(os.path.join(args.source_path, "sparse")): + scene_info = sceneLoadTypeCallbacks["Colmap"]( + args.source_path, args.images, args.eval + ) + elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")): + print("Found transforms_train.json file, assuming Blender data set!") + scene_info = sceneLoadTypeCallbacks["Blender"]( + args.source_path, args.white_background, args.eval + ) + elif hasattr(args, "num_frames"): + print("using video-nvs target") + scene_info = sceneLoadTypeCallbacks["VideoNVS"]( + args.num_frames, + args.radius, + args.elevation, + args.fov, + args.reso, + args.images, + args.masks, + args.num_pts, + args.train, + ) + else: + assert False, "Could not recognize scene type!" + + if not self.loaded_iter: + with open(scene_info.ply_path, "rb") as src_file, open( + os.path.join(self.model_path, "input.ply"), "wb" + ) as dest_file: + dest_file.write(src_file.read()) + json_cams = [] + camlist = [] + if scene_info.test_cameras: + camlist.extend(scene_info.test_cameras) + if scene_info.train_cameras: + camlist.extend(scene_info.train_cameras) + for id, cam in enumerate(camlist): + json_cams.append(camera_to_JSON(id, cam)) + with open(os.path.join(self.model_path, "cameras.json"), "w") as file: + json.dump(json_cams, file) + + if shuffle: + random.shuffle( + scene_info.train_cameras + ) # Multi-res consistent random shuffling + random.shuffle( + scene_info.test_cameras + ) # Multi-res consistent random shuffling + + self.cameras_extent = scene_info.nerf_normalization["radius"] + + for resolution_scale in resolution_scales: + print("Loading Training Cameras") + self.train_cameras[resolution_scale] = cameraList_from_camInfos( + scene_info.train_cameras, resolution_scale, args + ) + print("Loading Test Cameras") + self.test_cameras[resolution_scale] = cameraList_from_camInfos( + scene_info.test_cameras, resolution_scale, args + ) + + if not skip_gaussians: + if self.loaded_iter: + self.gaussians.load_ply( + os.path.join( + self.model_path, + "point_cloud", + "iteration_" + str(self.loaded_iter), + "point_cloud.ply", + ) + ) + else: + self.gaussians.create_from_pcd( + scene_info.point_cloud, self.cameras_extent + ) + + def save(self, iteration): + point_cloud_path = os.path.join( + self.model_path, "point_cloud/iteration_{}".format(iteration) + ) + self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) + + def getTrainCameras(self, scale=1.0): + return self.train_cameras[scale] + + def getTestCameras(self, scale=1.0): + return self.test_cameras[scale] diff --git a/recon/scene/cameras.py b/recon/scene/cameras.py new file mode 100644 index 0000000000000000000000000000000000000000..abf6e5242bc46ef1915ce24619a8319d0b7591c7 --- /dev/null +++ b/recon/scene/cameras.py @@ -0,0 +1,71 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +from torch import nn +import numpy as np +from utils.graphics_utils import getWorld2View2, getProjectionMatrix + +class Camera(nn.Module): + def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, + image_name, uid, + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" + ): + super(Camera, self).__init__() + + self.uid = uid + self.colmap_id = colmap_id + self.R = R + self.T = T + self.FoVx = FoVx + self.FoVy = FoVy + self.image_name = image_name + + try: + self.data_device = torch.device(data_device) + except Exception as e: + print(e) + print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) + self.data_device = torch.device("cuda") + + self.original_image = image.clamp(0.0, 1.0).to(self.data_device) + self.image_width = self.original_image.shape[2] + self.image_height = self.original_image.shape[1] + + if gt_alpha_mask is not None: + self.original_image *= gt_alpha_mask.to(self.data_device) + else: + self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) + + self.zfar = 100.0 + self.znear = 0.01 + + self.trans = trans + self.scale = scale + + self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() + self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() + self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) + self.camera_center = self.world_view_transform.inverse()[3, :3] + +class MiniCam: + def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): + self.image_width = width + self.image_height = height + self.FoVy = fovy + self.FoVx = fovx + self.znear = znear + self.zfar = zfar + self.world_view_transform = world_view_transform + self.full_proj_transform = full_proj_transform + view_inv = torch.inverse(self.world_view_transform) + self.camera_center = view_inv[3][:3] + diff --git a/recon/scene/colmap_loader.py b/recon/scene/colmap_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6fba6a9c961f52c88780ecb44d7821b4cb73ee --- /dev/null +++ b/recon/scene/colmap_loader.py @@ -0,0 +1,294 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import numpy as np +import collections +import struct + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"]) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"]) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) +} +CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) + for camera_model in CAMERA_MODELS]) +CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) + for camera_model in CAMERA_MODELS]) + + +def qvec2rotmat(qvec): + return np.array([ + [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], + [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], + [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = np.array([ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. 
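+    Example: read_next_bytes(fid, 8, "Q") returns a 1-tuple holding one little-endian uint64.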
+ """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + +def read_points3D_text(path): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + xyzs = None + rgbs = None + errors = None + num_points = 0 + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + num_points += 1 + + + xyzs = np.empty((num_points, 3)) + rgbs = np.empty((num_points, 3)) + errors = np.empty((num_points, 1)) + count = 0 + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = np.array(float(elems[7])) + xyzs[count] = xyz + rgbs[count] = rgb + errors[count] = error + count += 1 + + return xyzs, rgbs, errors + +def read_points3D_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + + + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + + xyzs = np.empty((num_points, 3)) + rgbs = np.empty((num_points, 3)) + errors = np.empty((num_points, 1)) + + for p_id in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd") + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q")[0] + track_elems = read_next_bytes( + fid, num_bytes=8*track_length, + format_char_sequence="ii"*track_length) + xyzs[p_id] = xyz + rgbs[p_id] = rgb + errors[p_id] = error + return xyzs, rgbs, errors + +def read_intrinsics_text(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + assert model == "PINHOLE", "While the loader support other types, the rest of the code assumes PINHOLE" + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera(id=camera_id, model=model, + width=width, height=height, + params=params) + return cameras + +def read_extrinsics_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi") + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + image_name = "" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look 
for the ASCII 0 entry + image_name += current_char.decode("utf-8") + current_char = read_next_bytes(fid, 1, "c")[0] + num_points2D = read_next_bytes(fid, num_bytes=8, + format_char_sequence="Q")[0] + x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, + format_char_sequence="ddq"*num_points2D) + xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3]))]) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_intrinsics_binary(path_to_model_file): + """ + see: src/base/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ") + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes(fid, num_bytes=8*num_params, + format_char_sequence="d"*num_params) + cameras[camera_id] = Camera(id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params)) + assert len(cameras) == num_cameras + return cameras + + +def read_extrinsics_text(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py + """ + images = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack([tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3]))]) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, qvec=qvec, tvec=tvec, + camera_id=camera_id, name=image_name, + xys=xys, point3D_ids=point3D_ids) + return images + + +def read_colmap_bin_array(path): + """ + Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py + + :param path: path to the colmap binary file. + :return: nd array with the floating point values in the value + """ + with open(path, "rb") as fid: + width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1, + usecols=(0, 1, 2), dtype=int) + fid.seek(0) + num_delimiter = 0 + byte = fid.read(1) + while True: + if byte == b"&": + num_delimiter += 1 + if num_delimiter >= 3: + break + byte = fid.read(1) + array = np.fromfile(fid, np.float32) + array = array.reshape((width, height, channels), order="F") + return np.transpose(array, (1, 0, 2)).squeeze() diff --git a/recon/scene/dataset_readers.py b/recon/scene/dataset_readers.py new file mode 100644 index 0000000000000000000000000000000000000000..1e09a82ea5726f3f7a46fd24bacb5f1b93ef5231 --- /dev/null +++ b/recon/scene/dataset_readers.py @@ -0,0 +1,512 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. 
+# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import sys +from PIL import Image +from typing import NamedTuple +from scene.colmap_loader import ( + read_extrinsics_text, + read_intrinsics_text, + qvec2rotmat, + read_extrinsics_binary, + read_intrinsics_binary, + read_points3D_binary, + read_points3D_text, +) +from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal +from utils.camera_utils import get_uniform_poses +import numpy as np +import json +from pathlib import Path +from plyfile import PlyData, PlyElement +from utils.sh_utils import SH2RGB +from scene.gaussian_model import BasicPointCloud +from scene.cameras import Camera +import torch +import rembg +import mcubes +import trimesh + + +class CameraInfo(NamedTuple): + uid: int + R: np.array + T: np.array + FovY: np.array + FovX: np.array + image: np.array + image_path: str + image_name: str + width: int + height: int + + +class SceneInfo(NamedTuple): + point_cloud: BasicPointCloud + train_cameras: list + test_cameras: list + nerf_normalization: dict + ply_path: str + + +def getNerfppNorm(cam_info): + def get_center_and_diag(cam_centers): + cam_centers = np.hstack(cam_centers) + avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) + center = avg_cam_center + dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) + diagonal = np.max(dist) + return center.flatten(), diagonal + + cam_centers = [] + + for cam in cam_info: + W2C = getWorld2View2(cam.R, cam.T) + C2W = np.linalg.inv(W2C) + cam_centers.append(C2W[:3, 3:4]) + + center, diagonal = get_center_and_diag(cam_centers) + radius = diagonal * 1.1 + + translate = -center + + return {"translate": translate, "radius": radius} + + +def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder): + cam_infos = [] + for idx, key in enumerate(cam_extrinsics): + sys.stdout.write("\r") + # the exact output you're looking for: + sys.stdout.write("Reading camera {}/{}".format(idx + 1, len(cam_extrinsics))) + sys.stdout.flush() + + extr = cam_extrinsics[key] + intr = cam_intrinsics[extr.camera_id] + height = intr.height + width = intr.width + + uid = intr.id + R = np.transpose(qvec2rotmat(extr.qvec)) + T = np.array(extr.tvec) + + if intr.model == "SIMPLE_PINHOLE": + focal_length_x = intr.params[0] + FovY = focal2fov(focal_length_x, height) + FovX = focal2fov(focal_length_x, width) + elif intr.model == "PINHOLE": + focal_length_x = intr.params[0] + focal_length_y = intr.params[1] + FovY = focal2fov(focal_length_y, height) + FovX = focal2fov(focal_length_x, width) + else: + assert ( + False + ), "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!" 
+ + image_path = os.path.join(images_folder, os.path.basename(extr.name)) + image_name = os.path.basename(image_path).split(".")[0] + image = Image.open(image_path) + + cam_info = CameraInfo( + uid=uid, + R=R, + T=T, + FovY=FovY, + FovX=FovX, + image=image, + image_path=image_path, + image_name=image_name, + width=width, + height=height, + ) + cam_infos.append(cam_info) + sys.stdout.write("\n") + return cam_infos + + +def fetchPly(path): + plydata = PlyData.read(path) + vertices = plydata["vertex"] + positions = np.vstack([vertices["x"], vertices["y"], vertices["z"]]).T + colors = np.vstack([vertices["red"], vertices["green"], vertices["blue"]]).T / 255.0 + normals = np.vstack([vertices["nx"], vertices["ny"], vertices["nz"]]).T + return BasicPointCloud(points=positions, colors=colors, normals=normals) + + +def storePly(path, xyz, rgb): + # Define the dtype for the structured array + dtype = [ + ("x", "f4"), + ("y", "f4"), + ("z", "f4"), + ("nx", "f4"), + ("ny", "f4"), + ("nz", "f4"), + ("red", "u1"), + ("green", "u1"), + ("blue", "u1"), + ] + + normals = np.zeros_like(xyz) + + elements = np.empty(xyz.shape[0], dtype=dtype) + attributes = np.concatenate((xyz, normals, rgb), axis=1) + elements[:] = list(map(tuple, attributes)) + + # Create the PlyData object and write to file + vertex_element = PlyElement.describe(elements, "vertex") + ply_data = PlyData([vertex_element]) + ply_data.write(path) + + +def readColmapSceneInfo(path, images, eval, llffhold=8): + try: + cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin") + cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin") + cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file) + cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file) + except: + cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt") + cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt") + cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file) + cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file) + + reading_dir = "images" if images == None else images + cam_infos_unsorted = readColmapCameras( + cam_extrinsics=cam_extrinsics, + cam_intrinsics=cam_intrinsics, + images_folder=os.path.join(path, reading_dir), + ) + cam_infos = sorted(cam_infos_unsorted.copy(), key=lambda x: x.image_name) + + if eval: + train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0] + test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0] + else: + train_cam_infos = cam_infos + test_cam_infos = [] + + nerf_normalization = getNerfppNorm(train_cam_infos) + + ply_path = os.path.join(path, "sparse/0/points3D.ply") + bin_path = os.path.join(path, "sparse/0/points3D.bin") + txt_path = os.path.join(path, "sparse/0/points3D.txt") + if not os.path.exists(ply_path): + print( + "Converting point3d.bin to .ply, will happen only the first time you open the scene." 
+ ) + try: + xyz, rgb, _ = read_points3D_binary(bin_path) + except: + xyz, rgb, _ = read_points3D_text(txt_path) + storePly(ply_path, xyz, rgb) + try: + pcd = fetchPly(ply_path) + except: + pcd = None + + scene_info = SceneInfo( + point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path, + ) + return scene_info + + +def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"): + cam_infos = [] + + with open(os.path.join(path, transformsfile)) as json_file: + contents = json.load(json_file) + fovx = contents["camera_angle_x"] + + frames = contents["frames"] + for idx, frame in enumerate(frames): + cam_name = os.path.join(path, frame["file_path"] + extension) + + # NeRF 'transform_matrix' is a camera-to-world transform + c2w = np.array(frame["transform_matrix"]) + # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward) + c2w[:3, 1:3] *= -1 + + # get the world-to-camera transform and set R, T + w2c = np.linalg.inv(c2w) + R = np.transpose( + w2c[:3, :3] + ) # R is stored transposed due to 'glm' in CUDA code + T = w2c[:3, 3] + + image_path = os.path.join(path, cam_name) + image_name = Path(cam_name).stem + image = Image.open(image_path) + + im_data = np.array(image.convert("RGBA")) + + bg = np.array([1, 1, 1]) if white_background else np.array([0, 0, 0]) + + norm_data = im_data / 255.0 + if norm_data.shape[-1] != 3: + arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + bg * ( + 1 - norm_data[:, :, 3:4] + ) + image = Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB") + + fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1]) + FovY = fovy + FovX = fovx + + cam_infos.append( + CameraInfo( + uid=idx, + R=R, + T=T, + FovY=FovY, + FovX=FovX, + image=image, + image_path=image_path, + image_name=image_name, + width=image.size[0], + height=image.size[1], + ) + ) + + return cam_infos + + +def uniform_surface_sampling_from_vertices_and_faces( + vertices, faces, num_points: int +) -> torch.Tensor: + """ + Uniformly sample points from the surface of a mesh. + + Args: + vertices (torch.Tensor): Vertices of the mesh. + faces (torch.Tensor): Faces of the mesh. + num_points (int): Number of points to sample. + + Returns: + torch.Tensor: Points sampled from the surface of the mesh. 
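+    Note: a second tensor of uniform random values with the same shape as the sampled
+    points is also returned (unpacked as placeholder colors by the caller).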
+ """ + mesh = trimesh.Trimesh(vertices=vertices, faces=faces) + n = num_points + points = [] + while n > 0: + p, _ = trimesh.sample.sample_surface_even(mesh, n) + n -= p.shape[0] + if n >= 0: + points.append(p) + else: + points.append(p[:n]) + + if len(points) > 1: + points = np.concatenate(points, axis=0) + else: + points = points[0] + + points = torch.from_numpy(points.astype(np.float32)) + + return points, torch.rand_like(points) + + +def occ_from_sparse_initialize(poses, images, cameras, grid_reso, num_points): + # fov is in degrees + this_session = rembg.new_session() + + imgs = [rembg.remove(im, session=this_session) for im in images] + + reso = grid_reso + occ_grid = torch.ones((reso, reso, reso), dtype=torch.bool, device="cuda") + + c2ws = poses + center = c2ws[..., :3, 3].mean(axis=0) + radius = np.linalg.norm(c2ws[..., :3, 3] - center, axis=-1).mean() + xx, yy, zz = torch.meshgrid( + torch.linspace(-radius, radius, reso, device="cuda"), + torch.linspace(-radius, radius, reso, device="cuda"), + torch.linspace(-radius, radius, reso, device="cuda"), + indexing="ij", + ) + print("radius", radius) + + # xyz_grid = torch.stack((xx.flatten(), yy.flatten(), zz.flatten()), dim=-1) + ww = torch.ones((reso, reso, reso), dtype=torch.float32, device="cuda") + xyzw_grid = torch.stack((xx, yy, zz, ww), dim=-1) + xyzw_grid[..., :3] += torch.from_numpy(center).cuda() + + c2ws = torch.tensor(c2ws, dtype=torch.float32) + + for c2w, camera, img in zip(c2ws, cameras, imgs): + img = np.asarray(img) + alpha = img[..., 3].astype(np.float32) / 255.0 + is_foreground = alpha > 0.05 + is_foreground = torch.from_numpy(is_foreground).cuda() + + full_proj_mtx = Camera( + colmap_id=camera.uid, + R=camera.R, + T=camera.T, + FoVx=camera.FovX, + FoVy=camera.FovY, + image=torch.randn(3, 10, 10), + gt_alpha_mask=None, + image_name="no", + uid=0, + data_device="cuda", + ).full_proj_transform + # check the scale + + ij = xyzw_grid @ full_proj_mtx + ij = (ij + 1) / 2.0 + h, w = img.shape[:2] + ij = ij[..., :2] * torch.tensor([w, h], dtype=torch.float32, device="cuda") + ij = ( + ij.clamp( + min=torch.tensor([0.0, 0.0], device="cuda"), + max=torch.tensor([w - 1, h - 1], dtype=torch.float32, device="cuda"), + ) + .to(torch.long) + .cuda() + ) + + occ_grid = torch.logical_and(occ_grid, is_foreground[ij[..., 1], ij[..., 0]]) + + # To mesh + occ_grid = occ_grid.to(torch.float32).cpu().numpy() + vertices, triangles = mcubes.marching_cubes(occ_grid, 0.5) + + # vertices = (vertices / reso - 0.5) * radius * 2 + center + # vertices = (vertices / (reso - 1.0) - 0.5) * radius * 2 * 2 + center + vertices = vertices / (grid_reso - 1) * 2 - 1 + vertices = vertices * radius + center + # mcubes.export_obj(vertices, triangles, "./tmp/occ_voxel.obj") + + xyz, rgb = uniform_surface_sampling_from_vertices_and_faces( + vertices, triangles, num_points + ) + + return xyz + + +def readNerfSyntheticInfo(path, white_background, eval, extension=".png"): + print("Reading Training Transforms") + train_cam_infos = readCamerasFromTransforms( + path, "transforms_train.json", white_background, extension + ) + print("Reading Test Transforms") + test_cam_infos = readCamerasFromTransforms( + path, "transforms_test.json", white_background, extension + ) + + if not eval: + train_cam_infos.extend(test_cam_infos) + test_cam_infos = [] + + nerf_normalization = getNerfppNorm(train_cam_infos) + + ply_path = os.path.join(path, "points3d.ply") + if not os.path.exists(ply_path): + # Since this data set has no colmap data, we start with random points + num_pts = 
100_000 + print(f"Generating random point cloud ({num_pts})...") + + # We create random points inside the bounds of the synthetic Blender scenes + xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 + shs = np.random.random((num_pts, 3)) / 255.0 + pcd = BasicPointCloud( + points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)) + ) + + storePly(ply_path, xyz, SH2RGB(shs) * 255) + try: + pcd = fetchPly(ply_path) + except: + pcd = None + + scene_info = SceneInfo( + point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=test_cam_infos, + nerf_normalization=nerf_normalization, + ply_path=ply_path, + ) + return scene_info + + +def constructVideoNVSInfo( + num_frames, + radius, + elevation, + fov, + reso, + images, + masks, + num_pts=100_000, + train=True, +): + poses = get_uniform_poses(num_frames, radius, elevation) + w2cs = np.linalg.inv(poses) + train_cam_infos = [] + + for idx, pose in enumerate(w2cs): + train_cam_infos.append( + CameraInfo( + uid=idx, + R=np.transpose(pose[:3, :3]), + T=pose[:3, 3], + FovY=np.deg2rad(fov), + FovX=np.deg2rad(fov), + image=images[idx], + image_path=None, + image_name=idx, + width=reso, + height=reso, + ) + ) + + nerf_normalization = getNerfppNorm(train_cam_infos) + # xyz = np.random.random((num_pts, 3)) * radius / 3 - radius / 3 + xyz = np.random.randn(num_pts, 3) * radius / 16 + # if len(poses) <= 24: + # xyz = occ_from_sparse_initialize(poses, images, train_cam_infos, 256, num_pts) + # num_pts = xyz.shape[0] + # else: + # xyz = np.random.randn(num_pts, 3) * radius / 16 + xyz = np.random.randn(num_pts, 3) * radius / 16 + # shs = np.random.random((num_pts, 3)) / 255.0 + shs = np.ones((num_pts, 3)) * 0.2 + pcd = BasicPointCloud( + points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)) + ) + + ply_path = "./tmp/points3d.ply" + storePly(ply_path, xyz, SH2RGB(shs) * 255) + pcd = fetchPly(ply_path) + + scene_info = SceneInfo( + point_cloud=pcd, + train_cameras=train_cam_infos, + test_cameras=[], + nerf_normalization=nerf_normalization, + ply_path="./tmp/points3d.ply", + ) + + return scene_info + + +sceneLoadTypeCallbacks = { + "Colmap": readColmapSceneInfo, + "Blender": readNerfSyntheticInfo, + "VideoNVS": constructVideoNVSInfo, +} diff --git a/recon/scene/gaussian_model.py b/recon/scene/gaussian_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ef893f2183e28a1d6a7cfa4455c446fbaed3ff76 --- /dev/null +++ b/recon/scene/gaussian_model.py @@ -0,0 +1,570 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import numpy as np +from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation +from torch import nn +import os +from utils.system_utils import mkdir_p +from plyfile import PlyData, PlyElement +from utils.sh_utils import RGB2SH +from simple_knn._C import distCUDA2 +from utils.graphics_utils import BasicPointCloud +from utils.general_utils import strip_symmetric, build_scaling_rotation + + +class GaussianModel: + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + self.scaling_activation = torch.exp + self.scaling_inverse_activation = torch.log + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + def __init__(self, sh_degree: int): + self.active_sh_degree = 0 + self.max_sh_degree = sh_degree + self._xyz = torch.empty(0) + self._features_dc = torch.empty(0) + self._features_rest = torch.empty(0) + self._scaling = torch.empty(0) + self._rotation = torch.empty(0) + self._opacity = torch.empty(0) + self.max_radii2D = torch.empty(0) + self.xyz_gradient_accum = torch.empty(0) + self.denom = torch.empty(0) + self.optimizer = None + self.percent_dense = 0 + self.spatial_lr_scale = 0 + self.setup_functions() + + def capture(self): + return ( + self.active_sh_degree, + self._xyz, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + self.xyz_gradient_accum, + self.denom, + self.optimizer.state_dict(), + self.spatial_lr_scale, + ) + + def restore(self, model_args, training_args): + ( + self.active_sh_degree, + self._xyz, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + xyz_gradient_accum, + denom, + opt_dict, + self.spatial_lr_scale, + ) = model_args + self.training_setup(training_args) + self.xyz_gradient_accum = xyz_gradient_accum + self.denom = denom + self.optimizer.load_state_dict(opt_dict) + + @property + def get_scaling(self): + return self.scaling_activation(self._scaling) + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation) + + @property + def get_xyz(self): + return self._xyz + + @property + def get_features(self): + features_dc = self._features_dc + features_rest = self._features_rest + return torch.cat((features_dc, features_rest), dim=1) + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity) + + def get_covariance(self, scaling_modifier=1): + return self.covariance_activation( + self.get_scaling, scaling_modifier, self._rotation + ) + + def oneupSHdegree(self): + if self.active_sh_degree < self.max_sh_degree: + self.active_sh_degree += 1 + + def create_from_pcd(self, pcd: BasicPointCloud, spatial_lr_scale: float): + self.spatial_lr_scale = spatial_lr_scale + fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() + fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) + features = ( + torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)) + .float() + .cuda() + ) + features[:, :3, 0] = fused_color + features[:, 3:, 1:] = 0.0 + + 
print("Number of points at initialisation : ", fused_point_cloud.shape[0]) + + dist2 = torch.clamp_min( + distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), + 0.0000001, + ) + scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3) + rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") + rots[:, 0] = 1 + + opacities = inverse_sigmoid( + 0.5 + * torch.ones( + (fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda" + ) + ) + + self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) + self._features_dc = nn.Parameter( + features[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True) + ) + self._features_rest = nn.Parameter( + features[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True) + ) + self._scaling = nn.Parameter(scales.requires_grad_(True)) + self._rotation = nn.Parameter(rots.requires_grad_(True)) + self._opacity = nn.Parameter(opacities.requires_grad_(True)) + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def training_setup(self, training_args): + self.percent_dense = training_args.percent_dense + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + + l = [ + { + "params": [self._xyz], + "lr": training_args.position_lr_init * self.spatial_lr_scale, + "name": "xyz", + }, + { + "params": [self._features_dc], + "lr": training_args.feature_lr, + "name": "f_dc", + }, + { + "params": [self._features_rest], + "lr": training_args.feature_lr / 20.0, + "name": "f_rest", + }, + { + "params": [self._opacity], + "lr": training_args.opacity_lr, + "name": "opacity", + }, + { + "params": [self._scaling], + "lr": training_args.scaling_lr, + "name": "scaling", + }, + { + "params": [self._rotation], + "lr": training_args.rotation_lr, + "name": "rotation", + }, + ] + + self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) + self.xyz_scheduler_args = get_expon_lr_func( + lr_init=training_args.position_lr_init * self.spatial_lr_scale, + lr_final=training_args.position_lr_final * self.spatial_lr_scale, + lr_delay_mult=training_args.position_lr_delay_mult, + max_steps=training_args.position_lr_max_steps, + ) + + def update_learning_rate(self, iteration): + """Learning rate scheduling per step""" + for param_group in self.optimizer.param_groups: + if param_group["name"] == "xyz": + lr = self.xyz_scheduler_args(iteration) + param_group["lr"] = lr + return lr + + def construct_list_of_attributes(self): + l = ["x", "y", "z", "nx", "ny", "nz"] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]): + l.append("f_dc_{}".format(i)) + for i in range(self._features_rest.shape[1] * self._features_rest.shape[2]): + l.append("f_rest_{}".format(i)) + l.append("opacity") + for i in range(self._scaling.shape[1]): + l.append("scale_{}".format(i)) + for i in range(self._rotation.shape[1]): + l.append("rot_{}".format(i)) + return l + + def save_ply(self, path): + mkdir_p(os.path.dirname(path)) + + xyz = self._xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = ( + self._features_dc.detach() + .transpose(1, 2) + .flatten(start_dim=1) + .contiguous() + .cpu() + .numpy() + ) + f_rest = ( + self._features_rest.detach() + .transpose(1, 2) + .flatten(start_dim=1) + .contiguous() + .cpu() + .numpy() + ) + opacities = self._opacity.detach().cpu().numpy() + scale = self._scaling.detach().cpu().numpy() + rotation = self._rotation.detach().cpu().numpy() + + dtype_full = [ 
+ (attribute, "f4") for attribute in self.construct_list_of_attributes() + ] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate( + (xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1 + ) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, "vertex") + PlyData([el]).write(path) + + def reset_opacity(self): + opacities_new = inverse_sigmoid( + torch.min(self.get_opacity, torch.ones_like(self.get_opacity) * 0.01) + ) + optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") + self._opacity = optimizable_tensors["opacity"] + + def load_ply(self, path): + plydata = PlyData.read(path) + + xyz = np.stack( + ( + np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"]), + ), + axis=1, + ) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + extra_f_names = [ + p.name + for p in plydata.elements[0].properties + if p.name.startswith("f_rest_") + ] + extra_f_names = sorted(extra_f_names, key=lambda x: int(x.split("_")[-1])) + assert len(extra_f_names) == 3 * (self.max_sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape( + (features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1) + ) + + scale_names = [ + p.name + for p in plydata.elements[0].properties + if p.name.startswith("scale_") + ] + scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [ + p.name for p in plydata.elements[0].properties if p.name.startswith("rot") + ] + rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + self._xyz = nn.Parameter( + torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True) + ) + self._features_dc = nn.Parameter( + torch.tensor(features_dc, dtype=torch.float, device="cuda") + .transpose(1, 2) + .contiguous() + .requires_grad_(True) + ) + self._features_rest = nn.Parameter( + torch.tensor(features_extra, dtype=torch.float, device="cuda") + .transpose(1, 2) + .contiguous() + .requires_grad_(True) + ) + self._opacity = nn.Parameter( + torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_( + True + ) + ) + self._scaling = nn.Parameter( + torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True) + ) + self._rotation = nn.Parameter( + torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True) + ) + + self.active_sh_degree = self.max_sh_degree + + def replace_tensor_to_optimizer(self, tensor, name): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if group["name"] == name: + stored_state = self.optimizer.state.get(group["params"][0], None) + stored_state["exp_avg"] = 
torch.zeros_like(tensor) + stored_state["exp_avg_sq"] = torch.zeros_like(tensor) + + del self.optimizer.state[group["params"][0]] + group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) + self.optimizer.state[group["params"][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def _prune_optimizer(self, mask): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + stored_state = self.optimizer.state.get(group["params"][0], None) + if stored_state is not None: + stored_state["exp_avg"] = stored_state["exp_avg"][mask] + stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] + + del self.optimizer.state[group["params"][0]] + group["params"][0] = nn.Parameter( + (group["params"][0][mask].requires_grad_(True)) + ) + self.optimizer.state[group["params"][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter( + group["params"][0][mask].requires_grad_(True) + ) + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def prune_points(self, mask): + valid_points_mask = ~mask + optimizable_tensors = self._prune_optimizer(valid_points_mask) + + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + + self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] + + self.denom = self.denom[valid_points_mask] + self.max_radii2D = self.max_radii2D[valid_points_mask] + + def cat_tensors_to_optimizer(self, tensors_dict): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + assert len(group["params"]) == 1 + extension_tensor = tensors_dict[group["name"]] + stored_state = self.optimizer.state.get(group["params"][0], None) + if stored_state is not None: + stored_state["exp_avg"] = torch.cat( + (stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0 + ) + stored_state["exp_avg_sq"] = torch.cat( + (stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), + dim=0, + ) + + del self.optimizer.state[group["params"][0]] + group["params"][0] = nn.Parameter( + torch.cat( + (group["params"][0], extension_tensor), dim=0 + ).requires_grad_(True) + ) + self.optimizer.state[group["params"][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter( + torch.cat( + (group["params"][0], extension_tensor), dim=0 + ).requires_grad_(True) + ) + optimizable_tensors[group["name"]] = group["params"][0] + + return optimizable_tensors + + def densification_postfix( + self, + new_xyz, + new_features_dc, + new_features_rest, + new_opacities, + new_scaling, + new_rotation, + ): + d = { + "xyz": new_xyz, + "f_dc": new_features_dc, + "f_rest": new_features_rest, + "opacity": new_opacities, + "scaling": new_scaling, + "rotation": new_rotation, + } + + optimizable_tensors = self.cat_tensors_to_optimizer(d) + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = 
torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): + n_init_points = self.get_xyz.shape[0] + # Extract points that satisfy the gradient condition + padded_grad = torch.zeros((n_init_points), device="cuda") + padded_grad[: grads.shape[0]] = grads.squeeze() + selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and( + selected_pts_mask, + torch.max(self.get_scaling, dim=1).values + > self.percent_dense * scene_extent, + ) + + stds = self.get_scaling[selected_pts_mask].repeat(N, 1) + means = torch.zeros((stds.size(0), 3), device="cuda") + samples = torch.normal(mean=means, std=stds) + rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N, 1, 1) + new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[ + selected_pts_mask + ].repeat(N, 1) + new_scaling = self.scaling_inverse_activation( + self.get_scaling[selected_pts_mask].repeat(N, 1) / (0.8 * N) + ) + new_rotation = self._rotation[selected_pts_mask].repeat(N, 1) + new_features_dc = self._features_dc[selected_pts_mask].repeat(N, 1, 1) + new_features_rest = self._features_rest[selected_pts_mask].repeat(N, 1, 1) + new_opacity = self._opacity[selected_pts_mask].repeat(N, 1) + + self.densification_postfix( + new_xyz, + new_features_dc, + new_features_rest, + new_opacity, + new_scaling, + new_rotation, + ) + + prune_filter = torch.cat( + ( + selected_pts_mask, + torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool), + ) + ) + self.prune_points(prune_filter) + + def densify_and_clone(self, grads, grad_threshold, scene_extent): + # Extract points that satisfy the gradient condition + selected_pts_mask = torch.where( + torch.norm(grads, dim=-1) >= grad_threshold, True, False + ) + selected_pts_mask = torch.logical_and( + selected_pts_mask, + torch.max(self.get_scaling, dim=1).values + <= self.percent_dense * scene_extent, + ) + + new_xyz = self._xyz[selected_pts_mask] + new_features_dc = self._features_dc[selected_pts_mask] + new_features_rest = self._features_rest[selected_pts_mask] + new_opacities = self._opacity[selected_pts_mask] + new_scaling = self._scaling[selected_pts_mask] + new_rotation = self._rotation[selected_pts_mask] + + self.densification_postfix( + new_xyz, + new_features_dc, + new_features_rest, + new_opacities, + new_scaling, + new_rotation, + ) + + def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): + grads = self.xyz_gradient_accum / self.denom + grads[grads.isnan()] = 0.0 + + self.densify_and_clone(grads, max_grad, extent) + self.densify_and_split(grads, max_grad, extent) + + prune_mask = (self.get_opacity < min_opacity).squeeze() + if max_screen_size: + big_points_vs = self.max_radii2D > max_screen_size + big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent + prune_mask = torch.logical_or( + torch.logical_or(prune_mask, big_points_vs), big_points_ws + ) + self.prune_points(prune_mask) + + torch.cuda.empty_cache() + + def add_densification_stats(self, viewspace_point_tensor, update_filter): + self.xyz_gradient_accum[update_filter] += torch.norm( + viewspace_point_tensor.grad[update_filter, :2], dim=-1, keepdim=True + ) + self.denom[update_filter] += 1 diff --git a/recon/sparse_pcd.py b/recon/sparse_pcd.py new file mode 100644 index 0000000000000000000000000000000000000000..523a3d5ebd6d758b500d27f83b6405ff3ada76af --- /dev/null +++ 
b/recon/sparse_pcd.py @@ -0,0 +1,141 @@
+import argparse
+import subprocess
+from pathlib import Path
+import os
+import pickle
+import numpy as np
+
+from skimage.io import imread, imsave
+from transforms3d.quaternions import mat2quat
+
+from colmap.database import COLMAPDatabase
+from colmap.read_write_model import CAMERA_MODEL_NAMES
+import open3d as o3d
+
+# from ldm.base_utils import read_pickle
+def read_pickle(pkl_path):
+    # Local stand-in for the commented-out ldm.base_utils.read_pickle import above.
+    with open(pkl_path, 'rb') as f:
+        return pickle.load(f)
+
+K, _, _, _, POSES = read_pickle(f'meta_info/camera-16.pkl')
+H, W, NUM_IMAGES = 256, 256, 16
+
+def extract_and_match_sift(colmap_path, database_path, image_dir):
+    cmd = [
+        str(colmap_path), 'feature_extractor',
+        '--database_path', str(database_path),
+        '--image_path', str(image_dir),
+    ]
+    print(' '.join(cmd))
+    subprocess.run(cmd, check=True)
+    cmd = [
+        str(colmap_path), 'exhaustive_matcher',
+        '--database_path', str(database_path),
+    ]
+    print(' '.join(cmd))
+    subprocess.run(cmd, check=True)
+
+def run_triangulation(colmap_path, model_path, in_sparse_model, database_path, image_dir):
+    print('Running the triangulation...')
+    model_path.mkdir(exist_ok=True, parents=True)
+    cmd = [
+        str(colmap_path), 'point_triangulator',
+        '--database_path', str(database_path),
+        '--image_path', str(image_dir),
+        '--input_path', str(in_sparse_model),
+        '--output_path', str(model_path),
+        '--Mapper.ba_refine_focal_length', '0',
+        '--Mapper.ba_refine_principal_point', '0',
+        '--Mapper.ba_refine_extra_params', '0']
+    print(' '.join(cmd))
+    subprocess.run(cmd, check=True)
+
+def run_patch_match(colmap_path, sparse_model: Path, image_dir: Path, dense_model: Path):
+    print('Running patch match...')
+    assert sparse_model.exists()
+    dense_model.mkdir(parents=True, exist_ok=True)
+    cmd = [str(colmap_path), 'image_undistorter', '--input_path', str(sparse_model), '--image_path', str(image_dir), '--output_path', str(dense_model),]
+    print(' '.join(cmd))
+    subprocess.run(cmd, check=True)
+    cmd = [str(colmap_path), 'patch_match_stereo','--workspace_path', str(dense_model),]
+    print(' '.join(cmd))
+    subprocess.run(cmd, check=True)
+
+def dump_images(in_image_dir, image_dir):
+    for index in range(NUM_IMAGES):
+        img = imread(f'{in_image_dir}/{index:03}.png')
+        imsave(f'{str(image_dir)}/{index:03}.png', img)
+
+def build_db_known_poses_fixed(db_path, in_sparse_path):
+    db = COLMAPDatabase.connect(db_path)
+    db.create_tables()
+
+    # insert intrinsics
+    with open(f'{str(in_sparse_path)}/cameras.txt', 'w') as f:
+        for index in range(NUM_IMAGES):
+            fx, fy = K[0,0], K[1,1]
+            cx, cy = K[0,2], K[1,2]
+            model, width, height, params = CAMERA_MODEL_NAMES['PINHOLE'].model_id, W, H, np.array((fx, fy, cx, cy),np.float32)
+            db.add_camera(model, width, height, params, prior_focal_length=(fx+fy)/2, camera_id=index+1)
+            f.write(f'{index+1} PINHOLE {W} {H} {fx:.3f} {fy:.3f} {cx:.3f} {cy:.3f}\n')
+
+    # insert known extrinsics as pose priors
+    with open(f'{str(in_sparse_path)}/images.txt','w') as f:
+        for index in range(NUM_IMAGES):
+            pose = POSES[index]
+            q = mat2quat(pose[:,:3])
+            t = pose[:,3]
+            img_id = db.add_image(f"{index:03}.png", camera_id=index+1, prior_q=q, prior_t=t)
+            f.write(f'{img_id} {q[0]:.5f} {q[1]:.5f} {q[2]:.5f} {q[3]:.5f} {t[0]:.5f} {t[1]:.5f} {t[2]:.5f} {index+1} {index:03}.png\n\n')
+
+    db.commit()
+    db.close()
+
+    with open(f'{in_sparse_path}/points3D.txt','w') as f:
+        f.write('\n')
+
+
+def patch_match_with_known_poses(in_image_dir, project_dir, colmap_path='colmap'):
+    Path(project_dir).mkdir(exist_ok=True, parents=True)
+    if os.path.exists(f'{str(project_dir)}/dense/stereo/depth_maps'): return
+
+    # output poses
+    db_path = f'{str(project_dir)}/database.db'
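+    # Workspace layout used below: images/ holds the input views, sparse_in/ the
+    # fixed known-pose model written by build_db_known_poses_fixed, sparse/ the
+    # triangulated sparse model, and dense/ the undistorted workspace for
+    # patch-match stereo and fusion.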
image_dir = Path(f'{str(project_dir)}/images') + sparse_dir = Path(f'{str(project_dir)}/sparse') + in_sparse_dir = Path(f'{str(project_dir)}/sparse_in') + dense_dir = Path(f'{str(project_dir)}/dense') + + image_dir.mkdir(exist_ok=True,parents=True) + sparse_dir.mkdir(exist_ok=True,parents=True) + in_sparse_dir.mkdir(exist_ok=True,parents=True) + dense_dir.mkdir(exist_ok=True,parents=True) + + dump_images(in_image_dir, image_dir) + build_db_known_poses_fixed(db_path, in_sparse_dir) + extract_and_match_sift(colmap_path, db_path, image_dir) + run_triangulation(colmap_path,sparse_dir, in_sparse_dir, db_path, image_dir) + run_patch_match(colmap_path, sparse_dir, image_dir, dense_dir) + + # fuse + cmd = [str(colmap_path), 'stereo_fusion', + '--workspace_path', f'{project_dir}/dense', + '--workspace_format', 'COLMAP', + '--input_type', 'geometric', + '--output_path', f'{project_dir}/points.ply',] + print(' '.join(cmd)) + subprocess.run(cmd, check=True) + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--dir',type=str) + parser.add_argument('--project',type=str) + parser.add_argument('--name',type=str) + parser.add_argument('--colmap',type=str, default='colmap') + args = parser.parse_args() + + if not os.path.exists(f'{args.project}/points.ply'): + patch_match_with_known_poses(args.dir, args.project, colmap_path=args.colmap) + + mesh = o3d.io.read_triangle_mesh(f'{args.project}/points.ply',) + vn = len(mesh.vertices) + with open('colmap-results.log', 'a') as f: + f.write(f'{args.name}\t{vn}\n') + +if __name__=="__main__": + main() \ No newline at end of file diff --git a/recon/sync_submodules.sh b/recon/sync_submodules.sh new file mode 100755 index 0000000000000000000000000000000000000000..2e7d433381b376ead37297aab5cde26205b1f7c6 --- /dev/null +++ b/recon/sync_submodules.sh @@ -0,0 +1,14 @@ +#!/bin/sh + +set -e + +git config -f .gitmodules --get-regexp '^submodule\..*\.path$' | + while read path_key path + do + name=$(echo $path_key | sed 's/\submodule\.\(.*\)\.path/\1/') + url_key=$(echo $path_key | sed 's/\.path/.url/') + branch_key=$(echo $path_key | sed 's/\.path/.branch/') + url=$(git config -f .gitmodules --get "$url_key") + branch=$(git config -f .gitmodules --get "$branch_key" || echo "master") + git submodule add -b $branch --name $name $url $path || continue + done diff --git a/recon/train.py b/recon/train.py new file mode 100644 index 0000000000000000000000000000000000000000..2a89ca9785071178a1b8505ebe512a328118f41d --- /dev/null +++ b/recon/train.py @@ -0,0 +1,369 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import torch +from random import randint +from utils.loss_utils import l1_loss, ssim, lpips +from gaussian_renderer import render, network_gui +import sys +from scene import Scene, GaussianModel +from utils.general_utils import safe_state +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams + +from scripts.sampling.simple_mv_sample import sample_one + +try: + from torch.utils.tensorboard import SummaryWriter + + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + + +def training( + dataset, + opt, + pipe, + testing_iterations, + saving_iterations, + checkpoint_iterations, + checkpoint, + debug_from, +): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing=True) + iter_end = torch.cuda.Event(enable_timing=True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + for iteration in range(first_iter, opt.iterations + 1): + if network_gui.conn == None: + network_gui.try_connect() + while network_gui.conn != None: + try: + net_image_bytes = None + ( + custom_cam, + do_training, + pipe.convert_SHs_python, + pipe.compute_cov3D_python, + keep_alive, + scaling_modifer, + ) = network_gui.receive() + if custom_cam != None: + net_image = render( + custom_cam, gaussians, pipe, background, scaling_modifer + )["render"] + net_image_bytes = memoryview( + (torch.clamp(net_image, min=0, max=1.0) * 255) + .byte() + .permute(1, 2, 0) + .contiguous() + .cpu() + .numpy() + ) + network_gui.send(net_image_bytes, dataset.source_path) + if do_training and ( + (iteration < int(opt.iterations)) or not keep_alive + ): + break + except Exception as e: + network_gui.conn = None + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) + + # Render + if (iteration - 1) == debug_from: + pipe.debug = True + + bg = torch.rand((3), device="cuda") if opt.random_background else background + + render_pkg = render(viewpoint_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = ( + render_pkg["render"], + render_pkg["viewspace_points"], + render_pkg["visibility_filter"], + render_pkg["radii"], + ) + + # Loss + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss(image, gt_image) + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * ( + 1.0 - ssim(image, gt_image) + ) + if opt.lambda_lpips > 0: + loss += opt.lambda_lpips * lpips(image, gt_image) + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + 
progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + iter_start.elapsed_time(iter_end), + testing_iterations, + scene, + render, + (pipe, background), + ) + if iteration in saving_iterations: + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max( + gaussians.max_radii2D[visibility_filter], radii[visibility_filter] + ) + gaussians.add_densification_stats( + viewspace_point_tensor, visibility_filter + ) + + if ( + iteration > opt.densify_from_iter + and iteration % opt.densification_interval == 0 + ): + size_threshold = ( + 20 if iteration > opt.opacity_reset_interval else None + ) + gaussians.densify_and_prune( + opt.densify_grad_threshold, + 0.005, + scene.cameras_extent, + size_threshold, + ) + + if iteration % opt.opacity_reset_interval == 0 or ( + dataset.white_background and iteration == opt.densify_from_iter + ): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none=True) + + if iteration in checkpoint_iterations: + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save( + (gaussians.capture(), iteration), + scene.model_path + "/chkpnt" + str(iteration) + ".pth", + ) + + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv("OAR_JOB_ID"): + unique_str = os.getenv("OAR_JOB_ID") + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok=True) + with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + + +def training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + elapsed, + testing_iterations, + scene: Scene, + renderFunc, + renderArgs, +): + if tb_writer: + tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration) + tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration) + tb_writer.add_scalar("iter_time", elapsed, iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ( + {"name": "test", "cameras": scene.getTestCameras()}, + { + "name": "train", + "cameras": [ + scene.getTrainCameras()[idx % len(scene.getTrainCameras())] + for idx in range(5, 30, 5) + ], + }, + ) + + for config in validation_configs: + if config["cameras"] and len(config["cameras"]) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config["cameras"]): + image = torch.clamp( + renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], + 0.0, + 1.0, + ) + gt_image = torch.clamp( + viewpoint.original_image.to("cuda"), 0.0, 1.0 + ) + if tb_writer and (idx < 5): + tb_writer.add_images( + config["name"] + + "_view_{}/render".format(viewpoint.image_name), + image[None], + global_step=iteration, + ) + if iteration 
== testing_iterations[0]: + tb_writer.add_images( + config["name"] + + "_view_{}/ground_truth".format(viewpoint.image_name), + gt_image[None], + global_step=iteration, + ) + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + psnr_test /= len(config["cameras"]) + l1_test /= len(config["cameras"]) + print( + "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format( + iteration, config["name"], l1_test, psnr_test + ) + ) + if tb_writer: + tb_writer.add_scalar( + config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration + ) + tb_writer.add_scalar( + config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration + ) + + if tb_writer: + tb_writer.add_histogram( + "scene/opacity_histogram", scene.gaussians.get_opacity, iteration + ) + tb_writer.add_scalar( + "total_points", scene.gaussians.get_xyz.shape[0], iteration + ) + torch.cuda.empty_cache() + + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--image", type=str, default="assets/images/ceramic.png") + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--ip", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=6009) + parser.add_argument("--debug_from", type=int, default=-1) + parser.add_argument("--detect_anomaly", action="store_true", default=False) + parser.add_argument( + "--test_iterations", nargs="+", type=int, default=[7_000, 30_000] + ) + parser.add_argument( + "--save_iterations", nargs="+", type=int, default=[7_000, 30_000] + ) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default=None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + + print("=====Start generating MV Images=====") + + images, _ = sample_one(args.image, args.ckpt_path, seed=args.seed) + + print("=====Finish generating MV Images=====") + + lp = lp.extract(args) + lp.images = images + + training( + lp, + op.extract(args), + pp.extract(args), + args.test_iterations, + args.save_iterations, + args.checkpoint_iterations, + args.start_checkpoint, + args.debug_from, + ) + + # All done + print("\nTraining complete.") diff --git a/recon/train_512.py b/recon/train_512.py new file mode 100644 index 0000000000000000000000000000000000000000..d7e70fe62e842d7ac16f937e6ae7a6c878ce5def --- /dev/null +++ b/recon/train_512.py @@ -0,0 +1,381 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import torch +from random import randint +from utils.loss_utils import l1_loss, ssim, lpips +from gaussian_renderer import render, network_gui +import sys +from scene import Scene, GaussianModel +from utils.general_utils import safe_state +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams + +from scripts.sampling.simple_mv_latent_sample import sample_one + +try: + from torch.utils.tensorboard import SummaryWriter + + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + + +def training( + dataset, + opt, + pipe, + testing_iterations, + saving_iterations, + checkpoint_iterations, + checkpoint, + debug_from, +): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing=True) + iter_end = torch.cuda.Event(enable_timing=True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + for iteration in range(first_iter, opt.iterations + 1): + if network_gui.conn == None: + network_gui.try_connect() + while network_gui.conn != None: + try: + net_image_bytes = None + ( + custom_cam, + do_training, + pipe.convert_SHs_python, + pipe.compute_cov3D_python, + keep_alive, + scaling_modifer, + ) = network_gui.receive() + if custom_cam != None: + net_image = render( + custom_cam, gaussians, pipe, background, scaling_modifer + )["render"] + net_image_bytes = memoryview( + (torch.clamp(net_image, min=0, max=1.0) * 255) + .byte() + .permute(1, 2, 0) + .contiguous() + .cpu() + .numpy() + ) + network_gui.send(net_image_bytes, dataset.source_path) + if do_training and ( + (iteration < int(opt.iterations)) or not keep_alive + ): + break + except Exception as e: + network_gui.conn = None + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) + + # Render + if (iteration - 1) == debug_from: + pipe.debug = True + + bg = torch.rand((3), device="cuda") if opt.random_background else background + + render_pkg = render(viewpoint_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = ( + render_pkg["render"], + render_pkg["viewspace_points"], + render_pkg["visibility_filter"], + render_pkg["radii"], + ) + + # Loss + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss(image, gt_image) + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * ( + 1.0 - ssim(image, gt_image) + ) + if opt.lambda_lpips > 0: + loss += opt.lambda_lpips * lpips(image, gt_image) + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 
0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + iter_start.elapsed_time(iter_end), + testing_iterations, + scene, + render, + (pipe, background), + ) + if iteration in saving_iterations: + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max( + gaussians.max_radii2D[visibility_filter], radii[visibility_filter] + ) + gaussians.add_densification_stats( + viewspace_point_tensor, visibility_filter + ) + + if ( + iteration > opt.densify_from_iter + and iteration % opt.densification_interval == 0 + ): + size_threshold = ( + 20 if iteration > opt.opacity_reset_interval else None + ) + gaussians.densify_and_prune( + opt.densify_grad_threshold, + 0.005, + scene.cameras_extent, + size_threshold, + ) + + if iteration % opt.opacity_reset_interval == 0 or ( + dataset.white_background and iteration == opt.densify_from_iter + ): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none=True) + + if iteration in checkpoint_iterations: + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save( + (gaussians.capture(), iteration), + scene.model_path + "/chkpnt" + str(iteration) + ".pth", + ) + + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv("OAR_JOB_ID"): + unique_str = os.getenv("OAR_JOB_ID") + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok=True) + with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + + +def training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + elapsed, + testing_iterations, + scene: Scene, + renderFunc, + renderArgs, +): + if tb_writer: + tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration) + tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration) + tb_writer.add_scalar("iter_time", elapsed, iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ( + {"name": "test", "cameras": scene.getTestCameras()}, + { + "name": "train", + "cameras": [ + scene.getTrainCameras()[idx % len(scene.getTrainCameras())] + for idx in range(5, 30, 5) + ], + }, + ) + + for config in validation_configs: + if config["cameras"] and len(config["cameras"]) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config["cameras"]): + image = torch.clamp( + renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], + 0.0, + 1.0, + ) + gt_image = torch.clamp( + viewpoint.original_image.to("cuda"), 0.0, 1.0 + ) + if tb_writer and (idx < 5): + tb_writer.add_images( + config["name"] + + "_view_{}/render".format(viewpoint.image_name), + image[None], + global_step=iteration, + ) + if 
iteration == testing_iterations[0]: + tb_writer.add_images( + config["name"] + + "_view_{}/ground_truth".format(viewpoint.image_name), + gt_image[None], + global_step=iteration, + ) + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + psnr_test /= len(config["cameras"]) + l1_test /= len(config["cameras"]) + print( + "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format( + iteration, config["name"], l1_test, psnr_test + ) + ) + if tb_writer: + tb_writer.add_scalar( + config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration + ) + tb_writer.add_scalar( + config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration + ) + + if tb_writer: + tb_writer.add_histogram( + "scene/opacity_histogram", scene.gaussians.get_opacity, iteration + ) + tb_writer.add_scalar( + "total_points", scene.gaussians.get_xyz.shape[0], iteration + ) + torch.cuda.empty_cache() + + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--image", type=str, default="assets/images/ceramic.png") + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--ip", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=6009) + parser.add_argument("--debug_from", type=int, default=-1) + parser.add_argument("--detect_anomaly", action="store_true", default=False) + parser.add_argument( + "--test_iterations", nargs="+", type=int, default=[7_000, 30_000] + ) + parser.add_argument( + "--save_iterations", nargs="+", type=int, default=[7_000, 30_000] + ) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default=None) + parser.add_argument("--border_ratio", type=float, default=0.3) + parser.add_argument("--min_guidance_scale", type=float, default=1.0) + parser.add_argument("--max_guidance_scale", type=float, default=2.5) + parser.add_argument("--sigma_max", type=float, default=None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + + print("=====Start generating MV Images=====") + + images, _ = sample_one( + args.image, + args.ckpt_path, + seed=args.seed, + border_ratio=args.border_ratio, + min_guidance_scale=args.min_guidance_scale, + max_guidance_scale=args.max_guidance_scale, + sigma_max=args.sigma_max, + ) + + print("=====Finish generating MV Images=====") + + lp = lp.extract(args) + lp.images = images + + training( + lp, + op.extract(args), + pp.extract(args), + args.test_iterations, + args.save_iterations, + args.checkpoint_iterations, + args.start_checkpoint, + args.debug_from, + ) + + # All done + print("\nTraining complete.") diff --git a/recon/train_autoaggressive.py b/recon/train_autoaggressive.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/recon/train_from_vid.py b/recon/train_from_vid.py new file mode 100644 index 
0000000000000000000000000000000000000000..88b5fd8cb8144e0d81dafbebd89ebae46b9ff9de --- /dev/null +++ b/recon/train_from_vid.py @@ -0,0 +1,389 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import torch +from random import randint +from PIL import Image +from mediapy import read_video +from utils.loss_utils import l1_loss, ssim, lpips +from gaussian_renderer import render, network_gui +import sys +from scene import Scene, GaussianModel +from utils.general_utils import safe_state +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams + +from scripts.sampling.simple_mv_latent_sample import sample_one + +try: + from torch.utils.tensorboard import SummaryWriter + + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + + +def training( + dataset, + opt, + pipe, + testing_iterations, + saving_iterations, + checkpoint_iterations, + checkpoint, + debug_from, +): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing=True) + iter_end = torch.cuda.Event(enable_timing=True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + for iteration in range(first_iter, opt.iterations + 1): + if network_gui.conn == None: + network_gui.try_connect() + while network_gui.conn != None: + try: + net_image_bytes = None + ( + custom_cam, + do_training, + pipe.convert_SHs_python, + pipe.compute_cov3D_python, + keep_alive, + scaling_modifer, + ) = network_gui.receive() + if custom_cam != None: + net_image = render( + custom_cam, gaussians, pipe, background, scaling_modifer + )["render"] + net_image_bytes = memoryview( + (torch.clamp(net_image, min=0, max=1.0) * 255) + .byte() + .permute(1, 2, 0) + .contiguous() + .cpu() + .numpy() + ) + network_gui.send(net_image_bytes, dataset.source_path) + if do_training and ( + (iteration < int(opt.iterations)) or not keep_alive + ): + break + except Exception as e: + network_gui.conn = None + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) + + # Render + if (iteration - 1) == debug_from: + pipe.debug = True + + bg = torch.rand((3), device="cuda") if opt.random_background else background + + render_pkg = render(viewpoint_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = ( + render_pkg["render"], + render_pkg["viewspace_points"], + render_pkg["visibility_filter"], + render_pkg["radii"], + ) + + 
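+        # In addition to the L1 / D-SSIM (and optional LPIPS) terms used by the
+        # other training scripts, the loss computed below also adds a small
+        # opacity penalty (0.1 * mean opacity), presumably to suppress
+        # semi-transparent floaters when fitting Gaussians to video frames.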
# Loss + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss(image, gt_image) + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * ( + 1.0 - ssim(image, gt_image) + ) + if opt.lambda_lpips > 0: + loss += opt.lambda_lpips * lpips(image, gt_image) + + loss += torch.mean(gaussians.get_opacity) * 0.1 + + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + iter_start.elapsed_time(iter_end), + testing_iterations, + scene, + render, + (pipe, background), + ) + if iteration in saving_iterations: + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max( + gaussians.max_radii2D[visibility_filter], radii[visibility_filter] + ) + gaussians.add_densification_stats( + viewspace_point_tensor, visibility_filter + ) + + if ( + iteration > opt.densify_from_iter + and iteration % opt.densification_interval == 0 + ): + size_threshold = ( + 20 if iteration > opt.opacity_reset_interval else None + ) + gaussians.densify_and_prune( + opt.densify_grad_threshold, + 0.005, + scene.cameras_extent, + size_threshold, + ) + + if iteration % opt.opacity_reset_interval == 0 or ( + dataset.white_background and iteration == opt.densify_from_iter + ): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none=True) + + if iteration in checkpoint_iterations: + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save( + (gaussians.capture(), iteration), + scene.model_path + "/chkpnt" + str(iteration) + ".pth", + ) + + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv("OAR_JOB_ID"): + unique_str = os.getenv("OAR_JOB_ID") + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok=True) + with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + + +def training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + elapsed, + testing_iterations, + scene: Scene, + renderFunc, + renderArgs, +): + if tb_writer: + tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration) + tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration) + tb_writer.add_scalar("iter_time", elapsed, iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ( + {"name": "test", "cameras": scene.getTestCameras()}, + { + "name": "train", + "cameras": [ + scene.getTrainCameras()[idx % len(scene.getTrainCameras())] + for idx in range(5, 30, 5) + ], + }, + ) + + for config in validation_configs: + if 
config["cameras"] and len(config["cameras"]) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config["cameras"]): + image = torch.clamp( + renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], + 0.0, + 1.0, + ) + gt_image = torch.clamp( + viewpoint.original_image.to("cuda"), 0.0, 1.0 + ) + if tb_writer and (idx < 5): + tb_writer.add_images( + config["name"] + + "_view_{}/render".format(viewpoint.image_name), + image[None], + global_step=iteration, + ) + if iteration == testing_iterations[0]: + tb_writer.add_images( + config["name"] + + "_view_{}/ground_truth".format(viewpoint.image_name), + gt_image[None], + global_step=iteration, + ) + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + psnr_test /= len(config["cameras"]) + l1_test /= len(config["cameras"]) + print( + "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format( + iteration, config["name"], l1_test, psnr_test + ) + ) + if tb_writer: + tb_writer.add_scalar( + config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration + ) + tb_writer.add_scalar( + config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration + ) + + if tb_writer: + tb_writer.add_histogram( + "scene/opacity_histogram", scene.gaussians.get_opacity, iteration + ) + tb_writer.add_scalar( + "total_points", scene.gaussians.get_xyz.shape[0], iteration + ) + torch.cuda.empty_cache() + + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--video", type=str, default="") + parser.add_argument("--ip", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=6009) + parser.add_argument("--debug_from", type=int, default=-1) + parser.add_argument("--detect_anomaly", action="store_true", default=False) + parser.add_argument( + "--test_iterations", nargs="+", type=int, default=[7_000, 30_000] + ) + parser.add_argument( + "--save_iterations", nargs="+", type=int, default=[7_000, 30_000] + ) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default=None) + parser.add_argument("--border_ratio", type=float, default=0.3) + parser.add_argument("--min_guidance_scale", type=float, default=1.0) + parser.add_argument("--max_guidance_scale", type=float, default=2.5) + parser.add_argument("--sigma_max", type=float, default=None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + + print("=====Start generating MV Images=====") + + # images, _ = sample_one( + # args.image, + # args.ckpt_path, + # seed=args.seed, + # border_ratio=args.border_ratio, + # min_guidance_scale=args.min_guidance_scale, + # max_guidance_scale=args.max_guidance_scale, + # sigma_max=args.sigma_max, + # ) + images = [] + frames = read_video(args.video) + for frame in frames: + images.append(Image.fromarray(frame)) + + print("=====Finish generating MV Images=====") + + lp = lp.extract(args) + lp.images = images + + training( + lp, + 
op.extract(args), + pp.extract(args), + args.test_iterations, + args.save_iterations, + args.checkpoint_iterations, + args.start_checkpoint, + args.debug_from, + ) + + # All done + print("\nTraining complete.") diff --git a/recon/train_iterative.py b/recon/train_iterative.py new file mode 100644 index 0000000000000000000000000000000000000000..91d2a79d7307367f864718fbded50863a6cc7d11 --- /dev/null +++ b/recon/train_iterative.py @@ -0,0 +1,400 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import torch +import numpy as np +from torchvision.transforms.functional import pil_to_tensor, to_tensor +from torchvision.utils import make_grid, save_image +from random import randint +from utils.loss_utils import l1_loss, ssim, lpips +from gaussian_renderer import render, network_gui +import sys +from scene import Scene, GaussianModel +from utils.general_utils import safe_state +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams + +from scripts.sampling.simple_mv_sample import sample_one + +try: + from torch.utils.tensorboard import SummaryWriter + + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + + +def training( + dataset, + opt, + pipe, + testing_iterations, + saving_iterations, + checkpoint_iterations, + checkpoint, + debug_from, + resample_period=500, + resample_sigma=0.1, + resample_start=1000, + model=None, +): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians, shuffle=False) + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing=True) + iter_end = torch.cuda.Event(enable_timing=True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + for iteration in range(first_iter, opt.iterations + 1): + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + with torch.no_grad(): + if iteration % resample_period == 0 and iteration > resample_start: + # if iteration % resample_period: + views = [] + viewpoint_stack = scene.getTrainCameras().copy() + for view_cam in viewpoint_stack: + bg = ( + torch.rand((3), device="cuda") + if opt.random_background + else background + ) + render_pkg = render(view_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = ( + render_pkg["render"], + render_pkg["viewspace_points"], + render_pkg["visibility_filter"], + render_pkg["radii"], + ) + views.append(image) + views = torch.stack(views) + save_image(views, f"tmp/views_{iteration}.png") + views = views * 2.0 - 1.0 + views = model.encode_first_stage(views) + noisy_views = views + torch.randn_like(views) * resample_sigma + noisy_views = ( + np.sqrt(1 - 
resample_sigma**2) * views + + torch.randn_like(views) * resample_sigma + ) + resampled_images = sample_one( + args.image, + args.ckpt_path, + noise=noisy_views, + cached_model=model, + )[0] + dataset.images = resampled_images + scene = Scene( + dataset, + gaussians, + shuffle=False, + skip_gaussians=True, + ) + resampled_images_grid = [] + for img in resampled_images: + resampled_images_grid.append(to_tensor(img)) + resampled_images_grid = torch.stack(resampled_images_grid) + resampled_images_grid = make_grid(resampled_images_grid, nrow=3) + save_image( + resampled_images_grid, f"tmp/resampled_images_{iteration}.png" + ) + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) + + # Render + if (iteration - 1) == debug_from: + pipe.debug = True + + bg = torch.rand((3), device="cuda") if opt.random_background else background + + render_pkg = render(viewpoint_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = ( + render_pkg["render"], + render_pkg["viewspace_points"], + render_pkg["visibility_filter"], + render_pkg["radii"], + ) + + # Loss + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss(image, gt_image) + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * ( + 1.0 - ssim(image, gt_image) + ) + if opt.lambda_lpips > 0: + loss += opt.lambda_lpips * lpips(image, gt_image) + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + iter_start.elapsed_time(iter_end), + testing_iterations, + scene, + render, + (pipe, background), + ) + if iteration in saving_iterations: + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max( + gaussians.max_radii2D[visibility_filter], radii[visibility_filter] + ) + gaussians.add_densification_stats( + viewspace_point_tensor, visibility_filter + ) + + if ( + iteration > opt.densify_from_iter + and iteration % opt.densification_interval == 0 + ): + size_threshold = ( + 20 if iteration > opt.opacity_reset_interval else None + ) + gaussians.densify_and_prune( + opt.densify_grad_threshold, + 0.005, + scene.cameras_extent, + size_threshold, + ) + + if iteration % opt.opacity_reset_interval == 0 or ( + dataset.white_background and iteration == opt.densify_from_iter + ): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none=True) + + if iteration in checkpoint_iterations: + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save( + (gaussians.capture(), iteration), + scene.model_path + "/chkpnt" + str(iteration) + ".pth", + ) + + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv("OAR_JOB_ID"): + unique_str = os.getenv("OAR_JOB_ID") + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: 
{}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok=True) + with open(os.path.join(args.model_path, "cfg_args"), "w") as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + + +def training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + elapsed, + testing_iterations, + scene: Scene, + renderFunc, + renderArgs, +): + if tb_writer: + tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration) + tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration) + tb_writer.add_scalar("iter_time", elapsed, iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ( + {"name": "test", "cameras": scene.getTestCameras()}, + { + "name": "train", + "cameras": [ + scene.getTrainCameras()[idx % len(scene.getTrainCameras())] + for idx in range(5, 30, 5) + ], + }, + ) + + for config in validation_configs: + if config["cameras"] and len(config["cameras"]) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config["cameras"]): + image = torch.clamp( + renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], + 0.0, + 1.0, + ) + gt_image = torch.clamp( + viewpoint.original_image.to("cuda"), 0.0, 1.0 + ) + if tb_writer and (idx < 5): + tb_writer.add_images( + config["name"] + + "_view_{}/render".format(viewpoint.image_name), + image[None], + global_step=iteration, + ) + if iteration == testing_iterations[0]: + tb_writer.add_images( + config["name"] + + "_view_{}/ground_truth".format(viewpoint.image_name), + gt_image[None], + global_step=iteration, + ) + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + psnr_test /= len(config["cameras"]) + l1_test /= len(config["cameras"]) + print( + "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format( + iteration, config["name"], l1_test, psnr_test + ) + ) + if tb_writer: + tb_writer.add_scalar( + config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration + ) + tb_writer.add_scalar( + config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration + ) + + if tb_writer: + tb_writer.add_histogram( + "scene/opacity_histogram", scene.gaussians.get_opacity, iteration + ) + tb_writer.add_scalar( + "total_points", scene.gaussians.get_xyz.shape[0], iteration + ) + torch.cuda.empty_cache() + + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--image", type=str, default="assets/images/ceramic.png") + parser.add_argument("--ckpt_path", type=str, required=True) + parser.add_argument("--ip", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=6009) + parser.add_argument("--debug_from", type=int, default=-1) + parser.add_argument("--detect_anomaly", action="store_true", default=False) + parser.add_argument( + "--test_iterations", nargs="+", type=int, default=[7_000, 30_000] + ) + parser.add_argument( + "--save_iterations", nargs="+", type=int, default=[7_000, 30_000] + ) + parser.add_argument("--quiet", action="store_true") + 
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default=None) + parser.add_argument("--resample_period", type=int, default=500) + parser.add_argument("--resample_sigma", type=float, default=0.1) + parser.add_argument("--resample_start", type=int, default=500) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + + print("=====Start generating MV Images=====") + + images, model = sample_one(args.image, args.ckpt_path, seed=args.seed) + + print("=====Finish generating MV Images=====") + + lp = lp.extract(args) + lp.images = images + + training( + lp, + op.extract(args), + pp.extract(args), + args.test_iterations, + args.save_iterations, + args.checkpoint_iterations, + args.start_checkpoint, + args.debug_from, + args.resample_period, + args.resample_sigma, + args.resample_start, + model, + ) + + # All done + print("\nTraining complete.") diff --git a/recon/train_scene.py b/recon/train_scene.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e373b9b9589a17e3bc188b29baeb4db4ab6fd2 --- /dev/null +++ b/recon/train_scene.py @@ -0,0 +1,352 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import torch +from random import randint +from utils.loss_utils import l1_loss, ssim +from gaussian_renderer import render, network_gui +import sys +from scene import Scene, GaussianModel +from utils.general_utils import safe_state +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams + +try: + from torch.utils.tensorboard import SummaryWriter + + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + + +def training( + dataset, + opt, + pipe, + testing_iterations, + saving_iterations, + checkpoint_iterations, + checkpoint, + debug_from, +): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gaussians) + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + + iter_start = torch.cuda.Event(enable_timing=True) + iter_end = torch.cuda.Event(enable_timing=True) + + viewpoint_stack = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + for iteration in range(first_iter, opt.iterations + 1): + if network_gui.conn == None: + network_gui.try_connect() + while network_gui.conn != None: + try: + net_image_bytes = None + ( + custom_cam, + do_training, + pipe.convert_SHs_python, + pipe.compute_cov3D_python, + keep_alive, + scaling_modifer, + ) = network_gui.receive() + if custom_cam != None: + net_image = render( + custom_cam, gaussians, 
pipe, background, scaling_modifer + )["render"] + net_image_bytes = memoryview( + (torch.clamp(net_image, min=0, max=1.0) * 255) + .byte() + .permute(1, 2, 0) + .contiguous() + .cpu() + .numpy() + ) + network_gui.send(net_image_bytes, dataset.source_path) + if do_training and ( + (iteration < int(opt.iterations)) or not keep_alive + ): + break + except Exception as e: + network_gui.conn = None + + iter_start.record() + + gaussians.update_learning_rate(iteration) + + # Every 1000 its we increase the levels of SH up to a maximum degree + if iteration % 1000 == 0: + gaussians.oneupSHdegree() + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack) - 1)) + + # Render + if (iteration - 1) == debug_from: + pipe.debug = True + + bg = torch.rand((3), device="cuda") if opt.random_background else background + + render_pkg = render(viewpoint_cam, gaussians, pipe, bg) + image, viewspace_point_tensor, visibility_filter, radii = ( + render_pkg["render"], + render_pkg["viewspace_points"], + render_pkg["visibility_filter"], + render_pkg["radii"], + ) + + # Loss + gt_image = viewpoint_cam.original_image.cuda() + Ll1 = l1_loss(image, gt_image) + loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * ( + 1.0 - ssim(image, gt_image) + ) + loss.backward() + + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + iter_start.elapsed_time(iter_end), + testing_iterations, + scene, + render, + (pipe, background), + ) + if iteration in saving_iterations: + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max( + gaussians.max_radii2D[visibility_filter], radii[visibility_filter] + ) + gaussians.add_densification_stats( + viewspace_point_tensor, visibility_filter + ) + + if ( + iteration > opt.densify_from_iter + and iteration % opt.densification_interval == 0 + ): + size_threshold = ( + 20 if iteration > opt.opacity_reset_interval else None + ) + gaussians.densify_and_prune( + opt.densify_grad_threshold, + 0.005, + scene.cameras_extent, + size_threshold, + ) + + if iteration % opt.opacity_reset_interval == 0 or ( + dataset.white_background and iteration == opt.densify_from_iter + ): + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none=True) + + if iteration in checkpoint_iterations: + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save( + (gaussians.capture(), iteration), + scene.model_path + "/chkpnt" + str(iteration) + ".pth", + ) + + +def prepare_output_and_logger(args): + if not args.model_path: + if os.getenv("OAR_JOB_ID"): + unique_str = os.getenv("OAR_JOB_ID") + else: + unique_str = str(uuid.uuid4()) + args.model_path = os.path.join("./output/", unique_str[0:10]) + + # Set up output folder + print("Output folder: {}".format(args.model_path)) + os.makedirs(args.model_path, exist_ok=True) + with open(os.path.join(args.model_path, "cfg_args"), "w") as 
cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args.model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + + +def training_report( + tb_writer, + iteration, + Ll1, + loss, + l1_loss, + elapsed, + testing_iterations, + scene: Scene, + renderFunc, + renderArgs, +): + if tb_writer: + tb_writer.add_scalar("train_loss_patches/l1_loss", Ll1.item(), iteration) + tb_writer.add_scalar("train_loss_patches/total_loss", loss.item(), iteration) + tb_writer.add_scalar("iter_time", elapsed, iteration) + + # Report test and samples of training set + if iteration in testing_iterations: + torch.cuda.empty_cache() + validation_configs = ( + {"name": "test", "cameras": scene.getTestCameras()}, + { + "name": "train", + "cameras": [ + scene.getTrainCameras()[idx % len(scene.getTrainCameras())] + for idx in range(5, 30, 5) + ], + }, + ) + + for config in validation_configs: + if config["cameras"] and len(config["cameras"]) > 0: + l1_test = 0.0 + psnr_test = 0.0 + for idx, viewpoint in enumerate(config["cameras"]): + image = torch.clamp( + renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], + 0.0, + 1.0, + ) + gt_image = torch.clamp( + viewpoint.original_image.to("cuda"), 0.0, 1.0 + ) + if tb_writer and (idx < 5): + tb_writer.add_images( + config["name"] + + "_view_{}/render".format(viewpoint.image_name), + image[None], + global_step=iteration, + ) + if iteration == testing_iterations[0]: + tb_writer.add_images( + config["name"] + + "_view_{}/ground_truth".format(viewpoint.image_name), + gt_image[None], + global_step=iteration, + ) + l1_test += l1_loss(image, gt_image).mean().double() + psnr_test += psnr(image, gt_image).mean().double() + psnr_test /= len(config["cameras"]) + l1_test /= len(config["cameras"]) + print( + "\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format( + iteration, config["name"], l1_test, psnr_test + ) + ) + if tb_writer: + tb_writer.add_scalar( + config["name"] + "/loss_viewpoint - l1_loss", l1_test, iteration + ) + tb_writer.add_scalar( + config["name"] + "/loss_viewpoint - psnr", psnr_test, iteration + ) + + if tb_writer: + tb_writer.add_histogram( + "scene/opacity_histogram", scene.gaussians.get_opacity, iteration + ) + tb_writer.add_scalar( + "total_points", scene.gaussians.get_xyz.shape[0], iteration + ) + torch.cuda.empty_cache() + + +if __name__ == "__main__": + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + parser.add_argument("--ip", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=6009) + parser.add_argument("--debug_from", type=int, default=-1) + parser.add_argument("--detect_anomaly", action="store_true", default=False) + parser.add_argument( + "--test_iterations", nargs="+", type=int, default=[7_000, 30_000] + ) + parser.add_argument( + "--save_iterations", nargs="+", type=int, default=[7_000, 30_000] + ) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default=None) + args = parser.parse_args(sys.argv[1:]) + args.save_iterations.append(args.iterations) + + print("Optimizing " + args.model_path) + + # Initialize system state (RNG) + safe_state(args.quiet) + + # Start GUI server, configure and 
run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + training( + lp.extract(args), + op.extract(args), + pp.extract(args), + args.test_iterations, + args.save_iterations, + args.checkpoint_iterations, + args.start_checkpoint, + args.debug_from, + ) + + # All done + print("\nTraining complete.") diff --git a/recon/utils/camera_utils.py b/recon/utils/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ace2474dd723975515438ae6f5d8a64e0c819317 --- /dev/null +++ b/recon/utils/camera_utils.py @@ -0,0 +1,151 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from pathlib import Path +from mediapy import read_video, write_video +from scene.cameras import Camera +import numpy as np +from utils.general_utils import PILtoTorch +from utils.graphics_utils import fov2focal + +WARNED = False + + +def loadCam(args, id, cam_info, resolution_scale): + orig_w, orig_h = cam_info.image.size + + if args.resolution in [1, 2, 4, 8]: + resolution = round(orig_w / (resolution_scale * args.resolution)), round( + orig_h / (resolution_scale * args.resolution) + ) + else: # should be a type that converts to float + if args.resolution == -1: + if orig_w > 1600: + global WARNED + if not WARNED: + print( + "[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " + "If this is not desired, please explicitly specify '--resolution/-r' as 1" + ) + WARNED = True + global_down = orig_w / 1600 + else: + global_down = 1 + else: + global_down = orig_w / args.resolution + + scale = float(global_down) * float(resolution_scale) + resolution = (int(orig_w / scale), int(orig_h / scale)) + + resized_image_rgb = PILtoTorch(cam_info.image, resolution) + + gt_image = resized_image_rgb[:3, ...] + loaded_mask = None + + if resized_image_rgb.shape[1] == 4: + loaded_mask = resized_image_rgb[3:4, ...] 
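# Worked example of the resolution rule above (illustrative): with the default
# `--resolution -1` and resolution_scale == 1.0, a 3840x2160 input has
# orig_w > 1600, so global_down = 3840 / 1600 = 2.4 and the image is loaded at
# (int(3840 / 2.4), int(2160 / 2.4)) = (1600, 900). With `-r 2` the same image
# is instead loaded at (round(3840 / 2), round(2160 / 2)) = (1920, 1080).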
+ + return Camera( + colmap_id=cam_info.uid, + R=cam_info.R, + T=cam_info.T, + FoVx=cam_info.FovX, + FoVy=cam_info.FovY, + image=gt_image, + gt_alpha_mask=loaded_mask, + image_name=cam_info.image_name, + uid=id, + data_device=args.data_device, + ) + + +def cameraList_from_camInfos(cam_infos, resolution_scale, args): + camera_list = [] + + for id, c in enumerate(cam_infos): + camera_list.append(loadCam(args, id, c, resolution_scale)) + + return camera_list + + +def camera_to_JSON(id, camera: Camera): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = camera.R.transpose() + Rt[:3, 3] = camera.T + Rt[3, 3] = 1.0 + + W2C = np.linalg.inv(Rt) + pos = W2C[:3, 3] + rot = W2C[:3, :3] + serializable_array_2d = [x.tolist() for x in rot] + camera_entry = { + "id": id, + "img_name": camera.image_name, + "width": camera.width, + "height": camera.height, + "position": pos.tolist(), + "rotation": serializable_array_2d, + "fy": fov2focal(camera.FovY, camera.height), + "fx": fov2focal(camera.FovX, camera.width), + } + return camera_entry + + +def get_c2w_from_up_and_look_at( + up, + look_at, + pos, + opengl=False, +): + up = up / np.linalg.norm(up) + z = look_at - pos + z = z / np.linalg.norm(z) + y = -up + x = np.cross(y, z) + x /= np.linalg.norm(x) + y = np.cross(z, x) + + c2w = np.zeros([4, 4], dtype=np.float32) + c2w[:3, 0] = x + c2w[:3, 1] = y + c2w[:3, 2] = z + c2w[:3, 3] = pos + c2w[3, 3] = 1.0 + + # opencv to opengl + if opengl: + c2w[..., 1:3] *= -1 + + return c2w + + +def get_uniform_poses(num_frames, radius, elevation, opengl=False): + T = num_frames + azimuths = np.deg2rad(np.linspace(0, 360, T + 1)[:T]) + elevations = np.full_like(azimuths, np.deg2rad(elevation)) + cam_dists = np.full_like(azimuths, radius) + + campos = np.stack( + [ + cam_dists * np.cos(elevations) * np.cos(azimuths), + cam_dists * np.cos(elevations) * np.sin(azimuths), + cam_dists * np.sin(elevations), + ], + axis=-1, + ) + + center = np.array([0, 0, 0], dtype=np.float32) + up = np.array([0, 0, 1], dtype=np.float32) + poses = [] + for t in range(T): + poses.append(get_c2w_from_up_and_look_at(up, center, campos[t], opengl=opengl)) + + return np.stack(poses, axis=0) diff --git a/recon/utils/colormaps.py b/recon/utils/colormaps.py new file mode 100644 index 0000000000000000000000000000000000000000..3ee85b4ff33b5aeb84e6779befb3601e167d744c --- /dev/null +++ b/recon/utils/colormaps.py @@ -0,0 +1,220 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
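The two pose helpers above, get_c2w_from_up_and_look_at and get_uniform_poses, build a circular orbit of camera-to-world matrices: azimuths are swept uniformly at a fixed elevation and radius, and each camera is aimed at the origin with +z as up. A minimal usage sketch, assuming the recon directory is on the import path so the module resolves as utils.camera_utils:

from utils.camera_utils import get_uniform_poses

# 18 evenly spaced azimuths on a radius-2.0 orbit at 0 degrees elevation
poses = get_uniform_poses(num_frames=18, radius=2.0, elevation=0.0, opengl=False)
print(poses.shape)       # (18, 4, 4) camera-to-world matrices
print(poses[0][:3, 3])   # first camera position, approximately [2, 0, 0]
# Every pose looks at the origin with +z as the up axis; opengl=True additionally
# flips the y and z camera axes to convert from the OpenCV to the OpenGL convention.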
+ +""" Helper functions for visualizing outputs """ + +from dataclasses import dataclass + +# from utils.typing import * +from typing import * + +import matplotlib +import torch +from jaxtyping import Bool, Float +from torch import Tensor + +from utils import colors + +Colormaps = Literal[ + "default", "turbo", "viridis", "magma", "inferno", "cividis", "gray", "pca" +] + + +@dataclass(frozen=True) +class ColormapOptions: + """Options for colormap""" + + colormap: Colormaps = "default" + """ The colormap to use """ + normalize: bool = False + """ Whether to normalize the input tensor image """ + colormap_min: float = 0 + """ Minimum value for the output colormap """ + colormap_max: float = 1 + """ Maximum value for the output colormap """ + invert: bool = False + """ Whether to invert the output colormap """ + + +def apply_colormap( + image: Float[Tensor, "*bs channels"], + colormap_options: ColormapOptions = ColormapOptions(), + eps: float = 1e-9, +) -> Float[Tensor, "*bs rgb"]: + """ + Applies a colormap to a tensor image. + If single channel, applies a colormap to the image. + If 3 channel, treats the channels as RGB. + If more than 3 channel, applies a PCA reduction on the dimensions to 3 channels + + Args: + image: Input tensor image. + eps: Epsilon value for numerical stability. + + Returns: + Tensor with the colormap applied. + """ + + # default for rgb images + if image.shape[-1] == 3: + return image + + # rendering depth outputs + if image.shape[-1] == 1 and torch.is_floating_point(image): + output = image + if colormap_options.normalize: + output = output - torch.min(output) + output = output / (torch.max(output) + eps) + output = ( + output * (colormap_options.colormap_max - colormap_options.colormap_min) + + colormap_options.colormap_min + ) + output = torch.clip(output, 0, 1) + if colormap_options.invert: + output = 1 - output + return apply_float_colormap(output, colormap=colormap_options.colormap) + + # rendering boolean outputs + if image.dtype == torch.bool: + return apply_boolean_colormap(image) + + if image.shape[-1] > 3: + return apply_pca_colormap(image) + + raise NotImplementedError + + +def apply_float_colormap( + image: Float[Tensor, "*bs 1"], colormap: Colormaps = "viridis" +) -> Float[Tensor, "*bs rgb"]: + """Convert single channel to a color image. + + Args: + image: Single channel image. + colormap: Colormap for image. + + Returns: + Tensor: Colored image with colors in [0, 1] + """ + if colormap == "default": + colormap = "turbo" + + image = torch.nan_to_num(image, 0) + if colormap == "gray": + return image.repeat(1, 1, 3) + image = image.clamp(0, 1) + image_long = (image * 255).long() + image_long_min = torch.min(image_long) + image_long_max = torch.max(image_long) + assert image_long_min >= 0, f"the min value is {image_long_min}" + assert image_long_max <= 255, f"the max value is {image_long_max}" + return torch.tensor(matplotlib.colormaps[colormap].colors, device=image.device)[ + image_long[..., 0] + ] + + +def apply_depth_colormap( + depth: Float[Tensor, "*bs 1"], + accumulation: Optional[Float[Tensor, "*bs 1"]] = None, + near_plane: Optional[float] = None, + far_plane: Optional[float] = None, + colormap_options: ColormapOptions = ColormapOptions(), +) -> Float[Tensor, "*bs rgb"]: + """Converts a depth image to color for easier analysis. + + Args: + depth: Depth image. + accumulation: Ray accumulation used for masking vis. + near_plane: Closest depth to consider. If None, use min image value. + far_plane: Furthest depth to consider. 
If None, use max image value. + colormap: Colormap to apply. + + Returns: + Colored depth image with colors in [0, 1] + """ + + near_plane = near_plane or float(torch.min(depth)) + far_plane = far_plane or float(torch.max(depth)) + + depth = (depth - near_plane) / (far_plane - near_plane + 1e-10) + depth = torch.clip(depth, 0, 1) + # depth = torch.nan_to_num(depth, nan=0.0) # TODO(ethan): remove this + + colored_image = apply_colormap(depth, colormap_options=colormap_options) + + if accumulation is not None: + colored_image = colored_image * accumulation + (1 - accumulation) + + return colored_image + + +def apply_boolean_colormap( + image: Bool[Tensor, "*bs 1"], + true_color: Float[Tensor, "*bs rgb"] = colors.WHITE, + false_color: Float[Tensor, "*bs rgb"] = colors.BLACK, +) -> Float[Tensor, "*bs rgb"]: + """Converts a depth image to color for easier analysis. + + Args: + image: Boolean image. + true_color: Color to use for True. + false_color: Color to use for False. + + Returns: + Colored boolean image + """ + + colored_image = torch.ones(image.shape[:-1] + (3,)) + colored_image[image[..., 0], :] = true_color + colored_image[~image[..., 0], :] = false_color + return colored_image + + +def apply_pca_colormap(image: Float[Tensor, "*bs dim"]) -> Float[Tensor, "*bs rgb"]: + """Convert feature image to 3-channel RGB via PCA. The first three principle + components are used for the color channels, with outlier rejection per-channel + + Args: + image: image of arbitrary vectors + + Returns: + Tensor: Colored image + """ + original_shape = image.shape + image = image.view(-1, image.shape[-1]) + _, _, v = torch.pca_lowrank(image) + image = torch.matmul(image, v[..., :3]) + d = torch.abs(image - torch.median(image, dim=0).values) + mdev = torch.median(d, dim=0).values + s = d / mdev + m = 3.0 # this is a hyperparam controlling how many std dev outside for outliers + rins = image[s[:, 0] < m, 0] + gins = image[s[:, 1] < m, 1] + bins = image[s[:, 2] < m, 2] + + image[:, 0] -= rins.min() + image[:, 1] -= gins.min() + image[:, 2] -= bins.min() + + image[:, 0] /= rins.max() - rins.min() + image[:, 1] /= gins.max() - gins.min() + image[:, 2] /= bins.max() - bins.min() + + image = torch.clamp(image, 0, 1) + image_long = (image * 255).long() + image_long_min = torch.min(image_long) + image_long_max = torch.max(image_long) + assert image_long_min >= 0, f"the min value is {image_long_min}" + assert image_long_max <= 255, f"the max value is {image_long_max}" + return image.view(*original_shape[:-1], 3) diff --git a/recon/utils/colors.py b/recon/utils/colors.py new file mode 100644 index 0000000000000000000000000000000000000000..66ac8d24357d0c6f5c0db9f560f13dff459a3c83 --- /dev/null +++ b/recon/utils/colors.py @@ -0,0 +1,55 @@ +# Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
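apply_colormap above dispatches on the input's channel count: 3-channel tensors pass through as RGB, single-channel floats are mapped through a matplotlib colormap, boolean masks become black and white, and anything wider than 3 channels is reduced to RGB with PCA. A minimal sketch of the depth path, assuming the module imports as utils.colormaps like the other recon utilities:

import torch
from utils.colormaps import ColormapOptions, apply_depth_colormap

depth = torch.rand(256, 256, 1) * 5.0    # synthetic depth in [0, 5)
accumulation = torch.ones(256, 256, 1)   # fully opaque rays, so no background blend
rgb = apply_depth_colormap(
    depth,
    accumulation=accumulation,
    near_plane=0.0,
    far_plane=5.0,
    colormap_options=ColormapOptions(colormap="turbo"),
)
print(rgb.shape)  # torch.Size([256, 256, 3]), values in [0, 1]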
+ +"""Common Colors""" +from typing import Union + +import torch +from jaxtyping import Float +from torch import Tensor + +WHITE = torch.tensor([1.0, 1.0, 1.0]) +BLACK = torch.tensor([0.0, 0.0, 0.0]) +RED = torch.tensor([1.0, 0.0, 0.0]) +GREEN = torch.tensor([0.0, 1.0, 0.0]) +BLUE = torch.tensor([0.0, 0.0, 1.0]) + +COLORS_DICT = { + "white": WHITE, + "black": BLACK, + "red": RED, + "green": GREEN, + "blue": BLUE, +} + + +def get_color(color: Union[str, list]) -> Float[Tensor, "3"]: + """ + Args: + Color as a string or a rgb list + + Returns: + Parsed color + """ + if isinstance(color, str): + color = color.lower() + if color not in COLORS_DICT: + raise ValueError(f"{color} is not a valid preset color") + return COLORS_DICT[color] + if isinstance(color, list): + if len(color) != 3: + raise ValueError(f"Color should be 3 values (RGB) instead got {color}") + return torch.tensor(color) + + raise ValueError(f"Color should be an RGB list or string, instead got {type(color)}") diff --git a/recon/utils/diffusion.py b/recon/utils/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..0217efb13b1342dcb4b7f22b8528c53c58a11627 --- /dev/null +++ b/recon/utils/diffusion.py @@ -0,0 +1,42 @@ +import torch +from PIL import Image +from pathlib import Path +from omegaconf import OmegaConf + +from scripts.demo.streamlit_helpers import ( + load_model_from_config, + get_sampler, + get_batch, + do_sample, +) + + +def load_config_and_model(ckpt: Path): + if (ckpt.parent.parent / "configs").exists(): + config_path = list((ckpt.parent.parent / "configs").glob("*-project.yaml"))[0] + else: + config_path = list( + (ckpt.parent.parent.parent / "configs").glob("*-project.yaml") + )[0] + + config = OmegaConf.load(config_path) + + model, msg = load_model_from_config(config, ckpt) + + return config, model + + +def load_sampler(sampler_cfg): + return get_sampler(**sampler_cfg) + + +def load_batch(): + pass + + +class DiffusionEngine: + def __init__(self, cfg) -> None: + self.cfg = cfg + + def sample(self): + pass diff --git a/recon/utils/general_utils.py b/recon/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..541c0825229a2d86e84460b765879f86f724a59d --- /dev/null +++ b/recon/utils/general_utils.py @@ -0,0 +1,133 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +def PILtoTorch(pil_image, resolution): + resized_image_PIL = pil_image.resize(resolution) + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). 
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. + :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.set_device(torch.device("cuda:0")) diff --git a/recon/utils/graphics_utils.py b/recon/utils/graphics_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b4627d837c74fcdffc898fa0c3071cb7b316802b --- /dev/null +++ b/recon/utils/graphics_utils.py @@ -0,0 +1,77 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
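get_expon_lr_func above returns a closure that decays the learning rate log-linearly from lr_init to lr_final over max_steps (an exponential decay), with an optional sine-shaped warm-up controlled by lr_delay_steps and lr_delay_mult. A quick sanity check with illustrative values, assuming recon is the working directory so the module imports as utils.general_utils:

from utils.general_utils import get_expon_lr_func

# Illustrative values: decay from 1.6e-4 to 1.6e-6 over 30k steps, no warm-up.
lr_fn = get_expon_lr_func(lr_init=1.6e-4, lr_final=1.6e-6, max_steps=30_000)
print(lr_fn(0))        # 1.6e-4  (lr_init)
print(lr_fn(15_000))   # ~1.6e-5 (geometric mean of lr_init and lr_final)
print(lr_fn(30_000))   # 1.6e-6  (lr_final)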
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +import numpy as np +from typing import NamedTuple + +class BasicPointCloud(NamedTuple): + points : np.array + colors : np.array + normals : np.array + +def geom_transform_points(points, transf_matrix): + P, _ = points.shape + ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) + points_hom = torch.cat([points, ones], dim=1) + points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) + + denom = points_out[..., 3:] + 0.0000001 + return (points_out[..., :3] / denom).squeeze(dim=0) + +def getWorld2View(R, t): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + return np.float32(Rt) + +def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + + C2W = np.linalg.inv(Rt) + cam_center = C2W[:3, 3] + cam_center = (cam_center + translate) * scale + C2W[:3, 3] = cam_center + Rt = np.linalg.inv(C2W) + return np.float32(Rt) + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + +def fov2focal(fov, pixels): + return pixels / (2 * math.tan(fov / 2)) + +def focal2fov(focal, pixels): + return 2*math.atan(pixels/(2*focal)) \ No newline at end of file diff --git a/recon/utils/image_utils.py b/recon/utils/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeaa1b6d250e549181ab165070f82ccd31b3eb9 --- /dev/null +++ b/recon/utils/image_utils.py @@ -0,0 +1,19 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch + +def mse(img1, img2): + return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + +def psnr(img1, img2): + mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + return 20 * torch.log10(1.0 / torch.sqrt(mse)) diff --git a/recon/utils/loss_utils.py b/recon/utils/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d1824708789f207d58f86c0e8350bc70e4b4037a --- /dev/null +++ b/recon/utils/loss_utils.py @@ -0,0 +1,96 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
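fov2focal and focal2fov above are exact inverses of each other (focal length in pixels versus field of view in radians), and psnr in image_utils.py expects image batches in [0, 1]. A short sketch under the same import-path assumption as the other recon utilities:

import math
import torch
from utils.graphics_utils import fov2focal, focal2fov
from utils.image_utils import psnr

fov_x = math.radians(60.0)
fx = fov2focal(fov_x, pixels=800)      # ~692.8 px for a 60 degree FoV over 800 px
assert abs(focal2fov(fx, 800) - fov_x) < 1e-9

img = torch.rand(1, 3, 64, 64)
print(psnr(img, img))                  # inf, since the MSE is zero
print(psnr(img, img * 0.9).item())     # finite PSNR for a slightly darkened copy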
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp +from lpipsPyTorch import lpips as lpips_fn +from lpipsPyTorch.modules.lpips import LPIPS + +_lpips = None + + +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + + +def gaussian(window_size, sigma): + gauss = torch.Tensor( + [ + exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) + for x in range(window_size) + ] + ) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable( + _2D_window.expand(channel, 1, window_size, window_size).contiguous() + ) + return window + + +def ssim(img1, img2, window_size=11, size_average=True): + channel = img1.size(-3) + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) + + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = ( + F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + ) + sigma2_sq = ( + F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + ) + sigma12 = ( + F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) + - mu1_mu2 + ) + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + + +def lpips(img1, img2): + global _lpips + if _lpips is None: + _lpips = LPIPS("vgg", "0.1").to("cuda") + return _lpips(img1, img2).mean() diff --git a/recon/utils/sh_utils.py b/recon/utils/sh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785 --- /dev/null +++ b/recon/utils/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/recon/utils/system_utils.py b/recon/utils/system_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..90ca6d7f77610c967affe313398777cd86920e8e --- /dev/null +++ b/recon/utils/system_utils.py @@ -0,0 +1,28 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
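RGB2SH and SH2RGB above map between a colour and its degree-0 (DC) spherical-harmonic coefficient, and eval_sh evaluates the hardcoded SH basis up to degree 4 for a batch of view directions. A small sanity check, again assuming the recon utils package is importable:

import torch
from utils.sh_utils import eval_sh, RGB2SH, SH2RGB

rgb = torch.tensor([0.2, 0.5, 0.8])
sh0 = RGB2SH(rgb)                         # DC coefficient per channel, shape (3,)
assert torch.allclose(SH2RGB(sh0), rgb)   # exact inverse up to float rounding

# eval_sh expects coefficients shaped [..., C, (deg + 1) ** 2]; at degree 0 the
# result is just C0 * sh[..., 0], independent of the viewing direction.
sh = sh0.reshape(3, 1)                    # 3 colour channels, 1 coefficient each
dirs = torch.tensor([0.0, 0.0, 1.0])      # unused when deg == 0
assert torch.allclose(eval_sh(0, sh, dirs) + 0.5, rgb)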
+# +# For inquiries contact george.drettakis@inria.fr +# + +from errno import EEXIST +from os import makedirs, path +import os + +def mkdir_p(folder_path): + # Creates a directory. equivalent to using mkdir -p on the command line + try: + makedirs(folder_path) + except OSError as exc: # Python >2.5 + if exc.errno == EEXIST and path.isdir(folder_path): + pass + else: + raise + +def searchForMaxIteration(folder): + saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] + return max(saved_iters) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..16898baf6dfc3f0f4ad7b3b63accac8b1834921a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,43 @@ +black==23.7.0 +chardet==5.1.0 +clip @ git+https://github.com/openai/CLIP.git +einops>=0.6.1 +fairscale>=0.4.13 +fire>=0.5.0 +fsspec>=2023.6.0 +invisible-watermark>=0.2.0 +kornia==0.6.9 +matplotlib>=3.7.2 +natsort>=8.4.0 +ninja>=1.11.1 +numpy>=1.24.4 +omegaconf>=2.3.0 +open-clip-torch>=2.20.0 +opencv-python==4.6.0.66 +pandas>=2.0.3 +pillow>=9.5.0 +pudb>=2022.1.3 +pytorch-lightning==2.0.1 +pyyaml>=6.0.1 +scipy>=1.10.1 +streamlit>=0.73.1 +tensorboardx==2.6 +timm>=0.9.2 +tokenizers==0.12.1 +torch>=2.0.1 +torchaudio>=2.0.2 +torchdata==0.6.1 +torchmetrics>=1.0.1 +torchvision>=0.15.2 +tqdm>=4.65.0 +transformers==4.19.1 +triton==2.0.0 +urllib3<1.27,>=1.25.4 +wandb>=0.15.6 +webdataset>=0.2.33 +wheel>=0.41.0 +xformers>=0.0.20 +streamlit-keyup==0.2.0 +mediapy +tyro +wget diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/pub/V3D_512.py b/scripts/pub/V3D_512.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1fd579348d477d395958e13f7e7002bb9be1f2 --- /dev/null +++ b/scripts/pub/V3D_512.py @@ -0,0 +1,317 @@ +import math +import os +from glob import glob +from pathlib import Path +from typing import Optional + +import cv2 +import numpy as np +import torch +from einops import rearrange, repeat +from fire import Fire +import tyro +from omegaconf import OmegaConf +from PIL import Image +from torchvision.transforms import ToTensor +from mediapy import write_video +import rembg +from kiui.op import recenter +from safetensors.torch import load_file as load_safetensors +from typing import Any + +from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering +from sgm.inference.helpers import embed_watermark +from sgm.util import default, instantiate_from_config + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list(set([x.input_key for x in conditioner.embedders])) + + +def get_batch(keys, value_dict, N, T, device): + batch = {} + batch_uc = {} + + for key in keys: + if key == "fps_id": + batch[key] = ( + torch.tensor([value_dict["fps_id"]]) + .to(device) + .repeat(int(math.prod(N))) + ) + elif key == "motion_bucket_id": + batch[key] = ( + torch.tensor([value_dict["motion_bucket_id"]]) + .to(device) + .repeat(int(math.prod(N))) + ) + elif key == "cond_aug": + batch[key] = repeat( + torch.tensor([value_dict["cond_aug"]]).to(device), + "1 -> b", + b=math.prod(N), + ) + elif key == "cond_frames": + batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0]) + elif key == "cond_frames_without_noise": + batch[key] = repeat( + value_dict["cond_frames_without_noise"], "1 ... 
-> b ...", b=N[0] + ) + else: + batch[key] = value_dict[key] + + if T is not None: + batch["num_video_frames"] = T + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +def load_model( + config: str, + device: str, + num_frames: int, + num_steps: int, + ckpt_path: Optional[str] = None, + min_cfg: Optional[float] = None, + max_cfg: Optional[float] = None, + sigma_max: Optional[float] = None, +): + config = OmegaConf.load(config) + + config.model.params.sampler_config.params.num_steps = num_steps + config.model.params.sampler_config.params.guider_config.params.num_frames = ( + num_frames + ) + if max_cfg is not None: + config.model.params.sampler_config.params.guider_config.params.max_scale = ( + max_cfg + ) + if min_cfg is not None: + config.model.params.sampler_config.params.guider_config.params.min_scale = ( + min_cfg + ) + if sigma_max is not None: + print("Overriding sigma_max to ", sigma_max) + config.model.params.sampler_config.params.discretization_config.params.sigma_max = ( + sigma_max + ) + + config.model.params.from_scratch = False + + if ckpt_path is not None: + config.model.params.ckpt_path = str(ckpt_path) + if device == "cuda": + with torch.device(device): + model = instantiate_from_config(config.model).to(device).eval() + else: + model = instantiate_from_config(config.model).to(device).eval() + + return model, None + + +def sample_one( + input_path: str = "assets/test_image.png", # Can either be image file or folder with image files + checkpoint_path: Optional[str] = None, + num_frames: Optional[int] = None, + num_steps: Optional[int] = None, + fps_id: int = 1, + motion_bucket_id: int = 300, + cond_aug: float = 0.02, + seed: int = 23, + decoding_t: int = 24, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary. 
+ device: str = "cuda", + output_folder: Optional[str] = None, + noise: torch.Tensor = None, + save: bool = False, + cached_model: Any = None, + border_ratio: float = 0.3, + min_guidance_scale: float = 3.5, + max_guidance_scale: float = 3.5, + sigma_max: float = None, + ignore_alpha: bool = False, +): + model_config = "scripts/pub/configs/V3D_512.yaml" + num_frames = OmegaConf.load( + model_config + ).model.params.sampler_config.params.guider_config.params.num_frames + print("Detected num_frames:", num_frames) + num_steps = default(num_steps, 25) + output_folder = default(output_folder, f"outputs/V3D_512") + decoding_t = min(decoding_t, num_frames) + + sd = load_safetensors("./ckpts/svd_xt.safetensors") + clip_model_config = OmegaConf.load("configs/embedder/clip_image.yaml") + clip_model = instantiate_from_config(clip_model_config).eval() + clip_sd = dict() + for k, v in sd.items(): + if "conditioner.embedders.0" in k: + clip_sd[k.replace("conditioner.embedders.0.", "")] = v + clip_model.load_state_dict(clip_sd) + clip_model = clip_model.to(device) + + ae_model_config = OmegaConf.load("configs/ae/video.yaml") + ae_model = instantiate_from_config(ae_model_config).eval() + encoder_sd = dict() + for k, v in sd.items(): + if "first_stage_model" in k: + encoder_sd[k.replace("first_stage_model.", "")] = v + ae_model.load_state_dict(encoder_sd) + ae_model = ae_model.to(device) + + if cached_model is None: + model, filter = load_model( + model_config, + device, + num_frames, + num_steps, + ckpt_path=checkpoint_path, + min_cfg=min_guidance_scale, + max_cfg=max_guidance_scale, + sigma_max=sigma_max, + ) + else: + model = cached_model + torch.manual_seed(seed) + + need_return = True + path = Path(input_path) + if path.is_file(): + if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]): + all_img_paths = [input_path] + else: + raise ValueError("Path is not valid image file.") + elif path.is_dir(): + all_img_paths = sorted( + [ + f + for f in path.iterdir() + if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"] + ] + ) + need_return = False + if len(all_img_paths) == 0: + raise ValueError("Folder does not contain any images.") + else: + raise ValueError + + for input_path in all_img_paths: + with Image.open(input_path) as image: + # if image.mode == "RGBA": + # image = image.convert("RGB") + w, h = image.size + + if border_ratio > 0: + if image.mode != "RGBA" or ignore_alpha: + image = image.convert("RGB") + image = np.asarray(image) + carved_image = rembg.remove(image) # [H, W, 4] + else: + image = np.asarray(image) + carved_image = image + mask = carved_image[..., -1] > 0 + image = recenter(carved_image, mask, border_ratio=border_ratio) + image = image.astype(np.float32) / 255.0 + if image.shape[-1] == 4: + image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4]) + image = Image.fromarray((image * 255).astype(np.uint8)) + else: + print("Ignore border ratio") + image = image.resize((512, 512)) + + image = ToTensor()(image) + image = image * 2.0 - 1.0 + + image = image.unsqueeze(0).to(device) + H, W = image.shape[2:] + assert image.shape[1] == 3 + F = 8 + C = 4 + shape = (num_frames, C, H // F, W // F) + + value_dict = {} + value_dict["motion_bucket_id"] = motion_bucket_id + value_dict["fps_id"] = fps_id + value_dict["cond_aug"] = cond_aug + value_dict["cond_frames_without_noise"] = clip_model(image) + value_dict["cond_frames"] = ae_model.encode(image) + value_dict["cond_frames"] += cond_aug * torch.randn_like( + value_dict["cond_frames"] + ) + value_dict["cond_aug"] = 
cond_aug + + with torch.no_grad(): + with torch.autocast(device): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [1, num_frames], + T=num_frames, + device=device, + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=[ + "cond_frames", + "cond_frames_without_noise", + ], + ) + + for k in ["crossattn", "concat"]: + uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames) + uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames) + c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames) + c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames) + + randn = torch.randn(shape, device=device) if noise is None else noise + randn = randn.to(device) + + additional_model_inputs = {} + additional_model_inputs["image_only_indicator"] = torch.zeros( + 2, num_frames + ).to(device) + additional_model_inputs["num_video_frames"] = batch["num_video_frames"] + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + samples_z = model.sampler(denoiser, randn, cond=c, uc=uc) + model.en_and_decode_n_samples_a_time = decoding_t + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + os.makedirs(output_folder, exist_ok=True) + base_count = len(glob(os.path.join(output_folder, "*.mp4"))) + video_path = os.path.join(output_folder, f"{base_count:06d}.mp4") + # writer = cv2.VideoWriter( + # video_path, + # cv2.VideoWriter_fourcc(*"MP4V"), + # fps_id + 1, + # (samples.shape[-1], samples.shape[-2]), + # ) + + frames = ( + (rearrange(samples, "t c h w -> t h w c") * 255) + .cpu() + .numpy() + .astype(np.uint8) + ) + + if save: + write_video(video_path, frames, fps=3) + + images = [] + for frame in frames: + images.append(Image.fromarray(frame)) + + if need_return: + return images, model + + +if __name__ == "__main__": + tyro.cli(sample_one) diff --git a/scripts/pub/configs/V3D_512.yaml b/scripts/pub/configs/V3D_512.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aee4108e741a50a75336d277e72a72d9b1df8ade --- /dev/null +++ b/scripts/pub/configs/V3D_512.yaml @@ -0,0 +1,161 @@ +model: + base_learning_rate: 1.0e-04 + target: sgm.models.video_diffusion.DiffusionEngine + params: + ckpt_path: ckpts/V3D_512.ckpt + scale_factor: 0.18215 + disable_first_stage_autocast: true + input_key: latents + log_keys: [] + scheduler_config: + target: sgm.lr_scheduler.LambdaLinearScheduler + params: + warm_up_steps: + - 1 + cycle_lengths: + - 10000000000000 + f_start: + - 1.0e-06 + f_max: + - 1.0 + f_min: + - 1.0 + denoiser_config: + target: sgm.modules.diffusionmodules.denoiser.Denoiser + params: + scaling_config: + target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise + network_config: + target: sgm.modules.diffusionmodules.video_model.VideoUNet + params: + adm_in_channels: 768 + num_classes: sequential + use_checkpoint: true + in_channels: 8 + out_channels: 4 + model_channels: 320 + attention_resolutions: + - 4 + - 2 + - 1 + num_res_blocks: 2 + channel_mult: + - 1 + - 2 + - 4 + - 4 + num_head_channels: 64 + use_linear_in_transformer: true + transformer_depth: 1 + context_dim: 1024 + spatial_transformer_attn_type: softmax-xformers + extra_ff_mix_layer: true + use_spatial_context: true + merge_strategy: learned_with_images + video_kernel_size: + - 3 + - 1 + - 1 + conditioner_config: + target: sgm.modules.GeneralConditioner + 
params: + emb_models: + - is_trainable: false + ucg_rate: 0.2 + input_key: cond_frames_without_noise + target: sgm.modules.encoders.modules.IdentityEncoder + - input_key: fps_id + is_trainable: true + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + - input_key: motion_bucket_id + is_trainable: true + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + - input_key: cond_frames + is_trainable: false + ucg_rate: 0.2 + target: sgm.modules.encoders.modules.IdentityEncoder + - input_key: cond_aug + is_trainable: true + target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND + params: + outdim: 256 + first_stage_config: + target: sgm.models.autoencoder.AutoencodingEngine + params: + loss_config: + target: torch.nn.Identity + regularizer_config: + target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer + encoder_config: + target: sgm.modules.diffusionmodules.model.Encoder + params: + attn_type: vanilla + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + decoder_config: + target: sgm.modules.autoencoding.temporal_ae.VideoDecoder + params: + attn_type: vanilla + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + video_kernel_size: + - 3 + - 1 + - 1 + sampler_config: + target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler + params: + num_steps: 30 + discretization_config: + target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization + params: + sigma_max: 700.0 + guider_config: + target: sgm.modules.diffusionmodules.guiders.LinearPredictionGuider + params: + max_scale: 3.5 + min_scale: 3.5 + num_frames: 18 + loss_fn_config: + target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss + params: + batch2model_keys: + - num_video_frames + - image_only_indicator + loss_weighting_config: + target: sgm.modules.diffusionmodules.loss_weighting.EDMWeighting + params: + sigma_data: 1.0 + sigma_sampler_config: + target: sgm.modules.diffusionmodules.sigma_sampling.EDMSampling + params: + p_mean: 1.5 + p_std: 2.0 \ No newline at end of file diff --git a/scripts/tests/attention.py b/scripts/tests/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d7c3f7c8da27c577a7ce0ea3a01ab7f9e9c1baa2 --- /dev/null +++ b/scripts/tests/attention.py @@ -0,0 +1,319 @@ +import einops +import torch +import torch.nn.functional as F +import torch.utils.benchmark as benchmark +from torch.backends.cuda import SDPBackend + +from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer + + +def benchmark_attn(): + # Lets define a helpful benchmarking function: + # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html + device = "cuda" if torch.cuda.is_available() else "cpu" + + def benchmark_torch_function_in_microseconds(f, *args, **kwargs): + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} + ) + return t0.blocked_autorange().mean * 1e6 + + # Lets define the hyper-parameters of our input + batch_size = 32 + max_sequence_len = 1024 + num_heads = 32 + embed_dimension = 32 + + dtype = torch.float16 + + query = torch.rand( + batch_size, + num_heads, + max_sequence_len, + embed_dimension, + device=device, + dtype=dtype, + ) + key = 
torch.rand( + batch_size, + num_heads, + max_sequence_len, + embed_dimension, + device=device, + dtype=dtype, + ) + value = torch.rand( + batch_size, + num_heads, + max_sequence_len, + embed_dimension, + device=device, + dtype=dtype, + ) + + print(f"q/k/v shape:", query.shape, key.shape, value.shape) + + # Lets explore the speed of each of the 3 implementations + from torch.backends.cuda import SDPBackend, sdp_kernel + + # Helpful arguments mapper + backend_map = { + SDPBackend.MATH: { + "enable_math": True, + "enable_flash": False, + "enable_mem_efficient": False, + }, + SDPBackend.FLASH_ATTENTION: { + "enable_math": False, + "enable_flash": True, + "enable_mem_efficient": False, + }, + SDPBackend.EFFICIENT_ATTENTION: { + "enable_math": False, + "enable_flash": False, + "enable_mem_efficient": True, + }, + } + + from torch.profiler import ProfilerActivity, profile, record_function + + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + print( + f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" + ) + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("Default detailed stats"): + for _ in range(25): + o = F.scaled_dot_product_attention(query, key, value) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + print( + f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" + ) + with sdp_kernel(**backend_map[SDPBackend.MATH]): + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("Math implmentation stats"): + for _ in range(25): + o = F.scaled_dot_product_attention(query, key, value) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]): + try: + print( + f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" + ) + except RuntimeError: + print("FlashAttention is not supported. See warnings for reasons.") + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("FlashAttention stats"): + for _ in range(25): + o = F.scaled_dot_product_attention(query, key, value) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): + try: + print( + f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds" + ) + except RuntimeError: + print("EfficientAttention is not supported. 
See warnings for reasons.") + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("EfficientAttention stats"): + for _ in range(25): + o = F.scaled_dot_product_attention(query, key, value) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + +def run_model(model, x, context): + return model(x, context) + + +def benchmark_transformer_blocks(): + device = "cuda" if torch.cuda.is_available() else "cpu" + import torch.utils.benchmark as benchmark + + def benchmark_torch_function_in_microseconds(f, *args, **kwargs): + t0 = benchmark.Timer( + stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f} + ) + return t0.blocked_autorange().mean * 1e6 + + checkpoint = True + compile = False + + batch_size = 32 + h, w = 64, 64 + context_len = 77 + embed_dimension = 1024 + context_dim = 1024 + d_head = 64 + + transformer_depth = 4 + + n_heads = embed_dimension // d_head + + dtype = torch.float16 + + model_native = SpatialTransformer( + embed_dimension, + n_heads, + d_head, + context_dim=context_dim, + use_linear=True, + use_checkpoint=checkpoint, + attn_type="softmax", + depth=transformer_depth, + sdp_backend=SDPBackend.FLASH_ATTENTION, + ).to(device) + model_efficient_attn = SpatialTransformer( + embed_dimension, + n_heads, + d_head, + context_dim=context_dim, + use_linear=True, + depth=transformer_depth, + use_checkpoint=checkpoint, + attn_type="softmax-xformers", + ).to(device) + if not checkpoint and compile: + print("compiling models") + model_native = torch.compile(model_native) + model_efficient_attn = torch.compile(model_efficient_attn) + + x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype) + c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype) + + from torch.profiler import ProfilerActivity, profile, record_function + + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + with torch.autocast("cuda"): + print( + f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds" + ) + print( + f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds" + ) + + print(75 * "+") + print("NATIVE") + print(75 * "+") + torch.cuda.reset_peak_memory_stats() + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("NativeAttention stats"): + for _ in range(25): + model_native(x, c) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block") + + print(75 * "+") + print("Xformers") + print(75 * "+") + torch.cuda.reset_peak_memory_stats() + with profile( + activities=activities, record_shapes=False, profile_memory=True + ) as prof: + with record_function("xformers stats"): + for _ in range(25): + model_efficient_attn(x, c) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block") + + +def test01(): + # conv1x1 vs linear + from sgm.util import count_params + + conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda() + print(count_params(conv)) + linear = torch.nn.Linear(3, 32).cuda() + print(count_params(linear)) + + print(conv.weight.shape) + + # use same initialization + linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1)) + linear.bias = 
torch.nn.Parameter(conv.bias) + + print(linear.weight.shape) + + x = torch.randn(11, 3, 64, 64).cuda() + + xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous() + print(xr.shape) + out_linear = linear(xr) + print(out_linear.mean(), out_linear.shape) + + out_conv = conv(x) + print(out_conv.mean(), out_conv.shape) + print("done with test01.\n") + + +def test02(): + # try cosine flash attention + import time + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.benchmark = True + print("testing cosine flash attention...") + DIM = 1024 + SEQLEN = 4096 + BS = 16 + + print(" softmax (vanilla) first...") + model = BasicTransformerBlock( + dim=DIM, + n_heads=16, + d_head=64, + dropout=0.0, + context_dim=None, + attn_mode="softmax", + ).cuda() + try: + x = torch.randn(BS, SEQLEN, DIM).cuda() + tic = time.time() + y = model(x) + toc = time.time() + print(y.shape, toc - tic) + except RuntimeError as e: + # likely oom + print(str(e)) + + print("\n now flash-cosine...") + model = BasicTransformerBlock( + dim=DIM, + n_heads=16, + d_head=64, + dropout=0.0, + context_dim=None, + attn_mode="flash-cosine", + ).cuda() + x = torch.randn(BS, SEQLEN, DIM).cuda() + tic = time.time() + y = model(x) + toc = time.time() + print(y.shape, toc - tic) + print("done with test02.\n") + + +if __name__ == "__main__": + # test01() + # test02() + # test03() + + # benchmark_attn() + benchmark_transformer_blocks() + + print("done.") diff --git a/scripts/util/__init__.py b/scripts/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/util/detection/__init__.py b/scripts/util/detection/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/util/detection/nsfw_and_watermark_dectection.py b/scripts/util/detection/nsfw_and_watermark_dectection.py new file mode 100644 index 0000000000000000000000000000000000000000..1096b8177d8e3dbcf8e913f924e98d5ce58cb120 --- /dev/null +++ b/scripts/util/detection/nsfw_and_watermark_dectection.py @@ -0,0 +1,110 @@ +import os + +import clip +import numpy as np +import torch +import torchvision.transforms as T +from PIL import Image + +RESOURCES_ROOT = "scripts/util/detection/" + + +def predict_proba(X, weights, biases): + logits = X @ weights.T + biases + proba = np.where( + logits >= 0, 1 / (1 + np.exp(-logits)), np.exp(logits) / (1 + np.exp(logits)) + ) + return proba.T + + +def load_model_weights(path: str): + model_weights = np.load(path) + return model_weights["weights"], model_weights["biases"] + + +def clip_process_images(images: torch.Tensor) -> torch.Tensor: + min_size = min(images.shape[-2:]) + return T.Compose( + [ + T.CenterCrop(min_size), # TODO: this might affect the watermark, check this + T.Resize(224, interpolation=T.InterpolationMode.BICUBIC, antialias=True), + T.Normalize( + (0.48145466, 0.4578275, 0.40821073), + (0.26862954, 0.26130258, 0.27577711), + ), + ] + )(images) + + +class DeepFloydDataFiltering(object): + def __init__( + self, verbose: bool = False, device: torch.device = torch.device("cpu") + ): + super().__init__() + self.verbose = verbose + self._device = None + self.clip_model, _ = clip.load("ViT-L/14", device=device) + self.clip_model.eval() + + self.cpu_w_weights, self.cpu_w_biases = load_model_weights( + os.path.join(RESOURCES_ROOT, "w_head_v1.npz") + ) + self.cpu_p_weights, self.cpu_p_biases = load_model_weights( + 
os.path.join(RESOURCES_ROOT, "p_head_v1.npz") + ) + self.w_threshold, self.p_threshold = 0.5, 0.5 + + @torch.inference_mode() + def __call__(self, images: torch.Tensor) -> torch.Tensor: + imgs = clip_process_images(images) + if self._device is None: + self._device = next(p for p in self.clip_model.parameters()).device + image_features = self.clip_model.encode_image(imgs.to(self._device)) + image_features = image_features.detach().cpu().numpy().astype(np.float16) + p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases) + w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases) + print(f"p_pred = {p_pred}, w_pred = {w_pred}") if self.verbose else None + query = p_pred > self.p_threshold + if query.sum() > 0: + print(f"Hit for p_threshold: {p_pred}") if self.verbose else None + images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query]) + query = w_pred > self.w_threshold + if query.sum() > 0: + print(f"Hit for w_threshold: {w_pred}") if self.verbose else None + images[query] = T.GaussianBlur(99, sigma=(100.0, 100.0))(images[query]) + return images + + +def load_img(path: str) -> torch.Tensor: + image = Image.open(path) + if not image.mode == "RGB": + image = image.convert("RGB") + image_transforms = T.Compose( + [ + T.ToTensor(), + ] + ) + return image_transforms(image)[None, ...] + + +def test(root): + from einops import rearrange + + filter = DeepFloydDataFiltering(verbose=True) + for p in os.listdir((root)): + print(f"running on {p}...") + img = load_img(os.path.join(root, p)) + filtered_img = filter(img) + filtered_img = rearrange( + 255.0 * (filtered_img.numpy())[0], "c h w -> h w c" + ).astype(np.uint8) + Image.fromarray(filtered_img).save( + os.path.join(root, f"{os.path.splitext(p)[0]}-filtered.jpg") + ) + + +if __name__ == "__main__": + import fire + + fire.Fire(test) + print("done.") diff --git a/sgm/__init__.py b/sgm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24bc84af8b1041de34b9816e0507cb1ac207bd13 --- /dev/null +++ b/sgm/__init__.py @@ -0,0 +1,4 @@ +from .models import AutoencodingEngine, DiffusionEngine +from .util import get_configs_path, instantiate_from_config + +__version__ = "0.1.0" diff --git a/sgm/data/__init__.py b/sgm/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7664a25c655c376bd1a7b0ccbaca7b983a2bf9ad --- /dev/null +++ b/sgm/data/__init__.py @@ -0,0 +1 @@ +from .dataset import StableDataModuleFromConfig diff --git a/sgm/data/cam_utils.py b/sgm/data/cam_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d44b38721dafc771c092887d93726b38e1ec0a6 --- /dev/null +++ b/sgm/data/cam_utils.py @@ -0,0 +1,1253 @@ +''' +Common camera utilities +''' + +import math +import numpy as np +import torch +import torch.nn as nn +from pytorch3d.renderer import PerspectiveCameras +from pytorch3d.renderer.cameras import look_at_view_transform +from pytorch3d.renderer.implicit.raysampling import _xy_to_ray_bundle + +class RelativeCameraLoader(nn.Module): + def __init__(self, + query_batch_size=1, + rand_query=True, + relative=True, + center_at_origin=False, + ): + super().__init__() + + self.query_batch_size = query_batch_size + self.rand_query = rand_query + self.relative = relative + self.center_at_origin = center_at_origin + + def plot_cameras(self, cameras_1, cameras_2): + ''' + Helper function to plot cameras + + Args: + cameras_1 (PyTorch3D camera): cameras object to plot + cameras_2 (PyTorch3D camera): cameras object to plot + 
''' + from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene + import plotly.graph_objects as go + plotlyplot = plot_scene( + { + 'scene_batch': { + 'cameras': cameras_1.to('cpu'), + 'rel_cameras': cameras_2.to('cpu'), + } + }, + camera_scale=.5,#0.05, + pointcloud_max_points=10000, + pointcloud_marker_size=1.0, + raybundle_max_rays=100 + ) + plotlyplot.show() + + def concat_cameras(self, camera_list): + ''' + Returns a concatenation of a list of cameras + + Args: + camera_list (List[PyTorch3D camera]): a list of PyTorch3D cameras + ''' + R_list, T_list, f_list, c_list, size_list = [], [], [], [], [] + for cameras in camera_list: + R_list.append(cameras.R) + T_list.append(cameras.T) + f_list.append(cameras.focal_length) + c_list.append(cameras.principal_point) + size_list.append(cameras.image_size) + + camera_slice = PerspectiveCameras( + R = torch.cat(R_list), + T = torch.cat(T_list), + focal_length = torch.cat(f_list), + principal_point = torch.cat(c_list), + image_size = torch.cat(size_list), + device = camera_list[0].device, + ) + return camera_slice + + def get_camera_slice(self, scene_cameras, indices): + ''' + Return a subset of cameras from a super set given indices + + Args: + scene_cameras (PyTorch3D Camera): cameras object + indices (tensor or List): a flat list or tensor of indices + + Returns: + camera_slice (PyTorch3D Camera) - cameras subset + ''' + camera_slice = PerspectiveCameras( + R = scene_cameras.R[indices], + T = scene_cameras.T[indices], + focal_length = scene_cameras.focal_length[indices], + principal_point = scene_cameras.principal_point[indices], + image_size = scene_cameras.image_size[indices], + device = scene_cameras.device, + ) + return camera_slice + + + def get_relative_camera(self, scene_cameras:PerspectiveCameras, query_idx, center_at_origin=False): + """ + Transform context cameras relative to a base query camera + + Args: + scene_cameras (PyTorch3D Camera): cameras object + query_idx (tensor or List): a length 1 list defining query idx + + Returns: + cams_relative (PyTorch3D Camera): cameras object relative to query camera + """ + + query_camera = self.get_camera_slice(scene_cameras, query_idx) + query_world2view = query_camera.get_world_to_view_transform() + all_world2view = scene_cameras.get_world_to_view_transform() + + if center_at_origin: + identity_cam = PerspectiveCameras(device=scene_cameras.device, R=query_camera.R, T=query_camera.T) + else: + T = torch.zeros((1, 3)) + identity_cam = PerspectiveCameras(device=scene_cameras.device, R=query_camera.R, T=T) + + identity_world2view = identity_cam.get_world_to_view_transform() + + # compose the relative transformation as g_i^{-1} g_j + relative_world2view = identity_world2view.inverse().compose(all_world2view) + + # generate a camera from the relative transform + relative_matrix = relative_world2view.get_matrix() + cams_relative = PerspectiveCameras( + R = relative_matrix[:, :3, :3], + T = relative_matrix[:, 3, :3], + focal_length = scene_cameras.focal_length, + principal_point = scene_cameras.principal_point, + image_size = scene_cameras.image_size, + device = scene_cameras.device, + ) + return cams_relative + + def forward(self, scene_cameras, scene_rgb=None, scene_masks=None, query_idx=None, context_size=3, context_idx=None, return_context=False): + ''' + Return a sampled batch of query and context cameras (used in training) + + Args: + scene_cameras (PyTorch3D Camera): a batch of PyTorch3D cameras + scene_rgb (Tensor): a batch of rgb + scene_masks (Tensor): a batch of 
masks (optional) + query_idx (List or Tensor): desired query idx (optional) + context_size (int): number of views for context + + Returns: + query_cameras, query_rgb, query_masks: random query view + context_cameras, context_rgb, context_masks: context views + ''' + + if query_idx is None: + query_idx = [0] + if self.rand_query: + rand = torch.randperm(len(scene_cameras)) + query_idx = rand[:1] + + if context_idx is None: + rand = torch.randperm(len(scene_cameras)) + context_idx = rand[:context_size] + + + if self.relative: + rel_cameras = self.get_relative_camera(scene_cameras, query_idx, center_at_origin=self.center_at_origin) + else: + rel_cameras = scene_cameras + + query_cameras = self.get_camera_slice(rel_cameras, query_idx) + query_rgb = None + if scene_rgb is not None: + query_rgb = scene_rgb[query_idx] + query_masks = None + if scene_masks is not None: + query_masks = scene_masks[query_idx] + + context_cameras = self.get_camera_slice(rel_cameras, context_idx) + context_rgb = None + if scene_rgb is not None: + context_rgb = scene_rgb[context_idx] + context_masks = None + if scene_masks is not None: + context_masks = scene_masks[context_idx] + + if return_context: + return query_cameras, query_rgb, query_masks, context_cameras, context_rgb, context_masks, context_idx + return query_cameras, query_rgb, query_masks, context_cameras, context_rgb, context_masks + + +def get_interpolated_path(cameras: PerspectiveCameras, n=50, method='circle', theta_offset_max=0.0): + ''' + Given a camera object containing a set of cameras, fit a circle and get + interpolated cameras + + Args: + cameras (PyTorch3D Camera): input camera object + n (int): length of cameras in new path + method (str): 'circle' + theta_offset_max (int): max camera jitter in radians + + Returns: + path_cameras (PyTorch3D Camera): interpolated cameras + ''' + device = cameras.device + cameras = cameras.cpu() + + if method == 'circle': + + #@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/ + #@ Fit plane + P = cameras.get_camera_center().cpu() + P_mean = P.mean(axis=0) + P_centered = P - P_mean + U,s,V = torch.linalg.svd(P_centered) + normal = V[2,:] + if (normal*2 - P_mean).norm() < (normal - P_mean).norm(): + normal = - normal + d = -torch.dot(P_mean, normal) # d = - + + #@ Project pts to plane + P_xy = rodrigues_rot(P_centered, normal, torch.tensor([0.0,0.0,1.0])) + + #@ Fit circle in 2D + xc, yc, r = fit_circle_2d(P_xy[:,0], P_xy[:,1]) + t = torch.linspace(0, 2*math.pi, 100) + xx = xc + r*torch.cos(t) + yy = yc + r*torch.sin(t) + + #@ Project circle to 3D + C = rodrigues_rot(torch.tensor([xc,yc,0.0]), torch.tensor([0.0,0.0,1.0]), normal) + P_mean + C = C.flatten() + + #@ Get pts n 3D + t = torch.linspace(0, 2*math.pi, n) + u = P[0] - C + new_camera_centers = generate_circle_by_vectors(t, C, r, normal, u) + + #@ OPTIONAL THETA OFFSET + if theta_offset_max > 0.0: + aug_theta = (torch.rand((new_camera_centers.shape[0])) * (2*theta_offset_max)) - theta_offset_max + new_camera_centers = rodrigues_rot2(new_camera_centers, normal, aug_theta) + + #@ Get camera look at + new_camera_look_at = get_nearest_centroid(cameras) + + #@ Get R T + up_vec = -normal + R, T = look_at_view_transform(eye=new_camera_centers, at=new_camera_look_at.unsqueeze(0), up=up_vec.unsqueeze(0), device=cameras.device) + else: + raise NotImplementedError + + c = (cameras.principal_point).mean(dim=0, keepdim=True).expand(R.shape[0],-1) + f = (cameras.focal_length).mean(dim=0, keepdim=True).expand(R.shape[0],-1) + 
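+    # The interpolated path reuses the intrinsics of the input cameras: the
+    # averaged principal point `c` and focal length `f` are broadcast to all
+    # new poses, and the image size is copied from the first input camera.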
image_size = cameras.image_size[:1].expand(R.shape[0],-1) + + + path_cameras = PerspectiveCameras(R=R,T=T,focal_length=f,principal_point=c,image_size=image_size, device=device) + cameras = cameras.to(device) + return path_cameras + +def np_normalize(vec, axis=-1): + vec = vec / (np.linalg.norm(vec, axis=axis, keepdims=True) + 1e-9) + return vec + + +#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/ +#------------------------------------------------------------------------------- +# Generate points on circle +# P(t) = r*cos(t)*u + r*sin(t)*(n x u) + C +#------------------------------------------------------------------------------- +def generate_circle_by_vectors(t, C, r, n, u): + n = n/torch.linalg.norm(n) + u = u/torch.linalg.norm(u) + P_circle = r*torch.cos(t)[:,None]*u + r*torch.sin(t)[:,None]*torch.cross(n,u) + C + return P_circle + +#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/ +#------------------------------------------------------------------------------- +# FIT CIRCLE 2D +# - Find center [xc, yc] and radius r of circle fitting to set of 2D points +# - Optionally specify weights for points +# +# - Implicit circle function: +# (x-xc)^2 + (y-yc)^2 = r^2 +# (2*xc)*x + (2*yc)*y + (r^2-xc^2-yc^2) = x^2+y^2 +# c[0]*x + c[1]*y + c[2] = x^2+y^2 +# +# - Solution by method of least squares: +# A*c = b, c' = argmin(||A*c - b||^2) +# A = [x y 1], b = [x^2+y^2] +#------------------------------------------------------------------------------- +def fit_circle_2d(x, y, w=[]): + + A = torch.stack([x, y, torch.ones(len(x))]).T + b = x**2 + y**2 + + # Modify A,b for weighted least squares + if len(w) == len(x): + W = torch.diag(w) + A = torch.dot(W,A) + b = torch.dot(W,b) + + # Solve by method of least squares + c = torch.linalg.lstsq(A,b,rcond=None)[0] + + # Get circle parameters from solution c + xc = c[0]/2 + yc = c[1]/2 + r = torch.sqrt(c[2] + xc**2 + yc**2) + return xc, yc, r + +#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/ +#------------------------------------------------------------------------------- +# RODRIGUES ROTATION +# - Rotate given points based on a starting and ending vector +# - Axis k and angle of rotation theta given by vectors n0,n1 +# P_rot = P*cos(theta) + (k x P)*sin(theta) + k**(1-cos(theta)) +#------------------------------------------------------------------------------- +def rodrigues_rot(P, n0, n1): + + # If P is only 1d array (coords of single point), fix it to be matrix + if P.ndim == 1: + P = P[None,...] + + # Get vector of rotation k and angle theta + n0 = n0/torch.linalg.norm(n0) + n1 = n1/torch.linalg.norm(n1) + k = torch.cross(n0,n1) + k = k/torch.linalg.norm(k) + theta = torch.arccos(torch.dot(n0,n1)) + + # Compute rotated points + P_rot = torch.zeros((len(P),3)) + for i in range(len(P)): + P_rot[i] = P[i]*torch.cos(theta) + torch.cross(k,P[i])*torch.sin(theta) + k*torch.dot(k,P[i])*(1-torch.cos(theta)) + + return P_rot + +def rodrigues_rot2(P, n1, theta): + ''' + Rotate points P wrt axis k by theta radians + ''' + + # If P is only 1d array (coords of single point), fix it to be matrix + if P.ndim == 1: + P = P[None,...] 
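+
+    # Each point P[i] gets its own rotation axis k[i] = P[i] x n1, i.e. an axis
+    # perpendicular to both the point and the plane normal, and is rotated about
+    # it by its own angle theta[i] (used for per-camera jitter in
+    # get_interpolated_path).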
+ + k = torch.cross(P, n1.unsqueeze(0)) + k = k/torch.linalg.norm(k) + + # Compute rotated points + P_rot = torch.zeros((len(P),3)) + for i in range(len(P)): + P_rot[i] = P[i]*torch.cos(theta[i]) + torch.cross(k[i],P[i])*torch.sin(theta[i]) + k[i]*torch.dot(k[i],P[i])*(1-torch.cos(theta[i])) + + return P_rot + +#@ https://meshlogic.github.io/posts/jupyter/curve-fitting/fitting-a-circle-to-cluster-of-3d-points/ +#------------------------------------------------------------------------------- +# ANGLE BETWEEN +# - Get angle between vectors u,v with sign based on plane with unit normal n +#------------------------------------------------------------------------------- +def angle_between(u, v, n=None): + if n is None: + return torch.arctan2(torch.linalg.norm(torch.cross(u,v)), torch.dot(u,v)) + else: + return torch.arctan2(torch.dot(n,torch.cross(u,v)), torch.dot(u,v)) + +#@ https://www.crewes.org/Documents/ResearchReports/2010/CRR201032.pdf +def get_nearest_centroid(cameras: PerspectiveCameras): + ''' + Given PyTorch3D cameras, find the nearest point along their principal ray + ''' + + #@ GET CAMERA CENTERS AND DIRECTIONS + camera_centers = cameras.get_camera_center() + + c_mean = (cameras.principal_point).mean(dim=0) + xy_grid = c_mean.unsqueeze(0).unsqueeze(0) + ray_vis = _xy_to_ray_bundle(cameras, xy_grid.expand(len(cameras),-1,-1), 1.0, 15.0, 20, True) + camera_directions = ray_vis.directions + + #@ CONSTRUCT MATRICIES + A = torch.zeros((3*len(cameras)), len(cameras)+3) + b = torch.zeros((3*len(cameras), 1)) + A[:,:3] = torch.eye(3).repeat(len(cameras),1) + for ci in range(len(camera_directions)): + A[3*ci:3*ci+3, ci+3] = -camera_directions[ci] + b[3*ci:3*ci+3, 0] = camera_centers[ci] + #' A (3*N, 3*N+3) b (3*N, 1) + + #@ SVD + U, s, VT = torch.linalg.svd(A) + Sinv = torch.diag(1/s) + if len(s) < 3*len(cameras): + Sinv = torch.cat((Sinv, torch.zeros((Sinv.shape[0], 3*len(cameras) - Sinv.shape[1]), device=Sinv.device)), dim=1) + x = torch.matmul(VT.T, torch.matmul(Sinv,torch.matmul(U.T, b))) + + centroid = x[:3,0] + return centroid + + +def get_angles(target_camera: PerspectiveCameras, context_cameras: PerspectiveCameras, centroid=None): + ''' + Get angles between cameras wrt a centroid + + Args: + target_camera (Pytorch3D Camera): a camera object with a single camera + context_cameras (PyTorch3D Camera): a camera object + + Returns: + theta_deg (Tensor): a tensor containing angles in degrees + ''' + a1 = target_camera.get_camera_center() + b1 = context_cameras.get_camera_center() + + a = a1 - centroid.unsqueeze(0) + a = a.expand(len(context_cameras), -1) + b = b1 - centroid.unsqueeze(0) + + ab_dot = (a*b).sum(dim=-1) + theta = torch.acos((ab_dot)/(torch.linalg.norm(a, dim=-1) * torch.linalg.norm(b, dim=-1))) + theta_deg = theta * 180 / math.pi + + return theta_deg + + +import math +from typing import List, Literal, Optional, Tuple + +import numpy as np +import torch +from jaxtyping import Float +from numpy.typing import NDArray +from torch import Tensor + +_EPS = np.finfo(float).eps * 4.0 + + +def unit_vector(data: NDArray, axis: Optional[int] = None) -> np.ndarray: + """Return ndarray normalized by length, i.e. Euclidean norm, along axis. + + Args: + axis: the axis along which to normalize into unit vector + out: where to write out the data to. 
If None, returns a new np ndarray + """ + data = np.array(data, dtype=np.float64, copy=True) + if data.ndim == 1: + data /= math.sqrt(np.dot(data, data)) + return data + length = np.atleast_1d(np.sum(data * data, axis)) + np.sqrt(length, length) + if axis is not None: + length = np.expand_dims(length, axis) + data /= length + return data + + +def quaternion_from_matrix(matrix: NDArray, isprecise: bool = False) -> np.ndarray: + """Return quaternion from rotation matrix. + + Args: + matrix: rotation matrix to obtain quaternion + isprecise: if True, input matrix is assumed to be precise rotation matrix and a faster algorithm is used. + """ + M = np.array(matrix, dtype=np.float64, copy=False)[:4, :4] + if isprecise: + q = np.empty((4,)) + t = np.trace(M) + if t > M[3, 3]: + q[0] = t + q[3] = M[1, 0] - M[0, 1] + q[2] = M[0, 2] - M[2, 0] + q[1] = M[2, 1] - M[1, 2] + else: + i, j, k = 1, 2, 3 + if M[1, 1] > M[0, 0]: + i, j, k = 2, 3, 1 + if M[2, 2] > M[i, i]: + i, j, k = 3, 1, 2 + t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3] + q[i] = t + q[j] = M[i, j] + M[j, i] + q[k] = M[k, i] + M[i, k] + q[3] = M[k, j] - M[j, k] + q *= 0.5 / math.sqrt(t * M[3, 3]) + else: + m00 = M[0, 0] + m01 = M[0, 1] + m02 = M[0, 2] + m10 = M[1, 0] + m11 = M[1, 1] + m12 = M[1, 2] + m20 = M[2, 0] + m21 = M[2, 1] + m22 = M[2, 2] + # symmetric matrix K + K = [ + [m00 - m11 - m22, 0.0, 0.0, 0.0], + [m01 + m10, m11 - m00 - m22, 0.0, 0.0], + [m02 + m20, m12 + m21, m22 - m00 - m11, 0.0], + [m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22], + ] + K = np.array(K) + K /= 3.0 + # quaternion is eigenvector of K that corresponds to largest eigenvalue + w, V = np.linalg.eigh(K) + q = V[np.array([3, 0, 1, 2]), np.argmax(w)] + if q[0] < 0.0: + np.negative(q, q) + return q + + +def quaternion_slerp( + quat0: NDArray, quat1: NDArray, fraction: float, spin: int = 0, shortestpath: bool = True +) -> np.ndarray: + """Return spherical linear interpolation between two quaternions. + Args: + quat0: first quaternion + quat1: second quaternion + fraction: how much to interpolate between quat0 vs quat1 (if 0, closer to quat0; if 1, closer to quat1) + spin: how much of an additional spin to place on the interpolation + shortestpath: whether to return the short or long path to rotation + """ + q0 = unit_vector(quat0[:4]) + q1 = unit_vector(quat1[:4]) + if q0 is None or q1 is None: + raise ValueError("Input quaternions invalid.") + if fraction == 0.0: + return q0 + if fraction == 1.0: + return q1 + d = np.dot(q0, q1) + if abs(abs(d) - 1.0) < _EPS: + return q0 + if shortestpath and d < 0.0: + # invert rotation + d = -d + np.negative(q1, q1) + angle = math.acos(d) + spin * math.pi + if abs(angle) < _EPS: + return q0 + isin = 1.0 / math.sin(angle) + q0 *= math.sin((1.0 - fraction) * angle) * isin + q1 *= math.sin(fraction * angle) * isin + q0 += q1 + return q0 + + +def quaternion_matrix(quaternion: NDArray) -> np.ndarray: + """Return homogeneous rotation matrix from quaternion. 
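+    The quaternion is expected in (w, x, y, z) order, matching the output of
+    `quaternion_from_matrix` above.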
+ + Args: + quaternion: value to convert to matrix + """ + q = np.array(quaternion, dtype=np.float64, copy=True) + n = np.dot(q, q) + if n < _EPS: + return np.identity(4) + q *= math.sqrt(2.0 / n) + q = np.outer(q, q) + return np.array( + [ + [1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0], 0.0], + [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0], 0.0], + [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2], 0.0], + [0.0, 0.0, 0.0, 1.0], + ] + ) + + +def get_interpolated_poses(pose_a: NDArray, pose_b: NDArray, steps: int = 10) -> List[float]: + """Return interpolation of poses with specified number of steps. + Args: + pose_a: first pose + pose_b: second pose + steps: number of steps the interpolated pose path should contain + """ + + quat_a = quaternion_from_matrix(pose_a[:3, :3]) + quat_b = quaternion_from_matrix(pose_b[:3, :3]) + + ts = np.linspace(0, 1, steps) + quats = [quaternion_slerp(quat_a, quat_b, t) for t in ts] + trans = [(1 - t) * pose_a[:3, 3] + t * pose_b[:3, 3] for t in ts] + + poses_ab = [] + for quat, tran in zip(quats, trans): + pose = np.identity(4) + pose[:3, :3] = quaternion_matrix(quat)[:3, :3] + pose[:3, 3] = tran + poses_ab.append(pose[:3]) + return poses_ab + + +def get_interpolated_k( + k_a: Float[Tensor, "3 3"], k_b: Float[Tensor, "3 3"], steps: int = 10 +) -> List[Float[Tensor, "3 4"]]: + """ + Returns interpolated path between two camera poses with specified number of steps. + + Args: + k_a: camera matrix 1 + k_b: camera matrix 2 + steps: number of steps the interpolated pose path should contain + + Returns: + List of interpolated camera poses + """ + Ks: List[Float[Tensor, "3 3"]] = [] + ts = np.linspace(0, 1, steps) + for t in ts: + new_k = k_a * (1.0 - t) + k_b * t + Ks.append(new_k) + return Ks + + +def get_ordered_poses_and_k( + poses: Float[Tensor, "num_poses 3 4"], + Ks: Float[Tensor, "num_poses 3 3"], +) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]: + """ + Returns ordered poses and intrinsics by euclidian distance between poses. + + Args: + poses: list of camera poses + Ks: list of camera intrinsics + + Returns: + tuple of ordered poses and intrinsics + + """ + + poses_num = len(poses) + + ordered_poses = torch.unsqueeze(poses[0], 0) + ordered_ks = torch.unsqueeze(Ks[0], 0) + + # remove the first pose from poses + poses = poses[1:] + Ks = Ks[1:] + + for _ in range(poses_num - 1): + distances = torch.norm(ordered_poses[-1][:, 3] - poses[:, :, 3], dim=1) + idx = torch.argmin(distances) + ordered_poses = torch.cat((ordered_poses, torch.unsqueeze(poses[idx], 0)), dim=0) + ordered_ks = torch.cat((ordered_ks, torch.unsqueeze(Ks[idx], 0)), dim=0) + poses = torch.cat((poses[0:idx], poses[idx + 1 :]), dim=0) + Ks = torch.cat((Ks[0:idx], Ks[idx + 1 :]), dim=0) + + return ordered_poses, ordered_ks + + +def get_interpolated_poses_many( + poses: Float[Tensor, "num_poses 3 4"], + Ks: Float[Tensor, "num_poses 3 3"], + steps_per_transition: int = 10, + order_poses: bool = False, +) -> Tuple[Float[Tensor, "num_poses 3 4"], Float[Tensor, "num_poses 3 3"]]: + """Return interpolated poses for many camera poses. 
+ + Args: + poses: list of camera poses + Ks: list of camera intrinsics + steps_per_transition: number of steps per transition + order_poses: whether to order poses by euclidian distance + + Returns: + tuple of new poses and intrinsics + """ + traj = [] + k_interp = [] + + if order_poses: + poses, Ks = get_ordered_poses_and_k(poses, Ks) + + for idx in range(poses.shape[0] - 1): + pose_a = poses[idx].cpu().numpy() + pose_b = poses[idx + 1].cpu().numpy() + poses_ab = get_interpolated_poses(pose_a, pose_b, steps=steps_per_transition) + traj += poses_ab + k_interp += get_interpolated_k(Ks[idx], Ks[idx + 1], steps=steps_per_transition) + + traj = np.stack(traj, axis=0) + k_interp = torch.stack(k_interp, dim=0) + + return torch.tensor(traj, dtype=torch.float32), torch.tensor(k_interp, dtype=torch.float32) + + +def normalize(x: torch.Tensor) -> Float[Tensor, "*batch"]: + """Returns a normalized vector.""" + return x / torch.linalg.norm(x) + + +def normalize_with_norm(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Normalize tensor along axis and return normalized value with norms. + + Args: + x: tensor to normalize. + dim: axis along which to normalize. + + Returns: + Tuple of normalized tensor and corresponding norm. + """ + + norm = torch.maximum(torch.linalg.vector_norm(x, dim=dim, keepdims=True), torch.tensor([_EPS]).to(x)) + return x / norm, norm + + +def viewmatrix(lookat: torch.Tensor, up: torch.Tensor, pos: torch.Tensor) -> Float[Tensor, "*batch"]: + """Returns a camera transformation matrix. + + Args: + lookat: The direction the camera is looking. + up: The upward direction of the camera. + pos: The position of the camera. + + Returns: + A camera transformation matrix. + """ + vec2 = normalize(lookat) + vec1_avg = normalize(up) + vec0 = normalize(torch.cross(vec1_avg, vec2)) + vec1 = normalize(torch.cross(vec2, vec0)) + m = torch.stack([vec0, vec1, vec2, pos], 1) + return m + + +def get_distortion_params( + k1: float = 0.0, + k2: float = 0.0, + k3: float = 0.0, + k4: float = 0.0, + p1: float = 0.0, + p2: float = 0.0, +) -> Float[Tensor, "*batch"]: + """Returns a distortion parameters matrix. + + Args: + k1: The first radial distortion parameter. + k2: The second radial distortion parameter. + k3: The third radial distortion parameter. + k4: The fourth radial distortion parameter. + p1: The first tangential distortion parameter. + p2: The second tangential distortion parameter. + Returns: + torch.Tensor: A distortion parameters matrix. + """ + return torch.Tensor([k1, k2, k3, k4, p1, p2]) + + +def _compute_residual_and_jacobian( + x: torch.Tensor, + y: torch.Tensor, + xd: torch.Tensor, + yd: torch.Tensor, + distortion_params: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Auxiliary function of radial_and_tangential_undistort() that computes residuals and jacobians. + Adapted from MultiNeRF: + https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/camera_utils.py#L427-L474 + + Args: + x: The updated x coordinates. + y: The updated y coordinates. + xd: The distorted x coordinates. + yd: The distorted y coordinates. + distortion_params: The distortion parameters [k1, k2, k3, k4, p1, p2]. + + Returns: + The residuals (fx, fy) and jacobians (fx_x, fx_y, fy_x, fy_y). 
+ """ + + k1 = distortion_params[..., 0] + k2 = distortion_params[..., 1] + k3 = distortion_params[..., 2] + k4 = distortion_params[..., 3] + p1 = distortion_params[..., 4] + p2 = distortion_params[..., 5] + + # let r(x, y) = x^2 + y^2; + # d(x, y) = 1 + k1 * r(x, y) + k2 * r(x, y) ^2 + k3 * r(x, y)^3 + + # k4 * r(x, y)^4; + r = x * x + y * y + d = 1.0 + r * (k1 + r * (k2 + r * (k3 + r * k4))) + + # The perfect projection is: + # xd = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2); + # yd = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2); + # + # Let's define + # + # fx(x, y) = x * d(x, y) + 2 * p1 * x * y + p2 * (r(x, y) + 2 * x^2) - xd; + # fy(x, y) = y * d(x, y) + 2 * p2 * x * y + p1 * (r(x, y) + 2 * y^2) - yd; + # + # We are looking for a solution that satisfies + # fx(x, y) = fy(x, y) = 0; + fx = d * x + 2 * p1 * x * y + p2 * (r + 2 * x * x) - xd + fy = d * y + 2 * p2 * x * y + p1 * (r + 2 * y * y) - yd + + # Compute derivative of d over [x, y] + d_r = k1 + r * (2.0 * k2 + r * (3.0 * k3 + r * 4.0 * k4)) + d_x = 2.0 * x * d_r + d_y = 2.0 * y * d_r + + # Compute derivative of fx over x and y. + fx_x = d + d_x * x + 2.0 * p1 * y + 6.0 * p2 * x + fx_y = d_y * x + 2.0 * p1 * x + 2.0 * p2 * y + + # Compute derivative of fy over x and y. + fy_x = d_x * y + 2.0 * p2 * y + 2.0 * p1 * x + fy_y = d + d_y * y + 2.0 * p2 * x + 6.0 * p1 * y + + return fx, fy, fx_x, fx_y, fy_x, fy_y + + +# @torch_compile(dynamic=True, mode="reduce-overhead", backend="eager") +def radial_and_tangential_undistort( + coords: torch.Tensor, + distortion_params: torch.Tensor, + eps: float = 1e-3, + max_iterations: int = 10, +) -> torch.Tensor: + """Computes undistorted coords given opencv distortion parameters. + Adapted from MultiNeRF + https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/camera_utils.py#L477-L509 + + Args: + coords: The distorted coordinates. + distortion_params: The distortion parameters [k1, k2, k3, k4, p1, p2]. + eps: The epsilon for the convergence. + max_iterations: The maximum number of iterations to perform. + + Returns: + The undistorted coordinates. + """ + + # Initialize from the distorted point. + x = coords[..., 0] + y = coords[..., 1] + + for _ in range(max_iterations): + fx, fy, fx_x, fx_y, fy_x, fy_y = _compute_residual_and_jacobian( + x=x, y=y, xd=coords[..., 0], yd=coords[..., 1], distortion_params=distortion_params + ) + denominator = fy_x * fx_y - fx_x * fy_y + x_numerator = fx * fy_y - fy * fx_y + y_numerator = fy * fx_x - fx * fy_x + step_x = torch.where(torch.abs(denominator) > eps, x_numerator / denominator, torch.zeros_like(denominator)) + step_y = torch.where(torch.abs(denominator) > eps, y_numerator / denominator, torch.zeros_like(denominator)) + + x = x + step_x + y = y + step_y + + return torch.stack([x, y], dim=-1) + + +def rotation_matrix(a: Float[Tensor, "3"], b: Float[Tensor, "3"]) -> Float[Tensor, "3 3"]: + """Compute the rotation matrix that rotates vector a to vector b. + + Args: + a: The vector to rotate. + b: The vector to rotate to. + Returns: + The rotation matrix. 
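+
+    Uses Rodrigues' formula, R = I + K + K^2 (1 - c) / s^2, where K is the
+    skew-symmetric cross-product matrix of v = a x b, c = a . b and s = |v|.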
+ """ + a = a / torch.linalg.norm(a) + b = b / torch.linalg.norm(b) + v = torch.cross(a, b) + c = torch.dot(a, b) + # If vectors are exactly opposite, we add a little noise to one of them + if c < -1 + 1e-8: + eps = (torch.rand(3) - 0.5) * 0.01 + return rotation_matrix(a + eps, b) + s = torch.linalg.norm(v) + skew_sym_mat = torch.Tensor( + [ + [0, -v[2], v[1]], + [v[2], 0, -v[0]], + [-v[1], v[0], 0], + ] + ) + return torch.eye(3) + skew_sym_mat + skew_sym_mat @ skew_sym_mat * ((1 - c) / (s**2 + 1e-8)) + + +def focus_of_attention(poses: Float[Tensor, "*num_poses 4 4"], initial_focus: Float[Tensor, "3"]) -> Float[Tensor, "3"]: + """Compute the focus of attention of a set of cameras. Only cameras + that have the focus of attention in front of them are considered. + + Args: + poses: The poses to orient. + initial_focus: The 3D point views to decide which cameras are initially activated. + + Returns: + The 3D position of the focus of attention. + """ + # References to the same method in third-party code: + # https://github.com/google-research/multinerf/blob/1c8b1c552133cdb2de1c1f3c871b2813f6662265/internal/camera_utils.py#L145 + # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/load_llff.py#L197 + active_directions = -poses[:, :3, 2:3] + active_origins = poses[:, :3, 3:4] + # initial value for testing if the focus_pt is in front or behind + focus_pt = initial_focus + # Prune cameras which have the current have the focus_pt behind them. + active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0 + done = False + # We need at least two active cameras, else fallback on the previous solution. + # This may be the "poses" solution if no cameras are active on first iteration, e.g. + # they are in an outward-looking configuration. + while torch.sum(active.int()) > 1 and not done: + active_directions = active_directions[active] + active_origins = active_origins[active] + # https://en.wikipedia.org/wiki/Line–line_intersection#In_more_than_two_dimensions + m = torch.eye(3) - active_directions * torch.transpose(active_directions, -2, -1) + mt_m = torch.transpose(m, -2, -1) @ m + focus_pt = torch.linalg.inv(mt_m.mean(0)) @ (mt_m @ active_origins).mean(0)[:, 0] + active = torch.sum(active_directions.squeeze(-1) * (focus_pt - active_origins.squeeze(-1)), dim=-1) > 0 + if active.all(): + # the set of active cameras did not change, so we're done. + done = True + return focus_pt + + +def auto_orient_and_center_poses( + poses: Float[Tensor, "*num_poses 4 4"], + method: Literal["pca", "up", "vertical", "none"] = "up", + center_method: Literal["poses", "focus", "none"] = "poses", +) -> Tuple[Float[Tensor, "*num_poses 3 4"], Float[Tensor, "3 4"]]: + """Orients and centers the poses. + + We provide three methods for orientation: + + - pca: Orient the poses so that the principal directions of the camera centers are aligned + with the axes, Z corresponding to the smallest principal component. + This method works well when all of the cameras are in the same plane, for example when + images are taken using a mobile robot. + - up: Orient the poses so that the average up vector is aligned with the z axis. + This method works well when images are not at arbitrary angles. + - vertical: Orient the poses so that the Z 3D direction projects close to the + y axis in images. This method works better if cameras are not all + looking in the same 3D direction, which may happen in camera arrays or in LLFF. 
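+    - none: Do not change the orientation of the poses; only the centering
+      translation selected by `center_method` is applied.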
+ + There are two centering methods: + + - poses: The poses are centered around the origin. + - focus: The origin is set to the focus of attention of all cameras (the + closest point to cameras optical axes). Recommended for inward-looking + camera configurations. + + Args: + poses: The poses to orient. + method: The method to use for orientation. + center_method: The method to use to center the poses. + + Returns: + Tuple of the oriented poses and the transform matrix. + """ + + origins = poses[..., :3, 3] + + mean_origin = torch.mean(origins, dim=0) + translation_diff = origins - mean_origin + + if center_method == "poses": + translation = mean_origin + elif center_method == "focus": + translation = focus_of_attention(poses, mean_origin) + elif center_method == "none": + translation = torch.zeros_like(mean_origin) + else: + raise ValueError(f"Unknown value for center_method: {center_method}") + + if method == "pca": + _, eigvec = torch.linalg.eigh(translation_diff.T @ translation_diff) + eigvec = torch.flip(eigvec, dims=(-1,)) + + if torch.linalg.det(eigvec) < 0: + eigvec[:, 2] = -eigvec[:, 2] + + transform = torch.cat([eigvec, eigvec @ -translation[..., None]], dim=-1) + oriented_poses = transform @ poses + + if oriented_poses.mean(dim=0)[2, 1] < 0: + oriented_poses[:, 1:3] = -1 * oriented_poses[:, 1:3] + elif method in ("up", "vertical"): + up = torch.mean(poses[:, :3, 1], dim=0) + up = up / torch.linalg.norm(up) + if method == "vertical": + # If cameras are not all parallel (e.g. not in an LLFF configuration), + # we can find the 3D direction that most projects vertically in all + # cameras by minimizing ||Xu|| s.t. ||u||=1. This total least squares + # problem is solved by SVD. + x_axis_matrix = poses[:, :3, 0] + _, S, Vh = torch.linalg.svd(x_axis_matrix, full_matrices=False) + # Singular values are S_i=||Xv_i|| for each right singular vector v_i. + # ||S|| = sqrt(n) because lines of X are all unit vectors and the v_i + # are an orthonormal basis. + # ||Xv_i|| = sqrt(sum(dot(x_axis_j,v_i)^2)), thus S_i/sqrt(n) is the + # RMS of cosines between x axes and v_i. If the second smallest singular + # value corresponds to an angle error less than 10° (cos(80°)=0.17), + # this is probably a degenerate camera configuration (typical values + # are around 5° average error for the true vertical). In this case, + # rather than taking the vector corresponding to the smallest singular + # value, we project the "up" vector on the plane spanned by the two + # best singular vectors. We could also just fallback to the "up" + # solution. + if S[1] > 0.17 * math.sqrt(poses.shape[0]): + # regular non-degenerate configuration + up_vertical = Vh[2, :] + # It may be pointing up or down. Use "up" to disambiguate the sign. + up = up_vertical if torch.dot(up_vertical, up) > 0 else -up_vertical + else: + # Degenerate configuration: project "up" on the plane spanned by + # the last two right singular vectors (which are orthogonal to the + # first). v_0 is a unit vector, no need to divide by its norm when + # projecting. 
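+                # Subtracting the component of "up" along v_0 leaves its
+                # projection onto the plane spanned by the remaining right
+                # singular vectors, which are orthogonal to v_0.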
+ up = up - Vh[0, :] * torch.dot(up, Vh[0, :]) + # re-normalize + up = up / torch.linalg.norm(up) + + rotation = rotation_matrix(up, torch.Tensor([0, 0, 1])) + transform = torch.cat([rotation, rotation @ -translation[..., None]], dim=-1) + oriented_poses = transform @ poses + elif method == "none": + transform = torch.eye(4) + transform[:3, 3] = -translation + transform = transform[:3, :] + oriented_poses = transform @ poses + else: + raise ValueError(f"Unknown value for method: {method}") + + return oriented_poses, transform + + +@torch.jit.script +def fisheye624_project(xyz, params): + """ + Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera + model project() function. + Inputs: + xyz: BxNx3 tensor of 3D points to be projected + params: Bx16 tensor of Fisheye624 parameters formatted like this: + [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] + or Bx15 tensor of Fisheye624 parameters formatted like this: + [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] + Outputs: + uv: BxNx2 tensor of 2D projections of xyz in image plane + Model for fisheye cameras with radial, tangential, and thin-prism distortion. + This model allows fu != fv. + Specifically, the model is: + uvDistorted = [x_r] + tangentialDistortion + thinPrismDistortion + [y_r] + proj = diag(fu,fv) * uvDistorted + [cu;cv]; + where: + a = x/z, b = y/z, r = (a^2+b^2)^(1/2) + th = atan(r) + cosPhi = a/r, sinPhi = b/r + [x_r] = (th+ k0 * th^3 + k1* th^5 + ...) [cosPhi] + [y_r] [sinPhi] + the number of terms in the series is determined by the template parameter numK. + tangentialDistortion = [(2 x_r^2 + rd^2)*p_0 + 2*x_r*y_r*p_1] + [(2 y_r^2 + rd^2)*p_1 + 2*x_r*y_r*p_0] + where rd^2 = x_r^2 + y_r^2 + thinPrismDistortion = [s0 * rd^2 + s1 rd^4] + [s2 * rd^2 + s3 rd^4] + Author: Daniel DeTone (ddetone@meta.com) + """ + + assert xyz.ndim == 3 + assert params.ndim == 2 + assert params.shape[-1] == 16 or params.shape[-1] == 15, "This model allows fx != fy" + eps = 1e-9 + B, N = xyz.shape[0], xyz.shape[1] + + # Radial correction. + z = xyz[:, :, 2].reshape(B, N, 1) + z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z) + ab = xyz[:, :, :2] / z + r = torch.norm(ab, dim=-1, p=2, keepdim=True) + th = torch.atan(r) + th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r) + th_k = th.reshape(B, N, 1).clone() + for i in range(6): + th_k = th_k + params[:, -12 + i].reshape(B, 1, 1) * torch.pow(th, 3 + i * 2) + xr_yr = th_k * th_divr + uv_dist = xr_yr + + # Tangential correction. + p0 = params[:, -6].reshape(B, 1) + p1 = params[:, -5].reshape(B, 1) + xr = xr_yr[:, :, 0].reshape(B, N) + yr = xr_yr[:, :, 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) + yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + uv_dist_tu = uv_dist[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1) + uv_dist_tv = uv_dist[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0) + uv_dist = torch.stack([uv_dist_tu, uv_dist_tv], dim=-1) # Avoids in-place complaint. + + # Thin Prism correction. + s0 = params[:, -4].reshape(B, 1) + s1 = params[:, -3].reshape(B, 1) + s2 = params[:, -2].reshape(B, 1) + s3 = params[:, -1].reshape(B, 1) + rd_4 = torch.square(rd_sq) + uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4) + uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4) + + # Finally, apply standard terms: focal length and camera centers. 
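+    # The 15-parameter layout shares a single focal length f for both axes,
+    # whereas the 16-parameter layout carries separate f_u and f_v.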
+ if params.shape[-1] == 15: + fx_fy = params[:, 0].reshape(B, 1, 1) + cx_cy = params[:, 1:3].reshape(B, 1, 2) + else: + fx_fy = params[:, 0:2].reshape(B, 1, 2) + cx_cy = params[:, 2:4].reshape(B, 1, 2) + result = uv_dist * fx_fy + cx_cy + + return result + + +# Core implementation of fisheye 624 unprojection. More details are documented here: +# https://facebookresearch.github.io/projectaria_tools/docs/tech_insights/camera_intrinsic_models#the-fisheye62-model +@torch.jit.script +def fisheye624_unproject_helper(uv, params, max_iters: int = 5): + """ + Batched implementation of the FisheyeRadTanThinPrism (aka Fisheye624) camera + model. There is no analytical solution for the inverse of the project() + function so this solves an optimization problem using Newton's method to get + the inverse. + Inputs: + uv: BxNx2 tensor of 2D pixels to be unprojected + params: Bx16 tensor of Fisheye624 parameters formatted like this: + [f_u f_v c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] + or Bx15 tensor of Fisheye624 parameters formatted like this: + [f c_u c_v {k_0 ... k_5} {p_0 p_1} {s_0 s_1 s_2 s_3}] + Outputs: + xyz: BxNx3 tensor of 3D rays of uv points with z = 1. + Model for fisheye cameras with radial, tangential, and thin-prism distortion. + This model assumes fu=fv. This unproject function holds that: + X = unproject(project(X)) [for X=(x,y,z) in R^3, z>0] + and + x = project(unproject(s*x)) [for s!=0 and x=(u,v) in R^2] + Author: Daniel DeTone (ddetone@meta.com) + """ + + assert uv.ndim == 3, "Expected batched input shaped BxNx3" + assert params.ndim == 2 + assert params.shape[-1] == 16 or params.shape[-1] == 15, "This model allows fx != fy" + eps = 1e-6 + B, N = uv.shape[0], uv.shape[1] + + if params.shape[-1] == 15: + fx_fy = params[:, 0].reshape(B, 1, 1) + cx_cy = params[:, 1:3].reshape(B, 1, 2) + else: + fx_fy = params[:, 0:2].reshape(B, 1, 2) + cx_cy = params[:, 2:4].reshape(B, 1, 2) + + uv_dist = (uv - cx_cy) / fx_fy + + # Compute xr_yr using Newton's method. + xr_yr = uv_dist.clone() # Initial guess. + for _ in range(max_iters): + uv_dist_est = xr_yr.clone() + # Tangential terms. + p0 = params[:, -6].reshape(B, 1) + p1 = params[:, -5].reshape(B, 1) + xr = xr_yr[:, :, 0].reshape(B, N) + yr = xr_yr[:, :, 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) + yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + ((2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1) + uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + ((2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0) + # Thin Prism terms. + s0 = params[:, -4].reshape(B, 1) + s1 = params[:, -3].reshape(B, 1) + s2 = params[:, -2].reshape(B, 1) + s3 = params[:, -1].reshape(B, 1) + rd_4 = torch.square(rd_sq) + uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4) + uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4) + # Compute the derivative of uv_dist w.r.t. xr_yr. 
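+        # duv_dist_dxr_yr is the 2x2 Jacobian of the distorted coordinates with
+        # respect to xr_yr: identity plus the derivatives of the tangential and
+        # thin-prism terms. It is inverted in closed form below and used to take
+        # one Newton step on xr_yr.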
+ duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2) + duv_dist_dxr_yr[:, :, 0, 0] = 1.0 + 6.0 * xr_yr[:, :, 0] * p0 + 2.0 * xr_yr[:, :, 1] * p1 + offdiag = 2.0 * (xr_yr[:, :, 0] * p1 + xr_yr[:, :, 1] * p0) + duv_dist_dxr_yr[:, :, 0, 1] = offdiag + duv_dist_dxr_yr[:, :, 1, 0] = offdiag + duv_dist_dxr_yr[:, :, 1, 1] = 1.0 + 6.0 * xr_yr[:, :, 1] * p1 + 2.0 * xr_yr[:, :, 0] * p0 + xr_yr_sq_norm = xr_yr_sq[:, :, 0] + xr_yr_sq[:, :, 1] + temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm) + duv_dist_dxr_yr[:, :, 0, 0] = duv_dist_dxr_yr[:, :, 0, 0] + (xr_yr[:, :, 0] * temp1) + duv_dist_dxr_yr[:, :, 0, 1] = duv_dist_dxr_yr[:, :, 0, 1] + (xr_yr[:, :, 1] * temp1) + temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm) + duv_dist_dxr_yr[:, :, 1, 0] = duv_dist_dxr_yr[:, :, 1, 0] + (xr_yr[:, :, 0] * temp2) + duv_dist_dxr_yr[:, :, 1, 1] = duv_dist_dxr_yr[:, :, 1, 1] + (xr_yr[:, :, 1] * temp2) + # Compute 2x2 inverse manually here since torch.inverse() is very slow. + # Because this is slow: inv = duv_dist_dxr_yr.inverse() + # About a 10x reduction in speed with above line. + mat = duv_dist_dxr_yr.reshape(-1, 2, 2) + a = mat[:, 0, 0].reshape(-1, 1, 1) + b = mat[:, 0, 1].reshape(-1, 1, 1) + c = mat[:, 1, 0].reshape(-1, 1, 1) + d = mat[:, 1, 1].reshape(-1, 1, 1) + det = 1.0 / ((a * d) - (b * c)) + top = torch.cat([d, -b], dim=2) + bot = torch.cat([-c, a], dim=2) + inv = det * torch.cat([top, bot], dim=1) + inv = inv.reshape(B, N, 2, 2) + # Manually compute 2x2 @ 2x1 matrix multiply. + # Because this is slow: step = (inv @ (uv_dist - uv_dist_est)[..., None])[..., 0] + diff = uv_dist - uv_dist_est + a = inv[:, :, 0, 0] + b = inv[:, :, 0, 1] + c = inv[:, :, 1, 0] + d = inv[:, :, 1, 1] + e = diff[:, :, 0] + f = diff[:, :, 1] + step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) + # Newton step. + xr_yr = xr_yr + step + + # Compute theta using Newton's method. + xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) + th = xr_yr_norm.clone() + for _ in range(max_iters): + th_radial = uv.new_ones(B, N, 1) + dthd_th = uv.new_ones(B, N, 1) + for k in range(6): + r_k = params[:, -12 + k].reshape(B, 1, 1) + th_radial = th_radial + (r_k * torch.pow(th, 2 + k * 2)) + dthd_th = dthd_th + ((3.0 + 2.0 * k) * r_k * torch.pow(th, 2 + k * 2)) + th_radial = th_radial * th + step = (xr_yr_norm - th_radial) / dthd_th + # handle dthd_th close to 0. + step = torch.where(dthd_th.abs() > eps, step, torch.sign(step) * eps * 10.0) + th = th + step + # Compute the ray direction using theta and xr_yr. 
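+    # In the forward model th = atan(r) with r = |(a, b)|, so tan(th) recovers r;
+    # scaling the unit direction xr_yr / |xr_yr| by tan(th) gives (a, b), and the
+    # ray is completed with z = 1. Near the optical axis (th ~ 0, |xr_yr| ~ 0)
+    # xr_yr itself is used to avoid dividing by zero.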
+ close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps) + ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr) + ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2) + return ray + + +# unproject 2D point to 3D with fisheye624 model +def fisheye624_unproject(coords: torch.Tensor, distortion_params: torch.Tensor) -> torch.Tensor: + dirs = fisheye624_unproject_helper(coords.unsqueeze(0), distortion_params[0].unsqueeze(0)) + # correct for camera space differences: + dirs[..., 1] = -dirs[..., 1] + dirs[..., 2] = -dirs[..., 2] + return dirs diff --git a/sgm/data/cifar10.py b/sgm/data/cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..6083646f136bad308a0485843b89234cf7a9d6cd --- /dev/null +++ b/sgm/data/cifar10.py @@ -0,0 +1,67 @@ +import pytorch_lightning as pl +import torchvision +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class CIFAR10DataDictWrapper(Dataset): + def __init__(self, dset): + super().__init__() + self.dset = dset + + def __getitem__(self, i): + x, y = self.dset[i] + return {"jpg": x, "cls": y} + + def __len__(self): + return len(self.dset) + + +class CIFAR10Loader(pl.LightningDataModule): + def __init__(self, batch_size, num_workers=0, shuffle=True): + super().__init__() + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] + ) + + self.batch_size = batch_size + self.num_workers = num_workers + self.shuffle = shuffle + self.train_dataset = CIFAR10DataDictWrapper( + torchvision.datasets.CIFAR10( + root=".data/", train=True, download=True, transform=transform + ) + ) + self.test_dataset = CIFAR10DataDictWrapper( + torchvision.datasets.CIFAR10( + root=".data/", train=False, download=True, transform=transform + ) + ) + + def prepare_data(self): + pass + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + ) + + def val_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + ) diff --git a/sgm/data/co3d.py b/sgm/data/co3d.py new file mode 100644 index 0000000000000000000000000000000000000000..ba95cfbb540e4664b0fdb313f67bb5013bdea6bf --- /dev/null +++ b/sgm/data/co3d.py @@ -0,0 +1,1367 @@ +""" +adopted from SparseFusion +Wrapper for the full CO3Dv2 dataset +#@ Modified from https://github.com/facebookresearch/pytorch3d +""" + +import json +import logging +import math +import os +import random +import time +import warnings +from collections import defaultdict +from itertools import islice +from typing import ( + Any, + ClassVar, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TypedDict, + Union, +) +from einops import rearrange, repeat + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from pytorch3d.utils import opencv_from_cameras_projection +from pytorch3d.implicitron.dataset import types +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase +from sgm.data.json_index_dataset import ( + FrameAnnotsEntry, + _bbox_xywh_to_xyxy, + _bbox_xyxy_to_xywh, + _clamp_box_to_image_bounds_and_round, + _crop_around_box, + _get_1d_bounds, + _get_bbox_from_mask, + 
_get_clamp_bbox, + _load_1bit_png_mask, + _load_16big_png_depth, + _load_depth, + _load_depth_mask, + _load_image, + _load_mask, + _load_pointcloud, + _rescale_bbox, + _safe_as_tensor, + _seq_name_to_seed, +) +from sgm.data.objaverse import video_collate_fn +from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import ( + get_available_subset_names, +) +from pytorch3d.renderer.cameras import PerspectiveCameras + +logger = logging.getLogger(__name__) + + +from dataclasses import dataclass, field, fields + +from pytorch3d.renderer.camera_utils import join_cameras_as_batch +from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch +from pytorch_lightning import LightningDataModule +from torch.utils.data import DataLoader + +CO3D_ALL_CATEGORIES = list( + reversed( + [ + "baseballbat", + "banana", + "bicycle", + "microwave", + "tv", + "cellphone", + "toilet", + "hairdryer", + "couch", + "kite", + "pizza", + "umbrella", + "wineglass", + "laptop", + "hotdog", + "stopsign", + "frisbee", + "baseballglove", + "cup", + "parkingmeter", + "backpack", + "toyplane", + "toybus", + "handbag", + "chair", + "keyboard", + "car", + "motorcycle", + "carrot", + "bottle", + "sandwich", + "remote", + "bowl", + "skateboard", + "toaster", + "mouse", + "toytrain", + "book", + "toytruck", + "orange", + "broccoli", + "plant", + "teddybear", + "suitcase", + "bench", + "ball", + "cake", + "vase", + "hydrant", + "apple", + "donut", + ] + ) +) + +CO3D_ALL_TEN = [ + "donut", + "apple", + "hydrant", + "vase", + "cake", + "ball", + "bench", + "suitcase", + "teddybear", + "plant", +] + + +# @ FROM https://github.com/facebookresearch/pytorch3d +@dataclass +class FrameData(Mapping[str, Any]): + """ + A type of the elements returned by indexing the dataset object. + It can represent both individual frames and batches of thereof; + in this documentation, the sizes of tensors refer to single frames; + add the first batch dimension for the collation result. + Args: + frame_number: The number of the frame within its sequence. + 0-based continuous integers. + sequence_name: The unique name of the frame's sequence. + sequence_category: The object category of the sequence. + frame_timestamp: The time elapsed since the start of a sequence in sec. + image_size_hw: The size of the image in pixels; (height, width) tensor + of shape (2,). + image_path: The qualified path to the loaded image (with dataset_root). + image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image + of the frame; elements are floats in [0, 1]. + mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image + regions. Regions can be invalid (mask_crop[i,j]=0) in case they + are a result of zero-padding of the image after cropping around + the object bounding box; elements are floats in {0.0, 1.0}. + depth_path: The qualified path to the frame's depth map. + depth_map: A float Tensor of shape `(1, H, W)` holding the depth map + of the frame; values correspond to distances from the camera; + use `depth_mask` and `mask_crop` to filter for valid pixels. + depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the + depth map that are valid for evaluation, they have been checked for + consistency across views; elements are floats in {0.0, 1.0}. + mask_path: A qualified path to the foreground probability mask. 
+ fg_probability: A Tensor of `(1, H, W)` denoting the probability of the + pixels belonging to the captured object; elements are floats + in [0, 1]. + bbox_xywh: The bounding box tightly enclosing the foreground object in the + format (x0, y0, width, height). The convention assumes that + `x0+width` and `y0+height` includes the boundary of the box. + I.e., to slice out the corresponding crop from an image tensor `I` + we execute `crop = I[..., y0:y0+height, x0:x0+width]` + crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb` + in the original image coordinates in the format (x0, y0, width, height). + The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs + from `bbox_xywh` due to padding (which can happen e.g. due to + setting `JsonIndexDataset.box_crop_context > 0`) + camera: A PyTorch3D camera object corresponding the frame's viewpoint, + corrected for cropping if it happened. + camera_quality_score: The score proportional to the confidence of the + frame's camera estimation (the higher the more accurate). + point_cloud_quality_score: The score proportional to the accuracy of the + frame's sequence point cloud (the higher the more accurate). + sequence_point_cloud_path: The path to the sequence's point cloud. + sequence_point_cloud: A PyTorch3D Pointclouds object holding the + point cloud corresponding to the frame's sequence. When the object + represents a batch of frames, point clouds may be deduplicated; + see `sequence_point_cloud_idx`. + sequence_point_cloud_idx: Integer indices mapping frame indices to the + corresponding point clouds in `sequence_point_cloud`; to get the + corresponding point cloud to `image_rgb[i]`, use + `sequence_point_cloud[sequence_point_cloud_idx[i]]`. + frame_type: The type of the loaded frame specified in + `subset_lists_file`, if provided. + meta: A dict for storing additional frame information. 
+ """ + + frame_number: Optional[torch.LongTensor] + sequence_name: Union[str, List[str]] + sequence_category: Union[str, List[str]] + frame_timestamp: Optional[torch.Tensor] = None + image_size_hw: Optional[torch.Tensor] = None + image_path: Union[str, List[str], None] = None + image_rgb: Optional[torch.Tensor] = None + # masks out padding added due to cropping the square bit + mask_crop: Optional[torch.Tensor] = None + depth_path: Union[str, List[str], None] = "" + depth_map: Optional[torch.Tensor] = torch.zeros(1) + depth_mask: Optional[torch.Tensor] = torch.zeros(1) + mask_path: Union[str, List[str], None] = None + fg_probability: Optional[torch.Tensor] = None + bbox_xywh: Optional[torch.Tensor] = None + crop_bbox_xywh: Optional[torch.Tensor] = None + camera: Optional[PerspectiveCameras] = None + camera_quality_score: Optional[torch.Tensor] = None + point_cloud_quality_score: Optional[torch.Tensor] = None + sequence_point_cloud_path: Union[str, List[str], None] = "" + sequence_point_cloud: Optional[Pointclouds] = torch.zeros(1) + sequence_point_cloud_idx: Optional[torch.Tensor] = torch.zeros(1) + frame_type: Union[str, List[str], None] = "" # known | unseen + meta: dict = field(default_factory=lambda: {}) + valid_region: Optional[torch.Tensor] = None + category_one_hot: Optional[torch.Tensor] = None + + def to(self, *args, **kwargs): + new_params = {} + for f in fields(self): + value = getattr(self, f.name) + if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): + new_params[f.name] = value.to(*args, **kwargs) + else: + new_params[f.name] = value + return type(self)(**new_params) + + def cpu(self): + return self.to(device=torch.device("cpu")) + + def cuda(self): + return self.to(device=torch.device("cuda")) + + # the following functions make sure **frame_data can be passed to functions + def __iter__(self): + for f in fields(self): + yield f.name + + def __getitem__(self, key): + return getattr(self, key) + + def __len__(self): + return len(fields(self)) + + @classmethod + def collate(cls, batch): + """ + Given a list objects `batch` of class `cls`, collates them into a batched + representation suitable for processing with deep networks. 
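+
+        Point clouds shared by frames of the same sequence are deduplicated:
+        `sequence_point_cloud` keeps one entry per distinct cloud and
+        `sequence_point_cloud_idx` maps each collated frame back to it.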
+ """ + + elem = batch[0] + + if isinstance(elem, cls): + pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] + id_to_idx = defaultdict(list) + for i, pc_id in enumerate(pointcloud_ids): + id_to_idx[pc_id].append(i) + + sequence_point_cloud = [] + sequence_point_cloud_idx = -np.ones((len(batch),)) + for i, ind in enumerate(id_to_idx.values()): + sequence_point_cloud_idx[ind] = i + sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) + assert (sequence_point_cloud_idx >= 0).all() + + override_fields = { + "sequence_point_cloud": sequence_point_cloud, + "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), + } + # note that the pre-collate value of sequence_point_cloud_idx is unused + + collated = {} + for f in fields(elem): + list_values = override_fields.get( + f.name, [getattr(d, f.name) for d in batch] + ) + collated[f.name] = ( + cls.collate(list_values) + if all(list_value is not None for list_value in list_values) + else None + ) + return cls(**collated) + + elif isinstance(elem, Pointclouds): + return join_pointclouds_as_batch(batch) + + elif isinstance(elem, CamerasBase): + # TODO: don't store K; enforce working in NDC space + return join_cameras_as_batch(batch) + else: + return torch.utils.data._utils.collate.default_collate(batch) + + +# @ MODIFIED FROM https://github.com/facebookresearch/pytorch3d +class CO3Dv2Wrapper(torch.utils.data.Dataset): + def __init__( + self, + root_dir="/drive/datasets/co3d/", + category="hydrant", + subset="fewview_train", + stage="train", + sample_batch_size=20, + image_size=256, + masked=False, + deprecated_val_region=False, + return_frame_data_list=False, + reso: int = 256, + mask_type: str = "random", + cond_aug_mean=-3.0, + cond_aug_std=0.5, + condition_on_elevation=False, + fps_id=0.0, + motion_bucket_id=300.0, + num_frames: int = 20, + use_mask: bool = True, + load_pixelnerf: bool = True, + scale_pose: bool = True, + max_n_cond: int = 5, + min_n_cond: int = 2, + cond_on_multi: bool = False, + ): + root = root_dir + from typing import List + + from co3d.dataset.data_types import ( + FrameAnnotation, + SequenceAnnotation, + load_dataclass_jgzip, + ) + + self.dataset_root = root + self.path_manager = None + self.subset = subset + self.stage = stage + self.subset_lists_file: List[str] = [ + f"{self.dataset_root}/{category}/set_lists/set_lists_{subset}.json" + ] + self.subsets: Optional[List[str]] = [subset] + self.sample_batch_size = sample_batch_size + self.limit_to: int = 0 + self.limit_sequences_to: int = 0 + self.pick_sequence: Tuple[str, ...] = () + self.exclude_sequence: Tuple[str, ...] = () + self.limit_category_to: Tuple[int, ...] 
= () + self.load_images: bool = True + self.load_depths: bool = False + self.load_depth_masks: bool = False + self.load_masks: bool = True + self.load_point_clouds: bool = False + self.max_points: int = 0 + self.mask_images: bool = False + self.mask_depths: bool = False + self.image_height: Optional[int] = image_size + self.image_width: Optional[int] = image_size + self.box_crop: bool = True + self.box_crop_mask_thr: float = 0.4 + self.box_crop_context: float = 0.3 + self.remove_empty_masks: bool = True + self.n_frames_per_sequence: int = -1 + self.seed: int = 0 + self.sort_frames: bool = False + self.eval_batches: Any = None + + self.img_h = self.image_height + self.img_w = self.image_width + self.masked = masked + self.deprecated_val_region = deprecated_val_region + self.return_frame_data_list = return_frame_data_list + + self.reso = reso + self.num_frames = num_frames + self.cond_aug_mean = cond_aug_mean + self.cond_aug_std = cond_aug_std + self.condition_on_elevation = condition_on_elevation + self.fps_id = fps_id + self.motion_bucket_id = motion_bucket_id + self.mask_type = mask_type + self.use_mask = use_mask + self.load_pixelnerf = load_pixelnerf + self.scale_pose = scale_pose + self.max_n_cond = max_n_cond + self.min_n_cond = min_n_cond + self.cond_on_multi = cond_on_multi + + if self.cond_on_multi: + assert self.min_n_cond == self.max_n_cond + + start_time = time.time() + if "all_" in category or category == "all": + self.category_frame_annotations = [] + self.category_sequence_annotations = [] + self.subset_lists_file = [] + + if category == "all": + cats = CO3D_ALL_CATEGORIES + elif category == "all_four": + cats = ["hydrant", "teddybear", "motorcycle", "bench"] + elif category == "all_ten": + cats = [ + "donut", + "apple", + "hydrant", + "vase", + "cake", + "ball", + "bench", + "suitcase", + "teddybear", + "plant", + ] + elif category == "all_15": + cats = [ + "hydrant", + "teddybear", + "motorcycle", + "bench", + "hotdog", + "remote", + "suitcase", + "donut", + "plant", + "toaster", + "keyboard", + "handbag", + "toyplane", + "tv", + "orange", + ] + else: + print("UNSPECIFIED CATEGORY SUBSET") + cats = ["hydrant", "teddybear"] + print("loading", cats) + for cat in cats: + self.category_frame_annotations.extend( + load_dataclass_jgzip( + f"{self.dataset_root}/{cat}/frame_annotations.jgz", + List[FrameAnnotation], + ) + ) + self.category_sequence_annotations.extend( + load_dataclass_jgzip( + f"{self.dataset_root}/{cat}/sequence_annotations.jgz", + List[SequenceAnnotation], + ) + ) + self.subset_lists_file.append( + f"{self.dataset_root}/{cat}/set_lists/set_lists_{subset}.json" + ) + + else: + self.category_frame_annotations = load_dataclass_jgzip( + f"{self.dataset_root}/{category}/frame_annotations.jgz", + List[FrameAnnotation], + ) + self.category_sequence_annotations = load_dataclass_jgzip( + f"{self.dataset_root}/{category}/sequence_annotations.jgz", + List[SequenceAnnotation], + ) + + self.subset_to_image_path = None + self._load_frames() + self._load_sequences() + self._sort_frames() + self._load_subset_lists() + self._filter_db() # also computes sequence indices + # self._extract_and_set_eval_batches() + # print(self.eval_batches) + logger.info(str(self)) + + self.seq_to_frames = {} + for fi, item in enumerate(self.frame_annots): + if item["frame_annotation"].sequence_name in self.seq_to_frames: + self.seq_to_frames[item["frame_annotation"].sequence_name].append(fi) + else: + self.seq_to_frames[item["frame_annotation"].sequence_name] = [fi] + + if self.stage != "test" or 
self.subset != "fewview_test": + count = 0 + new_seq_to_frames = {} + for item in self.seq_to_frames: + if len(self.seq_to_frames[item]) > 10: + count += 1 + new_seq_to_frames[item] = self.seq_to_frames[item] + self.seq_to_frames = new_seq_to_frames + + self.seq_list = list(self.seq_to_frames.keys()) + + # @ REMOVE A FEW TRAINING SEQ THAT CAUSES BUG + remove_list = ["411_55952_107659", "376_42884_85882"] + for remove_idx in remove_list: + if remove_idx in self.seq_to_frames: + self.seq_list.remove(remove_idx) + print("removing", remove_idx) + + print("total training seq", len(self.seq_to_frames)) + print("data loading took", time.time() - start_time, "seconds") + + self.all_category_list = list(CO3D_ALL_CATEGORIES) + self.all_category_list.sort() + self.cat_to_idx = {} + for ci, cname in enumerate(self.all_category_list): + self.cat_to_idx[cname] = ci + + def __len__(self): + return len(self.seq_list) + + def __getitem__(self, index): + seq_index = self.seq_list[index] + + if self.subset == "fewview_test" and self.stage == "test": + batch_idx = torch.arange(len(self.seq_to_frames[seq_index])) + + elif self.stage == "test": + batch_idx = ( + torch.linspace( + 0, len(self.seq_to_frames[seq_index]) - 1, self.sample_batch_size + ) + .long() + .tolist() + ) + else: + rand = torch.randperm(len(self.seq_to_frames[seq_index])) + batch_idx = rand[: min(len(rand), self.sample_batch_size)] + + frame_data_list = [] + idx_list = [] + timestamp_list = [] + for idx in batch_idx: + idx_list.append(self.seq_to_frames[seq_index][idx]) + timestamp_list.append( + self.frame_annots[self.seq_to_frames[seq_index][idx]][ + "frame_annotation" + ].frame_timestamp + ) + frame_data_list.append( + self._get_frame(int(self.seq_to_frames[seq_index][idx])) + ) + + time_order = torch.argsort(torch.tensor(timestamp_list)) + frame_data_list = [frame_data_list[i] for i in time_order] + + frame_data = FrameData.collate(frame_data_list) + image_size = torch.Tensor([self.image_height]).repeat( + frame_data.camera.R.shape[0], 2 + ) + frame_dict = { + "R": frame_data.camera.R, + "T": frame_data.camera.T, + "f": frame_data.camera.focal_length, + "c": frame_data.camera.principal_point, + "images": frame_data.image_rgb * frame_data.fg_probability + + (1 - frame_data.fg_probability), + "valid_region": frame_data.mask_crop, + "bbox": frame_data.valid_region, + "image_size": image_size, + "frame_type": frame_data.frame_type, + "idx": seq_index, + "category": frame_data.category_one_hot, + } + if not self.masked: + frame_dict["images_full"] = frame_data.image_rgb + frame_dict["masks"] = frame_data.fg_probability + frame_dict["mask_crop"] = frame_data.mask_crop + + cond_aug = np.exp( + np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean + ) + + def _pad(input): + return torch.cat([input, torch.flip(input, dims=[0])], dim=0)[ + : self.num_frames + ] + + if len(frame_dict["images"]) < self.num_frames: + for k in frame_dict: + if isinstance(frame_dict[k], torch.Tensor): + frame_dict[k] = _pad(frame_dict[k]) + + data = dict() + if "images_full" in frame_dict: + frames = frame_dict["images_full"] * 2 - 1 + else: + frames = frame_dict["images"] * 2 - 1 + data["frames"] = frames + cond = frames[0] + data["cond_frames_without_noise"] = cond + data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames) + data["cond_frames"] = cond + cond_aug * torch.randn_like(cond) + data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames) + data["motion_bucket_id"] = torch.as_tensor( + [self.motion_bucket_id] * self.num_frames + ) + 
data["num_video_frames"] = self.num_frames + data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames) + + if self.load_pixelnerf: + data["pixelnerf_input"] = dict() + # Rs = frame_dict["R"].transpose(-1, -2) + # Ts = frame_dict["T"] + # Rs[:, :, 2] *= -1 + # Rs[:, :, 0] *= -1 + # Ts[:, 2] *= -1 + # Ts[:, 0] *= -1 + # c2ws = torch.zeros(Rs.shape[0], 4, 4) + # c2ws[:, :3, :3] = Rs + # c2ws[:, :3, 3] = Ts + # c2ws[:, 3, 3] = 1 + # c2ws = c2ws.inverse() + # # c2ws[..., 0] *= -1 + # # c2ws[..., 2] *= -1 + # cx = frame_dict["c"][:, 0] + # cy = frame_dict["c"][:, 1] + # fx = frame_dict["f"][:, 0] + # fy = frame_dict["f"][:, 1] + # intrinsics = torch.zeros(cx.shape[0], 3, 3) + # intrinsics[:, 2, 2] = 1 + # intrinsics[:, 0, 0] = fx + # intrinsics[:, 1, 1] = fy + # intrinsics[:, 0, 2] = cx + # intrinsics[:, 1, 2] = cy + + scene_cameras = PerspectiveCameras( + R=frame_dict["R"], + T=frame_dict["T"], + focal_length=frame_dict["f"], + principal_point=frame_dict["c"], + image_size=frame_dict["image_size"], + ) + R, T, intrinsics = opencv_from_cameras_projection( + scene_cameras, frame_dict["image_size"] + ) + c2ws = torch.zeros(R.shape[0], 4, 4) + c2ws[:, :3, :3] = R + c2ws[:, :3, 3] = T + c2ws[:, 3, 3] = 1.0 + c2ws = c2ws.inverse() + c2ws[..., 1:3] *= -1 + intrinsics[:, :2] /= 256 + + cameras = torch.zeros(c2ws.shape[0], 25) + cameras[..., :16] = c2ws.reshape(-1, 16) + cameras[..., 16:] = intrinsics.reshape(-1, 9) + if self.scale_pose: + c2ws = cameras[..., :16].reshape(-1, 4, 4) + center = c2ws[:, :3, 3].mean(0) + radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max() + scale = 1.5 / radius + c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale + cameras[..., :16] = c2ws.reshape(-1, 16) + + data["pixelnerf_input"]["frames"] = frames + data["pixelnerf_input"]["cameras"] = cameras + data["pixelnerf_input"]["rgb"] = ( + F.interpolate( + frames, + (self.image_width // 8, self.image_height // 8), + mode="bilinear", + align_corners=False, + ) + + 1 + ) * 0.5 + + return data + # if self.return_frame_data_list: + # return (frame_dict, frame_data_list) + # return frame_dict + + def collate_fn(self, batch): + # a hack to add source index and keep consistent within a batch + if self.max_n_cond > 1: + # TODO implement this + n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1) + # debug + # source_index = [0] + if n_cond > 1: + for b in batch: + source_index = [0] + np.random.choice( + np.arange(1, self.num_frames), + self.max_n_cond - 1, + replace=False, + ).tolist() + b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index) + b["pixelnerf_input"]["n_cond"] = n_cond + b["pixelnerf_input"]["source_images"] = b["frames"][source_index] + b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][ + "cameras" + ][source_index] + + if self.cond_on_multi: + b["cond_frames_without_noise"] = b["frames"][source_index] + + ret = video_collate_fn(batch) + + if self.cond_on_multi: + ret["cond_frames_without_noise"] = rearrange( + ret["cond_frames_without_noise"], "b t ... -> (b t) ..." 
+ ) + + return ret + + def _get_frame(self, index): + # if index >= len(self.frame_annots): + # raise IndexError(f"index {index} out of range {len(self.frame_annots)}") + + entry = self.frame_annots[index]["frame_annotation"] + # pyre-ignore[16] + point_cloud = self.seq_annots[entry.sequence_name].point_cloud + frame_data = FrameData( + frame_number=_safe_as_tensor(entry.frame_number, torch.long), + frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float), + sequence_name=entry.sequence_name, + sequence_category=self.seq_annots[entry.sequence_name].category, + camera_quality_score=_safe_as_tensor( + self.seq_annots[entry.sequence_name].viewpoint_quality_score, + torch.float, + ), + point_cloud_quality_score=_safe_as_tensor( + point_cloud.quality_score, torch.float + ) + if point_cloud is not None + else None, + ) + + # The rest of the fields are optional + frame_data.frame_type = self._get_frame_type(self.frame_annots[index]) + + ( + frame_data.fg_probability, + frame_data.mask_path, + frame_data.bbox_xywh, + clamp_bbox_xyxy, + frame_data.crop_bbox_xywh, + ) = self._load_crop_fg_probability(entry) + + scale = 1.0 + if self.load_images and entry.image is not None: + # original image size + frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) + + ( + frame_data.image_rgb, + frame_data.image_path, + frame_data.mask_crop, + scale, + ) = self._load_crop_images( + entry, frame_data.fg_probability, clamp_bbox_xyxy + ) + # print(frame_data.fg_probability.sum()) + # print('scale', scale) + + #! INSERT + if self.deprecated_val_region: + # print(frame_data.crop_bbox_xywh) + valid_bbox = _bbox_xywh_to_xyxy(frame_data.crop_bbox_xywh).float() + # print(valid_bbox, frame_data.image_size_hw) + valid_bbox[0] = torch.clip( + ( + valid_bbox[0] + - torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor") + ) + / torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"), + -1.0, + 1.0, + ) + valid_bbox[1] = torch.clip( + ( + valid_bbox[1] + - torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor") + ) + / torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"), + -1.0, + 1.0, + ) + valid_bbox[2] = torch.clip( + ( + valid_bbox[2] + - torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor") + ) + / torch.div(frame_data.image_size_hw[1], 2, rounding_mode="floor"), + -1.0, + 1.0, + ) + valid_bbox[3] = torch.clip( + ( + valid_bbox[3] + - torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor") + ) + / torch.div(frame_data.image_size_hw[0], 2, rounding_mode="floor"), + -1.0, + 1.0, + ) + # print(valid_bbox) + frame_data.valid_region = valid_bbox + else: + #! 
UPDATED VALID BBOX + if self.stage == "train": + assert self.image_height == 256 and self.image_width == 256 + valid = torch.nonzero(frame_data.mask_crop[0]) + min_y = valid[:, 0].min() + min_x = valid[:, 1].min() + max_y = valid[:, 0].max() + max_x = valid[:, 1].max() + valid_bbox = torch.tensor( + [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device + ).unsqueeze(0) + valid_bbox = torch.clip( + (valid_bbox - (256 // 2)) / (256 // 2), -1.0, 1.0 + ) + frame_data.valid_region = valid_bbox[0] + else: + valid = torch.nonzero(frame_data.mask_crop[0]) + min_y = valid[:, 0].min() + min_x = valid[:, 1].min() + max_y = valid[:, 0].max() + max_x = valid[:, 1].max() + valid_bbox = torch.tensor( + [min_y, min_x, max_y, max_x], device=frame_data.image_rgb.device + ).unsqueeze(0) + valid_bbox = torch.clip( + (valid_bbox - (self.image_height // 2)) / (self.image_height // 2), + -1.0, + 1.0, + ) + frame_data.valid_region = valid_bbox[0] + + #! SET CLASS ONEHOT + frame_data.category_one_hot = torch.zeros( + (len(self.all_category_list)), device=frame_data.image_rgb.device + ) + frame_data.category_one_hot[self.cat_to_idx[frame_data.sequence_category]] = 1 + + if self.load_depths and entry.depth is not None: + ( + frame_data.depth_map, + frame_data.depth_path, + frame_data.depth_mask, + ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability) + + if entry.viewpoint is not None: + frame_data.camera = self._get_pytorch3d_camera( + entry, + scale, + clamp_bbox_xyxy, + ) + + if self.load_point_clouds and point_cloud is not None: + frame_data.sequence_point_cloud_path = pcl_path = os.path.join( + self.dataset_root, point_cloud.path + ) + frame_data.sequence_point_cloud = _load_pointcloud( + self._local_path(pcl_path), max_points=self.max_points + ) + + # for key in frame_data: + # if frame_data[key] == None: + # print(key) + return frame_data + + def _extract_and_set_eval_batches(self): + """ + Sets eval_batches based on input eval_batch_index. + """ + if self.eval_batch_index is not None: + if self.eval_batches is not None: + raise ValueError( + "Cannot define both eval_batch_index and eval_batches." + ) + self.eval_batches = self.seq_frame_index_to_dataset_index( + self.eval_batch_index + ) + + def _load_crop_fg_probability( + self, entry: types.FrameAnnotation + ) -> Tuple[ + Optional[torch.Tensor], + Optional[str], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + fg_probability = None + full_path = None + bbox_xywh = None + clamp_bbox_xyxy = None + crop_box_xywh = None + + if (self.load_masks or self.box_crop) and entry.mask is not None: + full_path = os.path.join(self.dataset_root, entry.mask.path) + mask = _load_mask(self._local_path(full_path)) + + if mask.shape[-2:] != entry.image.size: + raise ValueError( + f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" 
+ ) + + bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) + + if self.box_crop: + clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( + _get_clamp_bbox( + bbox_xywh, + image_path=entry.image.path, + box_crop_context=self.box_crop_context, + ), + image_size_hw=tuple(mask.shape[-2:]), + ) + crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) + + mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) + + fg_probability, _, _ = self._resize_image(mask, mode="nearest") + + return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh + + def _load_crop_images( + self, + entry: types.FrameAnnotation, + fg_probability: Optional[torch.Tensor], + clamp_bbox_xyxy: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, str, torch.Tensor, float]: + assert self.dataset_root is not None and entry.image is not None + path = os.path.join(self.dataset_root, entry.image.path) + image_rgb = _load_image(self._local_path(path)) + + if image_rgb.shape[-2:] != entry.image.size: + raise ValueError( + f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" + ) + + if self.box_crop: + assert clamp_bbox_xyxy is not None + image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path) + + image_rgb, scale, mask_crop = self._resize_image(image_rgb) + + if self.mask_images: + assert fg_probability is not None + image_rgb *= fg_probability + + return image_rgb, path, mask_crop, scale + + def _load_mask_depth( + self, + entry: types.FrameAnnotation, + clamp_bbox_xyxy: Optional[torch.Tensor], + fg_probability: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, str, torch.Tensor]: + entry_depth = entry.depth + assert entry_depth is not None + path = os.path.join(self.dataset_root, entry_depth.path) + depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) + + if self.box_crop: + assert clamp_bbox_xyxy is not None + depth_bbox_xyxy = _rescale_bbox( + clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:] + ) + depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path) + + depth_map, _, _ = self._resize_image(depth_map, mode="nearest") + + if self.mask_depths: + assert fg_probability is not None + depth_map *= fg_probability + + if self.load_depth_masks: + assert entry_depth.mask_path is not None + mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) + depth_mask = _load_depth_mask(self._local_path(mask_path)) + + if self.box_crop: + assert clamp_bbox_xyxy is not None + depth_mask_bbox_xyxy = _rescale_bbox( + clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:] + ) + depth_mask = _crop_around_box( + depth_mask, depth_mask_bbox_xyxy, mask_path + ) + + depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest") + else: + depth_mask = torch.ones_like(depth_map) + + return depth_map, path, depth_mask + + def _get_pytorch3d_camera( + self, + entry: types.FrameAnnotation, + scale: float, + clamp_bbox_xyxy: Optional[torch.Tensor], + ) -> PerspectiveCameras: + entry_viewpoint = entry.viewpoint + assert entry_viewpoint is not None + # principal point and focal length + principal_point = torch.tensor( + entry_viewpoint.principal_point, dtype=torch.float + ) + focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) + + half_image_size_wh_orig = ( + torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0 + ) + + # first, we convert from the dataset's NDC convention to pixels + format = entry_viewpoint.intrinsics_format + if format.lower() == "ndc_norm_image_bounds": + # this is e.g. 
currently used in CO3D for storing intrinsics + rescale = half_image_size_wh_orig + elif format.lower() == "ndc_isotropic": + rescale = half_image_size_wh_orig.min() + else: + raise ValueError(f"Unknown intrinsics format: {format}") + + # principal point and focal length in pixels + principal_point_px = half_image_size_wh_orig - principal_point * rescale + focal_length_px = focal_length * rescale + if self.box_crop: + assert clamp_bbox_xyxy is not None + principal_point_px -= clamp_bbox_xyxy[:2] + + # now, convert from pixels to PyTorch3D v0.5+ NDC convention + if self.image_height is None or self.image_width is None: + out_size = list(reversed(entry.image.size)) + else: + out_size = [self.image_width, self.image_height] + + half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 + half_min_image_size_output = half_image_size_output.min() + + # rescaled principal point and focal length in ndc + principal_point = ( + half_image_size_output - principal_point_px * scale + ) / half_min_image_size_output + focal_length = focal_length_px * scale / half_min_image_size_output + + return PerspectiveCameras( + focal_length=focal_length[None], + principal_point=principal_point[None], + R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], + T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], + ) + + def _load_frames(self) -> None: + self.frame_annots = [ + FrameAnnotsEntry(frame_annotation=a, subset=None) + for a in self.category_frame_annotations + ] + + def _load_sequences(self) -> None: + self.seq_annots = { + entry.sequence_name: entry for entry in self.category_sequence_annotations + } + + def _load_subset_lists(self) -> None: + logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.") + if not self.subset_lists_file: + return + + frame_path_to_subset = {} + + for subset_list_file in self.subset_lists_file: + with open(self._local_path(subset_list_file), "r") as f: + subset_to_seq_frame = json.load(f) + + #! PRINT SUBSET_LIST STATS + # if len(self.subset_lists_file) == 1: + # print('train frames', len(subset_to_seq_frame['train'])) + # print('val frames', len(subset_to_seq_frame['val'])) + # print('test frames', len(subset_to_seq_frame['test'])) + + for set_ in subset_to_seq_frame: + for _, _, path in subset_to_seq_frame[set_]: + if path in frame_path_to_subset: + frame_path_to_subset[path].add(set_) + else: + frame_path_to_subset[path] = {set_} + + # pyre-ignore[16] + for frame in self.frame_annots: + frame["subset"] = frame_path_to_subset.get( + frame["frame_annotation"].image.path, None + ) + + if frame["subset"] is None: + continue + warnings.warn( + "Subset lists are given but don't include " + + frame["frame_annotation"].image.path + ) + + def _sort_frames(self) -> None: + # Sort frames to have them grouped by sequence, ordered by timestamp + # pyre-ignore[16] + self.frame_annots = sorted( + self.frame_annots, + key=lambda f: ( + f["frame_annotation"].sequence_name, + f["frame_annotation"].frame_timestamp or 0, + ), + ) + + def _filter_db(self) -> None: + if self.remove_empty_masks: + logger.info("Removing images with empty masks.") + # pyre-ignore[16] + old_len = len(self.frame_annots) + + msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." 
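+ # `positive_mass` below implements the `remove_empty_masks` filter: a frame is
+ # kept only if its mask annotation exists and its precomputed `mass` (roughly
+ # the number of foreground pixels) is greater than 1; a mask annotation with
+ # no `mass` at all raises the error message defined above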
+ + def positive_mass(frame_annot: types.FrameAnnotation) -> bool: + mask = frame_annot.mask + if mask is None: + return False + if mask.mass is None: + raise ValueError(msg) + return mask.mass > 1 + + self.frame_annots = [ + frame + for frame in self.frame_annots + if positive_mass(frame["frame_annotation"]) + ] + logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots))) + + # this has to be called after joining with categories!! + subsets = self.subsets + if subsets: + if not self.subset_lists_file: + raise ValueError( + "Subset filter is on but subset_lists_file was not given" + ) + + logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.") + + # truncate the list of subsets to the valid one + self.frame_annots = [ + entry + for entry in self.frame_annots + if (entry["subset"] is not None and self.stage in entry["subset"]) + ] + + if len(self.frame_annots) == 0: + raise ValueError(f"There are no frames in the '{subsets}' subsets!") + + self._invalidate_indexes(filter_seq_annots=True) + + if len(self.limit_category_to) > 0: + logger.info(f"Limiting dataset to categories: {self.limit_category_to}") + # pyre-ignore[16] + self.seq_annots = { + name: entry + for name, entry in self.seq_annots.items() + if entry.category in self.limit_category_to + } + + # sequence filters + for prefix in ("pick", "exclude"): + orig_len = len(self.seq_annots) + attr = f"{prefix}_sequence" + arr = getattr(self, attr) + if len(arr) > 0: + logger.info(f"{attr}: {str(arr)}") + self.seq_annots = { + name: entry + for name, entry in self.seq_annots.items() + if (name in arr) == (prefix == "pick") + } + logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots))) + + if self.limit_sequences_to > 0: + self.seq_annots = dict( + islice(self.seq_annots.items(), self.limit_sequences_to) + ) + + # retain only frames from retained sequences + self.frame_annots = [ + f + for f in self.frame_annots + if f["frame_annotation"].sequence_name in self.seq_annots + ] + + self._invalidate_indexes() + + if self.n_frames_per_sequence > 0: + logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.") + keep_idx = [] + # pyre-ignore[16] + for seq, seq_indices in self._seq_to_idx.items(): + # infer the seed from the sequence name, this is reproducible + # and makes the selection differ for different sequences + seed = _seq_name_to_seed(seq) + self.seed + seq_idx_shuffled = random.Random(seed).sample( + sorted(seq_indices), len(seq_indices) + ) + keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence]) + + logger.info( + "... 
filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)) + ) + self.frame_annots = [self.frame_annots[i] for i in keep_idx] + self._invalidate_indexes(filter_seq_annots=False) + # sequences are not decimated, so self.seq_annots is valid + + if self.limit_to > 0 and self.limit_to < len(self.frame_annots): + logger.info( + "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to) + ) + self.frame_annots = self.frame_annots[: self.limit_to] + self._invalidate_indexes(filter_seq_annots=True) + + def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: + # update _seq_to_idx and filter seq_meta according to frame_annots change + # if filter_seq_annots, also uldates seq_annots based on the changed _seq_to_idx + self._invalidate_seq_to_idx() + + if filter_seq_annots: + # pyre-ignore[16] + self.seq_annots = { + k: v + for k, v in self.seq_annots.items() + # pyre-ignore[16] + if k in self._seq_to_idx + } + + def _invalidate_seq_to_idx(self) -> None: + seq_to_idx = defaultdict(list) + # pyre-ignore[16] + for idx, entry in enumerate(self.frame_annots): + seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) + # pyre-ignore[16] + self._seq_to_idx = seq_to_idx + + def _resize_image( + self, image, mode="bilinear" + ) -> Tuple[torch.Tensor, float, torch.Tensor]: + image_height, image_width = self.image_height, self.image_width + if image_height is None or image_width is None: + # skip the resizing + imre_ = torch.from_numpy(image) + return imre_, 1.0, torch.ones_like(imre_[:1]) + # takes numpy array, returns pytorch tensor + minscale = min( + image_height / image.shape[-2], + image_width / image.shape[-1], + ) + imre = torch.nn.functional.interpolate( + torch.from_numpy(image)[None], + scale_factor=minscale, + mode=mode, + align_corners=False if mode == "bilinear" else None, + recompute_scale_factor=True, + )[0] + # pyre-fixme[19]: Expected 1 positional argument. + imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width) + imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre + # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`. + # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`. 
+ mask = torch.zeros(1, self.image_height, self.image_width) + mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 + return imre_, minscale, mask + + def _local_path(self, path: str) -> str: + if self.path_manager is None: + return path + return self.path_manager.get_local_path(path) + + def get_frame_numbers_and_timestamps( + self, idxs: Sequence[int] + ) -> List[Tuple[int, float]]: + out: List[Tuple[int, float]] = [] + for idx in idxs: + # pyre-ignore[16] + frame_annotation = self.frame_annots[idx]["frame_annotation"] + out.append( + (frame_annotation.frame_number, frame_annotation.frame_timestamp) + ) + return out + + def get_eval_batches(self) -> Optional[List[List[int]]]: + return self.eval_batches + + def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]: + return entry["frame_annotation"].meta["frame_type"] + + +class CO3DDataset(LightningDataModule): + def __init__( + self, + root_dir, + batch_size=2, + shuffle=True, + num_workers=10, + prefetch_factor=2, + category="hydrant", + **kwargs, + ): + super().__init__() + + self.batch_size = batch_size + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.shuffle = shuffle + + self.train_dataset = CO3Dv2Wrapper( + root_dir=root_dir, + stage="train", + category=category, + **kwargs, + ) + + self.test_dataset = CO3Dv2Wrapper( + root_dir=root_dir, + stage="test", + subset="fewview_dev", + category=category, + **kwargs, + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + collate_fn=self.train_dataset.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + collate_fn=self.test_dataset.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + collate_fn=video_collate_fn, + ) diff --git a/sgm/data/colmap.py b/sgm/data/colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..b739f2e9637c0c96b80c42fce05dfeab6c5e1228 --- /dev/null +++ b/sgm/data/colmap.py @@ -0,0 +1,605 @@ +# Copyright (c) 2023, ETH Zurich and UNC Chapel Hill. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of +# its contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + + +import os +import collections +import numpy as np +import struct +import argparse + + +CameraModel = collections.namedtuple( + "CameraModel", ["model_id", "model_name", "num_params"] +) +Camera = collections.namedtuple( + "Camera", ["id", "model", "width", "height", "params"] +) +BaseImage = collections.namedtuple( + "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"] +) +Point3D = collections.namedtuple( + "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"] +) + + +class Image(BaseImage): + def qvec2rotmat(self): + return qvec2rotmat(self.qvec) + + +CAMERA_MODELS = { + CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), + CameraModel(model_id=1, model_name="PINHOLE", num_params=4), + CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), + CameraModel(model_id=3, model_name="RADIAL", num_params=5), + CameraModel(model_id=4, model_name="OPENCV", num_params=8), + CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), + CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), + CameraModel(model_id=7, model_name="FOV", num_params=5), + CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), + CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), + CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12), +} +CAMERA_MODEL_IDS = dict( + [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS] +) +CAMERA_MODEL_NAMES = dict( + [(camera_model.model_name, camera_model) for camera_model in CAMERA_MODELS] +) + + +def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): + """Read and unpack the next bytes from a binary file. + :param fid: + :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. + :param endian_character: Any of {@, =, <, >, !} + :return: Tuple of read and unpacked values. + """ + data = fid.read(num_bytes) + return struct.unpack(endian_character + format_char_sequence, data) + + +def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): + """pack and write to a binary file. + :param fid: + :param data: data to send, if multiple elements are sent at the same time, + they should be encapsuled either in a list or a tuple + :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 
+ should be the same length as the data list or tuple + :param endian_character: Any of {@, =, <, >, !} + """ + if isinstance(data, (list, tuple)): + bytes = struct.pack(endian_character + format_char_sequence, *data) + else: + bytes = struct.pack(endian_character + format_char_sequence, data) + fid.write(bytes) + + +def read_cameras_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + cameras = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + camera_id = int(elems[0]) + model = elems[1] + width = int(elems[2]) + height = int(elems[3]) + params = np.array(tuple(map(float, elems[4:]))) + cameras[camera_id] = Camera( + id=camera_id, + model=model, + width=width, + height=height, + params=params, + ) + return cameras + + +def read_cameras_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + cameras = {} + with open(path_to_model_file, "rb") as fid: + num_cameras = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_cameras): + camera_properties = read_next_bytes( + fid, num_bytes=24, format_char_sequence="iiQQ" + ) + camera_id = camera_properties[0] + model_id = camera_properties[1] + model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name + width = camera_properties[2] + height = camera_properties[3] + num_params = CAMERA_MODEL_IDS[model_id].num_params + params = read_next_bytes( + fid, + num_bytes=8 * num_params, + format_char_sequence="d" * num_params, + ) + cameras[camera_id] = Camera( + id=camera_id, + model=model_name, + width=width, + height=height, + params=np.array(params), + ) + assert len(cameras) == num_cameras + return cameras + + +def write_cameras_text(cameras, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasText(const std::string& path) + void Reconstruction::ReadCamerasText(const std::string& path) + """ + HEADER = ( + "# Camera list with one line of data per camera:\n" + + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + + "# Number of cameras: {}\n".format(len(cameras)) + ) + with open(path, "w") as fid: + fid.write(HEADER) + for _, cam in cameras.items(): + to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] + line = " ".join([str(elem) for elem in to_write]) + fid.write(line + "\n") + + +def write_cameras_binary(cameras, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::WriteCamerasBinary(const std::string& path) + void Reconstruction::ReadCamerasBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(cameras), "Q") + for _, cam in cameras.items(): + model_id = CAMERA_MODEL_NAMES[cam.model].model_id + camera_properties = [cam.id, model_id, cam.width, cam.height] + write_next_bytes(fid, camera_properties, "iiQQ") + for p in cam.params: + write_next_bytes(fid, float(p), "d") + return cameras + + +def read_images_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + images = {} + with open(path, "r") as fid: + while True: + line = 
fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + image_id = int(elems[0]) + qvec = np.array(tuple(map(float, elems[1:5]))) + tvec = np.array(tuple(map(float, elems[5:8]))) + camera_id = int(elems[8]) + image_name = elems[9] + elems = fid.readline().split() + xys = np.column_stack( + [ + tuple(map(float, elems[0::3])), + tuple(map(float, elems[1::3])), + ] + ) + point3D_ids = np.array(tuple(map(int, elems[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def read_images_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + images = {} + with open(path_to_model_file, "rb") as fid: + num_reg_images = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_reg_images): + binary_image_properties = read_next_bytes( + fid, num_bytes=64, format_char_sequence="idddddddi" + ) + image_id = binary_image_properties[0] + qvec = np.array(binary_image_properties[1:5]) + tvec = np.array(binary_image_properties[5:8]) + camera_id = binary_image_properties[8] + binary_image_name = b"" + current_char = read_next_bytes(fid, 1, "c")[0] + while current_char != b"\x00": # look for the ASCII 0 entry + binary_image_name += current_char + current_char = read_next_bytes(fid, 1, "c")[0] + image_name = binary_image_name.decode("utf-8") + num_points2D = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q" + )[0] + x_y_id_s = read_next_bytes( + fid, + num_bytes=24 * num_points2D, + format_char_sequence="ddq" * num_points2D, + ) + xys = np.column_stack( + [ + tuple(map(float, x_y_id_s[0::3])), + tuple(map(float, x_y_id_s[1::3])), + ] + ) + point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) + images[image_id] = Image( + id=image_id, + qvec=qvec, + tvec=tvec, + camera_id=camera_id, + name=image_name, + xys=xys, + point3D_ids=point3D_ids, + ) + return images + + +def write_images_text(images, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesText(const std::string& path) + void Reconstruction::WriteImagesText(const std::string& path) + """ + if len(images) == 0: + mean_observations = 0 + else: + mean_observations = sum( + (len(img.point3D_ids) for _, img in images.items()) + ) / len(images) + HEADER = ( + "# Image list with two lines of data per image:\n" + + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + + "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + + "# Number of images: {}, mean observations per image: {}\n".format( + len(images), mean_observations + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, img in images.items(): + image_header = [ + img.id, + *img.qvec, + *img.tvec, + img.camera_id, + img.name, + ] + first_line = " ".join(map(str, image_header)) + fid.write(first_line + "\n") + + points_strings = [] + for xy, point3D_id in zip(img.xys, img.point3D_ids): + points_strings.append(" ".join(map(str, [*xy, point3D_id]))) + fid.write(" ".join(points_strings) + "\n") + + +def write_images_binary(images, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadImagesBinary(const std::string& path) + void Reconstruction::WriteImagesBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, 
len(images), "Q") + for _, img in images.items(): + write_next_bytes(fid, img.id, "i") + write_next_bytes(fid, img.qvec.tolist(), "dddd") + write_next_bytes(fid, img.tvec.tolist(), "ddd") + write_next_bytes(fid, img.camera_id, "i") + for char in img.name: + write_next_bytes(fid, char.encode("utf-8"), "c") + write_next_bytes(fid, b"\x00", "c") + write_next_bytes(fid, len(img.point3D_ids), "Q") + for xy, p3d_id in zip(img.xys, img.point3D_ids): + write_next_bytes(fid, [*xy, p3d_id], "ddq") + + +def read_points3D_text(path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + points3D = {} + with open(path, "r") as fid: + while True: + line = fid.readline() + if not line: + break + line = line.strip() + if len(line) > 0 and line[0] != "#": + elems = line.split() + point3D_id = int(elems[0]) + xyz = np.array(tuple(map(float, elems[1:4]))) + rgb = np.array(tuple(map(int, elems[4:7]))) + error = float(elems[7]) + image_ids = np.array(tuple(map(int, elems[8::2]))) + point2D_idxs = np.array(tuple(map(int, elems[9::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def read_points3D_binary(path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + points3D = {} + with open(path_to_model_file, "rb") as fid: + num_points = read_next_bytes(fid, 8, "Q")[0] + for _ in range(num_points): + binary_point_line_properties = read_next_bytes( + fid, num_bytes=43, format_char_sequence="QdddBBBd" + ) + point3D_id = binary_point_line_properties[0] + xyz = np.array(binary_point_line_properties[1:4]) + rgb = np.array(binary_point_line_properties[4:7]) + error = np.array(binary_point_line_properties[7]) + track_length = read_next_bytes( + fid, num_bytes=8, format_char_sequence="Q" + )[0] + track_elems = read_next_bytes( + fid, + num_bytes=8 * track_length, + format_char_sequence="ii" * track_length, + ) + image_ids = np.array(tuple(map(int, track_elems[0::2]))) + point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) + points3D[point3D_id] = Point3D( + id=point3D_id, + xyz=xyz, + rgb=rgb, + error=error, + image_ids=image_ids, + point2D_idxs=point2D_idxs, + ) + return points3D + + +def write_points3D_text(points3D, path): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DText(const std::string& path) + void Reconstruction::WritePoints3DText(const std::string& path) + """ + if len(points3D) == 0: + mean_track_length = 0 + else: + mean_track_length = sum( + (len(pt.image_ids) for _, pt in points3D.items()) + ) / len(points3D) + HEADER = ( + "# 3D point list with one line of data per point:\n" + + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + + "# Number of points: {}, mean track length: {}\n".format( + len(points3D), mean_track_length + ) + ) + + with open(path, "w") as fid: + fid.write(HEADER) + for _, pt in points3D.items(): + point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] + fid.write(" ".join(map(str, point_header)) + " ") + track_strings = [] + for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): + track_strings.append(" ".join(map(str, [image_id, point2D]))) + fid.write(" ".join(track_strings) + "\n") + + +def 
write_points3D_binary(points3D, path_to_model_file): + """ + see: src/colmap/scene/reconstruction.cc + void Reconstruction::ReadPoints3DBinary(const std::string& path) + void Reconstruction::WritePoints3DBinary(const std::string& path) + """ + with open(path_to_model_file, "wb") as fid: + write_next_bytes(fid, len(points3D), "Q") + for _, pt in points3D.items(): + write_next_bytes(fid, pt.id, "Q") + write_next_bytes(fid, pt.xyz.tolist(), "ddd") + write_next_bytes(fid, pt.rgb.tolist(), "BBB") + write_next_bytes(fid, pt.error, "d") + track_length = pt.image_ids.shape[0] + write_next_bytes(fid, track_length, "Q") + for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): + write_next_bytes(fid, [image_id, point2D_id], "ii") + + +def detect_model_format(path, ext): + if ( + os.path.isfile(os.path.join(path, "cameras" + ext)) + and os.path.isfile(os.path.join(path, "images" + ext)) + and os.path.isfile(os.path.join(path, "points3D" + ext)) + ): + print("Detected model format: '" + ext + "'") + return True + + return False + + +def read_model(path, ext=""): + # try to detect the extension automatically + if ext == "": + if detect_model_format(path, ".bin"): + ext = ".bin" + elif detect_model_format(path, ".txt"): + ext = ".txt" + else: + print("Provide model format: '.bin' or '.txt'") + return + + if ext == ".txt": + cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) + images = read_images_text(os.path.join(path, "images" + ext)) + points3D = read_points3D_text(os.path.join(path, "points3D") + ext) + else: + cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) + images = read_images_binary(os.path.join(path, "images" + ext)) + points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def write_model(cameras, images, points3D, path, ext=".bin"): + if ext == ".txt": + write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) + write_images_text(images, os.path.join(path, "images" + ext)) + write_points3D_text(points3D, os.path.join(path, "points3D") + ext) + else: + write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) + write_images_binary(images, os.path.join(path, "images" + ext)) + write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) + return cameras, images, points3D + + +def qvec2rotmat(qvec): + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + + +def rotmat2qvec(R): + Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat + K = ( + np.array( + [ + [Rxx - Ryy - Rzz, 0, 0, 0], + [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], + [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], + [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz], + ] + ) + / 3.0 + ) + eigvals, eigvecs = np.linalg.eigh(K) + qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] + if qvec[0] < 0: + qvec *= -1 + return qvec + + +def main(): + parser = argparse.ArgumentParser( + description="Read and write COLMAP binary and text models" + ) + parser.add_argument("--input_model", help="path to input model folder") + parser.add_argument( + "--input_format", + choices=[".bin", ".txt"], + help="input model format", 
+ default="", + ) + parser.add_argument("--output_model", help="path to output model folder") + parser.add_argument( + "--output_format", + choices=[".bin", ".txt"], + help="outut model format", + default=".txt", + ) + args = parser.parse_args() + + cameras, images, points3D = read_model( + path=args.input_model, ext=args.input_format + ) + + print("num_cameras:", len(cameras)) + print("num_images:", len(images)) + print("num_points3D:", len(points3D)) + + if args.output_model is not None: + write_model( + cameras, + images, + points3D, + path=args.output_model, + ext=args.output_format, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/sgm/data/dataset.py b/sgm/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b726149996591c6c3db69230e1bb68c07d2faa12 --- /dev/null +++ b/sgm/data/dataset.py @@ -0,0 +1,80 @@ +from typing import Optional + +import torchdata.datapipes.iter +import webdataset as wds +from omegaconf import DictConfig +from pytorch_lightning import LightningDataModule + +try: + from sdata import create_dataset, create_dummy_dataset, create_loader +except ImportError as e: + print("#" * 100) + print("Datasets not yet available") + print("to enable, we need to add stable-datasets as a submodule") + print("please use ``git submodule update --init --recursive``") + print("and do ``pip install -e stable-datasets/`` from the root of this repo") + print("#" * 100) + exit(1) + + +class StableDataModuleFromConfig(LightningDataModule): + def __init__( + self, + train: DictConfig, + validation: Optional[DictConfig] = None, + test: Optional[DictConfig] = None, + skip_val_loader: bool = False, + dummy: bool = False, + ): + super().__init__() + self.train_config = train + assert ( + "datapipeline" in self.train_config and "loader" in self.train_config + ), "train config requires the fields `datapipeline` and `loader`" + + self.val_config = validation + if not skip_val_loader: + if self.val_config is not None: + assert ( + "datapipeline" in self.val_config and "loader" in self.val_config + ), "validation config requires the fields `datapipeline` and `loader`" + else: + print( + "Warning: No Validation datapipeline defined, using that one from training" + ) + self.val_config = train + + self.test_config = test + if self.test_config is not None: + assert ( + "datapipeline" in self.test_config and "loader" in self.test_config + ), "test config requires the fields `datapipeline` and `loader`" + + self.dummy = dummy + if self.dummy: + print("#" * 100) + print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)") + print("#" * 100) + + def setup(self, stage: str) -> None: + print("Preparing datasets") + if self.dummy: + data_fn = create_dummy_dataset + else: + data_fn = create_dataset + + self.train_datapipeline = data_fn(**self.train_config.datapipeline) + if self.val_config: + self.val_datapipeline = data_fn(**self.val_config.datapipeline) + if self.test_config: + self.test_datapipeline = data_fn(**self.test_config.datapipeline) + + def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe: + loader = create_loader(self.train_datapipeline, **self.train_config.loader) + return loader + + def val_dataloader(self) -> wds.DataPipeline: + return create_loader(self.val_datapipeline, **self.val_config.loader) + + def test_dataloader(self) -> wds.DataPipeline: + return create_loader(self.test_datapipeline, **self.test_config.loader) diff --git a/sgm/data/joint3d.py b/sgm/data/joint3d.py new file mode 100644 index 
0000000000000000000000000000000000000000..0569210466a2391bdbb3be358c5cd8f8477aeba1 --- /dev/null +++ b/sgm/data/joint3d.py @@ -0,0 +1,10 @@ +import torch +from torch.utils.data import Dataset + +default_sub_data_config = {} + + +class Joint3D(Dataset): + def __init__(self, sub_data_config: dict) -> None: + super().__init__() + self.sub_data_config = sub_data_config diff --git a/sgm/data/json_index_dataset.py b/sgm/data/json_index_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..16f1dbf3bbae4fb6861f45703d1493914ffaf791 --- /dev/null +++ b/sgm/data/json_index_dataset.py @@ -0,0 +1,1080 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import functools +import gzip +import hashlib +import json +import logging +import os +import random +import warnings +from collections import defaultdict +from itertools import islice +from pathlib import Path +from typing import ( + Any, + ClassVar, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + TYPE_CHECKING, + Union, +) + +import numpy as np +import torch +from PIL import Image +from pytorch3d.implicitron.tools.config import registry, ReplaceableBase +from pytorch3d.io import IO +from pytorch3d.renderer.camera_utils import join_cameras_as_batch +from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +from pytorch3d.structures.pointclouds import Pointclouds +from tqdm import tqdm + +from pytorch3d.implicitron.dataset import types +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData +from pytorch3d.implicitron.dataset.utils import is_known_frame_scalar + + +logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + from typing import TypedDict + + class FrameAnnotsEntry(TypedDict): + subset: Optional[str] + frame_annotation: types.FrameAnnotation + +else: + FrameAnnotsEntry = dict + + +@registry.register +class JsonIndexDataset(DatasetBase, ReplaceableBase): + """ + A dataset with annotations in json files like the Common Objects in 3D + (CO3D) dataset. + + Args: + frame_annotations_file: A zipped json file containing metadata of the + frames in the dataset, serialized List[types.FrameAnnotation]. + sequence_annotations_file: A zipped json file containing metadata of the + sequences in the dataset, serialized List[types.SequenceAnnotation]. + subset_lists_file: A json file containing the lists of frames corresponding + corresponding to different subsets (e.g. train/val/test) of the dataset; + format: {subset: (sequence_name, frame_id, file_path)}. + subsets: Restrict frames/sequences only to the given list of subsets + as defined in subset_lists_file (see above). + limit_to: Limit the dataset to the first #limit_to frames (after other + filters have been applied). + limit_sequences_to: Limit the dataset to the first + #limit_sequences_to sequences (after other sequence filters have been + applied but before frame-based filters). + pick_sequence: A list of sequence names to restrict the dataset to. + exclude_sequence: A list of the names of the sequences to exclude. + limit_category_to: Restrict the dataset to the given list of categories. + dataset_root: The root folder of the dataset; all the paths in jsons are + specified relative to this root (but not json paths themselves). + load_images: Enable loading the frame RGB data. + load_depths: Enable loading the frame depth maps. 
+ load_depth_masks: Enable loading the frame depth map masks denoting the + depth values used for evaluation (the points consistent across views). + load_masks: Enable loading frame foreground masks. + load_point_clouds: Enable loading sequence-level point clouds. + max_points: Cap on the number of loaded points in the point cloud; + if reached, they are randomly sampled without replacement. + mask_images: Whether to mask the images with the loaded foreground masks; + 0 value is used for background. + mask_depths: Whether to mask the depth maps with the loaded foreground + masks; 0 value is used for background. + image_height: The height of the returned images, masks, and depth maps; + aspect ratio is preserved during cropping/resizing. + image_width: The width of the returned images, masks, and depth maps; + aspect ratio is preserved during cropping/resizing. + box_crop: Enable cropping of the image around the bounding box inferred + from the foreground region of the loaded segmentation mask; masks + and depth maps are cropped accordingly; cameras are corrected. + box_crop_mask_thr: The threshold used to separate pixels into foreground + and background based on the foreground_probability mask; if no value + is greater than this threshold, the loader lowers it and repeats. + box_crop_context: The amount of additional padding added to each + dimension of the cropping bounding box, relative to box size. + remove_empty_masks: Removes the frames with no active foreground pixels + in the segmentation mask after thresholding (see box_crop_mask_thr). + n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence + frames in each sequences uniformly without replacement if it has + more frames than that; applied before other frame-level filters. + seed: The seed of the random generator sampling #n_frames_per_sequence + random frames per sequence. + sort_frames: Enable frame annotations sorting to group frames from the + same sequences together and order them by timestamps + eval_batches: A list of batches that form the evaluation set; + list of batch-sized lists of indices corresponding to __getitem__ + of this class, thus it can be used directly as a batch sampler. + eval_batch_index: + ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] ) + A list of batches of frames described as (sequence_name, frame_idx) + that can form the evaluation set, `eval_batches` will be set from this. + + """ + + frame_annotations_type: ClassVar[ + Type[types.FrameAnnotation] + ] = types.FrameAnnotation + + path_manager: Any = None + frame_annotations_file: str = "" + sequence_annotations_file: str = "" + subset_lists_file: str = "" + subsets: Optional[List[str]] = None + limit_to: int = 0 + limit_sequences_to: int = 0 + pick_sequence: Tuple[str, ...] = () + exclude_sequence: Tuple[str, ...] = () + limit_category_to: Tuple[int, ...] 
= () + dataset_root: str = "" + load_images: bool = True + load_depths: bool = True + load_depth_masks: bool = True + load_masks: bool = True + load_point_clouds: bool = False + max_points: int = 0 + mask_images: bool = False + mask_depths: bool = False + image_height: Optional[int] = 800 + image_width: Optional[int] = 800 + box_crop: bool = True + box_crop_mask_thr: float = 0.4 + box_crop_context: float = 0.3 + remove_empty_masks: bool = True + n_frames_per_sequence: int = -1 + seed: int = 0 + sort_frames: bool = False + eval_batches: Any = None + eval_batch_index: Any = None + # frame_annots: List[FrameAnnotsEntry] = field(init=False) + # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False) + + def __post_init__(self) -> None: + # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`. + self.subset_to_image_path = None + self._load_frames() + self._load_sequences() + if self.sort_frames: + self._sort_frames() + self._load_subset_lists() + self._filter_db() # also computes sequence indices + self._extract_and_set_eval_batches() + logger.info(str(self)) + + def _extract_and_set_eval_batches(self): + """ + Sets eval_batches based on input eval_batch_index. + """ + if self.eval_batch_index is not None: + if self.eval_batches is not None: + raise ValueError( + "Cannot define both eval_batch_index and eval_batches." + ) + self.eval_batches = self.seq_frame_index_to_dataset_index( + self.eval_batch_index + ) + + def join(self, other_datasets: Iterable[DatasetBase]) -> None: + """ + Join the dataset with other JsonIndexDataset objects. + + Args: + other_datasets: A list of JsonIndexDataset objects to be joined + into the current dataset. + """ + if not all(isinstance(d, JsonIndexDataset) for d in other_datasets): + raise ValueError("This function can only join a list of JsonIndexDataset") + # pyre-ignore[16] + self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots]) + # pyre-ignore[16] + self.seq_annots.update( + # https://gist.github.com/treyhunner/f35292e676efa0be1728 + functools.reduce( + lambda a, b: {**a, **b}, + [d.seq_annots for d in other_datasets], # pyre-ignore[16] + ) + ) + all_eval_batches = [ + self.eval_batches, + # pyre-ignore + *[d.eval_batches for d in other_datasets], + ] + if not ( + all(ba is None for ba in all_eval_batches) + or all(ba is not None for ba in all_eval_batches) + ): + raise ValueError( + "When joining datasets, either all joined datasets have to have their" + " eval_batches defined, or all should have their eval batches undefined." + ) + if self.eval_batches is not None: + self.eval_batches = sum(all_eval_batches, []) + self._invalidate_indexes(filter_seq_annots=True) + + def is_filtered(self) -> bool: + """ + Returns `True` in case the dataset has been filtered and thus some frame annotations + stored on the disk might be missing in the dataset object. + + Returns: + is_filtered: `True` if the dataset has been filtered, else `False`. 
+        """
+        return (
+            self.remove_empty_masks
+            or self.limit_to > 0
+            or self.limit_sequences_to > 0
+            or len(self.pick_sequence) > 0
+            or len(self.exclude_sequence) > 0
+            or len(self.limit_category_to) > 0
+            or self.n_frames_per_sequence > 0
+        )
+
+    def seq_frame_index_to_dataset_index(
+        self,
+        seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
+        allow_missing_indices: bool = False,
+        remove_missing_indices: bool = False,
+        suppress_missing_index_warning: bool = True,
+    ) -> List[List[Union[Optional[int], int]]]:
+        """
+        Obtain indices into the dataset object given a list of frame ids.
+
+        Args:
+            seq_frame_index: The list of frame ids specified as
+                `List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally,
+                image paths relative to the dataset_root can be specified as well:
+                `List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]`
+            allow_missing_indices: If `False`, throws an IndexError upon reaching the first
+                entry from `seq_frame_index` which is missing in the dataset.
+                Otherwise, depending on `remove_missing_indices`, either returns `None`
+                in place of missing entries or removes the indices of missing entries.
+            remove_missing_indices: Active when `allow_missing_indices=True`.
+                If `False`, returns `None` in place of `seq_frame_index` entries that
+                are not present in the dataset.
+                If `True`, removes missing indices from the returned indices.
+            suppress_missing_index_warning:
+                Active if `allow_missing_indices==True`. Suppresses a warning message
+                in case an entry from `seq_frame_index` is missing in the dataset
+                (expected in certain cases - e.g. when setting
+                `self.remove_empty_masks=True`).
+
+        Returns:
+            dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
+        """
+        _dataset_seq_frame_n_index = {
+            seq: {
+                # pyre-ignore[16]
+                self.frame_annots[idx]["frame_annotation"].frame_number: idx
+                for idx in seq_idx
+            }
+            # pyre-ignore[16]
+            for seq, seq_idx in self._seq_to_idx.items()
+        }
+
+        def _get_dataset_idx(
+            seq_name: str, frame_no: int, path: Optional[str] = None
+        ) -> Optional[int]:
+            idx_seq = _dataset_seq_frame_n_index.get(seq_name, None)
+            idx = idx_seq.get(frame_no, None) if idx_seq is not None else None
+            if idx is None:
+                msg = (
+                    f"sequence_name={seq_name} / frame_number={frame_no}"
+                    " not in the dataset!"
+                )
+                if not allow_missing_indices:
+                    raise IndexError(msg)
+                if not suppress_missing_index_warning:
+                    warnings.warn(msg)
+                return idx
+            if path is not None:
+                # Check that the loaded frame path is consistent
+                # with the one stored in self.frame_annots.
+                assert os.path.normpath(
+                    # pyre-ignore[16]
+                    self.frame_annots[idx]["frame_annotation"].image.path
+                ) == os.path.normpath(
+                    path
+                ), f"Inconsistent frame indices {seq_name, frame_no, path}."
+            return idx
+
+        dataset_idx = [
+            [_get_dataset_idx(*b) for b in batch]  # pyre-ignore [6]
+            for batch in seq_frame_index
+        ]
+
+        if allow_missing_indices and remove_missing_indices:
+            # remove all None indices, and also batches with only None entries
+            valid_dataset_idx = [
+                [b for b in batch if b is not None] for batch in dataset_idx
+            ]
+            return [  # pyre-ignore[7]
+                batch for batch in valid_dataset_idx if len(batch) > 0
+            ]
+
+        return dataset_idx
+
+    def subset_from_frame_index(
+        self,
+        frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
+        allow_missing_indices: bool = True,
+    ) -> "JsonIndexDataset":
+        """
+        Generate a dataset subset given the list of frames specified in `frame_index`.
+
+        Args:
+            frame_index: The list of frame identifiers (as stored in the metadata)
+                specified as `List[Tuple[sequence_name:str, frame_number:int]]`. Optionally,
+                image paths relative to the dataset_root can be specified as well:
+                `List[Tuple[sequence_name:str, frame_number:int, image_path:str]]`;
+                in the latter case, if image_path does not match the stored paths, an error
+                is raised.
+            allow_missing_indices: If `False`, throws an IndexError upon reaching the first
+                entry from `frame_index` which is missing in the dataset.
+                Otherwise, generates a subset consisting of the frame entries that actually
+                exist in the dataset.
+        """
+        # Get the indices into the frame annots.
+        dataset_indices = self.seq_frame_index_to_dataset_index(
+            [frame_index],
+            allow_missing_indices=self.is_filtered() and allow_missing_indices,
+        )[0]
+        valid_dataset_indices = [i for i in dataset_indices if i is not None]
+
+        # Deep copy the whole dataset except frame_annots, which are large so we
+        # deep copy only the requested subset of frame_annots.
+        memo = {id(self.frame_annots): None}  # pyre-ignore[16]
+        dataset_new = copy.deepcopy(self, memo)
+        dataset_new.frame_annots = copy.deepcopy(
+            [self.frame_annots[i] for i in valid_dataset_indices]
+        )
+
+        # This will kill all unneeded sequence annotations.
+        dataset_new._invalidate_indexes(filter_seq_annots=True)
+
+        # Finally annotate the frame annotations with the name of the subset
+        # stored in meta.
+        for frame_annot in dataset_new.frame_annots:
+            frame_annotation = frame_annot["frame_annotation"]
+            if frame_annotation.meta is not None:
+                frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)
+
+        # A sanity check - this will crash in case some entries from frame_index are missing
+        # in dataset_new.
+        valid_frame_index = [
+            fi for fi, di in zip(frame_index, dataset_indices) if di is not None
+        ]
+        dataset_new.seq_frame_index_to_dataset_index(
+            [valid_frame_index], allow_missing_indices=False
+        )
+
+        return dataset_new
+
+    def __str__(self) -> str:
+        # pyre-ignore[16]
+        return f"JsonIndexDataset #frames={len(self.frame_annots)}"
+
+    def __len__(self) -> int:
+        # pyre-ignore[16]
+        return len(self.frame_annots)
+
+    def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
+        return entry["subset"]
+
+    def get_all_train_cameras(self) -> CamerasBase:
+        """
+        Returns the cameras corresponding to all the known frames.
+ """ + logger.info("Loading all train cameras.") + cameras = [] + # pyre-ignore[16] + for frame_idx, frame_annot in enumerate(tqdm(self.frame_annots)): + frame_type = self._get_frame_type(frame_annot) + if frame_type is None: + raise ValueError("subsets not loaded") + if is_known_frame_scalar(frame_type): + cameras.append(self[frame_idx].camera) + return join_cameras_as_batch(cameras) + + def __getitem__(self, index) -> FrameData: + # pyre-ignore[16] + if index >= len(self.frame_annots): + raise IndexError(f"index {index} out of range {len(self.frame_annots)}") + + entry = self.frame_annots[index]["frame_annotation"] + # pyre-ignore[16] + point_cloud = self.seq_annots[entry.sequence_name].point_cloud + frame_data = FrameData( + frame_number=_safe_as_tensor(entry.frame_number, torch.long), + frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float), + sequence_name=entry.sequence_name, + sequence_category=self.seq_annots[entry.sequence_name].category, + camera_quality_score=_safe_as_tensor( + self.seq_annots[entry.sequence_name].viewpoint_quality_score, + torch.float, + ), + point_cloud_quality_score=_safe_as_tensor( + point_cloud.quality_score, torch.float + ) + if point_cloud is not None + else None, + ) + + # The rest of the fields are optional + frame_data.frame_type = self._get_frame_type(self.frame_annots[index]) + + ( + frame_data.fg_probability, + frame_data.mask_path, + frame_data.bbox_xywh, + clamp_bbox_xyxy, + frame_data.crop_bbox_xywh, + ) = self._load_crop_fg_probability(entry) + + scale = 1.0 + if self.load_images and entry.image is not None: + # original image size + frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) + + ( + frame_data.image_rgb, + frame_data.image_path, + frame_data.mask_crop, + scale, + ) = self._load_crop_images( + entry, frame_data.fg_probability, clamp_bbox_xyxy + ) + + if self.load_depths and entry.depth is not None: + ( + frame_data.depth_map, + frame_data.depth_path, + frame_data.depth_mask, + ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability) + + if entry.viewpoint is not None: + frame_data.camera = self._get_pytorch3d_camera( + entry, + scale, + clamp_bbox_xyxy, + ) + + if self.load_point_clouds and point_cloud is not None: + pcl_path = self._fix_point_cloud_path(point_cloud.path) + frame_data.sequence_point_cloud = _load_pointcloud( + self._local_path(pcl_path), max_points=self.max_points + ) + frame_data.sequence_point_cloud_path = pcl_path + + return frame_data + + def _fix_point_cloud_path(self, path: str) -> str: + """ + Fix up a point cloud path from the dataset. + Some files in Co3Dv2 have an accidental absolute path stored. + """ + unwanted_prefix = ( + "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/" + ) + if path.startswith(unwanted_prefix): + path = path[len(unwanted_prefix) :] + return os.path.join(self.dataset_root, path) + + def _load_crop_fg_probability( + self, entry: types.FrameAnnotation + ) -> Tuple[ + Optional[torch.Tensor], + Optional[str], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + fg_probability = None + full_path = None + bbox_xywh = None + clamp_bbox_xyxy = None + crop_box_xywh = None + + if (self.load_masks or self.box_crop) and entry.mask is not None: + full_path = os.path.join(self.dataset_root, entry.mask.path) + mask = _load_mask(self._local_path(full_path)) + + if mask.shape[-2:] != entry.image.size: + raise ValueError( + f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" 
+ ) + + bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) + + if self.box_crop: + clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round( + _get_clamp_bbox( + bbox_xywh, + image_path=entry.image.path, + box_crop_context=self.box_crop_context, + ), + image_size_hw=tuple(mask.shape[-2:]), + ) + crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy) + + mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) + + fg_probability, _, _ = self._resize_image(mask, mode="nearest") + + return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh + + def _load_crop_images( + self, + entry: types.FrameAnnotation, + fg_probability: Optional[torch.Tensor], + clamp_bbox_xyxy: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, str, torch.Tensor, float]: + assert self.dataset_root is not None and entry.image is not None + path = os.path.join(self.dataset_root, entry.image.path) + image_rgb = _load_image(self._local_path(path)) + + if image_rgb.shape[-2:] != entry.image.size: + raise ValueError( + f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" + ) + + if self.box_crop: + assert clamp_bbox_xyxy is not None + image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path) + + image_rgb, scale, mask_crop = self._resize_image(image_rgb) + + if self.mask_images: + assert fg_probability is not None + image_rgb *= fg_probability + + return image_rgb, path, mask_crop, scale + + def _load_mask_depth( + self, + entry: types.FrameAnnotation, + clamp_bbox_xyxy: Optional[torch.Tensor], + fg_probability: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, str, torch.Tensor]: + entry_depth = entry.depth + assert entry_depth is not None + path = os.path.join(self.dataset_root, entry_depth.path) + depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) + + if self.box_crop: + assert clamp_bbox_xyxy is not None + depth_bbox_xyxy = _rescale_bbox( + clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:] + ) + depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path) + + depth_map, _, _ = self._resize_image(depth_map, mode="nearest") + + if self.mask_depths: + assert fg_probability is not None + depth_map *= fg_probability + + if self.load_depth_masks: + assert entry_depth.mask_path is not None + mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) + depth_mask = _load_depth_mask(self._local_path(mask_path)) + + if self.box_crop: + assert clamp_bbox_xyxy is not None + depth_mask_bbox_xyxy = _rescale_bbox( + clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:] + ) + depth_mask = _crop_around_box( + depth_mask, depth_mask_bbox_xyxy, mask_path + ) + + depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest") + else: + depth_mask = torch.ones_like(depth_map) + + return depth_map, path, depth_mask + + def _get_pytorch3d_camera( + self, + entry: types.FrameAnnotation, + scale: float, + clamp_bbox_xyxy: Optional[torch.Tensor], + ) -> PerspectiveCameras: + entry_viewpoint = entry.viewpoint + assert entry_viewpoint is not None + # principal point and focal length + principal_point = torch.tensor( + entry_viewpoint.principal_point, dtype=torch.float + ) + focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) + + half_image_size_wh_orig = ( + torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0 + ) + + # first, we convert from the dataset's NDC convention to pixels + format = entry_viewpoint.intrinsics_format + if format.lower() == "ndc_norm_image_bounds": + # this is e.g. 
currently used in CO3D for storing intrinsics + rescale = half_image_size_wh_orig + elif format.lower() == "ndc_isotropic": + rescale = half_image_size_wh_orig.min() + else: + raise ValueError(f"Unknown intrinsics format: {format}") + + # principal point and focal length in pixels + principal_point_px = half_image_size_wh_orig - principal_point * rescale + focal_length_px = focal_length * rescale + if self.box_crop: + assert clamp_bbox_xyxy is not None + principal_point_px -= clamp_bbox_xyxy[:2] + + # now, convert from pixels to PyTorch3D v0.5+ NDC convention + if self.image_height is None or self.image_width is None: + out_size = list(reversed(entry.image.size)) + else: + out_size = [self.image_width, self.image_height] + + half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 + half_min_image_size_output = half_image_size_output.min() + + # rescaled principal point and focal length in ndc + principal_point = ( + half_image_size_output - principal_point_px * scale + ) / half_min_image_size_output + focal_length = focal_length_px * scale / half_min_image_size_output + + return PerspectiveCameras( + focal_length=focal_length[None], + principal_point=principal_point[None], + R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], + T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], + ) + + def _load_frames(self) -> None: + logger.info(f"Loading Co3D frames from {self.frame_annotations_file}.") + local_file = self._local_path(self.frame_annotations_file) + with gzip.open(local_file, "rt", encoding="utf8") as zipfile: + frame_annots_list = types.load_dataclass( + zipfile, List[self.frame_annotations_type] + ) + if not frame_annots_list: + raise ValueError("Empty dataset!") + # pyre-ignore[16] + self.frame_annots = [ + FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list + ] + + def _load_sequences(self) -> None: + logger.info(f"Loading Co3D sequences from {self.sequence_annotations_file}.") + local_file = self._local_path(self.sequence_annotations_file) + with gzip.open(local_file, "rt", encoding="utf8") as zipfile: + seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation]) + if not seq_annots: + raise ValueError("Empty sequences file!") + # pyre-ignore[16] + self.seq_annots = {entry.sequence_name: entry for entry in seq_annots} + + def _load_subset_lists(self) -> None: + logger.info(f"Loading Co3D subset lists from {self.subset_lists_file}.") + if not self.subset_lists_file: + return + + with open(self._local_path(self.subset_lists_file), "r") as f: + subset_to_seq_frame = json.load(f) + + frame_path_to_subset = { + path: subset + for subset, frames in subset_to_seq_frame.items() + for _, _, path in frames + } + # pyre-ignore[16] + for frame in self.frame_annots: + frame["subset"] = frame_path_to_subset.get( + frame["frame_annotation"].image.path, None + ) + if frame["subset"] is None: + warnings.warn( + "Subset lists are given but don't include " + + frame["frame_annotation"].image.path + ) + + def _sort_frames(self) -> None: + # Sort frames to have them grouped by sequence, ordered by timestamp + # pyre-ignore[16] + self.frame_annots = sorted( + self.frame_annots, + key=lambda f: ( + f["frame_annotation"].sequence_name, + f["frame_annotation"].frame_timestamp or 0, + ), + ) + + def _filter_db(self) -> None: + if self.remove_empty_masks: + logger.info("Removing images with empty masks.") + # pyre-ignore[16] + old_len = len(self.frame_annots) + + msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." 
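As an illustrative aside (not part of the patch): a minimal numeric sketch of the intrinsics conversion performed in `_get_pytorch3d_camera` above, assuming the CO3D `ndc_norm_image_bounds` convention, a square 256x256 output, and no box crop; all numbers are made up.

```python
import torch

image_size_hw = (800, 600)                     # original (height, width)
principal_point = torch.tensor([0.05, -0.02])  # stored in normalized NDC units
focal_length = torch.tensor([2.0, 2.0])

# half image extent in (width, height) order, as in the method above
half_wh = torch.tensor([image_size_hw[1], image_size_hw[0]], dtype=torch.float) / 2.0

# NDC -> pixels ("ndc_norm_image_bounds" rescales each axis by half the image extent)
principal_point_px = half_wh - principal_point * half_wh
focal_length_px = focal_length * half_wh

# pixels -> PyTorch3D (v0.5+) NDC for a 256x256 output with no crop
out_half = torch.tensor([256.0, 256.0]) / 2.0
scale = 256.0 / max(image_size_hw)  # the resize factor returned by _resize_image
principal_point_ndc = (out_half - principal_point_px * scale) / out_half.min()
focal_length_ndc = focal_length_px * scale / out_half.min()
print(principal_point_ndc, focal_length_ndc)
```

With a box crop enabled, the crop offset `clamp_bbox_xyxy[:2]` would additionally be subtracted from `principal_point_px` before the final rescale, as done in the method above.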
+ + def positive_mass(frame_annot: types.FrameAnnotation) -> bool: + mask = frame_annot.mask + if mask is None: + return False + if mask.mass is None: + raise ValueError(msg) + return mask.mass > 1 + + self.frame_annots = [ + frame + for frame in self.frame_annots + if positive_mass(frame["frame_annotation"]) + ] + logger.info("... filtered %d -> %d" % (old_len, len(self.frame_annots))) + + # this has to be called after joining with categories!! + subsets = self.subsets + if subsets: + if not self.subset_lists_file: + raise ValueError( + "Subset filter is on but subset_lists_file was not given" + ) + + logger.info(f"Limiting Co3D dataset to the '{subsets}' subsets.") + + # truncate the list of subsets to the valid one + self.frame_annots = [ + entry for entry in self.frame_annots if entry["subset"] in subsets + ] + if len(self.frame_annots) == 0: + raise ValueError(f"There are no frames in the '{subsets}' subsets!") + + self._invalidate_indexes(filter_seq_annots=True) + + if len(self.limit_category_to) > 0: + logger.info(f"Limiting dataset to categories: {self.limit_category_to}") + # pyre-ignore[16] + self.seq_annots = { + name: entry + for name, entry in self.seq_annots.items() + if entry.category in self.limit_category_to + } + + # sequence filters + for prefix in ("pick", "exclude"): + orig_len = len(self.seq_annots) + attr = f"{prefix}_sequence" + arr = getattr(self, attr) + if len(arr) > 0: + logger.info(f"{attr}: {str(arr)}") + self.seq_annots = { + name: entry + for name, entry in self.seq_annots.items() + if (name in arr) == (prefix == "pick") + } + logger.info("... filtered %d -> %d" % (orig_len, len(self.seq_annots))) + + if self.limit_sequences_to > 0: + self.seq_annots = dict( + islice(self.seq_annots.items(), self.limit_sequences_to) + ) + + # retain only frames from retained sequences + self.frame_annots = [ + f + for f in self.frame_annots + if f["frame_annotation"].sequence_name in self.seq_annots + ] + + self._invalidate_indexes() + + if self.n_frames_per_sequence > 0: + logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.") + keep_idx = [] + # pyre-ignore[16] + for seq, seq_indices in self._seq_to_idx.items(): + # infer the seed from the sequence name, this is reproducible + # and makes the selection differ for different sequences + seed = _seq_name_to_seed(seq) + self.seed + seq_idx_shuffled = random.Random(seed).sample( + sorted(seq_indices), len(seq_indices) + ) + keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence]) + + logger.info( + "... 
filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)) + ) + self.frame_annots = [self.frame_annots[i] for i in keep_idx] + self._invalidate_indexes(filter_seq_annots=False) + # sequences are not decimated, so self.seq_annots is valid + + if self.limit_to > 0 and self.limit_to < len(self.frame_annots): + logger.info( + "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to) + ) + self.frame_annots = self.frame_annots[: self.limit_to] + self._invalidate_indexes(filter_seq_annots=True) + + def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: + # update _seq_to_idx and filter seq_meta according to frame_annots change + # if filter_seq_annots, also uldates seq_annots based on the changed _seq_to_idx + self._invalidate_seq_to_idx() + + if filter_seq_annots: + # pyre-ignore[16] + self.seq_annots = { + k: v + for k, v in self.seq_annots.items() + # pyre-ignore[16] + if k in self._seq_to_idx + } + + def _invalidate_seq_to_idx(self) -> None: + seq_to_idx = defaultdict(list) + # pyre-ignore[16] + for idx, entry in enumerate(self.frame_annots): + seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) + # pyre-ignore[16] + self._seq_to_idx = seq_to_idx + + def _resize_image( + self, image, mode="bilinear" + ) -> Tuple[torch.Tensor, float, torch.Tensor]: + image_height, image_width = self.image_height, self.image_width + if image_height is None or image_width is None: + # skip the resizing + imre_ = torch.from_numpy(image) + return imre_, 1.0, torch.ones_like(imre_[:1]) + # takes numpy array, returns pytorch tensor + minscale = min( + image_height / image.shape[-2], + image_width / image.shape[-1], + ) + imre = torch.nn.functional.interpolate( + torch.from_numpy(image)[None], + scale_factor=minscale, + mode=mode, + align_corners=False if mode == "bilinear" else None, + recompute_scale_factor=True, + )[0] + # pyre-fixme[19]: Expected 1 positional argument. + imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width) + imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre + # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`. + # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`. 
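+        # the returned mask marks the valid (top-left) region of the padded
+        # output; pixels outside the resized image stay zero and are exposed
+        # to callers as `mask_crop` by `_load_crop_images`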
+ mask = torch.zeros(1, self.image_height, self.image_width) + mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0 + return imre_, minscale, mask + + def _local_path(self, path: str) -> str: + if self.path_manager is None: + return path + return self.path_manager.get_local_path(path) + + def get_frame_numbers_and_timestamps( + self, idxs: Sequence[int] + ) -> List[Tuple[int, float]]: + out: List[Tuple[int, float]] = [] + for idx in idxs: + # pyre-ignore[16] + frame_annotation = self.frame_annots[idx]["frame_annotation"] + out.append( + (frame_annotation.frame_number, frame_annotation.frame_timestamp) + ) + return out + + def category_to_sequence_names(self) -> Dict[str, List[str]]: + c2seq = defaultdict(list) + # pyre-ignore + for sequence_name, sa in self.seq_annots.items(): + c2seq[sa.category].append(sequence_name) + return dict(c2seq) + + def get_eval_batches(self) -> Optional[List[List[int]]]: + return self.eval_batches + + +def _seq_name_to_seed(seq_name) -> int: + return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16) + + +def _load_image(path) -> np.ndarray: + with Image.open(path) as pil_im: + im = np.array(pil_im.convert("RGB")) + im = im.transpose((2, 0, 1)) + im = im.astype(np.float32) / 255.0 + return im + + +def _load_16big_png_depth(depth_png) -> np.ndarray: + with Image.open(depth_png) as depth_pil: + # the image is stored with 16-bit depth but PIL reads it as I (32 bit). + # we cast it to uint16, then reinterpret as float16, then cast to float32 + depth = ( + np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) + .astype(np.float32) + .reshape((depth_pil.size[1], depth_pil.size[0])) + ) + return depth + + +def _load_1bit_png_mask(file: str) -> np.ndarray: + with Image.open(file) as pil_im: + mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32) + return mask + + +def _load_depth_mask(path: str) -> np.ndarray: + if not path.lower().endswith(".png"): + raise ValueError('unsupported depth mask file name "%s"' % path) + m = _load_1bit_png_mask(path) + return m[None] # fake feature channel + + +def _load_depth(path, scale_adjustment) -> np.ndarray: + if not path.lower().endswith(".png"): + raise ValueError('unsupported depth file name "%s"' % path) + + d = _load_16big_png_depth(path) * scale_adjustment + d[~np.isfinite(d)] = 0.0 + return d[None] # fake feature channel + + +def _load_mask(path) -> np.ndarray: + with Image.open(path) as pil_im: + mask = np.array(pil_im) + mask = mask.astype(np.float32) / 255.0 + return mask[None] # fake feature channel + + +def _get_1d_bounds(arr) -> Tuple[int, int]: + nz = np.flatnonzero(arr) + return nz[0], nz[-1] + 1 + + +def _get_bbox_from_mask( + mask, thr, decrease_quant: float = 0.05 +) -> Tuple[int, int, int, int]: + # bbox in xywh + masks_for_box = np.zeros_like(mask) + while masks_for_box.sum() <= 1.0: + masks_for_box = (mask > thr).astype(np.float32) + thr -= decrease_quant + if thr <= 0.0: + warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.") + + x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2)) + y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1)) + + return x0, y0, x1 - x0, y1 - y0 + + +def _get_clamp_bbox( + bbox: torch.Tensor, + box_crop_context: float = 0.0, + image_path: str = "", +) -> torch.Tensor: + # box_crop_context: rate of expansion for bbox + # returns possibly expanded bbox xyxy as float + + bbox = bbox.clone() # do not edit bbox in place + + # increase box size + if box_crop_context > 0.0: + c = box_crop_context + bbox = bbox.float() + bbox[0] -= bbox[2] * c 
/ 2 + bbox[1] -= bbox[3] * c / 2 + bbox[2] += bbox[2] * c + bbox[3] += bbox[3] * c + + if (bbox[2:] <= 1.0).any(): + raise ValueError( + f"squashed image {image_path}!! The bounding box contains no pixels." + ) + + bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes + bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2) + + return bbox_xyxy + + +def _crop_around_box(tensor, bbox, impath: str = ""): + # bbox is xyxy, where the upper bound is corrected with +1 + bbox = _clamp_box_to_image_bounds_and_round( + bbox, + image_size_hw=tensor.shape[-2:], + ) + tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] + assert all(c > 0 for c in tensor.shape), f"squashed image {impath}" + return tensor + + +def _clamp_box_to_image_bounds_and_round( + bbox_xyxy: torch.Tensor, + image_size_hw: Tuple[int, int], +) -> torch.LongTensor: + bbox_xyxy = bbox_xyxy.clone() + bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1]) + bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2]) + if not isinstance(bbox_xyxy, torch.LongTensor): + bbox_xyxy = bbox_xyxy.round().long() + return bbox_xyxy # pyre-ignore [7] + + +def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor: + assert bbox is not None + assert np.prod(orig_res) > 1e-8 + # average ratio of dimensions + rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0 + return bbox * rel_size + + +def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor: + wh = xyxy[2:] - xyxy[:2] + xywh = torch.cat([xyxy[:2], wh]) + return xywh + + +def _bbox_xywh_to_xyxy( + xywh: torch.Tensor, clamp_size: Optional[int] = None +) -> torch.Tensor: + xyxy = xywh.clone() + if clamp_size is not None: + xyxy[2:] = torch.clamp(xyxy[2:], clamp_size) + xyxy[2:] += xyxy[:2] + return xyxy + + +def _safe_as_tensor(data, dtype): + if data is None: + return None + return torch.tensor(data, dtype=dtype) + + +# NOTE this cache is per-worker; they are implemented as processes. +# each batch is loaded and collated by a single worker; +# since sequences tend to co-occur within batches, this is useful. 
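As an illustrative aside (not part of the patch): a self-contained sketch of the bounding-box conventions used by `_get_bbox_from_mask` and `_get_clamp_bbox` above: an xywh box is inferred from a thresholded mask, expanded by `box_crop_context`, and converted to a rounded, clamped xyxy box. The toy 8x8 mask and the context value 0.3 are made up, and the sketch mirrors the logic rather than importing it.

```python
import numpy as np
import torch

mask = np.zeros((1, 8, 8), dtype=np.float32)
mask[0, 2:6, 3:7] = 1.0  # a 4x4 foreground square

def get_1d_bounds(arr):
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1] + 1

# xywh bbox from the thresholded mask, as in _get_bbox_from_mask
x0, x1 = get_1d_bounds(mask[0].sum(axis=-2) > 0)
y0, y1 = get_1d_bounds(mask[0].sum(axis=-1) > 0)
bbox_xywh = torch.tensor([x0, y0, x1 - x0, y1 - y0], dtype=torch.float)

# expand by box_crop_context and convert to a clamped, rounded xyxy box,
# as in _get_clamp_bbox / _clamp_box_to_image_bounds_and_round
c = 0.3
bbox = bbox_xywh.clone()
bbox[0] -= bbox[2] * c / 2
bbox[1] -= bbox[3] * c / 2
bbox[2:] += bbox[2:] * c
xyxy = torch.cat([bbox[:2], bbox[:2] + bbox[2:]]).round().long()
xyxy[[0, 2]] = xyxy[[0, 2]].clamp(0, 8)
xyxy[[1, 3]] = xyxy[[1, 3]].clamp(0, 8)
print(bbox_xywh.tolist(), xyxy.tolist())  # [3.0, 2.0, 4.0, 4.0] [2, 1, 8, 7]
```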
+@functools.lru_cache(maxsize=256) +def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds: + pcl = IO().load_pointcloud(pcl_path) + if max_points > 0: + pcl = pcl.subsample(max_points) + + return pcl \ No newline at end of file diff --git a/sgm/data/latent_objaverse.py b/sgm/data/latent_objaverse.py new file mode 100644 index 0000000000000000000000000000000000000000..8819c1e7529efb1fcf44a6f95f92df3d73869517 --- /dev/null +++ b/sgm/data/latent_objaverse.py @@ -0,0 +1,52 @@ +import numpy as np +from pathlib import Path +from PIL import Image +import json +import torch +from torch.utils.data import Dataset, DataLoader, default_collate +from torchvision.transforms import ToTensor, Normalize, Compose, Resize +from pytorch_lightning import LightningDataModule +from einops import rearrange + + +class LatentObjaverseSpiral(Dataset): + def __init__( + self, + root_dir, + split="train", + transform=None, + random_front=False, + max_item=None, + cond_aug_mean=-3.0, + cond_aug_std=0.5, + condition_on_elevation=False, + **unused_kwargs, + ): + print("Using LVIS subset with precomputed Latents") + self.root_dir = Path(root_dir) + self.split = split + self.random_front = random_front + self.transform = transform + + self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512") + + self.ids = json.load(open("./assets/lvis_uids.json", "r")) + self.n_views = 18 + valid_ids = [] + for idx in self.ids: + if (self.root_dir / idx).exists(): + valid_ids.append(idx) + self.ids = valid_ids + print("=" * 30) + print("Number of valid ids: ", len(self.ids)) + print("=" * 30) + + self.cond_aug_mean = cond_aug_mean + self.cond_aug_std = cond_aug_std + self.condition_on_elevation = condition_on_elevation + + if max_item is not None: + self.ids = self.ids[:max_item] + + ## debug + self.ids = self.ids * 10000 diff --git a/sgm/data/mnist.py b/sgm/data/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..dea4d7e670666bec80ecb22aa89603345e173d09 --- /dev/null +++ b/sgm/data/mnist.py @@ -0,0 +1,85 @@ +import pytorch_lightning as pl +import torchvision +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms + + +class MNISTDataDictWrapper(Dataset): + def __init__(self, dset): + super().__init__() + self.dset = dset + + def __getitem__(self, i): + x, y = self.dset[i] + return {"jpg": x, "cls": y} + + def __len__(self): + return len(self.dset) + + +class MNISTLoader(pl.LightningDataModule): + def __init__(self, batch_size, num_workers=0, prefetch_factor=2, shuffle=True): + super().__init__() + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] + ) + + self.batch_size = batch_size + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor if num_workers > 0 else 0 + self.shuffle = shuffle + self.train_dataset = MNISTDataDictWrapper( + torchvision.datasets.MNIST( + root=".data/", train=True, download=True, transform=transform + ) + ) + self.test_dataset = MNISTDataDictWrapper( + torchvision.datasets.MNIST( + root=".data/", train=False, download=True, transform=transform + ) + ) + + def prepare_data(self): + pass + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, 
+ prefetch_factor=self.prefetch_factor, + ) + + def val_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + ) + + +if __name__ == "__main__": + dset = MNISTDataDictWrapper( + torchvision.datasets.MNIST( + root=".data/", + train=False, + download=True, + transform=transforms.Compose( + [transforms.ToTensor(), transforms.Lambda(lambda x: x * 2.0 - 1.0)] + ), + ) + ) + ex = dset[0] diff --git a/sgm/data/mvimagenet.py b/sgm/data/mvimagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..b20c398c08dd976c8bef1455845022f181cfcb73 --- /dev/null +++ b/sgm/data/mvimagenet.py @@ -0,0 +1,408 @@ +import numpy as np +import torch +from torch.utils.data import Dataset, DataLoader, default_collate +from pathlib import Path +from PIL import Image +from scipy.spatial.transform import Rotation +import rembg +from rembg import remove, new_session +from einops import rearrange + +from torchvision.transforms import ToTensor, Normalize, Compose, Resize +from torchvision.transforms.functional import to_tensor +from pytorch_lightning import LightningDataModule + +from sgm.data.colmap import read_cameras_binary, read_images_binary +from sgm.data.objaverse import video_collate_fn, FLATTEN_FIELDS, flatten_for_video + + +def qvec2rotmat(qvec): + return np.array( + [ + [ + 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], + 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2], + ], + [ + 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2, + 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1], + ], + [ + 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], + 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], + 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2, + ], + ] + ) + + +def qt2c2w(q, t): + # NOTE: remember to convert to opengl coordinate system + # rot = Rotation.from_quat(q).as_matrix() + rot = qvec2rotmat(q) + c2w = np.eye(4) + c2w[:3, :3] = np.transpose(rot) + c2w[:3, 3] = -np.transpose(rot) @ t + c2w[..., 1:3] *= -1 + return c2w + + +def random_crop(): + pass + + +class MVImageNet(Dataset): + def __init__( + self, + root_dir, + split, + transform, + reso: int = 256, + mask_type: str = "random", + cond_aug_mean=-3.0, + cond_aug_std=0.5, + condition_on_elevation=False, + fps_id=0.0, + motion_bucket_id=300.0, + num_frames: int = 24, + use_mask: bool = True, + load_pixelnerf: bool = False, + scale_pose: bool = False, + max_n_cond: int = 1, + min_n_cond: int = 1, + cond_on_multi: bool = False, + ) -> None: + super().__init__() + + self.root_dir = Path(root_dir) + self.split = split + + avails = self.root_dir.glob("*/*") + self.ids = list( + map( + lambda x: str(x.relative_to(self.root_dir)), + filter(lambda x: x.is_dir(), avails), + ) + ) + + self.transform = transform + self.reso = reso + self.num_frames = num_frames + self.cond_aug_mean = cond_aug_mean + self.cond_aug_std = cond_aug_std + self.condition_on_elevation = condition_on_elevation + self.fps_id = fps_id + self.motion_bucket_id = motion_bucket_id + self.mask_type = mask_type + self.use_mask = use_mask + self.load_pixelnerf = load_pixelnerf + self.scale_pose = scale_pose + self.max_n_cond = max_n_cond + self.min_n_cond = min_n_cond + self.cond_on_multi = cond_on_multi + + if self.cond_on_multi: + assert self.min_n_cond == self.max_n_cond + self.session = new_session() + + def __getitem__(self, index: int): + # mvimgnet starts with idx==1 + idx_list = 
np.arange(0, self.num_frames) + this_image_dir = self.root_dir / self.ids[index] / "images" + this_camera_dir = self.root_dir / self.ids[index] / "sparse/0" + + # while not this_camera_dir.exists(): + # index = (index + 1) % len(self.ids) + # this_image_dir = self.root_dir / self.ids[index] / "images" + # this_camera_dir = self.root_dir / self.ids[index] / "sparse/0" + if not this_camera_dir.exists(): + index = 0 + this_image_dir = self.root_dir / self.ids[index] / "images" + this_camera_dir = self.root_dir / self.ids[index] / "sparse/0" + + this_images = read_images_binary(this_camera_dir / "images.bin") + # filenames = list(map(lambda x: f"{x:03d}", this_images.keys())) + filenames = list(this_images.keys()) + + if len(filenames) == 0: + index = 0 + this_image_dir = self.root_dir / self.ids[index] / "images" + this_camera_dir = self.root_dir / self.ids[index] / "sparse/0" + this_images = read_images_binary(this_camera_dir / "images.bin") + # filenames = list(map(lambda x: f"{x:03d}", this_images.keys())) + filenames = list(this_images.keys()) + + filenames = list( + filter(lambda x: (this_image_dir / this_images[x].name).exists(), filenames) + ) + + filenames = sorted(filenames, key=lambda x: this_images[x].name) + + # # debug + # names = [] + # for v in filenames: + # names.append(this_images[v].name) + # breakpoint() + + while len(filenames) < self.num_frames: + num_surpass = self.num_frames - len(filenames) + filenames += list(reversed(filenames[-num_surpass:])) + + if len(filenames) < self.num_frames: + print(f"\n\n{self.ids[index]}\n\n") + + frames = [] + cameras = [] + downsampled_rgb = [] + for view_idx in idx_list: + this_id = filenames[view_idx] + frame = Image.open(this_image_dir / this_images[this_id].name) + w, h = frame.size + + if self.mask_type == "random": + image_size = min(h, w) + left = np.random.randint(0, w - image_size + 1) + right = left + image_size + top = np.random.randint(0, h - image_size + 1) + bottom = top + image_size + ## need to assign left, right, top, bottom, image_size + elif self.mask_type == "object": + pass + elif self.mask_type == "rembg": + image_size = min(h, w) + if ( + cached := this_image_dir + / f"{this_images[this_id].name[:-4]}_rembg.png" + ).exists(): + try: + mask = np.asarray(Image.open(cached, formats=["png"]))[..., 3] + except: + mask = remove(frame, session=self.session) + mask.save(cached) + mask = np.asarray(mask)[..., 3] + else: + mask = remove(frame, session=self.session) + mask.save(cached) + mask = np.asarray(mask)[..., 3] + # in h,w order + y, x = np.array(mask.nonzero()) + bbox_cx = x.mean() + bbox_cy = y.mean() + + if bbox_cy - image_size / 2 < 0: + top = 0 + elif bbox_cy + image_size / 2 > h: + top = h - image_size + else: + top = int(bbox_cy - image_size / 2) + + if bbox_cx - image_size / 2 < 0: + left = 0 + elif bbox_cx + image_size / 2 > w: + left = w - image_size + else: + left = int(bbox_cx - image_size / 2) + + # top = max(int(bbox_cy - image_size / 2), 0) + # left = max(int(bbox_cx - image_size / 2), 0) + bottom = top + image_size + right = left + image_size + else: + raise ValueError(f"Unknown mask type: {self.mask_type}") + + frame = frame.crop((left, top, right, bottom)) + frame = frame.resize((self.reso, self.reso)) + frames.append(self.transform(frame)) + + if self.load_pixelnerf: + # extrinsics + extrinsics = this_images[this_id] + c2w = qt2c2w(extrinsics.qvec, extrinsics.tvec) + # intrinsics + intrinsics = read_cameras_binary(this_camera_dir / "cameras.bin") + assert len(intrinsics) == 1 + intrinsics = 
intrinsics[1] + f, cx, cy, _ = intrinsics.params + f *= 1 / image_size + cx -= left + cy -= top + cx *= 1 / image_size + cy *= 1 / image_size # all are relative values + intrinsics = np.array([[f, 0, cx], [0, f, cy], [0, 0, 1]]) + + this_camera = np.zeros(25) + this_camera[:16] = c2w.reshape(-1) + this_camera[16:] = intrinsics.reshape(-1) + + cameras.append(this_camera) + downsampled = frame.resize((self.reso // 8, self.reso // 8)) + downsampled_rgb.append((self.transform(downsampled) + 1.0) * 0.5) + + data = dict() + + cond_aug = np.exp( + np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean + ) + frames = torch.stack(frames) + cond = frames[0] + # setting all things in data + data["frames"] = frames + data["cond_frames_without_noise"] = cond + data["cond_aug"] = torch.as_tensor([cond_aug] * self.num_frames) + data["cond_frames"] = cond + cond_aug * torch.randn_like(cond) + data["fps_id"] = torch.as_tensor([self.fps_id] * self.num_frames) + data["motion_bucket_id"] = torch.as_tensor( + [self.motion_bucket_id] * self.num_frames + ) + data["num_video_frames"] = self.num_frames + data["image_only_indicator"] = torch.as_tensor([0.0] * self.num_frames) + + if self.load_pixelnerf: + # TODO: normalize camera poses + data["pixelnerf_input"] = dict() + data["pixelnerf_input"]["frames"] = frames + data["pixelnerf_input"]["rgb"] = torch.stack(downsampled_rgb) + + cameras = torch.from_numpy(np.stack(cameras)).float() + if self.scale_pose: + c2ws = cameras[..., :16].reshape(-1, 4, 4) + center = c2ws[:, :3, 3].mean(0) + radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max() + scale = 1.5 / radius + c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale + cameras[..., :16] = c2ws.reshape(-1, 16) + + # if self.max_n_cond > 1: + # # TODO implement this + # n_cond = np.random.randint(1, self.max_n_cond + 1) + # # debug + # source_index = [0] + # if n_cond > 1: + # source_index += np.random.choice( + # np.arange(1, self.num_frames), + # self.max_n_cond - 1, + # replace=False, + # ).tolist() + # data["pixelnerf_input"]["source_index"] = torch.as_tensor( + # source_index + # ) + # data["pixelnerf_input"]["n_cond"] = n_cond + # data["pixelnerf_input"]["source_images"] = frames[source_index] + # data["pixelnerf_input"]["source_cameras"] = cameras[source_index] + + data["pixelnerf_input"]["cameras"] = cameras + + return data + + def __len__(self): + return len(self.ids) + + def collate_fn(self, batch): + # a hack to add source index and keep consistent within a batch + if self.max_n_cond > 1: + # TODO implement this + n_cond = np.random.randint(self.min_n_cond, self.max_n_cond + 1) + # debug + # source_index = [0] + if n_cond > 1: + for b in batch: + source_index = [0] + np.random.choice( + np.arange(1, self.num_frames), + self.max_n_cond - 1, + replace=False, + ).tolist() + b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index) + b["pixelnerf_input"]["n_cond"] = n_cond + b["pixelnerf_input"]["source_images"] = b["frames"][source_index] + b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][ + "cameras" + ][source_index] + + if self.cond_on_multi: + b["cond_frames_without_noise"] = b["frames"][source_index] + + ret = video_collate_fn(batch) + + if self.cond_on_multi: + ret["cond_frames_without_noise"] = rearrange(ret["cond_frames_without_noise"], "b t ... 
-> (b t) ...") + + return ret + + +class MVImageNetFixedCond(MVImageNet): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + +class MVImageNetDataset(LightningDataModule): + def __init__( + self, + root_dir, + batch_size=2, + shuffle=True, + num_workers=10, + prefetch_factor=2, + **kwargs, + ): + super().__init__() + + self.batch_size = batch_size + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.shuffle = shuffle + + self.transform = Compose( + [ + ToTensor(), + Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + + self.train_dataset = MVImageNet( + root_dir=root_dir, + split="train", + transform=self.transform, + **kwargs, + ) + + self.test_dataset = MVImageNet( + root_dir=root_dir, + split="test", + transform=self.transform, + **kwargs, + ) + + def train_dataloader(self): + def worker_init_fn(worker_id): + np.random.seed(np.random.get_state()[1][0]) + + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + collate_fn=self.train_dataset.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + collate_fn=self.test_dataset.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + collate_fn=video_collate_fn, + ) diff --git a/sgm/data/objaverse.py b/sgm/data/objaverse.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ae0730ab09dc4e5ad87e3212b3f2ae22581934 --- /dev/null +++ b/sgm/data/objaverse.py @@ -0,0 +1,882 @@ +import numpy as np +from pathlib import Path +from PIL import Image +import json +import torch +import torch.nn.functional as F +from torch.utils.data import Dataset, DataLoader, default_collate +from torchvision.transforms import ToTensor, Normalize, Compose, Resize +from torchvision.transforms.functional import to_tensor +from pytorch_lightning import LightningDataModule +from einops import rearrange + + +def read_camera_matrix_single(json_file): + # for gobjaverse + with open(json_file, "r", encoding="utf8") as reader: + json_content = json.load(reader) + + # negative sign for opencv to opengl + camera_matrix = torch.zeros(3, 4) + camera_matrix[:3, 0] = torch.tensor(json_content["x"]) + camera_matrix[:3, 1] = -torch.tensor(json_content["y"]) + camera_matrix[:3, 2] = -torch.tensor(json_content["z"]) + camera_matrix[:3, 3] = torch.tensor(json_content["origin"]) + """ + camera_matrix = np.eye(4) + camera_matrix[:3, 0] = np.array(json_content['x']) + camera_matrix[:3, 1] = np.array(json_content['y']) + camera_matrix[:3, 2] = np.array(json_content['z']) + camera_matrix[:3, 3] = np.array(json_content['origin']) + # print(camera_matrix) + """ + + return camera_matrix + + +def read_camera_instrinsics_single(json_file, h: int, w: int, scale: float = 1.0): + with open(json_file, "r", encoding="utf8") as reader: + json_content = json.load(reader) + + h = int(h * scale) + w = int(w * scale) + + y_fov = json_content["y_fov"] + x_fov = json_content["x_fov"] + + fy = h / 2 / np.tan(y_fov / 2) + fx = w / 2 / np.tan(x_fov / 2) + + cx = w // 2 + cy = h // 2 + + intrinsics = torch.tensor( + [ + [fx, fy], + [cx, cy], + [w, h], + ], + dtype=torch.float32, + ) + return intrinsics + + +def 
compose_extrinsic_RT(RT: torch.Tensor): + """ + Compose the standard form extrinsic matrix from RT. + Batched I/O. + """ + return torch.cat( + [ + RT, + torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat( + RT.shape[0], 1, 1 + ), + ], + dim=1, + ) + + +def get_normalized_camera_intrinsics(intrinsics: torch.Tensor): + """ + intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]] + Return batched fx, fy, cx, cy + """ + fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1] + cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1] + width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1] + fx, fy = fx / width, fy / height + cx, cy = cx / width, cy / height + return fx, fy, cx, cy + + +def build_camera_standard(RT: torch.Tensor, intrinsics: torch.Tensor): + """ + RT: (N, 3, 4) + intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]] + """ + E = compose_extrinsic_RT(RT) + fx, fy, cx, cy = get_normalized_camera_intrinsics(intrinsics) + I = torch.stack( + [ + torch.stack([fx, torch.zeros_like(fx), cx], dim=-1), + torch.stack([torch.zeros_like(fy), fy, cy], dim=-1), + torch.tensor([[0, 0, 1]], dtype=torch.float32).repeat(RT.shape[0], 1), + ], + dim=1, + ) + return torch.cat( + [ + E.reshape(-1, 16), + I.reshape(-1, 9), + ], + dim=-1, + ) + + +def calc_elevation(c2w): + ## works for single or batched c2w + ## assume world up is (0, 0, 1) + pos = c2w[..., :3, 3] + + return np.arcsin(pos[..., 2] / np.linalg.norm(pos, axis=-1, keepdims=False)) + + +def read_camera_matrix_single(json_file): + with open(json_file, "r", encoding="utf8") as reader: + json_content = json.load(reader) + + # negative sign for opencv to opengl + # camera_matrix = np.zeros([3, 4]) + # camera_matrix[:3, 0] = np.array(json_content["x"]) + # camera_matrix[:3, 1] = -np.array(json_content["y"]) + # camera_matrix[:3, 2] = -np.array(json_content["z"]) + # camera_matrix[:3, 3] = np.array(json_content["origin"]) + camera_matrix = torch.zeros([3, 4]) + camera_matrix[:3, 0] = torch.tensor(json_content["x"]) + camera_matrix[:3, 1] = -torch.tensor(json_content["y"]) + camera_matrix[:3, 2] = -torch.tensor(json_content["z"]) + camera_matrix[:3, 3] = torch.tensor(json_content["origin"]) + """ + camera_matrix = np.eye(4) + camera_matrix[:3, 0] = np.array(json_content['x']) + camera_matrix[:3, 1] = np.array(json_content['y']) + camera_matrix[:3, 2] = np.array(json_content['z']) + camera_matrix[:3, 3] = np.array(json_content['origin']) + # print(camera_matrix) + """ + + return camera_matrix + + +def blend_white_bg(image): + new_image = Image.new("RGB", image.size, (255, 255, 255)) + new_image.paste(image, mask=image.split()[3]) + + return new_image + + +def flatten_for_video(input): + return input.flatten() + + +FLATTEN_FIELDS = ["fps_id", "motion_bucket_id", "cond_aug", "elevation"] + + +def video_collate_fn(batch: list[dict], *args, **kwargs): + out = {} + for key in batch[0].keys(): + if key in FLATTEN_FIELDS: + out[key] = default_collate([item[key] for item in batch]) + out[key] = flatten_for_video(out[key]) + elif key == "num_video_frames": + out[key] = batch[0][key] + elif key in ["frames", "latents", "rgb"]: + out[key] = default_collate([item[key] for item in batch]) + out[key] = rearrange(out[key], "b t c h w -> (b t) c h w") + else: + out[key] = default_collate([item[key] for item in batch]) + + if "pixelnerf_input" in out: + out["pixelnerf_input"]["rgb"] = rearrange( + out["pixelnerf_input"]["rgb"], "b t c h w -> (b t) c h w" + ) + + return out + + +class GObjaverse(Dataset): + def __init__( + self, + root_dir, + 
split="train", + transform=None, + random_front=False, + max_item=None, + cond_aug_mean=-3.0, + cond_aug_std=0.5, + condition_on_elevation=False, + fps_id=0.0, + motion_bucket_id=300.0, + use_latents=False, + load_caps=False, + front_view_selection="random", + load_pixelnerf=False, + debug_base_idx=None, + scale_pose: bool = False, + max_n_cond: int = 1, + **unused_kwargs, + ): + self.root_dir = Path(root_dir) + self.split = split + self.random_front = random_front + self.transform = transform + self.use_latents = use_latents + + self.ids = json.load(open(self.root_dir / "valid_uids.json", "r")) + self.n_views = 24 + + self.load_caps = load_caps + if self.load_caps: + self.caps = json.load(open(self.root_dir / "text_captions_cap3d.json", "r")) + + self.cond_aug_mean = cond_aug_mean + self.cond_aug_std = cond_aug_std + self.condition_on_elevation = condition_on_elevation + self.fps_id = fps_id + self.motion_bucket_id = motion_bucket_id + self.load_pixelnerf = load_pixelnerf + self.scale_pose = scale_pose + self.max_n_cond = max_n_cond + + if self.use_latents: + self.latents_dir = self.root_dir / "latents256" + self.clip_dir = self.root_dir / "clip_emb256" + + self.front_view_selection = front_view_selection + if self.front_view_selection == "random": + pass + elif self.front_view_selection == "fixed": + pass + elif self.front_view_selection.startswith("clip_score"): + self.clip_scores = torch.load(self.root_dir / "clip_score_per_view.pt") + self.ids = list(self.clip_scores.keys()) + else: + raise ValueError( + f"Unknown front view selection method {self.front_view_selection}" + ) + + if max_item is not None: + self.ids = self.ids[:max_item] + ## debug + self.ids = self.ids * 10000 + + if debug_base_idx is not None: + print(f"debug mode with base idx: {debug_base_idx}") + self.debug_base_idx = debug_base_idx + + def __getitem__(self, idx: int): + if hasattr(self, "debug_base_idx"): + idx = (idx + self.debug_base_idx) % len(self.ids) + data = {} + idx_list = np.arange(self.n_views) + # if self.random_front: + # roll_idx = np.random.randint(self.n_views) + # idx_list = np.roll(idx_list, roll_idx) + if self.front_view_selection == "random": + roll_idx = np.random.randint(self.n_views) + idx_list = np.roll(idx_list, roll_idx) + elif self.front_view_selection == "fixed": + pass + elif self.front_view_selection == "clip_score_softmax": + this_clip_score = ( + F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy() + ) + roll_idx = np.random.choice(idx_list, p=this_clip_score) + idx_list = np.roll(idx_list, roll_idx) + elif self.front_view_selection == "clip_score_max": + this_clip_score = ( + F.softmax(self.clip_scores[self.ids[idx]], dim=-1).cpu().numpy() + ) + roll_idx = np.argmax(this_clip_score) + idx_list = np.roll(idx_list, roll_idx) + frames = [] + if not self.use_latents: + try: + for view_idx in idx_list: + frame = Image.open( + self.root_dir + / "gobjaverse" + / self.ids[idx] + / f"{view_idx:05d}/{view_idx:05d}.png" + ) + frames.append(self.transform(frame)) + except: + idx = 0 + frames = [] + for view_idx in idx_list: + frame = Image.open( + self.root_dir + / "gobjaverse" + / self.ids[idx] + / f"{view_idx:05d}/{view_idx:05d}.png" + ) + frames.append(self.transform(frame)) + # a workaround for some bugs in gobjaverse + # use idx=0 and the repeat will be resolved when gathering results, valid number of items can be checked by the len of results + frames = torch.stack(frames, dim=0) + cond = frames[0] + + cond_aug = np.exp( + np.random.randn(1)[0] * self.cond_aug_std + 
self.cond_aug_mean + ) + + data.update( + { + "frames": frames, + "cond_frames_without_noise": cond, + "cond_aug": torch.as_tensor([cond_aug] * self.n_views), + "cond_frames": cond + cond_aug * torch.randn_like(cond), + "fps_id": torch.as_tensor([self.fps_id] * self.n_views), + "motion_bucket_id": torch.as_tensor( + [self.motion_bucket_id] * self.n_views + ), + "num_video_frames": 24, + "image_only_indicator": torch.as_tensor([0.0] * self.n_views), + } + ) + else: + latents = torch.load(self.latents_dir / f"{self.ids[idx]}.pt")[idx_list] + clip_emb = torch.load(self.clip_dir / f"{self.ids[idx]}.pt")[idx_list][0] + + cond = latents[0] + + cond_aug = np.exp( + np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean + ) + + data.update( + { + "latents": latents, + "cond_frames_without_noise": clip_emb, + "cond_aug": torch.as_tensor([cond_aug] * self.n_views), + "cond_frames": cond + cond_aug * torch.randn_like(cond), + "fps_id": torch.as_tensor([self.fps_id] * self.n_views), + "motion_bucket_id": torch.as_tensor( + [self.motion_bucket_id] * self.n_views + ), + "num_video_frames": 24, + "image_only_indicator": torch.as_tensor([0.0] * self.n_views), + } + ) + + if self.condition_on_elevation: + sample_c2w = read_camera_matrix_single( + self.root_dir / self.ids[idx] / f"00000/00000.json" + ) + elevation = calc_elevation(sample_c2w) + data["elevation"] = torch.as_tensor([elevation] * self.n_views) + + if self.load_pixelnerf: + assert "frames" in data, f"pixelnerf cannot work with latents only mode" + data["pixelnerf_input"] = {} + RTs = [] + intrinsics = [] + for view_idx in idx_list: + meta = ( + self.root_dir + / "gobjaverse" + / self.ids[idx] + / f"{view_idx:05d}/{view_idx:05d}.json" + ) + RTs.append(read_camera_matrix_single(meta)[:3]) + intrinsics.append(read_camera_instrinsics_single(meta, 256, 256)) + RTs = torch.stack(RTs, dim=0) + intrinsics = torch.stack(intrinsics, dim=0) + cameras = build_camera_standard(RTs, intrinsics) + data["pixelnerf_input"]["cameras"] = cameras + + downsampled = [] + for view_idx in idx_list: + frame = Image.open( + self.root_dir + / "gobjaverse" + / self.ids[idx] + / f"{view_idx:05d}/{view_idx:05d}.png" + ).resize((32, 32)) + downsampled.append(to_tensor(blend_white_bg(frame))) + data["pixelnerf_input"]["rgb"] = torch.stack(downsampled, dim=0) + data["pixelnerf_input"]["frames"] = data["frames"] + if self.scale_pose: + c2ws = cameras[..., :16].reshape(-1, 4, 4) + center = c2ws[:, :3, 3].mean(0) + radius = (c2ws[:, :3, 3] - center).norm(dim=-1).max() + scale = 1.5 / radius + c2ws[..., :3, 3] = (c2ws[..., :3, 3] - center) * scale + cameras[..., :16] = c2ws.reshape(-1, 16) + + if self.load_caps: + data["caption"] = self.caps[self.ids[idx]] + data["ids"] = self.ids[idx] + + return data + + def __len__(self): + return len(self.ids) + + def collate_fn(self, batch): + if self.max_n_cond > 1: + n_cond = np.random.randint(1, self.max_n_cond + 1) + if n_cond > 1: + for b in batch: + source_index = [0] + np.random.choice( + np.arange(1, self.n_views), + self.max_n_cond - 1, + replace=False, + ).tolist() + b["pixelnerf_input"]["source_index"] = torch.as_tensor(source_index) + b["pixelnerf_input"]["n_cond"] = n_cond + b["pixelnerf_input"]["source_images"] = b["frames"][source_index] + b["pixelnerf_input"]["source_cameras"] = b["pixelnerf_input"][ + "cameras" + ][source_index] + + return video_collate_fn(batch) + + +class ObjaverseSpiral(Dataset): + def __init__( + self, + root_dir, + split="train", + transform=None, + random_front=False, + max_item=None, + 
cond_aug_mean=-3.0, + cond_aug_std=0.5, + condition_on_elevation=False, + **unused_kwargs, + ): + self.root_dir = Path(root_dir) + self.split = split + self.random_front = random_front + self.transform = transform + + self.ids = json.load(open(self.root_dir / f"{split}_ids.json", "r")) + self.n_views = 24 + valid_ids = [] + for idx in self.ids: + if (self.root_dir / idx).exists(): + valid_ids.append(idx) + self.ids = valid_ids + + self.cond_aug_mean = cond_aug_mean + self.cond_aug_std = cond_aug_std + self.condition_on_elevation = condition_on_elevation + + if max_item is not None: + self.ids = self.ids[:max_item] + + ## debug + self.ids = self.ids * 10000 + + def __getitem__(self, idx: int): + frames = [] + idx_list = np.arange(self.n_views) + if self.random_front: + roll_idx = np.random.randint(self.n_views) + idx_list = np.roll(idx_list, roll_idx) + for view_idx in idx_list: + frame = Image.open( + self.root_dir / self.ids[idx] / f"{view_idx:05d}/{view_idx:05d}.png" + ) + frames.append(self.transform(frame)) + + # data = {"jpg": torch.stack(frames, dim=0)} # [T, C, H, W] + frames = torch.stack(frames, dim=0) + cond = frames[0] + + cond_aug = np.exp( + np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean + ) + + data = { + "frames": frames, + "cond_frames_without_noise": cond, + "cond_aug": torch.as_tensor([cond_aug] * self.n_views), + "cond_frames": cond + cond_aug * torch.randn_like(cond), + "fps_id": torch.as_tensor([1.0] * self.n_views), + "motion_bucket_id": torch.as_tensor([300.0] * self.n_views), + "num_video_frames": 24, + "image_only_indicator": torch.as_tensor([0.0] * self.n_views), + } + + if self.condition_on_elevation: + sample_c2w = read_camera_matrix_single( + self.root_dir / self.ids[idx] / f"00000/00000.json" + ) + elevation = calc_elevation(sample_c2w) + data["elevation"] = torch.as_tensor([elevation] * self.n_views) + + return data + + def __len__(self): + return len(self.ids) + + +class ObjaverseLVISSpiral(Dataset): + def __init__( + self, + root_dir, + split="train", + transform=None, + random_front=False, + max_item=None, + cond_aug_mean=-3.0, + cond_aug_std=0.5, + condition_on_elevation=False, + use_precomputed_latents=False, + **unused_kwargs, + ): + print("Using LVIS subset") + self.root_dir = Path(root_dir) + self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512") + self.split = split + self.random_front = random_front + self.transform = transform + self.use_precomputed_latents = use_precomputed_latents + + self.ids = json.load(open("./assets/lvis_uids.json", "r")) + self.n_views = 18 + valid_ids = [] + for idx in self.ids: + if (self.root_dir / idx).exists(): + valid_ids.append(idx) + self.ids = valid_ids + print("=" * 30) + print("Number of valid ids: ", len(self.ids)) + print("=" * 30) + + self.cond_aug_mean = cond_aug_mean + self.cond_aug_std = cond_aug_std + self.condition_on_elevation = condition_on_elevation + + if max_item is not None: + self.ids = self.ids[:max_item] + + ## debug + self.ids = self.ids * 10000 + + def __getitem__(self, idx: int): + frames = [] + idx_list = np.arange(self.n_views) + if self.random_front: + roll_idx = np.random.randint(self.n_views) + idx_list = np.roll(idx_list, roll_idx) + for view_idx in idx_list: + frame = Image.open( + self.root_dir + / self.ids[idx] + / "elevations_0" + / f"colors_{view_idx * 2}.png" + ) + frames.append(self.transform(frame)) + + frames = torch.stack(frames, dim=0) + cond = frames[0] + + cond_aug = np.exp( + np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean + ) 
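+
+        # cond_aug is a single per-sample noise level drawn log-normally
+        # (exp of a Gaussian with the configured mean/std); it is broadcast
+        # to all views below and also scales the noise added to the
+        # conditioning frame in "cond_frames"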
+ + data = { + "frames": frames, + "cond_frames_without_noise": cond, + "cond_aug": torch.as_tensor([cond_aug] * self.n_views), + "cond_frames": cond + cond_aug * torch.randn_like(cond), + "fps_id": torch.as_tensor([0.0] * self.n_views), + "motion_bucket_id": torch.as_tensor([300.0] * self.n_views), + "num_video_frames": self.n_views, + "image_only_indicator": torch.as_tensor([0.0] * self.n_views), + } + + if self.use_precomputed_latents: + data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt") + + if self.condition_on_elevation: + # sample_c2w = read_camera_matrix_single( + # self.root_dir / self.ids[idx] / f"00000/00000.json" + # ) + # elevation = calc_elevation(sample_c2w) + # data["elevation"] = torch.as_tensor([elevation] * self.n_views) + assert False, "currently assumes elevation 0" + + return data + + def __len__(self): + return len(self.ids) + + +class ObjaverseALLSpiral(ObjaverseLVISSpiral): + def __init__( + self, + root_dir, + split="train", + transform=None, + random_front=False, + max_item=None, + cond_aug_mean=-3.0, + cond_aug_std=0.5, + condition_on_elevation=False, + use_precomputed_latents=False, + **unused_kwargs, + ): + print("Using ALL objects in Objaverse") + self.root_dir = Path(root_dir) + self.split = split + self.random_front = random_front + self.transform = transform + self.use_precomputed_latents = use_precomputed_latents + self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512") + + self.ids = json.load(open("./assets/all_ids.json", "r")) + self.n_views = 18 + valid_ids = [] + for idx in self.ids: + if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir(): + valid_ids.append(idx) + self.ids = valid_ids + print("=" * 30) + print("Number of valid ids: ", len(self.ids)) + print("=" * 30) + + self.cond_aug_mean = cond_aug_mean + self.cond_aug_std = cond_aug_std + self.condition_on_elevation = condition_on_elevation + + if max_item is not None: + self.ids = self.ids[:max_item] + + ## debug + self.ids = self.ids * 10000 + + +class ObjaverseWithPose(Dataset): + def __init__( + self, + root_dir, + split="train", + transform=None, + random_front=False, + max_item=None, + cond_aug_mean=-3.0, + cond_aug_std=0.5, + condition_on_elevation=False, + use_precomputed_latents=False, + **unused_kwargs, + ): + print("Using Objaverse with poses") + self.root_dir = Path(root_dir) + self.split = split + self.random_front = random_front + self.transform = transform + self.use_precomputed_latents = use_precomputed_latents + self.latent_dir = Path("/mnt/vepfs/3Ddataset/render_results/latents512") + + self.ids = json.load(open("./assets/all_ids.json", "r")) + self.n_views = 18 + valid_ids = [] + for idx in self.ids: + if (self.root_dir / idx).exists() and (self.root_dir / idx).is_dir(): + valid_ids.append(idx) + self.ids = valid_ids + print("=" * 30) + print("Number of valid ids: ", len(self.ids)) + print("=" * 30) + + self.cond_aug_mean = cond_aug_mean + self.cond_aug_std = cond_aug_std + self.condition_on_elevation = condition_on_elevation + + def __getitem__(self, idx: int): + frames = [] + idx_list = np.arange(self.n_views) + if self.random_front: + roll_idx = np.random.randint(self.n_views) + idx_list = np.roll(idx_list, roll_idx) + for view_idx in idx_list: + frame = Image.open( + self.root_dir + / self.ids[idx] + / "elevations_0" + / f"colors_{view_idx * 2}.png" + ) + frames.append(self.transform(frame)) + + frames = torch.stack(frames, dim=0) + cond = frames[0] + + cond_aug = np.exp( + np.random.randn(1)[0] * self.cond_aug_std + 
self.cond_aug_mean + ) + + data = { + "frames": frames, + "cond_frames_without_noise": cond, + "cond_aug": torch.as_tensor([cond_aug] * self.n_views), + "cond_frames": cond + cond_aug * torch.randn_like(cond), + "fps_id": torch.as_tensor([0.0] * self.n_views), + "motion_bucket_id": torch.as_tensor([300.0] * self.n_views), + "num_video_frames": self.n_views, + "image_only_indicator": torch.as_tensor([0.0] * self.n_views), + } + + if self.use_precomputed_latents: + data["latents"] = torch.load(self.latent_dir / f"{self.ids[idx]}.pt") + + if self.condition_on_elevation: + assert False, "currently assumes elevation 0" + + return data + + +class LatentObjaverse(Dataset): + def __init__( + self, + root_dir, + split="train", + random_front=False, + subset="lvis", + fps_id=1.0, + motion_bucket_id=300.0, + cond_aug_mean=-3.0, + cond_aug_std=0.5, + **unused_kwargs, + ): + self.root_dir = Path(root_dir) + self.split = split + self.random_front = random_front + self.ids = json.load(open(Path("./assets") / f"{subset}_ids.json", "r")) + self.clip_emb_dir = self.root_dir / ".." / "clip_emb512" + self.n_views = 18 + self.fps_id = fps_id + self.motion_bucket_id = motion_bucket_id + self.cond_aug_mean = cond_aug_mean + self.cond_aug_std = cond_aug_std + if self.random_front: + print("Using a random view as front view") + + valid_ids = [] + for idx in self.ids: + if (self.root_dir / f"{idx}.pt").exists() and ( + self.clip_emb_dir / f"{idx}.pt" + ).exists(): + valid_ids.append(idx) + self.ids = valid_ids + print("=" * 30) + print("Number of valid ids: ", len(self.ids)) + print("=" * 30) + + def __getitem__(self, idx: int): + uid = self.ids[idx] + idx_list = torch.arange(self.n_views) + latents = torch.load(self.root_dir / f"{uid}.pt") + clip_emb = torch.load(self.clip_emb_dir / f"{uid}.pt") + if self.random_front: + idx_list = torch.roll(idx_list, np.random.randint(self.n_views)) + latents = latents[idx_list] + clip_emb = clip_emb[idx_list][0] + + cond_aug = np.exp( + np.random.randn(1)[0] * self.cond_aug_std + self.cond_aug_mean + ) + cond = latents[0] + + data = { + "latents": latents, + "cond_frames_without_noise": clip_emb, + "cond_frames": cond + cond_aug * torch.randn_like(cond), + "fps_id": torch.as_tensor([self.fps_id] * self.n_views), + "motion_bucket_id": torch.as_tensor([self.motion_bucket_id] * self.n_views), + "cond_aug": torch.as_tensor([cond_aug] * self.n_views), + "num_video_frames": self.n_views, + "image_only_indicator": torch.as_tensor([0.0] * self.n_views), + } + + return data + + def __len__(self): + return len(self.ids) + + +class ObjaverseSpiralDataset(LightningDataModule): + def __init__( + self, + root_dir, + random_front=False, + batch_size=2, + num_workers=10, + prefetch_factor=2, + shuffle=True, + max_item=None, + dataset_cls="richdreamer", + reso: int = 256, + **kwargs, + ) -> None: + super().__init__() + + self.batch_size = batch_size + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.shuffle = shuffle + self.max_item = max_item + + self.transform = Compose( + [ + blend_white_bg, + Resize((reso, reso)), + ToTensor(), + Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + + data_cls = { + "richdreamer": ObjaverseSpiral, + "lvis": ObjaverseLVISSpiral, + "shengshu_all": ObjaverseALLSpiral, + "latent": LatentObjaverse, + "gobjaverse": GObjaverse, + }[dataset_cls] + + self.train_dataset = data_cls( + root_dir=root_dir, + split="train", + random_front=random_front, + transform=self.transform, + max_item=self.max_item, + **kwargs, + ) + self.test_dataset 
= data_cls( + root_dir=root_dir, + split="val", + random_front=random_front, + transform=self.transform, + max_item=self.max_item, + **kwargs, + ) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + collate_fn=video_collate_fn + if not hasattr(self.train_dataset, "collate_fn") + else self.train_dataset.collate_fn, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + collate_fn=video_collate_fn + if not hasattr(self.test_dataset, "collate_fn") + else self.train_dataset.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + prefetch_factor=self.prefetch_factor, + collate_fn=video_collate_fn + if not hasattr(self.test_dataset, "collate_fn") + else self.train_dataset.collate_fn, + ) diff --git a/sgm/inference/api.py b/sgm/inference/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a359a67bcd9740acc9e320d2f26dc6a3befb36e0 --- /dev/null +++ b/sgm/inference/api.py @@ -0,0 +1,385 @@ +import pathlib +from dataclasses import asdict, dataclass +from enum import Enum +from typing import Optional + +from omegaconf import OmegaConf + +from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img, + do_sample) +from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler, + DPMPP2SAncestralSampler, + EulerAncestralSampler, + EulerEDMSampler, + HeunEDMSampler, + LinearMultistepSampler) +from sgm.util import load_model_from_config + + +class ModelArchitecture(str, Enum): + SD_2_1 = "stable-diffusion-v2-1" + SD_2_1_768 = "stable-diffusion-v2-1-768" + SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base" + SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner" + SDXL_V1_BASE = "stable-diffusion-xl-v1-base" + SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner" + + +class Sampler(str, Enum): + EULER_EDM = "EulerEDMSampler" + HEUN_EDM = "HeunEDMSampler" + EULER_ANCESTRAL = "EulerAncestralSampler" + DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler" + DPMPP2M = "DPMPP2MSampler" + LINEAR_MULTISTEP = "LinearMultistepSampler" + + +class Discretization(str, Enum): + LEGACY_DDPM = "LegacyDDPMDiscretization" + EDM = "EDMDiscretization" + + +class Guider(str, Enum): + VANILLA = "VanillaCFG" + IDENTITY = "IdentityGuider" + + +class Thresholder(str, Enum): + NONE = "None" + + +@dataclass +class SamplingParams: + width: int = 1024 + height: int = 1024 + steps: int = 50 + sampler: Sampler = Sampler.DPMPP2M + discretization: Discretization = Discretization.LEGACY_DDPM + guider: Guider = Guider.VANILLA + thresholder: Thresholder = Thresholder.NONE + scale: float = 6.0 + aesthetic_score: float = 5.0 + negative_aesthetic_score: float = 5.0 + img2img_strength: float = 1.0 + orig_width: int = 1024 + orig_height: int = 1024 + crop_coords_top: int = 0 + crop_coords_left: int = 0 + sigma_min: float = 0.0292 + sigma_max: float = 14.6146 + rho: float = 3.0 + s_churn: float = 0.0 + s_tmin: float = 0.0 + s_tmax: float = 999.0 + s_noise: float = 1.0 + eta: float = 1.0 + order: int = 4 + + +@dataclass +class SamplingSpec: + width: int + height: int + channels: int + factor: int + is_legacy: bool + config: str + ckpt: str + is_guided: bool + + +model_specs = { + ModelArchitecture.SD_2_1: 
SamplingSpec( + height=512, + width=512, + channels=4, + factor=8, + is_legacy=True, + config="sd_2_1.yaml", + ckpt="v2-1_512-ema-pruned.safetensors", + is_guided=True, + ), + ModelArchitecture.SD_2_1_768: SamplingSpec( + height=768, + width=768, + channels=4, + factor=8, + is_legacy=True, + config="sd_2_1_768.yaml", + ckpt="v2-1_768-ema-pruned.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=False, + config="sd_xl_base.yaml", + ckpt="sd_xl_base_0.9.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=True, + config="sd_xl_refiner.yaml", + ckpt="sd_xl_refiner_0.9.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V1_BASE: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=False, + config="sd_xl_base.yaml", + ckpt="sd_xl_base_1.0.safetensors", + is_guided=True, + ), + ModelArchitecture.SDXL_V1_REFINER: SamplingSpec( + height=1024, + width=1024, + channels=4, + factor=8, + is_legacy=True, + config="sd_xl_refiner.yaml", + ckpt="sd_xl_refiner_1.0.safetensors", + is_guided=True, + ), +} + + +class SamplingPipeline: + def __init__( + self, + model_id: ModelArchitecture, + model_path="checkpoints", + config_path="configs/inference", + device="cuda", + use_fp16=True, + ) -> None: + if model_id not in model_specs: + raise ValueError(f"Model {model_id} not supported") + self.model_id = model_id + self.specs = model_specs[self.model_id] + self.config = str(pathlib.Path(config_path, self.specs.config)) + self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) + self.device = device + self.model = self._load_model(device=device, use_fp16=use_fp16) + + def _load_model(self, device="cuda", use_fp16=True): + config = OmegaConf.load(self.config) + model = load_model_from_config(config, self.ckpt) + if model is None: + raise ValueError(f"Model {self.model_id} could not be loaded") + model.to(device) + if use_fp16: + model.conditioner.half() + model.model.half() + return model + + def text_to_image( + self, + params: SamplingParams, + prompt: str, + negative_prompt: str = "", + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + value_dict = asdict(params) + value_dict["prompt"] = prompt + value_dict["negative_prompt"] = negative_prompt + value_dict["target_width"] = params.width + value_dict["target_height"] = params.height + return do_sample( + self.model, + sampler, + value_dict, + samples, + params.height, + params.width, + self.specs.channels, + self.specs.factor, + force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [], + return_latents=return_latents, + filter=None, + ) + + def image_to_image( + self, + params: SamplingParams, + image, + prompt: str, + negative_prompt: str = "", + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + + if params.img2img_strength < 1.0: + sampler.discretization = Img2ImgDiscretizationWrapper( + sampler.discretization, + strength=params.img2img_strength, + ) + height, width = image.shape[2], image.shape[3] + value_dict = asdict(params) + value_dict["prompt"] = prompt + value_dict["negative_prompt"] = negative_prompt + value_dict["target_width"] = width + value_dict["target_height"] = height + return do_img2img( + image, + self.model, + sampler, + value_dict, + samples, + force_uc_zero_embeddings=["txt"] if not 
self.specs.is_legacy else [], + return_latents=return_latents, + filter=None, + ) + + def refiner( + self, + params: SamplingParams, + image, + prompt: str, + negative_prompt: Optional[str] = None, + samples: int = 1, + return_latents: bool = False, + ): + sampler = get_sampler_config(params) + value_dict = { + "orig_width": image.shape[3] * 8, + "orig_height": image.shape[2] * 8, + "target_width": image.shape[3] * 8, + "target_height": image.shape[2] * 8, + "prompt": prompt, + "negative_prompt": negative_prompt, + "crop_coords_top": 0, + "crop_coords_left": 0, + "aesthetic_score": 6.0, + "negative_aesthetic_score": 2.5, + } + + return do_img2img( + image, + self.model, + sampler, + value_dict, + samples, + skip_encode=True, + return_latents=return_latents, + filter=None, + ) + + +def get_guider_config(params: SamplingParams): + if params.guider == Guider.IDENTITY: + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" + } + elif params.guider == Guider.VANILLA: + scale = params.scale + + thresholder = params.thresholder + + if thresholder == Thresholder.NONE: + dyn_thresh_config = { + "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding" + } + else: + raise NotImplementedError + + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG", + "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config}, + } + else: + raise NotImplementedError + return guider_config + + +def get_discretization_config(params: SamplingParams): + if params.discretization == Discretization.LEGACY_DDPM: + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", + } + elif params.discretization == Discretization.EDM: + discretization_config = { + "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization", + "params": { + "sigma_min": params.sigma_min, + "sigma_max": params.sigma_max, + "rho": params.rho, + }, + } + else: + raise ValueError(f"unknown discretization {params.discretization}") + return discretization_config + + +def get_sampler_config(params: SamplingParams): + discretization_config = get_discretization_config(params) + guider_config = get_guider_config(params) + sampler = None + if params.sampler == Sampler.EULER_EDM: + return EulerEDMSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=params.s_churn, + s_tmin=params.s_tmin, + s_tmax=params.s_tmax, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.HEUN_EDM: + return HeunEDMSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + s_churn=params.s_churn, + s_tmin=params.s_tmin, + s_tmax=params.s_tmax, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.EULER_ANCESTRAL: + return EulerAncestralSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=params.eta, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.DPMPP2S_ANCESTRAL: + return DPMPP2SAncestralSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + eta=params.eta, + s_noise=params.s_noise, + verbose=True, + ) + if params.sampler == Sampler.DPMPP2M: + return DPMPP2MSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + verbose=True, + ) + if params.sampler == 
Sampler.LINEAR_MULTISTEP: + return LinearMultistepSampler( + num_steps=params.steps, + discretization_config=discretization_config, + guider_config=guider_config, + order=params.order, + verbose=True, + ) + + raise ValueError(f"unknown sampler {params.sampler}!") diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..31b0ec3dc414bf522261e35f73805810cd35582d --- /dev/null +++ b/sgm/inference/helpers.py @@ -0,0 +1,305 @@ +import math +import os +from typing import List, Optional, Union + +import numpy as np +import torch +from einops import rearrange +from imwatermark import WatermarkEncoder +from omegaconf import ListConfig +from PIL import Image +from torch import autocast + +from sgm.util import append_dims + + +class WatermarkEmbedder: + def __init__(self, watermark): + self.watermark = watermark + self.num_bits = len(WATERMARK_BITS) + self.encoder = WatermarkEncoder() + self.encoder.set_watermark("bits", self.watermark) + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + """ + Adds a predefined watermark to the input image + + Args: + image: ([N,] B, RGB, H, W) in range [0, 1] + + Returns: + same as input but watermarked + """ + squeeze = len(image.shape) == 4 + if squeeze: + image = image[None, ...] + n = image.shape[0] + image_np = rearrange( + (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" + ).numpy()[:, :, :, ::-1] + # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] + # watermarking libary expects input as cv2 BGR format + for k in range(image_np.shape[0]): + image_np[k] = self.encoder.encode(image_np[k], "dwtDct") + image = torch.from_numpy( + rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) + ).to(image.device) + image = torch.clamp(image / 255, min=0.0, max=1.0) + if squeeze: + image = image[0] + return image + + +# A fixed 48-bit message that was choosen at random +# WATERMARK_MESSAGE = 0xB3EC907BB19E +WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110 +# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 +WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] +embed_watermark = WatermarkEmbedder(WATERMARK_BITS) + + +def get_unique_embedder_keys_from_conditioner(conditioner): + return list({x.input_key for x in conditioner.embedders}) + + +def perform_save_locally(save_path, samples): + os.makedirs(os.path.join(save_path), exist_ok=True) + base_count = len(os.listdir(os.path.join(save_path))) + samples = embed_watermark(samples) + for sample in samples: + sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c") + Image.fromarray(sample.astype(np.uint8)).save( + os.path.join(save_path, f"{base_count:09}.png") + ) + base_count += 1 + + +class Img2ImgDiscretizationWrapper: + """ + wraps a discretizer, and prunes the sigmas + params: + strength: float between 0.0 and 1.0. 
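+        values below 1.0 keep only the smallest max(int(strength * len(sigmas)), 1)
+        sigmas, so denoising starts from a partially noised input rather than from pure noise.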
1.0 means full sampling (all sigmas are returned) + """ + + def __init__(self, discretization, strength: float = 1.0): + self.discretization = discretization + self.strength = strength + assert 0.0 <= self.strength <= 1.0 + + def __call__(self, *args, **kwargs): + # sigmas start large first, and decrease then + sigmas = self.discretization(*args, **kwargs) + print(f"sigmas after discretization, before pruning img2img: ", sigmas) + sigmas = torch.flip(sigmas, (0,)) + sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)] + print("prune index:", max(int(self.strength * len(sigmas)), 1)) + sigmas = torch.flip(sigmas, (0,)) + print(f"sigmas after pruning: ", sigmas) + return sigmas + + +def do_sample( + model, + sampler, + value_dict, + num_samples, + H, + W, + C, + F, + force_uc_zero_embeddings: Optional[List] = None, + batch2model_input: Optional[List] = None, + return_latents=False, + filter=None, + device="cuda", +): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + if batch2model_input is None: + batch2model_input = [] + + with torch.no_grad(): + with autocast(device) as precision_scope: + with model.ema_scope(): + num_samples = [num_samples] + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + num_samples, + ) + for key in batch: + if isinstance(batch[key], torch.Tensor): + print(key, batch[key].shape) + elif isinstance(batch[key], list): + print(key, [len(l) for l in batch[key]]) + else: + print(key, batch[key]) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + if not k == "crossattn": + c[k], uc[k] = map( + lambda y: y[k][: math.prod(num_samples)].to(device), (c, uc) + ) + + additional_model_inputs = {} + for k in batch2model_input: + additional_model_inputs[k] = batch[k] + + shape = (math.prod(num_samples), C, H // F, W // F) + randn = torch.randn(shape).to(device) + + def denoiser(input, sigma, c): + return model.denoiser( + model.model, input, sigma, c, **additional_model_inputs + ) + + samples_z = sampler(denoiser, randn, cond=c, uc=uc) + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + if filter is not None: + samples = filter(samples) + + if return_latents: + return samples, samples_z + return samples + + +def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): + # Hardcoded demo setups; might undergo some changes in the future + + batch = {} + batch_uc = {} + + for key in keys: + if key == "txt": + batch["txt"] = ( + np.repeat([value_dict["prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + batch_uc["txt"] = ( + np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) + .reshape(N) + .tolist() + ) + elif key == "original_size_as_tuple": + batch["original_size_as_tuple"] = ( + torch.tensor([value_dict["orig_height"], value_dict["orig_width"]]) + .to(device) + .repeat(*N, 1) + ) + elif key == "crop_coords_top_left": + batch["crop_coords_top_left"] = ( + torch.tensor( + [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] + ) + .to(device) + .repeat(*N, 1) + ) + elif key == "aesthetic_score": + batch["aesthetic_score"] = ( + torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) + ) + batch_uc["aesthetic_score"] = ( + torch.tensor([value_dict["negative_aesthetic_score"]]) + .to(device) + .repeat(*N, 1) + ) + + elif key == "target_size_as_tuple": + 
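+            # target-size conditioning (the SDXL-style size key): pack the requested
+            # output resolution into a tensor repeated once per sample; the loop at
+            # the end of this function mirrors it into batch_uc.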
batch["target_size_as_tuple"] = ( + torch.tensor([value_dict["target_height"], value_dict["target_width"]]) + .to(device) + .repeat(*N, 1) + ) + else: + batch[key] = value_dict[key] + + for key in batch.keys(): + if key not in batch_uc and isinstance(batch[key], torch.Tensor): + batch_uc[key] = torch.clone(batch[key]) + return batch, batch_uc + + +def get_input_image_tensor(image: Image.Image, device="cuda"): + w, h = image.size + print(f"loaded input image of size ({w}, {h})") + width, height = map( + lambda x: x - x % 64, (w, h) + ) # resize to integer multiple of 64 + image = image.resize((width, height)) + image_array = np.array(image.convert("RGB")) + image_array = image_array[None].transpose(0, 3, 1, 2) + image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0 + return image_tensor.to(device) + + +def do_img2img( + img, + model, + sampler, + value_dict, + num_samples, + force_uc_zero_embeddings=[], + additional_kwargs={}, + offset_noise_level: float = 0.0, + return_latents=False, + skip_encode=False, + filter=None, + device="cuda", +): + with torch.no_grad(): + with autocast(device) as precision_scope: + with model.ema_scope(): + batch, batch_uc = get_batch( + get_unique_embedder_keys_from_conditioner(model.conditioner), + value_dict, + [num_samples], + ) + c, uc = model.conditioner.get_unconditional_conditioning( + batch, + batch_uc=batch_uc, + force_uc_zero_embeddings=force_uc_zero_embeddings, + ) + + for k in c: + c[k], uc[k] = map(lambda y: y[k][:num_samples].to(device), (c, uc)) + + for k in additional_kwargs: + c[k] = uc[k] = additional_kwargs[k] + if skip_encode: + z = img + else: + z = model.encode_first_stage(img) + noise = torch.randn_like(z) + sigmas = sampler.discretization(sampler.num_steps) + sigma = sigmas[0].to(z.device) + + if offset_noise_level > 0.0: + noise = noise + offset_noise_level * append_dims( + torch.randn(z.shape[0], device=z.device), z.ndim + ) + noised_z = z + noise * append_dims(sigma, z.ndim) + noised_z = noised_z / torch.sqrt( + 1.0 + sigmas[0] ** 2.0 + ) # Note: hardcoded to DDPM-like scaling. need to generalize later. 
+ + def denoiser(x, sigma, c): + return model.denoiser(model.model, x, sigma, c) + + samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) + samples_x = model.decode_first_stage(samples_z) + samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) + + if filter is not None: + samples = filter(samples) + + if return_latents: + return samples, samples_z + return samples diff --git a/sgm/lr_scheduler.py b/sgm/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f4d384c1fcaff0df13e0564450d3fa972ace42 --- /dev/null +++ b/sgm/lr_scheduler.py @@ -0,0 +1,135 @@ +import numpy as np + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + + def __init__( + self, + warm_up_steps, + lr_min, + lr_max, + lr_start, + max_decay_steps, + verbosity_interval=0, + ): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0.0 + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = ( + self.lr_max - self.lr_start + ) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / ( + self.lr_max_decay_steps - self.lr_warm_up_steps + ) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi) + ) + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. 
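+    each list index describes one cycle: the multiplier ramps linearly from
+    f_start to f_max over warm_up_steps, then decays to f_min along a cosine
+    over the remainder of cycle_lengths.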
+ """ + + def __init__( + self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0 + ): + assert ( + len(warm_up_steps) + == len(f_min) + == len(f_max) + == len(f_start) + == len(cycle_lengths) + ) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0.0 + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ + cycle + ] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / ( + self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle] + ) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( + 1 + np.cos(t * np.pi) + ) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ + cycle + ] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * ( + self.cycle_lengths[cycle] - n + ) / (self.cycle_lengths[cycle]) + self.last_f = f + return f diff --git a/sgm/models/__init__.py b/sgm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c410b3747afc208e4204c8f140170e0a7808eace --- /dev/null +++ b/sgm/models/__init__.py @@ -0,0 +1,2 @@ +from .autoencoder import AutoencodingEngine +from .diffusion import DiffusionEngine diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2949b91011a2be7a6b8ca17ce260812f20ce8b75 --- /dev/null +++ b/sgm/models/autoencoder.py @@ -0,0 +1,615 @@ +import logging +import math +import re +from abc import abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytorch_lightning as pl +import torch +import torch.nn as nn +from einops import rearrange +from packaging import version + +from ..modules.autoencoding.regularizers import AbstractRegularizer +from ..modules.ema import LitEma +from ..util import (default, get_nested_attribute, get_obj_from_str, + instantiate_from_config) + +logpy = logging.getLogger(__name__) + + +class AbstractAutoencoder(pl.LightningModule): + """ + This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, + unCLIP models, etc. Hence, it is fairly general, and specific features + (e.g. discriminator training, encoding, decoding) must be implemented in subclasses. 
+ """ + + def __init__( + self, + ema_decay: Union[None, float] = None, + monitor: Union[None, str] = None, + input_key: str = "jpg", + ): + super().__init__() + + self.input_key = input_key + self.use_ema = ema_decay is not None + if monitor is not None: + self.monitor = monitor + + if self.use_ema: + self.model_ema = LitEma(self, decay=ema_decay) + logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if version.parse(torch.__version__) >= version.parse("2.0.0"): + self.automatic_optimization = False + + def apply_ckpt(self, ckpt: Union[None, str, dict]): + if ckpt is None: + return + if isinstance(ckpt, str): + ckpt = { + "target": "sgm.modules.checkpoint.CheckpointEngine", + "params": {"ckpt_path": ckpt}, + } + engine = instantiate_from_config(ckpt) + engine(self) + + @abstractmethod + def get_input(self, batch) -> Any: + raise NotImplementedError() + + def on_train_batch_end(self, *args, **kwargs): + # for EMA computation + if self.use_ema: + self.model_ema(self) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + logpy.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + logpy.info(f"{context}: Restored training weights") + + @abstractmethod + def encode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("encode()-method of abstract base class called") + + @abstractmethod + def decode(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError("decode()-method of abstract base class called") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config") + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) + + def configure_optimizers(self) -> Any: + raise NotImplementedError() + + +class AutoencodingEngine(AbstractAutoencoder): + """ + Base class for all image autoencoders that we train, like VQGAN or AutoencoderKL + (we also restore them explicitly as special cases for legacy reasons). + Regularizations such as KL or VQ are moved to the regularizer class. 
+ """ + + def __init__( + self, + *args, + encoder_config: Dict, + decoder_config: Dict, + loss_config: Dict, + regularizer_config: Dict, + optimizer_config: Union[Dict, None] = None, + lr_g_factor: float = 1.0, + trainable_ae_params: Optional[List[List[str]]] = None, + ae_optimizer_args: Optional[List[dict]] = None, + trainable_disc_params: Optional[List[List[str]]] = None, + disc_optimizer_args: Optional[List[dict]] = None, + disc_start_iter: int = 0, + diff_boost_factor: float = 3.0, + ckpt_engine: Union[None, str, dict] = None, + ckpt_path: Optional[str] = None, + additional_decode_keys: Optional[List[str]] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.automatic_optimization = False # pytorch lightning + + self.encoder: torch.nn.Module = instantiate_from_config(encoder_config) + self.decoder: torch.nn.Module = instantiate_from_config(decoder_config) + self.loss: torch.nn.Module = instantiate_from_config(loss_config) + self.regularization: AbstractRegularizer = instantiate_from_config( + regularizer_config + ) + self.optimizer_config = default( + optimizer_config, {"target": "torch.optim.Adam"} + ) + self.diff_boost_factor = diff_boost_factor + self.disc_start_iter = disc_start_iter + self.lr_g_factor = lr_g_factor + self.trainable_ae_params = trainable_ae_params + if self.trainable_ae_params is not None: + self.ae_optimizer_args = default( + ae_optimizer_args, + [{} for _ in range(len(self.trainable_ae_params))], + ) + assert len(self.ae_optimizer_args) == len(self.trainable_ae_params) + else: + self.ae_optimizer_args = [{}] # makes type consitent + + self.trainable_disc_params = trainable_disc_params + if self.trainable_disc_params is not None: + self.disc_optimizer_args = default( + disc_optimizer_args, + [{} for _ in range(len(self.trainable_disc_params))], + ) + assert len(self.disc_optimizer_args) == len(self.trainable_disc_params) + else: + self.disc_optimizer_args = [{}] # makes type consitent + + if ckpt_path is not None: + assert ckpt_engine is None, "Can't set ckpt_engine and ckpt_path" + logpy.warn("Checkpoint path is deprecated, use `checkpoint_egnine` instead") + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + self.additional_decode_keys = set(default(additional_decode_keys, [])) + + def get_input(self, batch: Dict) -> torch.Tensor: + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 
1 and in channels-first + # format (e.g., bchw instead if bhwc) + return batch[self.input_key] + + def get_autoencoder_params(self) -> list: + params = [] + if hasattr(self.loss, "get_trainable_autoencoder_parameters"): + params += list(self.loss.get_trainable_autoencoder_parameters()) + if hasattr(self.regularization, "get_trainable_parameters"): + params += list(self.regularization.get_trainable_parameters()) + params = params + list(self.encoder.parameters()) + params = params + list(self.decoder.parameters()) + return params + + def get_discriminator_params(self) -> list: + if hasattr(self.loss, "get_trainable_parameters"): + params = list(self.loss.get_trainable_parameters()) # e.g., discriminator + else: + params = [] + return params + + def get_last_layer(self): + return self.decoder.get_last_layer() + + def encode( + self, + x: torch.Tensor, + return_reg_log: bool = False, + unregularized: bool = False, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + z = self.encoder(x) + if unregularized: + return z, dict() + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.decoder(z, **kwargs) + return x + + def forward( + self, x: torch.Tensor, **additional_decode_kwargs + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + z, reg_log = self.encode(x, return_reg_log=True) + dec = self.decode(z, **additional_decode_kwargs) + return z, dec, reg_log + + def inner_training_step( + self, batch: dict, batch_idx: int, optimizer_idx: int = 0 + ) -> torch.Tensor: + x = self.get_input(batch) + additional_decode_kwargs = { + key: batch[key] for key in self.additional_decode_keys.intersection(batch) + } + z, xrec, regularization_log = self(x, **additional_decode_kwargs) + if hasattr(self.loss, "forward_keys"): + extra_info = { + "z": z, + "optimizer_idx": optimizer_idx, + "global_step": self.global_step, + "last_layer": self.get_last_layer(), + "split": "train", + "regularization_log": regularization_log, + "autoencoder": self, + } + extra_info = {k: extra_info[k] for k in self.loss.forward_keys} + else: + extra_info = dict() + + if optimizer_idx == 0: + # autoencode + out_loss = self.loss(x, xrec, **extra_info) + if isinstance(out_loss, tuple): + aeloss, log_dict_ae = out_loss + else: + # simple loss function + aeloss = out_loss + log_dict_ae = {"train/loss/rec": aeloss.detach()} + + self.log_dict( + log_dict_ae, + prog_bar=False, + logger=True, + on_step=True, + on_epoch=True, + sync_dist=False, + ) + self.log( + "loss", + aeloss.mean().detach(), + prog_bar=True, + logger=False, + on_epoch=False, + on_step=True, + ) + return aeloss + elif optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(x, xrec, **extra_info) + # -> discriminator always needs to return a tuple + self.log_dict( + log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True + ) + return discloss + else: + raise NotImplementedError(f"Unknown optimizer {optimizer_idx}") + + def training_step(self, batch: dict, batch_idx: int): + opts = self.optimizers() + if not isinstance(opts, list): + # Non-adversarial case + opts = [opts] + optimizer_idx = batch_idx % len(opts) + if self.global_step < self.disc_start_iter: + optimizer_idx = 0 + opt = opts[optimizer_idx] + opt.zero_grad() + with opt.toggle_model(): + loss = self.inner_training_step( + batch, batch_idx, optimizer_idx=optimizer_idx + ) + self.manual_backward(loss) + opt.step() + + def validation_step(self, batch: dict, batch_idx: int) 
-> Dict: + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + log_dict.update(log_dict_ema) + return log_dict + + def _validation_step(self, batch: dict, batch_idx: int, postfix: str = "") -> Dict: + x = self.get_input(batch) + + z, xrec, regularization_log = self(x) + if hasattr(self.loss, "forward_keys"): + extra_info = { + "z": z, + "optimizer_idx": 0, + "global_step": self.global_step, + "last_layer": self.get_last_layer(), + "split": "val" + postfix, + "regularization_log": regularization_log, + "autoencoder": self, + } + extra_info = {k: extra_info[k] for k in self.loss.forward_keys} + else: + extra_info = dict() + out_loss = self.loss(x, xrec, **extra_info) + if isinstance(out_loss, tuple): + aeloss, log_dict_ae = out_loss + else: + # simple loss function + aeloss = out_loss + log_dict_ae = {f"val{postfix}/loss/rec": aeloss.detach()} + full_log_dict = log_dict_ae + + if "optimizer_idx" in extra_info: + extra_info["optimizer_idx"] = 1 + discloss, log_dict_disc = self.loss(x, xrec, **extra_info) + full_log_dict.update(log_dict_disc) + self.log( + f"val{postfix}/loss/rec", + log_dict_ae[f"val{postfix}/loss/rec"], + sync_dist=True, + ) + self.log_dict(full_log_dict, sync_dist=True) + return full_log_dict + + def get_param_groups( + self, parameter_names: List[List[str]], optimizer_args: List[dict] + ) -> Tuple[List[Dict[str, Any]], int]: + groups = [] + num_params = 0 + for names, args in zip(parameter_names, optimizer_args): + params = [] + for pattern_ in names: + pattern_params = [] + pattern = re.compile(pattern_) + for p_name, param in self.named_parameters(): + if re.match(pattern, p_name): + pattern_params.append(param) + num_params += param.numel() + if len(pattern_params) == 0: + logpy.warn(f"Did not find parameters for pattern {pattern_}") + params.extend(pattern_params) + groups.append({"params": params, **args}) + return groups, num_params + + def configure_optimizers(self) -> List[torch.optim.Optimizer]: + if self.trainable_ae_params is None: + ae_params = self.get_autoencoder_params() + else: + ae_params, num_ae_params = self.get_param_groups( + self.trainable_ae_params, self.ae_optimizer_args + ) + logpy.info(f"Number of trainable autoencoder parameters: {num_ae_params:,}") + if self.trainable_disc_params is None: + disc_params = self.get_discriminator_params() + else: + disc_params, num_disc_params = self.get_param_groups( + self.trainable_disc_params, self.disc_optimizer_args + ) + logpy.info( + f"Number of trainable discriminator parameters: {num_disc_params:,}" + ) + opt_ae = self.instantiate_optimizer_from_config( + ae_params, + default(self.lr_g_factor, 1.0) * self.learning_rate, + self.optimizer_config, + ) + opts = [opt_ae] + if len(disc_params) > 0: + opt_disc = self.instantiate_optimizer_from_config( + disc_params, self.learning_rate, self.optimizer_config + ) + opts.append(opt_disc) + + return opts + + @torch.no_grad() + def log_images( + self, batch: dict, additional_log_kwargs: Optional[Dict] = None, **kwargs + ) -> dict: + log = dict() + additional_decode_kwargs = {} + x = self.get_input(batch) + additional_decode_kwargs.update( + {key: batch[key] for key in self.additional_decode_keys.intersection(batch)} + ) + + _, xrec, _ = self(x, **additional_decode_kwargs) + log["inputs"] = x + log["reconstructions"] = xrec + diff = 0.5 * torch.abs(torch.clamp(xrec, -1.0, 1.0) - x) + diff.clamp_(0, 1.0) + log["diff"] = 2.0 * diff - 1.0 + # diff_boost shows location of 
small errors, by boosting their + # brightness. + log["diff_boost"] = ( + 2.0 * torch.clamp(self.diff_boost_factor * diff, 0.0, 1.0) - 1 + ) + if hasattr(self.loss, "log_images"): + log.update(self.loss.log_images(x, xrec)) + with self.ema_scope(): + _, xrec_ema, _ = self(x, **additional_decode_kwargs) + log["reconstructions_ema"] = xrec_ema + diff_ema = 0.5 * torch.abs(torch.clamp(xrec_ema, -1.0, 1.0) - x) + diff_ema.clamp_(0, 1.0) + log["diff_ema"] = 2.0 * diff_ema - 1.0 + log["diff_boost_ema"] = ( + 2.0 * torch.clamp(self.diff_boost_factor * diff_ema, 0.0, 1.0) - 1 + ) + if additional_log_kwargs: + additional_decode_kwargs.update(additional_log_kwargs) + _, xrec_add, _ = self(x, **additional_decode_kwargs) + log_str = "reconstructions-" + "-".join( + [f"{key}={additional_log_kwargs[key]}" for key in additional_log_kwargs] + ) + log[log_str] = xrec_add + return log + + +class AutoencodingEngineLegacy(AutoencodingEngine): + def __init__(self, embed_dim: int, **kwargs): + self.max_batch_size = kwargs.pop("max_batch_size", None) + ddconfig = kwargs.pop("ddconfig") + ckpt_path = kwargs.pop("ckpt_path", None) + ckpt_engine = kwargs.pop("ckpt_engine", None) + super().__init__( + encoder_config={ + "target": "sgm.modules.diffusionmodules.model.Encoder", + "params": ddconfig, + }, + decoder_config={ + "target": "sgm.modules.diffusionmodules.model.Decoder", + "params": ddconfig, + }, + **kwargs, + ) + self.quant_conv = torch.nn.Conv2d( + (1 + ddconfig["double_z"]) * ddconfig["z_channels"], + (1 + ddconfig["double_z"]) * embed_dim, + 1, + ) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + + self.apply_ckpt(default(ckpt_path, ckpt_engine)) + + def get_autoencoder_params(self) -> list: + params = super().get_autoencoder_params() + return params + + def encode( + self, x: torch.Tensor, return_reg_log: bool = False + ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: + if self.max_batch_size is None: + z = self.encoder(x) + z = self.quant_conv(z) + else: + N = x.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + z = list() + for i_batch in range(n_batches): + z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs]) + z_batch = self.quant_conv(z_batch) + z.append(z_batch) + z = torch.cat(z, 0) + + z, reg_log = self.regularization(z) + if return_reg_log: + return z, reg_log + return z + + def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.max_batch_size is None: + dec = self.post_quant_conv(z) + dec = self.decoder(dec, **decoder_kwargs) + else: + N = z.shape[0] + bs = self.max_batch_size + n_batches = int(math.ceil(N / bs)) + dec = list() + for i_batch in range(n_batches): + dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs]) + dec_batch = self.decoder(dec_batch, **decoder_kwargs) + dec.append(dec_batch) + dec = torch.cat(dec, 0) + + return dec + + +class AutoencoderKL(AutoencodingEngineLegacy): + def __init__(self, **kwargs): + if "lossconfig" in kwargs: + kwargs["loss_config"] = kwargs.pop("lossconfig") + super().__init__( + regularizer_config={ + "target": ( + "sgm.modules.autoencoding.regularizers" + ".DiagonalGaussianRegularizer" + ) + }, + **kwargs, + ) + + +class AutoencoderLegacyVQ(AutoencodingEngineLegacy): + def __init__( + self, + embed_dim: int, + n_embed: int, + sane_index_shape: bool = False, + **kwargs, + ): + if "lossconfig" in kwargs: + logpy.warn(f"Parameter `lossconfig` is deprecated, use `loss_config`.") + kwargs["loss_config"] = 
kwargs.pop("lossconfig") + super().__init__( + regularizer_config={ + "target": ( + "sgm.modules.autoencoding.regularizers.quantize" ".VectorQuantizer" + ), + "params": { + "n_e": n_embed, + "e_dim": embed_dim, + "sane_index_shape": sane_index_shape, + }, + }, + **kwargs, + ) + + +class IdentityFirstStage(AbstractAutoencoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def get_input(self, x: Any) -> Any: + return x + + def encode(self, x: Any, *args, **kwargs) -> Any: + return x + + def decode(self, x: Any, *args, **kwargs) -> Any: + return x + + +class AEIntegerWrapper(nn.Module): + def __init__( + self, + model: nn.Module, + shape: Union[None, Tuple[int, int], List[int]] = (16, 16), + regularization_key: str = "regularization", + encoder_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__() + self.model = model + assert hasattr(model, "encode") and hasattr( + model, "decode" + ), "Need AE interface" + self.regularization = get_nested_attribute(model, regularization_key) + self.shape = shape + self.encoder_kwargs = default(encoder_kwargs, {"return_reg_log": True}) + + def encode(self, x) -> torch.Tensor: + assert ( + not self.training + ), f"{self.__class__.__name__} only supports inference currently" + _, log = self.model.encode(x, **self.encoder_kwargs) + assert isinstance(log, dict) + inds = log["min_encoding_indices"] + return rearrange(inds, "b ... -> b (...)") + + def decode( + self, inds: torch.Tensor, shape: Union[None, tuple, list] = None + ) -> torch.Tensor: + # expect inds shape (b, s) with s = h*w + shape = default(shape, self.shape) # Optional[(h, w)] + if shape is not None: + assert len(shape) == 2, f"Unhandeled shape {shape}" + inds = rearrange(inds, "b (h w) -> b h w", h=shape[0], w=shape[1]) + h = self.regularization.get_codebook_entry(inds) # (b, h, w, c) + h = rearrange(h, "b h w c -> b c h w") + return self.model.decode(h) + + +class AutoencoderKLModeOnly(AutoencodingEngineLegacy): + def __init__(self, **kwargs): + if "lossconfig" in kwargs: + kwargs["loss_config"] = kwargs.pop("lossconfig") + super().__init__( + regularizer_config={ + "target": ( + "sgm.modules.autoencoding.regularizers" + ".DiagonalGaussianRegularizer" + ), + "params": {"sample": False}, + }, + **kwargs, + ) diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..41a0f4a7c6a7ed49e2d2538879d47d18ede16cba --- /dev/null +++ b/sgm/models/diffusion.py @@ -0,0 +1,358 @@ +import math +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytorch_lightning as pl +import torch +from omegaconf import ListConfig, OmegaConf +from safetensors.torch import load_file as load_safetensors +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange + +from ..modules import UNCONDITIONAL_CONFIG +from ..modules.autoencoding.temporal_ae import VideoDecoder +from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER +from ..modules.ema import LitEma +from ..util import ( + default, + disabled_train, + get_obj_from_str, + instantiate_from_config, + log_txt_as_img, +) + + +class DiffusionEngine(pl.LightningModule): + def __init__( + self, + network_config, + denoiser_config, + first_stage_config, + conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, + sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, + optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, + scheduler_config: 
Union[None, Dict, ListConfig, OmegaConf] = None, + loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, + network_wrapper: Union[None, str] = None, + ckpt_path: Union[None, str] = None, + use_ema: bool = False, + ema_decay_rate: float = 0.9999, + scale_factor: float = 1.0, + disable_first_stage_autocast=False, + input_key: str = "jpg", + log_keys: Union[List, None] = None, + no_cond_log: bool = False, + compile_model: bool = False, + en_and_decode_n_samples_a_time: Optional[int] = None, + ): + super().__init__() + self.log_keys = log_keys + self.input_key = input_key + self.optimizer_config = default( + optimizer_config, {"target": "torch.optim.AdamW"} + ) + model = instantiate_from_config(network_config) + self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( + model, compile_model=compile_model + ) + + self.denoiser = instantiate_from_config(denoiser_config) + self.sampler = ( + instantiate_from_config(sampler_config) + if sampler_config is not None + else None + ) + self.conditioner = instantiate_from_config( + default(conditioner_config, UNCONDITIONAL_CONFIG) + ) + self.scheduler_config = scheduler_config + self._init_first_stage(first_stage_config) + + self.loss_fn = ( + instantiate_from_config(loss_fn_config) + if loss_fn_config is not None + else None + ) + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model, decay=ema_decay_rate) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.scale_factor = scale_factor + self.disable_first_stage_autocast = disable_first_stage_autocast + self.no_cond_log = no_cond_log + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time + + def init_from_ckpt( + self, + path: str, + ) -> None: + if path.endswith("ckpt"): + sd = torch.load(path, map_location="cpu")["state_dict"] + elif path.endswith("safetensors"): + sd = load_safetensors(path) + else: + raise NotImplementedError + + missing, unexpected = self.load_state_dict(sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + def _init_first_stage(self, config): + model = instantiate_from_config(config).eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + self.first_stage_model = model + + def get_input(self, batch): + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 
1 and in bchw format + return batch[self.input_key] + + @torch.no_grad() + def decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) + + n_rounds = math.ceil(z.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + if isinstance(self.first_stage_model.decoder, VideoDecoder): + kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} + else: + kwargs = {} + out = self.first_stage_model.decode( + z[n * n_samples : (n + 1) * n_samples], **kwargs + ) + all_out.append(out) + out = torch.cat(all_out, dim=0) + return out + + @torch.no_grad() + def encode_first_stage(self, x): + bs = x.shape[0] + is_video_input = False + if x.dim() == 5: + is_video_input = True + # for video diffusion + x = rearrange(x, "b t c h w -> (b t) c h w") + n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) + n_rounds = math.ceil(x.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + out = self.first_stage_model.encode( + x[n * n_samples : (n + 1) * n_samples] + ) + all_out.append(out) + z = torch.cat(all_out, dim=0) + z = self.scale_factor * z + + if is_video_input: + z = rearrange(z, "(b t) c h w -> b t c h w", b=bs) + + return z + + def forward(self, x, batch): + loss = self.loss_fn(self.model, self.denoiser, self.conditioner, x, batch) + loss_mean = loss.mean() + loss_dict = {"loss": loss_mean} + return loss_mean, loss_dict + + def shared_step(self, batch: Dict) -> Any: + x = self.get_input(batch) + breakpoint() + x = self.encode_first_stage(x) + batch["global_step"] = self.global_step + loss, loss_dict = self(x, batch) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + self.log_dict( + loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + self.log( + "global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) + + if self.scheduler_config is not None: + lr = self.optimizers().param_groups[0]["lr"] + self.log( + "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + return loss + + def on_train_start(self, *args, **kwargs): + if self.sampler is None or self.loss_fn is None: + raise ValueError("Sampler and loss function need to be set for training.") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + for embedder in self.conditioner.embedders: + if embedder.is_trainable: + params = params + list(embedder.parameters()) + opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) + if self.scheduler_config is not None: + scheduler = 
instantiate_from_config(self.scheduler_config) + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } + ] + return [opt], scheduler + return opt + + @torch.no_grad() + def sample( + self, + cond: Dict, + uc: Union[Dict, None] = None, + batch_size: int = 16, + shape: Union[None, Tuple, List] = None, + **kwargs, + ): + randn = torch.randn(batch_size, *shape).to(self.device) + + denoiser = lambda input, sigma, c: self.denoiser( + self.model, input, sigma, c, **kwargs + ) + samples = self.sampler(denoiser, randn, cond, uc=uc) + return samples + + @torch.no_grad() + def log_conditionings(self, batch: Dict, n: int) -> Dict: + """ + Defines heuristics to log different conditionings. + These can be lists of strings (text-to-image), tensors, ints, ... + """ + image_h, image_w = batch[self.input_key].shape[2:] + log = dict() + + for embedder in self.conditioner.embedders: + if ( + (self.log_keys is None) or (embedder.input_key in self.log_keys) + ) and not self.no_cond_log: + x = batch[embedder.input_key][:n] + if isinstance(x, torch.Tensor): + if x.dim() == 1: + # class-conditional, convert integer to string + x = [str(x[i].item()) for i in range(x.shape[0])] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) + elif x.dim() == 2: + # size and crop cond and the like + x = [ + "x".join([str(xx) for xx in x[i].tolist()]) + for i in range(x.shape[0]) + ] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + else: + raise NotImplementedError() + elif isinstance(x, (List, ListConfig)): + if isinstance(x[0], str): + # strings + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + log[embedder.input_key] = xc + return log + + @torch.no_grad() + def log_images( + self, + batch: Dict, + N: int = 8, + sample: bool = True, + ucg_keys: List[str] = None, + **kwargs, + ) -> Dict: + conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] + if ucg_keys: + assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( + "Each defined ucg key for sampling must be in the provided conditioner input keys," + f"but we have {ucg_keys} vs. 
{conditioner_input_keys}" + ) + else: + ucg_keys = conditioner_input_keys + log = dict() + + x = self.get_input(batch) + + c, uc = self.conditioner.get_unconditional_conditioning( + batch, + force_uc_zero_embeddings=ucg_keys + if len(self.conditioner.embedders) > 0 + else [], + ) + + sampling_kwargs = {} + + N = min(x.shape[0], N) + x = x.to(self.device)[:N] + log["inputs"] = x + z = self.encode_first_stage(x) + log["reconstructions"] = self.decode_first_stage(z) + log.update(self.log_conditionings(batch, N)) + + for k in c: + if isinstance(c[k], torch.Tensor): + c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) + + if sample: + with self.ema_scope("Plotting"): + samples = self.sample( + c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs + ) + samples = self.decode_first_stage(samples) + log["samples"] = samples + return log diff --git a/sgm/models/video3d_diffusion.py b/sgm/models/video3d_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4f97ec0c975937f4686471b1fa5698af013197 --- /dev/null +++ b/sgm/models/video3d_diffusion.py @@ -0,0 +1,524 @@ +import re +import math +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +import torch +from omegaconf import ListConfig, OmegaConf +from safetensors.torch import load_file as load_safetensors +from torch.optim.lr_scheduler import LambdaLR +from torchvision.utils import make_grid +from einops import rearrange, repeat + +from ..modules import UNCONDITIONAL_CONFIG +from ..modules.autoencoding.temporal_ae import VideoDecoder +from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER +from ..modules.ema import LitEma +from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder +from ..util import ( + default, + disabled_train, + get_obj_from_str, + instantiate_from_config, + log_txt_as_img, + video_frames_as_grid, +) + + +def flatten_for_video(input): + return input.flatten() + + +class Video3DDiffusionEngine(pl.LightningModule): + def __init__( + self, + network_config, + denoiser_config, + first_stage_config, + conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, + sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, + optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, + scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, + loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, + network_wrapper: Union[None, str] = None, + ckpt_path: Union[None, str] = None, + use_ema: bool = False, + ema_decay_rate: float = 0.9999, + scale_factor: float = 1.0, + disable_first_stage_autocast=False, + input_key: str = "frames", # for video inputs + log_keys: Union[List, None] = None, + no_cond_log: bool = False, + compile_model: bool = False, + en_and_decode_n_samples_a_time: Optional[int] = None, + ): + super().__init__() + self.log_keys = log_keys + self.input_key = input_key + self.optimizer_config = default( + optimizer_config, {"target": "torch.optim.AdamW"} + ) + model = instantiate_from_config(network_config) + self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( + model, compile_model=compile_model + ) + + self.denoiser = instantiate_from_config(denoiser_config) + self.sampler = ( + instantiate_from_config(sampler_config) + if sampler_config is not None + else None + ) + self.conditioner = instantiate_from_config( + default(conditioner_config, 
UNCONDITIONAL_CONFIG) + ) + self.scheduler_config = scheduler_config + self._init_first_stage(first_stage_config) + + self.loss_fn = ( + instantiate_from_config(loss_fn_config) + if loss_fn_config is not None + else None + ) + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model, decay=ema_decay_rate) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.scale_factor = scale_factor + self.disable_first_stage_autocast = disable_first_stage_autocast + self.no_cond_log = no_cond_log + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path) + + self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time + + def _load_last_embedder(self, original_state_dict): + original_module_name = "conditioner.embedders.3" + state_dict = dict() + for k, v in original_state_dict.items(): + m = re.match(rf"^{original_module_name}\.(.*)$", k) + if m is None: + continue + state_dict[m.group(1)] = v + + idx = -1 + for i in range(len(self.conditioner.embedders)): + if isinstance( + self.conditioner.embedders[i], VideoPredictionEmbedderWithEncoder + ): + idx = i + + print(f"Embedder [{idx}] is the frame encoder, make sure this is expected") + + self.conditioner.embedders[idx].load_state_dict(state_dict) + + def init_from_ckpt( + self, + path: str, + ) -> None: + if path.endswith("ckpt"): + sd = torch.load(path, map_location="cpu")["state_dict"] + elif path.endswith("safetensors"): + sd = load_safetensors(path) + else: + raise NotImplementedError + + self_sd = self.state_dict() + input_keys = [ + "model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight", + ] + for input_key in input_keys: + if input_key not in sd or input_key not in self_sd: + continue + + input_weight = self_sd[input_key] + + if input_weight.shape != sd[input_key].shape: + print("Manual init: {}".format(input_key)) + input_weight.zero_() + input_weight[:, :8, :, :].copy_(sd[input_key]) + + deleted_keys = [] + for k, v in self.state_dict().items(): + # resolve shape dismatch + if k in sd: + if v.shape != sd[k].shape: + del sd[k] + deleted_keys.append(k) + + if len(deleted_keys) > 0: + print(f"Deleted Keys: {deleted_keys}") + + missing, unexpected = self.load_state_dict(sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + if len(deleted_keys) > 0: + print(f"Deleted Keys: {deleted_keys}") + + if len(missing) > 0 or len(unexpected) > 0: + # means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id) + print("Modified embedder to support 3d spiral video inputs") + try: + self._load_last_embedder(sd) + except: + print("Failed to load last embedder, make sure this is expected") + + def _init_first_stage(self, config): + model = instantiate_from_config(config).eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + self.first_stage_model = model + + def get_input(self, batch): + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 
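The checkpoint loading above tolerates architecture changes: tensors whose shapes no longer match are dropped before a strict=False load, and the widened input convolution is zero-initialised with the pretrained weights copied into its first channels. A reduced sketch of the shape-tolerant part (load_partial and its arguments are illustrative):

import torch

def load_partial(model: torch.nn.Module, sd: dict) -> None:
    own = model.state_dict()
    # Keep only checkpoint tensors whose shapes still match the current model.
    sd = {k: v for k, v in sd.items() if k not in own or own[k].shape == v.shape}
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print(f"partial load: {len(missing)} missing, {len(unexpected)} unexpected keys")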
1 and in bchw format + return batch[self.input_key] + + @torch.no_grad() + def decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + is_video_input = False + bs = z.shape[0] + if z.dim() == 5: + is_video_input = True + # for video diffusion + z = rearrange(z, "b t c h w -> (b t) c h w") + n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) + + n_rounds = math.ceil(z.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + if isinstance(self.first_stage_model.decoder, VideoDecoder): + kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} + else: + kwargs = {} + out = self.first_stage_model.decode( + z[n * n_samples : (n + 1) * n_samples], **kwargs + ) + all_out.append(out) + out = torch.cat(all_out, dim=0) + + if is_video_input: + out = rearrange(out, "(b t) c h w -> b t c h w", b=bs) + + return out + + @torch.no_grad() + def encode_first_stage(self, x): + if self.input_key == "latents": + return x + + bs = x.shape[0] + is_video_input = False + if x.dim() == 5: + is_video_input = True + # for video diffusion + x = rearrange(x, "b t c h w -> (b t) c h w") + n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) + n_rounds = math.ceil(x.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + out = self.first_stage_model.encode( + x[n * n_samples : (n + 1) * n_samples] + ) + all_out.append(out) + z = torch.cat(all_out, dim=0) + z = self.scale_factor * z + + # if is_video_input: + # z = rearrange(z, "(b t) c h w -> b t c h w", b=bs) + + return z + + def forward(self, x, batch): + loss, model_output = self.loss_fn( + self.model, + self.denoiser, + self.conditioner, + x, + batch, + return_model_output=True, + ) + loss_mean = loss.mean() + loss_dict = {"loss": loss_mean, "model_output": model_output} + return loss_mean, loss_dict + + def shared_step(self, batch: Dict) -> Any: + # TODO: move this shit to collate_fn in dataloader + # if "fps_id" in batch: + # batch["fps_id"] = flatten_for_video(batch["fps_id"]) + # if "motion_bucket_id" in batch: + # batch["motion_bucket_id"] = flatten_for_video(batch["motion_bucket_id"]) + # if "cond_aug" in batch: + # batch["cond_aug"] = flatten_for_video(batch["cond_aug"]) + x = self.get_input(batch) + x = self.encode_first_stage(x) + # ## debug + # x_recon = self.decode_first_stage(x) + # video_frames_as_grid((batch["frames"][0] + 1.0) / 2.0, "./tmp/origin.jpg") + # video_frames_as_grid((x_recon[0] + 1.0) / 2.0, "./tmp/recon.jpg") + # ## debug + batch["global_step"] = self.global_step + loss, loss_dict = self(x, batch) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + with torch.no_grad(): + if "model_output" in loss_dict: + if batch_idx % 100 == 0: + if isinstance(self.logger, WandbLogger): + model_output = loss_dict["model_output"].detach()[ + : batch["num_video_frames"] + ] + recons = ( + (self.decode_first_stage(model_output) + 1.0) / 2.0 + ).clamp(0.0, 1.0) + recon_grid = make_grid(recons, nrow=4) + self.logger.log_image( + key=f"train/model_output_recon", + images=[recon_grid], + step=self.global_step, + ) + del loss_dict["model_output"] + + if torch.isnan(loss).any(): + print("Nan detected") + loss = None + + self.log_dict( + loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + self.log( + "global_step", + self.global_step, + 
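The encode/decode paths above flatten the time axis into the batch axis so the image VAE can process video frames; a small round trip shows the shape bookkeeping (sizes are illustrative):

import torch
from einops import rearrange

frames = torch.randn(2, 8, 3, 64, 64)                     # (b, t, c, h, w)
flat = rearrange(frames, "b t c h w -> (b t) c h w")      # (16, 3, 64, 64), what the VAE sees
back = rearrange(flat, "(b t) c h w -> b t c h w", b=2)   # restore the per-video layout
assert torch.equal(frames, back)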
prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, + ) + + if self.scheduler_config is not None: + lr = self.optimizers().param_groups[0]["lr"] + self.log( + "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + return loss + + def on_train_start(self, *args, **kwargs): + if self.sampler is None or self.loss_fn is None: + raise ValueError("Sampler and loss function need to be set for training.") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + for embedder in self.conditioner.embedders: + if embedder.is_trainable: + params = params + list(embedder.parameters()) + opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } + ] + return [opt], scheduler + return opt + + @torch.no_grad() + def sample( + self, + cond: Dict, + uc: Union[Dict, None] = None, + batch_size: int = 16, + shape: Union[None, Tuple, List] = None, + **kwargs, + ): + randn = torch.randn(batch_size, *shape).to(self.device) + + denoiser = lambda input, sigma, c: self.denoiser( + self.model, input, sigma, c, **kwargs + ) + samples = self.sampler(denoiser, randn, cond, uc=uc) + return samples + + @torch.no_grad() + def log_conditionings(self, batch: Dict, n: int) -> Dict: + """ + Defines heuristics to log different conditionings. + These can be lists of strings (text-to-image), tensors, ints, ... 
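Typical use of the ema_scope context manager defined above (engine, cond and uc are hypothetical): the EMA weights are swapped in for the duration of the block and the training weights are restored afterwards, even if the block raises.

# Hypothetical usage; assumes the engine was built with use_ema=True and that
# `cond`/`uc` come from engine.conditioner.get_unconditional_conditioning(batch).
with engine.ema_scope("Sampling"):
    samples = engine.sample(cond, uc=uc, batch_size=1, shape=(4, 64, 64))
# Back outside the block, the raw training weights are active again.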
+ """ + image_h, image_w = batch[self.input_key].shape[-2:] + log = dict() + + for embedder in self.conditioner.embedders: + if ( + (self.log_keys is None) or (embedder.input_key in self.log_keys) + ) and not self.no_cond_log: + x = batch[embedder.input_key][:n] + if isinstance(x, torch.Tensor): + if x.dim() == 1: + # class-conditional, convert integer to string + x = [str(x[i].item()) for i in range(x.shape[0])] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) + elif x.dim() == 2: + # size and crop cond and the like + x = [ + "x".join([str(xx) for xx in x[i].tolist()]) + for i in range(x.shape[0]) + ] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + elif x.dim() == 4: + # image + xc = x + else: + raise NotImplementedError() + elif isinstance(x, (List, ListConfig)): + if isinstance(x[0], str): + # strings + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + log[embedder.input_key] = xc + + return log + + # for video diffusions will be logging frames of a video + @torch.no_grad() + def log_images( + self, + batch: Dict, + N: int = 1, + sample: bool = True, + ucg_keys: List[str] = None, + **kwargs, + ) -> Dict: + # # debug + # return {} + # # debug + assert "num_video_frames" in batch, "num_video_frames must be in batch" + num_video_frames = batch["num_video_frames"] + conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] + conditioner_input_keys = [] + for e in self.conditioner.embedders: + if e.input_key is not None: + conditioner_input_keys.append(e.input_key) + else: + conditioner_input_keys.extend(e.input_keys) + if ucg_keys: + assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( + "Each defined ucg key for sampling must be in the provided conditioner input keys," + f"but we have {ucg_keys} vs. {conditioner_input_keys}" + ) + else: + ucg_keys = conditioner_input_keys + log = dict() + + x = self.get_input(batch) + + c, uc = self.conditioner.get_unconditional_conditioning( + batch, + force_uc_zero_embeddings=ucg_keys + if len(self.conditioner.embedders) > 0 + else [], + ) + + sampling_kwargs = {"num_video_frames": num_video_frames} + n = min(x.shape[0] // num_video_frames, N) + sampling_kwargs["image_only_indicator"] = torch.cat( + [batch["image_only_indicator"][:n]] * 2 + ) + + N = min(x.shape[0] // num_video_frames, N) * num_video_frames + x = x.to(self.device)[:N] + # log["inputs"] = rearrange(x, "(b t) c h w -> b c h (t w)", t=num_video_frames) + log["inputs"] = x + z = self.encode_first_stage(x) + recon = self.decode_first_stage(z) + # log["reconstructions"] = rearrange( + # recon, "(b t) c h w -> b c h (t w)", t=num_video_frames + # ) + log["reconstructions"] = recon + log.update(self.log_conditionings(batch, N)) + log["pixelnerf_rgb"] = c["rgb"] + + for k in ["crossattn", "concat", "vector"]: + if k in c: + c[k] = c[k][:N] + uc[k] = uc[k][:N] + + # for k in c: + # if isinstance(c[k], torch.Tensor): + # if k == "vector": + # end = N + # else: + # end = n + # c[k], uc[k] = map(lambda y: y[k][:end].to(self.device), (c, uc)) + + # # for k in c: + # # print(c[k].shape) + + # breakpoint() + # for k in ["crossattn", "concat"]: + # c[k] = repeat(c[k], "b ... -> b t ...", t=num_video_frames) + # c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_video_frames) + # uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_video_frames) + # uc[k] = rearrange(uc[k], "b t ... 
-> (b t) ...", t=num_video_frames) + + # for k in c: + # print(c[k].shape) + if sample: + with self.ema_scope("Plotting"): + samples = self.sample( + c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs + ) + samples = self.decode_first_stage(samples) + log["samples"] = samples + return log diff --git a/sgm/models/video_diffusion.py b/sgm/models/video_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..5dbaa4a6d99e44fb2662f13e7cb5ca3ff9b0939e --- /dev/null +++ b/sgm/models/video_diffusion.py @@ -0,0 +1,503 @@ +import re +import math +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +import torch +from omegaconf import ListConfig, OmegaConf +from safetensors.torch import load_file as load_safetensors +from torch.optim.lr_scheduler import LambdaLR +from torchvision.utils import make_grid +from einops import rearrange, repeat + +from ..modules import UNCONDITIONAL_CONFIG +from ..modules.autoencoding.temporal_ae import VideoDecoder +from ..modules.diffusionmodules.wrappers import OPENAIUNETWRAPPER +from ..modules.ema import LitEma +from ..modules.encoders.modules import VideoPredictionEmbedderWithEncoder +from ..util import ( + default, + disabled_train, + get_obj_from_str, + instantiate_from_config, + log_txt_as_img, + video_frames_as_grid, +) + + +def flatten_for_video(input): + return input.flatten() + + +class DiffusionEngine(pl.LightningModule): + def __init__( + self, + network_config, + denoiser_config, + first_stage_config, + conditioner_config: Union[None, Dict, ListConfig, OmegaConf] = None, + sampler_config: Union[None, Dict, ListConfig, OmegaConf] = None, + optimizer_config: Union[None, Dict, ListConfig, OmegaConf] = None, + scheduler_config: Union[None, Dict, ListConfig, OmegaConf] = None, + loss_fn_config: Union[None, Dict, ListConfig, OmegaConf] = None, + network_wrapper: Union[None, str] = None, + ckpt_path: Union[None, str] = None, + use_ema: bool = False, + ema_decay_rate: float = 0.9999, + scale_factor: float = 1.0, + disable_first_stage_autocast=False, + input_key: str = "frames", # for video inputs + log_keys: Union[List, None] = None, + no_cond_log: bool = False, + compile_model: bool = False, + en_and_decode_n_samples_a_time: Optional[int] = None, + load_last_embedder: bool = False, + from_scratch: bool = False, + ): + super().__init__() + self.log_keys = log_keys + self.input_key = input_key + self.optimizer_config = default( + optimizer_config, {"target": "torch.optim.AdamW"} + ) + model = instantiate_from_config(network_config) + self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))( + model, compile_model=compile_model + ) + + self.denoiser = instantiate_from_config(denoiser_config) + self.sampler = ( + instantiate_from_config(sampler_config) + if sampler_config is not None + else None + ) + self.conditioner = instantiate_from_config( + default(conditioner_config, UNCONDITIONAL_CONFIG) + ) + self.scheduler_config = scheduler_config + self._init_first_stage(first_stage_config) + + self.loss_fn = ( + instantiate_from_config(loss_fn_config) + if loss_fn_config is not None + else None + ) + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model, decay=ema_decay_rate) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.scale_factor = scale_factor + self.disable_first_stage_autocast = disable_first_stage_autocast + self.no_cond_log = 
no_cond_log + + self.load_last_embedder = load_last_embedder + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, from_scratch) + + self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time + + def _load_last_embedder(self, original_state_dict): + original_module_name = "conditioner.embedders.3" + state_dict = dict() + for k, v in original_state_dict.items(): + m = re.match(rf"^{original_module_name}\.(.*)$", k) + if m is None: + continue + state_dict[m.group(1)] = v + + idx = -1 + for i in range(len(self.conditioner.embedders)): + if isinstance( + self.conditioner.embedders[i], VideoPredictionEmbedderWithEncoder + ): + idx = i + + print(f"Embedder [{idx}] is the frame encoder, make sure this is expected") + + self.conditioner.embedders[idx].load_state_dict(state_dict) + + def init_from_ckpt( + self, + path: str, + from_scratch: bool = False, + ) -> None: + if path.endswith("ckpt"): + sd = torch.load(path, map_location="cpu")["state_dict"] + elif path.endswith("safetensors"): + sd = load_safetensors(path) + else: + raise NotImplementedError + + deleted_keys = [] + for k, v in self.state_dict().items(): + # resolve shape dismatch + if k in sd: + if v.shape != sd[k].shape: + del sd[k] + deleted_keys.append(k) + + if from_scratch: + new_sd = {} + for k in sd: + if "first_stage_model" in k: + new_sd[k] = sd[k] + sd = new_sd + print(sd.keys()) + + if len(deleted_keys) > 0: + print(f"Deleted Keys: {deleted_keys}") + + missing, unexpected = self.load_state_dict(sd, strict=False) + print( + f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" + ) + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + if len(deleted_keys) > 0: + print(f"Deleted Keys: {deleted_keys}") + + if (len(missing) > 0 or len(unexpected) > 0) and self.load_last_embedder: + # means we are loading from a checkpoint that has the old embedder (motion bucket id and fps id) + print("Modified embedder to support 3d spiral video inputs") + self._load_last_embedder(sd) + + def _init_first_stage(self, config): + model = instantiate_from_config(config).eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + self.first_stage_model = model + + def get_input(self, batch): + # assuming unified data format, dataloader returns a dict. + # image tensors should be scaled to -1 ... 
1 and in bchw format + return batch[self.input_key] + + @torch.no_grad() + def decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + is_video_input = False + bs = z.shape[0] + if z.dim() == 5: + is_video_input = True + # for video diffusion + z = rearrange(z, "b t c h w -> (b t) c h w") + n_samples = default(self.en_and_decode_n_samples_a_time, z.shape[0]) + + n_rounds = math.ceil(z.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + if isinstance(self.first_stage_model.decoder, VideoDecoder): + kwargs = {"timesteps": len(z[n * n_samples : (n + 1) * n_samples])} + else: + kwargs = {} + out = self.first_stage_model.decode( + z[n * n_samples : (n + 1) * n_samples], **kwargs + ) + all_out.append(out) + out = torch.cat(all_out, dim=0) + + if is_video_input: + out = rearrange(out, "(b t) c h w -> b t c h w", b=bs) + + return out + + @torch.no_grad() + def encode_first_stage(self, x): + if self.input_key == "latents": + return x * self.scale_factor + + bs = x.shape[0] + is_video_input = False + if x.dim() == 5: + is_video_input = True + # for video diffusion + x = rearrange(x, "b t c h w -> (b t) c h w") + n_samples = default(self.en_and_decode_n_samples_a_time, x.shape[0]) + n_rounds = math.ceil(x.shape[0] / n_samples) + all_out = [] + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + for n in range(n_rounds): + out = self.first_stage_model.encode( + x[n * n_samples : (n + 1) * n_samples] + ) + all_out.append(out) + z = torch.cat(all_out, dim=0) + z = self.scale_factor * z + + # if is_video_input: + # z = rearrange(z, "(b t) c h w -> b t c h w", b=bs) + + return z + + def forward(self, x, batch): + loss, model_output = self.loss_fn( + self.model, + self.denoiser, + self.conditioner, + x, + batch, + return_model_output=True, + ) + loss_mean = loss.mean() + loss_dict = {"loss": loss_mean, "model_output": model_output} + return loss_mean, loss_dict + + def shared_step(self, batch: Dict) -> Any: + # TODO: move this shit to collate_fn in dataloader + # if "fps_id" in batch: + # batch["fps_id"] = flatten_for_video(batch["fps_id"]) + # if "motion_bucket_id" in batch: + # batch["motion_bucket_id"] = flatten_for_video(batch["motion_bucket_id"]) + # if "cond_aug" in batch: + # batch["cond_aug"] = flatten_for_video(batch["cond_aug"]) + x = self.get_input(batch) + x = self.encode_first_stage(x) + # ## debug + # x_recon = self.decode_first_stage(x) + # video_frames_as_grid((batch["frames"][0] + 1.0) / 2.0, "./tmp/origin.jpg") + # video_frames_as_grid((x_recon[0] + 1.0) / 2.0, "./tmp/recon.jpg") + # ## debug + batch["global_step"] = self.global_step + # breakpoint() + loss, loss_dict = self(x, batch) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + loss, loss_dict = self.shared_step(batch) + + with torch.no_grad(): + if "model_output" in loss_dict: + if batch_idx % 100 == 0: + if isinstance(self.logger, WandbLogger): + model_output = loss_dict["model_output"].detach()[ + : batch["num_video_frames"] + ] + recons = ( + (self.decode_first_stage(model_output) + 1.0) / 2.0 + ).clamp(0.0, 1.0) + recon_grid = make_grid(recons, nrow=4) + self.logger.log_image( + key=f"train/model_output_recon", + images=[recon_grid], + step=self.global_step, + ) + del loss_dict["model_output"] + + self.log_dict( + loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + self.log( + "global_step", + self.global_step, + prog_bar=True, + logger=True, + 
on_step=True, + on_epoch=False, + ) + + if self.scheduler_config is not None: + lr = self.optimizers().param_groups[0]["lr"] + self.log( + "lr_abs", lr, prog_bar=True, logger=True, on_step=True, on_epoch=False + ) + + return loss + + def on_train_start(self, *args, **kwargs): + if self.sampler is None or self.loss_fn is None: + raise ValueError("Sampler and loss function need to be set for training.") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def instantiate_optimizer_from_config(self, params, lr, cfg): + return get_obj_from_str(cfg["target"])( + params, lr=lr, **cfg.get("params", dict()) + ) + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + for embedder in self.conditioner.embedders: + if embedder.is_trainable: + params = params + list(embedder.parameters()) + opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } + ] + return [opt], scheduler + return opt + + @torch.no_grad() + def sample( + self, + cond: Dict, + uc: Union[Dict, None] = None, + batch_size: int = 16, + shape: Union[None, Tuple, List] = None, + **kwargs, + ): + randn = torch.randn(batch_size, *shape).to(self.device) + + denoiser = lambda input, sigma, c: self.denoiser( + self.model, input, sigma, c, **kwargs + ) + samples = self.sampler(denoiser, randn, cond, uc=uc) + return samples + + @torch.no_grad() + def log_conditionings(self, batch: Dict, n: int) -> Dict: + """ + Defines heuristics to log different conditionings. + These can be lists of strings (text-to-image), tensors, ints, ... 
+ """ + image_h, image_w = batch[self.input_key].shape[-2:] + log = dict() + + for embedder in self.conditioner.embedders: + if ( + (self.log_keys is None) or (embedder.input_key in self.log_keys) + ) and not self.no_cond_log: + x = batch[embedder.input_key][:n] + if isinstance(x, torch.Tensor): + if x.dim() == 1: + # class-conditional, convert integer to string + x = [str(x[i].item()) for i in range(x.shape[0])] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 4) + elif x.dim() == 2: + # size and crop cond and the like + x = [ + "x".join([str(xx) for xx in x[i].tolist()]) + for i in range(x.shape[0]) + ] + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + elif x.dim() == 4: + # image + xc = x + else: + pass + # breakpoint() + # raise NotImplementedError() + elif isinstance(x, (List, ListConfig)): + if isinstance(x[0], str): + # strings + xc = log_txt_as_img((image_h, image_w), x, size=image_h // 20) + else: + raise NotImplementedError() + else: + raise NotImplementedError() + log[embedder.input_key] = xc + return log + + # for video diffusions will be logging frames of a video + @torch.no_grad() + def log_images( + self, + batch: Dict, + N: int = 1, + sample: bool = True, + ucg_keys: List[str] = None, + **kwargs, + ) -> Dict: + # # debug + # return {} + # # debug + assert "num_video_frames" in batch, "num_video_frames must be in batch" + num_video_frames = batch["num_video_frames"] + conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] + if ucg_keys: + assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( + "Each defined ucg key for sampling must be in the provided conditioner input keys," + f"but we have {ucg_keys} vs. {conditioner_input_keys}" + ) + else: + ucg_keys = conditioner_input_keys + log = dict() + + x = self.get_input(batch) + + c, uc = self.conditioner.get_unconditional_conditioning( + batch, + force_uc_zero_embeddings=ucg_keys + if len(self.conditioner.embedders) > 0 + else [], + ) + + sampling_kwargs = {"num_video_frames": num_video_frames} + n = min(x.shape[0] // num_video_frames, N) + sampling_kwargs["image_only_indicator"] = torch.cat( + [batch["image_only_indicator"][:n]] * 2 + ) + + N = min(x.shape[0] // num_video_frames, N) * num_video_frames + x = x.to(self.device)[:N] + # log["inputs"] = rearrange(x, "(b t) c h w -> b c h (t w)", t=num_video_frames) + if self.input_key != "latents": + log["inputs"] = x + z = self.encode_first_stage(x) + recon = self.decode_first_stage(z) + # log["reconstructions"] = rearrange( + # recon, "(b t) c h w -> b c h (t w)", t=num_video_frames + # ) + log["reconstructions"] = recon + log.update(self.log_conditionings(batch, N)) + + for k in c: + if isinstance(c[k], torch.Tensor): + if k == "vector": + end = N + else: + end = n + c[k], uc[k] = map(lambda y: y[k][:end].to(self.device), (c, uc)) + + # for k in c: + # print(c[k].shape) + + for k in ["crossattn", "concat"]: + c[k] = repeat(c[k], "b ... -> b t ...", t=num_video_frames) + c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_video_frames) + uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_video_frames) + uc[k] = rearrange(uc[k], "b t ... 
-> (b t) ...", t=num_video_frames) + + # for k in c: + # print(c[k].shape) + if sample: + with self.ema_scope("Plotting"): + samples = self.sample( + c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs + ) + samples = self.decode_first_stage(samples) + log["samples"] = samples + return log diff --git a/sgm/modules/__init__.py b/sgm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2aa9ad360acf32dab22989d81630b3eb7978abb1 --- /dev/null +++ b/sgm/modules/__init__.py @@ -0,0 +1,6 @@ +from .encoders.modules import GeneralConditioner, ExtraConditioner + +UNCONDITIONAL_CONFIG = { + "target": "sgm.modules.GeneralConditioner", + "params": {"emb_models": []}, +} diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b60cabce854b52527f6dee85ea4f0cb0951eb6 --- /dev/null +++ b/sgm/modules/attention.py @@ -0,0 +1,764 @@ +import logging +import math +from inspect import isfunction +from typing import Any, Optional +from functools import partial + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from packaging import version +from torch import nn + +# from torch.utils.checkpoint import checkpoint + +checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + + +logpy = logging.getLogger(__name__) + +if version.parse(torch.__version__) >= version.parse("2.0.0"): + SDP_IS_AVAILABLE = True + from torch.backends.cuda import SDPBackend, sdp_kernel + + BACKEND_MAP = { + SDPBackend.MATH: { + "enable_math": True, + "enable_flash": False, + "enable_mem_efficient": False, + }, + SDPBackend.FLASH_ATTENTION: { + "enable_math": False, + "enable_flash": True, + "enable_mem_efficient": False, + }, + SDPBackend.EFFICIENT_ATTENTION: { + "enable_math": False, + "enable_flash": False, + "enable_mem_efficient": True, + }, + None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True}, + } +else: + from contextlib import nullcontext + + SDP_IS_AVAILABLE = False + sdp_kernel = nullcontext + BACKEND_MAP = {} + logpy.warn( + f"No SDP backend available, likely because you are running in pytorch " + f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. " + f"You might want to consider upgrading." + ) + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILABLE = True +except: + XFORMERS_IS_AVAILABLE = False + logpy.warn("no module 'xformers'. 
Processing without...") + +# from .diffusionmodules.util import mixed_checkpoint as checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SelfAttention(nn.Module): + ATTENTION_MODES = ("xformers", "torch", "math") + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + attn_mode: str = "xformers", + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + assert attn_mode in self.ATTENTION_MODES + self.attn_mode = attn_mode + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, L, C = x.shape + + qkv = self.qkv(x) + if self.attn_mode == "torch": + qkv = rearrange( + qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads + ).float() + q, k, v = qkv[0], qkv[1], qkv[2] # B H L D + x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = rearrange(x, "B H L D -> B L (H D)") + elif self.attn_mode == "xformers": + qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B L H D + x = xformers.ops.memory_efficient_attention(q, k, v) + x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads) + elif self.attn_mode == "math": + qkv = rearrange(qkv, "B L (K H D) -> K B H L D", 
K=3, H=self.num_heads) + q, k, v = qkv[0], qkv[1], qkv[2] # B H L D + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + else: + raise NotImplemented + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + backend=None, + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.backend = backend + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + h = self.heads + + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + n_cp = x.shape[0] // n_times_crossframe_attn_in_self + k = repeat( + k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp + ) + v = repeat( + v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp + ) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + ## old + """ + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + """ + ## new + with sdp_kernel(**BACKEND_MAP[self.backend]): + # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape) + out = F.scaled_dot_product_attention( + q, k, v, attn_mask=mask + ) # scale is dim_head ** -0.5 per default + + del q, k, v + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__( + self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs + ): + super().__init__() + logpy.debug( + f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, " + f"context_dim is {context_dim} and using {heads} heads with a " + f"dimension of {dim_head}." + ) + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) + ) + self.attention_op: Optional[Any] = None + + def forward( + self, + x, + context=None, + mask=None, + additional_tokens=None, + n_times_crossframe_attn_in_self=0, + ): + if additional_tokens is not None: + # get the number of masked tokens at the beginning of the output sequence + n_tokens_to_mask = additional_tokens.shape[1] + # add additional token + x = torch.cat([additional_tokens, x], dim=1) + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + if n_times_crossframe_attn_in_self: + # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439 + assert x.shape[0] % n_times_crossframe_attn_in_self == 0 + # n_cp = x.shape[0]//n_times_crossframe_attn_in_self + k = repeat( + k[::n_times_crossframe_attn_in_self], + "b ... -> (b n) ...", + n=n_times_crossframe_attn_in_self, + ) + v = repeat( + v[::n_times_crossframe_attn_in_self], + "b ... 
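The n_times_crossframe_attn_in_self branch above reuses keys and values from anchor frames across a clip, following the cross-frame attention idea of arXiv:2303.13439. A reduced sketch of the idea, assuming the batch is laid out as clips of n_frames consecutive frames (the exact striding in the repository code differs slightly):

import torch
from einops import repeat

def crossframe_kv(k: torch.Tensor, v: torch.Tensor, n_frames: int):
    # Keep k/v of the first frame of each clip and tile it over the clip, so every
    # frame's self-attention attends to the anchor frame's tokens.
    k = repeat(k[::n_frames], "b ... -> (b n) ...", n=n_frames)
    v = repeat(v[::n_frames], "b ... -> (b n) ...", n=n_frames)
    return k, v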
-> (b n) ...", + n=n_times_crossframe_attn_in_self, + ) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + if version.parse(xformers.__version__) >= version.parse("0.0.21"): + # NOTE: workaround for + # https://github.com/facebookresearch/xformers/issues/845 + max_bs = 32768 + N = q.shape[0] + n_batches = math.ceil(N / max_bs) + out = list() + for i_batch in range(n_batches): + batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs) + out.append( + xformers.ops.memory_efficient_attention( + q[batch], + k[batch], + v[batch], + attn_bias=None, + op=self.attention_op, + ) + ) + out = torch.cat(out, 0) + else: + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + # TODO: Use this directly in the attention operation, as a bias + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + if additional_tokens is not None: + # remove additional token + out = out[:, n_tokens_to_mask:] + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention, # ampere + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False, + attn_mode="softmax", + sdp_backend=None, + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: + logpy.warn( + f"Attention mode '{attn_mode}' is not available. Falling " + f"back to native attention. This is not a problem in " + f"Pytorch >= 2.0. FYI, you are running with PyTorch " + f"version {torch.__version__}." + ) + attn_mode = "softmax" + elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: + logpy.warn( + "We do not support vanilla attention anymore, as it is too " + "expensive. Sorry." + ) + if not XFORMERS_IS_AVAILABLE: + assert ( + False + ), "Please install xformers via e.g. 
'pip install xformers==0.0.16'" + else: + logpy.info("Falling back to xformers efficient attention.") + attn_mode = "softmax-xformers" + attn_cls = self.ATTENTION_MODES[attn_mode] + if version.parse(torch.__version__) >= version.parse("2.0.0"): + assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) + else: + assert sdp_backend is None + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None, + backend=sdp_backend, + ) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + backend=sdp_backend, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + if self.checkpoint: + logpy.debug(f"{self.__class__.__name__} is using checkpointing") + + def forward( + self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 + ): + kwargs = {"x": x} + + if context is not None: + kwargs.update({"context": context}) + + if additional_tokens is not None: + kwargs.update({"additional_tokens": additional_tokens}) + + if n_times_crossframe_attn_in_self: + kwargs.update( + {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self} + ) + + # return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint) + if self.checkpoint: + # inputs = {"x": x, "context": context} + return checkpoint(self._forward, x, context) + # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint) + else: + return self._forward(**kwargs) + + def _forward( + self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 + ): + x = ( + self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None, + additional_tokens=additional_tokens, + n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self + if not self.disable_self_attn + else 0, + ) + + x + ) + x = ( + self.attn2( + self.norm2(x), context=context, additional_tokens=additional_tokens + ) + + x + ) + x = self.ff(self.norm3(x)) + x + return x + + +class BasicTransformerSingleLayerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version + # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128]) + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + attn_mode="softmax", + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim, + ) + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + # inputs = {"x": x, "context": context} + # return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint) + return checkpoint(self._forward, x, context) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context) + x + x = 
self.ff(self.norm2(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + disable_self_attn=False, + use_linear=False, + attn_type="softmax", + use_checkpoint=True, + # sdp_backend=SDPBackend.FLASH_ATTENTION + sdp_backend=None, + ): + super().__init__() + logpy.debug( + f"constructing {self.__class__.__name__} of depth {depth} w/ " + f"{in_channels} channels and {n_heads} heads." + ) + + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + if exists(context_dim) and isinstance(context_dim, list): + if depth != len(context_dim): + logpy.warn( + f"{self.__class__.__name__}: Found context dims " + f"{context_dim} of depth {len(context_dim)}, which does not " + f"match the specified 'depth' of {depth}. Setting context_dim " + f"to {depth * [context_dim[0]]} now." + ) + # depth does not match context dims. + assert all( + map(lambda x: x == context_dim[0], context_dim) + ), "need homogenous context_dim to match depth automatically" + context_dim = depth * [context_dim[0]] + elif context_dim is None: + context_dim = [None] * depth + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + attn_mode=attn_type, + checkpoint=use_checkpoint, + sdp_backend=sdp_backend, + ) + for d in range(depth) + ] + ) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + else: + # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c").contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + if i > 0 and len(context) == 1: + i = 0 # use same context for each block + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class SimpleTransformer(nn.Module): + def __init__( + self, + dim: int, + depth: int, + heads: int, + dim_head: int, + context_dim: Optional[int] = None, + dropout: float = 0.0, + checkpoint: bool = True, + ): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + BasicTransformerBlock( + dim, + heads, + dim_head, + dropout=dropout, + context_dim=context_dim, + attn_mode="softmax-xformers", + checkpoint=checkpoint, + ) + ) + + def forward( 
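The zero_module wrapper around proj_out in SpatialTransformer above makes the block an identity mapping at initialisation: with a zero-initialised output projection, the residual connection returns the input unchanged, so a newly inserted transformer block does not perturb pretrained features at step 0. A tiny illustration (layer sizes are arbitrary):

import torch
import torch.nn as nn

proj_out = nn.Conv2d(320, 320, kernel_size=1)
nn.init.zeros_(proj_out.weight)
nn.init.zeros_(proj_out.bias)

x_in = torch.randn(1, 320, 8, 8)
h = torch.randn(1, 320, 8, 8)                 # stand-in for the transformer branch output
assert torch.equal(proj_out(h) + x_in, x_in)  # block is a no-op at initialisation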
+ self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, context) + return x diff --git a/sgm/modules/autoencoding/__init__.py b/sgm/modules/autoencoding/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b316c7aa6ea1c5e31a58987aa3b37b2933eb7e2 --- /dev/null +++ b/sgm/modules/autoencoding/losses/__init__.py @@ -0,0 +1,7 @@ +__all__ = [ + "GeneralLPIPSWithDiscriminator", + "LatentLPIPS", +] + +from .discriminator_loss import GeneralLPIPSWithDiscriminator +from .lpips import LatentLPIPS diff --git a/sgm/modules/autoencoding/losses/discriminator_loss.py b/sgm/modules/autoencoding/losses/discriminator_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..09b6829267bf8e4d98c3f29abdc19e58dcbcbe64 --- /dev/null +++ b/sgm/modules/autoencoding/losses/discriminator_loss.py @@ -0,0 +1,306 @@ +from typing import Dict, Iterator, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torchvision +from einops import rearrange +from matplotlib import colormaps +from matplotlib import pyplot as plt + +from ....util import default, instantiate_from_config +from ..lpips.loss.lpips import LPIPS +from ..lpips.model.model import weights_init +from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss + + +class GeneralLPIPSWithDiscriminator(nn.Module): + def __init__( + self, + disc_start: int, + logvar_init: float = 0.0, + disc_num_layers: int = 3, + disc_in_channels: int = 3, + disc_factor: float = 1.0, + disc_weight: float = 1.0, + perceptual_weight: float = 1.0, + disc_loss: str = "hinge", + scale_input_to_tgt_size: bool = False, + dims: int = 2, + learn_logvar: bool = False, + regularization_weights: Union[None, Dict[str, float]] = None, + additional_log_keys: Optional[List[str]] = None, + discriminator_config: Optional[Dict] = None, + ): + super().__init__() + self.dims = dims + if self.dims > 2: + print( + f"running with dims={dims}. This means that for perceptual loss " + f"calculation, the LPIPS loss will be applied to each frame " + f"independently." 
+ ) + self.scale_input_to_tgt_size = scale_input_to_tgt_size + assert disc_loss in ["hinge", "vanilla"] + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter( + torch.full((), logvar_init), requires_grad=learn_logvar + ) + self.learn_logvar = learn_logvar + + discriminator_config = default( + discriminator_config, + { + "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator", + "params": { + "input_nc": disc_in_channels, + "n_layers": disc_num_layers, + "use_actnorm": False, + }, + }, + ) + + self.discriminator = instantiate_from_config(discriminator_config).apply( + weights_init + ) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.regularization_weights = default(regularization_weights, {}) + + self.forward_keys = [ + "optimizer_idx", + "global_step", + "last_layer", + "split", + "regularization_log", + ] + + self.additional_log_keys = set(default(additional_log_keys, [])) + self.additional_log_keys.update(set(self.regularization_weights.keys())) + + def get_trainable_parameters(self) -> Iterator[nn.Parameter]: + return self.discriminator.parameters() + + def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]: + if self.learn_logvar: + yield self.logvar + yield from () + + @torch.no_grad() + def log_images( + self, inputs: torch.Tensor, reconstructions: torch.Tensor + ) -> Dict[str, torch.Tensor]: + # calc logits of real/fake + logits_real = self.discriminator(inputs.contiguous().detach()) + if len(logits_real.shape) < 4: + # Non patch-discriminator + return dict() + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + # -> (b, 1, h, w) + + # parameters for colormapping + high = max(logits_fake.abs().max(), logits_real.abs().max()).item() + cmap = colormaps["PiYG"] # diverging colormap + + def to_colormap(logits: torch.Tensor) -> torch.Tensor: + """(b, 1, ...) -> (b, 3, ...)""" + logits = (logits + high) / (2 * high) + logits_np = cmap(logits.cpu().numpy())[..., :3] # truncate alpha channel + # -> (b, 1, ..., 3) + logits = torch.from_numpy(logits_np).to(logits.device) + return rearrange(logits, "b 1 ... 
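The disc_loss == "hinge" option selected above corresponds to the standard hinge GAN objective for the discriminator; the repository imports its own hinge_d_loss, and the sketch below just restates that formula:

import torch
import torch.nn.functional as F

def hinge_d_loss(logits_real: torch.Tensor, logits_fake: torch.Tensor) -> torch.Tensor:
    # Push real logits above +1 and fake logits below -1.
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)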
c -> b c ...") + + logits_real = torch.nn.functional.interpolate( + logits_real, + size=inputs.shape[-2:], + mode="nearest", + antialias=False, + ) + logits_fake = torch.nn.functional.interpolate( + logits_fake, + size=reconstructions.shape[-2:], + mode="nearest", + antialias=False, + ) + + # alpha value of logits for overlay + alpha_real = torch.abs(logits_real) / high + alpha_fake = torch.abs(logits_fake) / high + # -> (b, 1, h, w) in range [0, 0.5] + # alpha value of lines don't really matter, since the values are the same + # for both images and logits anyway + grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4) + grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4) + grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1) + # -> (1, h, w) + # blend logits and images together + + # prepare logits for plotting + logits_real = to_colormap(logits_real) + logits_fake = to_colormap(logits_fake) + # resize logits + # -> (b, 3, h, w) + + # make some grids + # add all logits to one plot + logits_real = torchvision.utils.make_grid(logits_real, nrow=4) + logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4) + # I just love how torchvision calls the number of columns `nrow` + grid_logits = torch.cat((logits_real, logits_fake), dim=1) + # -> (3, h, w) + + grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4) + grid_images_fake = torchvision.utils.make_grid( + 0.5 * reconstructions + 0.5, nrow=4 + ) + grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1) + # -> (3, h, w) in range [0, 1] + + grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images + + # Create labeled colorbar + dpi = 100 + height = 128 / dpi + width = grid_logits.shape[2] / dpi + fig, ax = plt.subplots(figsize=(width, height), dpi=dpi) + img = ax.imshow(np.array([[-high, high]]), cmap=cmap) + plt.colorbar( + img, + cax=ax, + orientation="horizontal", + fraction=0.9, + aspect=width / height, + pad=0.0, + ) + img.set_visible(False) + fig.tight_layout() + fig.canvas.draw() + # manually convert figure to numpy + cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0 + cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device) + + # Add colorbar to plot + annotated_grid = torch.cat((grid_logits, cbar), dim=1) + blended_grid = torch.cat((grid_blend, cbar), dim=1) + return { + "vis_logits": 2 * annotated_grid[None, ...] - 1, + "vis_logits_blended": 2 * blended_grid[None, ...] 
- 1, + } + + def calculate_adaptive_weight( + self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor + ) -> torch.Tensor: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward( + self, + inputs: torch.Tensor, + reconstructions: torch.Tensor, + *, # added because I changed the order here + regularization_log: Dict[str, torch.Tensor], + optimizer_idx: int, + global_step: int, + last_layer: torch.Tensor, + split: str = "train", + weights: Union[None, float, torch.Tensor] = None, + ) -> Tuple[torch.Tensor, dict]: + if self.scale_input_to_tgt_size: + inputs = torch.nn.functional.interpolate( + inputs, reconstructions.shape[2:], mode="bicubic", antialias=True + ) + + if self.dims > 2: + inputs, reconstructions = map( + lambda x: rearrange(x, "b c t h w -> (b t) c h w"), + (inputs, reconstructions), + ) + + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss( + inputs.contiguous(), reconstructions.contiguous() + ) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if global_step >= self.discriminator_iter_start or not self.training: + logits_fake = self.discriminator(reconstructions.contiguous()) + g_loss = -torch.mean(logits_fake) + if self.training: + d_weight = self.calculate_adaptive_weight( + nll_loss, g_loss, last_layer=last_layer + ) + else: + d_weight = torch.tensor(1.0) + else: + d_weight = torch.tensor(0.0) + g_loss = torch.tensor(0.0, requires_grad=True) + + loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss + log = dict() + for k in regularization_log: + if k in self.regularization_weights: + loss = loss + self.regularization_weights[k] * regularization_log[k] + if k in self.additional_log_keys: + log[f"{split}/{k}"] = regularization_log[k].detach().float().mean() + + log.update( + { + f"{split}/loss/total": loss.clone().detach().mean(), + f"{split}/loss/nll": nll_loss.detach().mean(), + f"{split}/loss/rec": rec_loss.detach().mean(), + f"{split}/loss/g": g_loss.detach().mean(), + f"{split}/scalars/logvar": self.logvar.detach(), + f"{split}/scalars/d_weight": d_weight.detach(), + } + ) + + return loss, log + elif optimizer_idx == 1: + # second pass for discriminator update + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + + if global_step >= self.discriminator_iter_start or not self.training: + d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake) + else: + d_loss = torch.tensor(0.0, requires_grad=True) + + log = { + f"{split}/loss/disc": d_loss.clone().detach().mean(), + f"{split}/logits/real": logits_real.detach().mean(), + f"{split}/logits/fake": logits_fake.detach().mean(), + } + return d_loss, log + else: + raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}") + + def get_nll_loss( + self, + rec_loss: torch.Tensor, + weights: Optional[Union[float, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss 
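+ # Heteroscedastic-style NLL: rec_loss is scaled by exp(-logvar) and logvar is added as a penalty term; + # the (optionally weighted) result below is summed per sample and averaged over the batch.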
+ if weights is not None: + weighted_nll_loss = weights * nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + + return nll_loss, weighted_nll_loss diff --git a/sgm/modules/autoencoding/losses/lpips.py b/sgm/modules/autoencoding/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..b329fcc2ee9477f0122aa7d066866cdfe71ce521 --- /dev/null +++ b/sgm/modules/autoencoding/losses/lpips.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn + +from ....util import default, instantiate_from_config +from ..lpips.loss.lpips import LPIPS + + +class LatentLPIPS(nn.Module): + def __init__( + self, + decoder_config, + perceptual_weight=1.0, + latent_weight=1.0, + scale_input_to_tgt_size=False, + scale_tgt_to_input_size=False, + perceptual_weight_on_inputs=0.0, + ): + super().__init__() + self.scale_input_to_tgt_size = scale_input_to_tgt_size + self.scale_tgt_to_input_size = scale_tgt_to_input_size + self.init_decoder(decoder_config) + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + self.latent_weight = latent_weight + self.perceptual_weight_on_inputs = perceptual_weight_on_inputs + + def init_decoder(self, config): + self.decoder = instantiate_from_config(config) + if hasattr(self.decoder, "encoder"): + del self.decoder.encoder + + def forward(self, latent_inputs, latent_predictions, image_inputs, split="train"): + log = dict() + loss = (latent_inputs - latent_predictions) ** 2 + log[f"{split}/latent_l2_loss"] = loss.mean().detach() + image_reconstructions = None + if self.perceptual_weight > 0.0: + image_reconstructions = self.decoder.decode(latent_predictions) + image_targets = self.decoder.decode(latent_inputs) + perceptual_loss = self.perceptual_loss( + image_targets.contiguous(), image_reconstructions.contiguous() + ) + loss = ( + self.latent_weight * loss.mean() + + self.perceptual_weight * perceptual_loss.mean() + ) + log[f"{split}/perceptual_loss"] = perceptual_loss.mean().detach() + + if self.perceptual_weight_on_inputs > 0.0: + image_reconstructions = default( + image_reconstructions, self.decoder.decode(latent_predictions) + ) + if self.scale_input_to_tgt_size: + image_inputs = torch.nn.functional.interpolate( + image_inputs, + image_reconstructions.shape[2:], + mode="bicubic", + antialias=True, + ) + elif self.scale_tgt_to_input_size: + image_reconstructions = torch.nn.functional.interpolate( + image_reconstructions, + image_inputs.shape[2:], + mode="bicubic", + antialias=True, + ) + + perceptual_loss2 = self.perceptual_loss( + image_inputs.contiguous(), image_reconstructions.contiguous() + ) + loss = loss + self.perceptual_weight_on_inputs * perceptual_loss2.mean() + log[f"{split}/perceptual_loss_on_inputs"] = perceptual_loss2.mean().detach() + return loss, log diff --git a/sgm/modules/autoencoding/lpips/__init__.py b/sgm/modules/autoencoding/lpips/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sgm/modules/autoencoding/lpips/loss/.gitignore b/sgm/modules/autoencoding/lpips/loss/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a92958a1cd4ffe005e1f5448ab3e6fd9c795a43a --- /dev/null +++ b/sgm/modules/autoencoding/lpips/loss/.gitignore @@ -0,0 +1 @@ +vgg.pth \ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/loss/LICENSE b/sgm/modules/autoencoding/lpips/loss/LICENSE new file mode 100644 index 
0000000000000000000000000000000000000000..924cfc85b8d63ef538f5676f830a2a8497932108 --- /dev/null +++ b/sgm/modules/autoencoding/lpips/loss/LICENSE @@ -0,0 +1,23 @@ +Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/loss/__init__.py b/sgm/modules/autoencoding/lpips/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sgm/modules/autoencoding/lpips/loss/lpips.py b/sgm/modules/autoencoding/lpips/loss/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..3e34f3d083674f675a5ca024e9bd27fb77e2b6b5 --- /dev/null +++ b/sgm/modules/autoencoding/lpips/loss/lpips.py @@ -0,0 +1,147 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +from collections import namedtuple + +import torch +import torch.nn as nn +from torchvision import models + +from ..util import get_ckpt_path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") + self.load_state_dict( + torch.load(ckpt, map_location=torch.device("cpu")), strict=False + ) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict( + torch.load(ckpt, map_location=torch.device("cpu")), strict=False + ) + 
return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( + outs1[kk] + ) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [ + spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns)) + ] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer( + "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] + ) + self.register_buffer( + "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] + ) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] + ) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/sgm/modules/autoencoding/lpips/model/LICENSE b/sgm/modules/autoencoding/lpips/model/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..4b356e66b5aa689b339f1a80a9f1b5ba378003bb --- /dev/null +++ b/sgm/modules/autoencoding/lpips/model/LICENSE @@ -0,0 +1,58 @@ +Copyright (c) 2017, Jun-Yan Zhu and Taesung Park +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +--------------------------- LICENSE FOR pix2pix -------------------------------- +BSD License + +For pix2pix software +Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +----------------------------- LICENSE FOR DCGAN -------------------------------- +BSD License + +For dcgan.torch software + +Copyright (c) 2015, Facebook, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/model/__init__.py b/sgm/modules/autoencoding/lpips/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sgm/modules/autoencoding/lpips/model/model.py b/sgm/modules/autoencoding/lpips/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..66357d4e627f9a69a5abbbad15546c96fcd758fe --- /dev/null +++ b/sgm/modules/autoencoding/lpips/model/model.py @@ -0,0 +1,88 @@ +import functools + +import torch.nn as nn + +from ..util import ActNorm + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if ( + type(norm_layer) == functools.partial + ): # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True), + ] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) diff --git a/sgm/modules/autoencoding/lpips/util.py b/sgm/modules/autoencoding/lpips/util.py new file mode 100644 index 0000000000000000000000000000000000000000..49c76e370bf16888ab61f42844b3c9f14ad9014c --- /dev/null +++ b/sgm/modules/autoencoding/lpips/util.py @@ -0,0 +1,128 @@ +import hashlib +import os + +import requests +import torch +import torch.nn as nn +from tqdm import tqdm + +URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} + +CKPT_MAP = {"vgg_lpips": "vgg.pth"} + +MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = 
int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class ActNorm(nn.Module): + def __init__( + self, num_features, logdet=False, affine=True, allow_reverse_init=False + ): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:, :, None, None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height * width * torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." 
+ ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:, :, None, None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h diff --git a/sgm/modules/autoencoding/lpips/vqperceptual.py b/sgm/modules/autoencoding/lpips/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..6195f0a6ed7ee6fd32c1bccea071e6075e95ee43 --- /dev/null +++ b/sgm/modules/autoencoding/lpips/vqperceptual.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) + return d_loss diff --git a/sgm/modules/autoencoding/regularizers/__init__.py b/sgm/modules/autoencoding/regularizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2b1815a5ba88892375e8ec9bedacea49024113 --- /dev/null +++ b/sgm/modules/autoencoding/regularizers/__init__.py @@ -0,0 +1,31 @@ +from abc import abstractmethod +from typing import Any, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ....modules.distributions.distributions import \ + DiagonalGaussianDistribution +from .base import AbstractRegularizer + + +class DiagonalGaussianRegularizer(AbstractRegularizer): + def __init__(self, sample: bool = True): + super().__init__() + self.sample = sample + + def get_trainable_parameters(self) -> Any: + yield from () + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + log = dict() + posterior = DiagonalGaussianDistribution(z) + if self.sample: + z = posterior.sample() + else: + z = posterior.mode() + kl_loss = posterior.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + log["kl_loss"] = kl_loss + return z, log diff --git a/sgm/modules/autoencoding/regularizers/base.py b/sgm/modules/autoencoding/regularizers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..fca681bb3c1f4818b57e956e31b98f76077ccb67 --- /dev/null +++ b/sgm/modules/autoencoding/regularizers/base.py @@ -0,0 +1,40 @@ +from abc import abstractmethod +from typing import Any, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + + +class AbstractRegularizer(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + raise NotImplementedError() + + @abstractmethod + def get_trainable_parameters(self) -> Any: + raise NotImplementedError() + + +class IdentityRegularizer(AbstractRegularizer): + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + return z, dict() + + def get_trainable_parameters(self) -> Any: + yield from () + + +def measure_perplexity( + predicted_indices: torch.Tensor, num_centroids: int +) -> Tuple[torch.Tensor, torch.Tensor]: + # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py + # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally + encodings = ( + F.one_hot(predicted_indices, num_centroids).float().reshape(-1, num_centroids) + ) + avg_probs = encodings.mean(0) + perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() + cluster_use = torch.sum(avg_probs > 0) + return perplexity, cluster_use diff --git a/sgm/modules/autoencoding/regularizers/quantize.py b/sgm/modules/autoencoding/regularizers/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..86a4dbdd10101b24f03bba134c4f8d2ab007f0db --- /dev/null +++ b/sgm/modules/autoencoding/regularizers/quantize.py @@ -0,0 +1,487 @@ +import logging +from abc import abstractmethod +from typing import Dict, Iterator, Literal, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import einsum + +from .base import AbstractRegularizer, measure_perplexity + +logpy = logging.getLogger(__name__) + + +class AbstractQuantizer(AbstractRegularizer): + def __init__(self): + super().__init__() + # Define these in your init + # shape (N,) + self.used: Optional[torch.Tensor] + self.re_embed: int + self.unknown_index: Union[Literal["random"], int] + + def remap_to_used(self, inds: torch.Tensor) -> torch.Tensor: + assert self.used is not None, "You need to define used indices for remap" + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to( + device=new.device + ) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds: torch.Tensor) -> torch.Tensor: + assert self.used is not None, "You need to define used indices for remap" + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + @abstractmethod + def get_codebook_entry( + self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None + ) -> torch.Tensor: + raise NotImplementedError() + + def get_trainable_parameters(self) -> Iterator[torch.nn.Parameter]: + yield from self.parameters() + + +class GumbelQuantizer(AbstractQuantizer): + """ + credit to @karpathy: + https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 
2016 + https://arxiv.org/abs/1611.01144 + """ + + def __init__( + self, + num_hiddens: int, + embedding_dim: int, + n_embed: int, + straight_through: bool = True, + kl_weight: float = 5e-4, + temp_init: float = 1.0, + remap: Optional[str] = None, + unknown_index: str = "random", + loss_key: str = "loss/vq", + ) -> None: + super().__init__() + + self.loss_key = loss_key + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + else: + self.used = None + self.re_embed = n_embed + if unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + else: + assert unknown_index == "random" or isinstance( + unknown_index, int + ), "unknown index needs to be 'random', 'extra' or any integer" + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.remap is not None: + logpy.info( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + + def forward( + self, z: torch.Tensor, temp: Optional[float] = None, return_logits: bool = False + ) -> Tuple[torch.Tensor, Dict]: + # force hard = True when we are in eval mode, as we must quantize. + # actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + out_dict = {} + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:, self.used, ...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:, self.used, ...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = ( + self.kl_weight + * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + ) + out_dict[self.loss_key] = diff + + ind = soft_one_hot.argmax(dim=1) + out_dict["indices"] = ind + if self.remap is not None: + ind = self.remap_to_used(ind) + + if return_logits: + out_dict["logits"] = logits + + return z_q, out_dict + + def get_codebook_entry(self, indices, shape): + # TODO: shape not yet optional + b, h, w, c = shape + assert b * h * w == indices.shape[0] + indices = rearrange(indices, "(b h w) -> b h w", b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = ( + F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + ) + z_q = einsum("b n h w, n d -> b d h w", one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer(AbstractQuantizer): + """ + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. 
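+ Each latent vector is assigned to its nearest codebook entry; gradients reach the encoder via the straight-through estimator z_q = z + (z_q - z).detach().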
+ Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, + beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + def __init__( + self, + n_e: int, + e_dim: int, + beta: float = 0.25, + remap: Optional[str] = None, + unknown_index: str = "random", + sane_index_shape: bool = False, + log_perplexity: bool = False, + embedding_weight_norm: bool = False, + loss_key: str = "loss/vq", + ): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.loss_key = loss_key + + if not embedding_weight_norm: + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + else: + self.embedding = torch.nn.utils.weight_norm( + nn.Embedding(self.n_e, self.e_dim), dim=1 + ) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + else: + self.used = None + self.re_embed = n_e + if unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + else: + assert unknown_index == "random" or isinstance( + unknown_index, int + ), "unknown index needs to be 'random', 'extra' or any integer" + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.remap is not None: + logpy.info( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + + self.sane_index_shape = sane_index_shape + self.log_perplexity = log_perplexity + + def forward( + self, + z: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict]: + do_reshape = z.ndim == 4 + if do_reshape: + # # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, "b c h w -> b h w c").contiguous() + + else: + assert z.ndim < 4, "No reshaping strategy for inputs > 4 dimensions defined" + z = z.contiguous() + + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = ( + torch.sum(z_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) + - 2 + * torch.einsum( + "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n") + ) + ) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + loss_dict = {} + if self.log_perplexity: + perplexity, cluster_usage = measure_perplexity( + min_encoding_indices.detach(), self.n_e + ) + loss_dict.update({"perplexity": perplexity, "cluster_usage": cluster_usage}) + + # compute loss for embedding + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean( + (z_q - z.detach()) ** 2 + ) + loss_dict[self.loss_key] = loss + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + if do_reshape: + z_q = rearrange(z_q, "b h w c -> b c h w").contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape( + z.shape[0], -1 + ) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + if do_reshape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3] + ) + else: + min_encoding_indices = rearrange( + min_encoding_indices, "(b s) 1 -> b s", b=z_q.shape[0] + ) + + loss_dict["min_encoding_indices"] = min_encoding_indices + + return 
z_q, loss_dict + + def get_codebook_entry( + self, indices: torch.Tensor, shape: Optional[Tuple[int, ...]] = None + ) -> torch.Tensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + assert shape is not None, "Need to give shape for remap" + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad=False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad=False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad=False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_( + new_cluster_size, alpha=1 - self.decay + ) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + # normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(AbstractQuantizer): + def __init__( + self, + n_embed: int, + embedding_dim: int, + beta: float, + decay: float = 0.99, + eps: float = 1e-5, + remap: Optional[str] = None, + unknown_index: str = "random", + loss_key: str = "loss/vq", + ): + super().__init__() + self.codebook_dim = embedding_dim + self.num_tokens = n_embed + self.beta = beta + self.loss_key = loss_key + + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + else: + self.used = None + self.re_embed = n_embed + if unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + else: + assert unknown_index == "random" or isinstance( + unknown_index, int + ), "unknown index needs to be 'random', 'extra' or any integer" + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.remap is not None: + logpy.info( + f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." 
+ ) + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: + # reshape z -> (batch, height, width, channel) and flatten + # z, 'b c h w -> b h w c' + z = rearrange(z, "b c h w -> b h w c") + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = ( + z_flattened.pow(2).sum(dim=1, keepdim=True) + + self.embedding.weight.pow(2).sum(dim=1) + - 2 * torch.einsum("bd,nd->bn", z_flattened, self.embedding.weight) + ) # 'n d -> d n' + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + # EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + # EMA embedding average + embed_sum = encodings.transpose(0, 1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + # normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + # z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, "b h w c -> b c h w") + + out_dict = { + self.loss_key: loss, + "encodings": encodings, + "encoding_indices": encoding_indices, + "perplexity": perplexity, + } + + return z_q, out_dict + + +class VectorQuantizerWithInputProjection(VectorQuantizer): + def __init__( + self, + input_dim: int, + n_codes: int, + codebook_dim: int, + beta: float = 1.0, + output_dim: Optional[int] = None, + **kwargs, + ): + super().__init__(n_codes, codebook_dim, beta, **kwargs) + self.proj_in = nn.Linear(input_dim, codebook_dim) + self.output_dim = output_dim + if output_dim is not None: + self.proj_out = nn.Linear(codebook_dim, output_dim) + else: + self.proj_out = nn.Identity() + + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Dict]: + rearr = False + in_shape = z.shape + + if z.ndim > 3: + rearr = self.output_dim is not None + z = rearrange(z, "b c ... -> b (...) c") + z = self.proj_in(z) + z_q, loss_dict = super().forward(z) + + z_q = self.proj_out(z_q) + if rearr: + if len(in_shape) == 4: + z_q = rearrange(z_q, "b (h w) c -> b c h w ", w=in_shape[-1]) + elif len(in_shape) == 5: + z_q = rearrange( + z_q, "b (t h w) c -> b c t h w ", w=in_shape[-1], h=in_shape[-2] + ) + else: + raise NotImplementedError( + f"rearranging not available for {len(in_shape)}-dimensional input." 
+ ) + + return z_q, loss_dict diff --git a/sgm/modules/autoencoding/temporal_ae.py b/sgm/modules/autoencoding/temporal_ae.py new file mode 100644 index 0000000000000000000000000000000000000000..374373e2e4330846ffef28d9061dcc64f70d2722 --- /dev/null +++ b/sgm/modules/autoencoding/temporal_ae.py @@ -0,0 +1,349 @@ +from typing import Callable, Iterable, Union + +import torch +from einops import rearrange, repeat + +from sgm.modules.diffusionmodules.model import ( + XFORMERS_IS_AVAILABLE, + AttnBlock, + Decoder, + MemoryEfficientAttnBlock, + ResnetBlock, +) +from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding +from sgm.modules.video_attention import VideoTransformerBlock +from sgm.util import partialclass + + +class VideoResBlock(ResnetBlock): + def __init__( + self, + out_channels, + *args, + dropout=0.0, + video_kernel_size=3, + alpha=0.0, + merge_strategy="learned", + **kwargs, + ): + super().__init__(out_channels=out_channels, dropout=dropout, *args, **kwargs) + if video_kernel_size is None: + video_kernel_size = [3, 1, 1] + self.time_stack = ResBlock( + channels=out_channels, + emb_channels=0, + dropout=dropout, + dims=3, + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=False, + skip_t_emb=True, + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, bs): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError() + + def forward(self, x, temb, skip_video=False, timesteps=None): + if timesteps is None: + timesteps = self.timesteps + + b, c, h, w = x.shape + + x = super().forward(x, temb) + + if not skip_video: + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + + x = self.time_stack(x, temb) + + alpha = self.get_alpha(bs=b // timesteps) + x = alpha * x + (1.0 - alpha) * x_mix + + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class AE3DConv(torch.nn.Conv2d): + def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs): + super().__init__(in_channels, out_channels, *args, **kwargs) + if isinstance(video_kernel_size, Iterable): + padding = [int(k // 2) for k in video_kernel_size] + else: + padding = int(video_kernel_size // 2) + + self.time_mix_conv = torch.nn.Conv3d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=video_kernel_size, + padding=padding, + ) + + def forward(self, input, timesteps, skip_video=False): + x = super().forward(input) + if skip_video: + return x + x = rearrange(x, "(b t) c h w -> b c t h w", t=timesteps) + x = self.time_mix_conv(x) + return rearrange(x, "b c t h w -> (b t) c h w") + + +class VideoBlock(AttnBlock): + def __init__( + self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" + ): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = VideoTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + attn_mode="softmax", + ) + + time_embed_dim = self.in_channels * 4 + 
self.video_time_embed = torch.nn.Sequential( + torch.nn.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + torch.nn.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps, skip_video=False): + if skip_video: + return super().forward(x) + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + +class MemoryEfficientVideoBlock(MemoryEfficientAttnBlock): + def __init__( + self, in_channels: int, alpha: float = 0, merge_strategy: str = "learned" + ): + super().__init__(in_channels) + # no context, single headed, as in base class + self.time_mix_block = VideoTransformerBlock( + dim=in_channels, + n_heads=1, + d_head=in_channels, + checkpoint=False, + ff_in=True, + attn_mode="softmax-xformers", + ) + + time_embed_dim = self.in_channels * 4 + self.video_time_embed = torch.nn.Sequential( + torch.nn.Linear(self.in_channels, time_embed_dim), + torch.nn.SiLU(), + torch.nn.Linear(time_embed_dim, self.in_channels), + ) + + self.merge_strategy = merge_strategy + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif self.merge_strategy == "learned": + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def forward(self, x, timesteps, skip_time_block=False): + if skip_time_block: + return super().forward(x) + + x_in = x + x = self.attention(x) + h, w = x.shape[2:] + x = rearrange(x, "b c h w -> b (h w) c") + + x_mix = x + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding(num_frames, self.in_channels, repeat_only=False) + emb = self.video_time_embed(t_emb) # b, n_channels + emb = emb[:, None, :] + x_mix = x_mix + emb + + alpha = self.get_alpha() + x_mix = self.time_mix_block(x_mix, timesteps=timesteps) + x = alpha * x + (1.0 - alpha) * x_mix # alpha merge + + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + + return x_in + x + + def get_alpha( + self, + ): + if self.merge_strategy == "fixed": + return self.mix_factor + elif self.merge_strategy == "learned": + return 
torch.sigmoid(self.mix_factor) + else: + raise NotImplementedError(f"unknown merge strategy {self.merge_strategy}") + + +def make_time_attn( + in_channels, + attn_type="vanilla", + attn_kwargs=None, + alpha: float = 0, + merge_strategy: str = "learned", +): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + ], f"attn_type {attn_type} not supported for spatio-temporal attention" + print( + f"making spatial and temporal attention of type '{attn_type}' with {in_channels} in_channels" + ) + if not XFORMERS_IS_AVAILABLE and attn_type == "vanilla-xformers": + print( + f"Attention mode '{attn_type}' is not available. Falling back to vanilla attention. " + f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" + ) + attn_type = "vanilla" + + if attn_type == "vanilla": + assert attn_kwargs is None + return partialclass( + VideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy + ) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return partialclass( + MemoryEfficientVideoBlock, + in_channels, + alpha=alpha, + merge_strategy=merge_strategy, + ) + else: + return NotImplementedError() + + +class Conv2DWrapper(torch.nn.Conv2d): + def forward(self, input: torch.Tensor, **kwargs) -> torch.Tensor: + return super().forward(input) + + +class VideoDecoder(Decoder): + available_time_modes = ["all", "conv-only", "attn-only"] + + def __init__( + self, + *args, + video_kernel_size: Union[int, list] = 3, + alpha: float = 0.0, + merge_strategy: str = "learned", + time_mode: str = "conv-only", + **kwargs, + ): + self.video_kernel_size = video_kernel_size + self.alpha = alpha + self.merge_strategy = merge_strategy + self.time_mode = time_mode + assert ( + self.time_mode in self.available_time_modes + ), f"time_mode parameter has to be in {self.available_time_modes}" + super().__init__(*args, **kwargs) + + def get_last_layer(self, skip_time_mix=False, **kwargs): + if self.time_mode == "attn-only": + raise NotImplementedError("TODO") + else: + return ( + self.conv_out.time_mix_conv.weight + if not skip_time_mix + else self.conv_out.weight + ) + + def _make_attn(self) -> Callable: + if self.time_mode not in ["conv-only", "only-last-conv"]: + return partialclass( + make_time_attn, + alpha=self.alpha, + merge_strategy=self.merge_strategy, + ) + else: + return super()._make_attn() + + def _make_conv(self) -> Callable: + if self.time_mode != "attn-only": + return partialclass(AE3DConv, video_kernel_size=self.video_kernel_size) + else: + return Conv2DWrapper + + def _make_resblock(self) -> Callable: + if self.time_mode not in ["attn-only", "only-last-conv"]: + return partialclass( + VideoResBlock, + video_kernel_size=self.video_kernel_size, + alpha=self.alpha, + merge_strategy=self.merge_strategy, + ) + else: + return super()._make_resblock() diff --git a/sgm/modules/diffusionmodules/__init__.py b/sgm/modules/diffusionmodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sgm/modules/diffusionmodules/denoiser.py b/sgm/modules/diffusionmodules/denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..d86e7a262d1f036139e41f500d8579a2b95071ef --- /dev/null +++ b/sgm/modules/diffusionmodules/denoiser.py @@ -0,0 +1,75 @@ +from typing import Dict, Union + +import torch +import torch.nn as nn + +from ...util import append_dims, instantiate_from_config +from .denoiser_scaling 
import DenoiserScaling +from .discretizer import Discretization + + +class Denoiser(nn.Module): + def __init__(self, scaling_config: Dict): + super().__init__() + + self.scaling: DenoiserScaling = instantiate_from_config(scaling_config) + + def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: + return sigma + + def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: + return c_noise + + def forward( + self, + network: nn.Module, + input: torch.Tensor, + sigma: torch.Tensor, + cond: Dict, + **additional_model_inputs, + ) -> torch.Tensor: + sigma = self.possibly_quantize_sigma(sigma) + sigma_shape = sigma.shape + sigma = append_dims(sigma, input.ndim) + c_skip, c_out, c_in, c_noise = self.scaling(sigma) + c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) + return ( + network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out + + input * c_skip + ) + + +class DiscreteDenoiser(Denoiser): + def __init__( + self, + scaling_config: Dict, + num_idx: int, + discretization_config: Dict, + do_append_zero: bool = False, + quantize_c_noise: bool = True, + flip: bool = True, + ): + super().__init__(scaling_config) + self.discretization: Discretization = instantiate_from_config( + discretization_config + ) + sigmas = self.discretization(num_idx, do_append_zero=do_append_zero, flip=flip) + self.register_buffer("sigmas", sigmas) + self.quantize_c_noise = quantize_c_noise + self.num_idx = num_idx + + def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor: + dists = sigma - self.sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape) + + def idx_to_sigma(self, idx: Union[torch.Tensor, int]) -> torch.Tensor: + return self.sigmas[idx] + + def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor: + return self.idx_to_sigma(self.sigma_to_idx(sigma)) + + def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor: + if self.quantize_c_noise: + return self.sigma_to_idx(c_noise) + else: + return c_noise diff --git a/sgm/modules/diffusionmodules/denoiser_scaling.py b/sgm/modules/diffusionmodules/denoiser_scaling.py new file mode 100644 index 0000000000000000000000000000000000000000..f4e287bfe8a82839a9a12fbd25c3446f43ab493b --- /dev/null +++ b/sgm/modules/diffusionmodules/denoiser_scaling.py @@ -0,0 +1,59 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch + + +class DenoiserScaling(ABC): + @abstractmethod + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + pass + + +class EDMScaling: + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5 + c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise + + +class EpsScaling: + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = torch.ones_like(sigma, device=sigma.device) + c_out = -sigma + c_in = 1 / (sigma**2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScaling: + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = 1.0 / (sigma**2 + 1.0) + c_out 
= -sigma / (sigma**2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_noise = sigma.clone() + return c_skip, c_out, c_in, c_noise + + +class VScalingWithEDMcNoise(DenoiserScaling): + def __call__( + self, sigma: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + c_skip = 1.0 / (sigma**2 + 1.0) + c_out = -sigma / (sigma**2 + 1.0) ** 0.5 + c_in = 1.0 / (sigma**2 + 1.0) ** 0.5 + c_noise = 0.25 * sigma.log() + return c_skip, c_out, c_in, c_noise diff --git a/sgm/modules/diffusionmodules/denoiser_weighting.py b/sgm/modules/diffusionmodules/denoiser_weighting.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b03ca58f17ea3d7374f4bbb7bf1d2994755e00 --- /dev/null +++ b/sgm/modules/diffusionmodules/denoiser_weighting.py @@ -0,0 +1,24 @@ +import torch + + +class UnitWeighting: + def __call__(self, sigma): + return torch.ones_like(sigma, device=sigma.device) + + +class EDMWeighting: + def __init__(self, sigma_data=0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma): + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + +class VWeighting(EDMWeighting): + def __init__(self): + super().__init__(sigma_data=1.0) + + +class EpsWeighting: + def __call__(self, sigma): + return sigma**-2.0 diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py new file mode 100644 index 0000000000000000000000000000000000000000..02add6081c5e3164d4402619b44d5be235d3ec58 --- /dev/null +++ b/sgm/modules/diffusionmodules/discretizer.py @@ -0,0 +1,69 @@ +from abc import abstractmethod +from functools import partial + +import numpy as np +import torch + +from ...modules.diffusionmodules.util import make_beta_schedule +from ...util import append_zero + + +def generate_roughly_equally_spaced_steps( + num_substeps: int, max_step: int +) -> np.ndarray: + return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] + + +class Discretization: + def __call__(self, n, do_append_zero=True, device="cpu", flip=False): + sigmas = self.get_sigmas(n, device=device) + sigmas = append_zero(sigmas) if do_append_zero else sigmas + return sigmas if not flip else torch.flip(sigmas, (0,)) + + @abstractmethod + def get_sigmas(self, n, device): + pass + + +class EDMDiscretization(Discretization): + def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + + def get_sigmas(self, n, device="cpu"): + ramp = torch.linspace(0, 1, n, device=device) + min_inv_rho = self.sigma_min ** (1 / self.rho) + max_inv_rho = self.sigma_max ** (1 / self.rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho + return sigmas + + +class LegacyDDPMDiscretization(Discretization): + def __init__( + self, + linear_start=0.00085, + linear_end=0.0120, + num_timesteps=1000, + ): + super().__init__() + self.num_timesteps = num_timesteps + betas = make_beta_schedule( + "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end + ) + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.to_torch = partial(torch.tensor, dtype=torch.float32) + + def get_sigmas(self, n, device="cpu"): + if n < self.num_timesteps: + timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) + alphas_cumprod = self.alphas_cumprod[timesteps] + elif n == self.num_timesteps: + alphas_cumprod = self.alphas_cumprod + else: + raise ValueError + + to_torch = partial(torch.tensor, 
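Taken together, the Denoiser and the scaling classes above implement the EDM-style preconditioning D(x, sigma) = c_skip * x + c_out * F(c_in * x, c_noise), with the four coefficients picked per parameterization (EDM, eps, v). Below is a minimal standalone sketch of that combination using the EDMScaling algebra; `toy_net` and `append_dims_like` are illustrative stand-ins, not names from this diff.

```python
import torch

def append_dims_like(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # broadcast a per-sample sigma to the shape of x, as append_dims does
    return t.reshape(t.shape + (1,) * (x.ndim - t.ndim))

def edm_scaling(sigma: torch.Tensor, sigma_data: float = 0.5):
    # same algebra as EDMScaling.__call__
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5
    c_noise = 0.25 * sigma.log()
    return c_skip, c_out, c_in, c_noise

def denoise(net, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    # mirrors Denoiser.forward: precondition the network input and output
    c_skip, c_out, c_in, c_noise = edm_scaling(sigma)
    c_skip, c_out, c_in = (append_dims_like(c, x) for c in (c_skip, c_out, c_in))
    return net(x * c_in, c_noise) * c_out + x * c_skip

toy_net = lambda x, t: torch.zeros_like(x)   # hypothetical stand-in for the UNet
x, sigma = torch.randn(2, 4, 8, 8), torch.tensor([0.1, 10.0])
print(denoise(toy_net, x, sigma).shape)      # torch.Size([2, 4, 8, 8])
```

At small sigma the skip term dominates (the network only predicts a correction), while at large sigma the output term dominates; the weightings in denoiser_weighting.py compensate for this in the training loss.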
dtype=torch.float32, device=device) + sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + return torch.flip(sigmas, (0,)) diff --git a/sgm/modules/diffusionmodules/guiders.py b/sgm/modules/diffusionmodules/guiders.py new file mode 100644 index 0000000000000000000000000000000000000000..63b5775b6ca857b4706f65f8cf3187cc8e4506d8 --- /dev/null +++ b/sgm/modules/diffusionmodules/guiders.py @@ -0,0 +1,146 @@ +import logging +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Tuple, Union + +import torch +from einops import rearrange, repeat + +from ...util import append_dims, default + +logpy = logging.getLogger(__name__) + + +class Guider(ABC): + @abstractmethod + def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: + pass + + def prepare_inputs( + self, x: torch.Tensor, s: float, c: Dict, uc: Dict + ) -> Tuple[torch.Tensor, float, Dict]: + pass + + +class VanillaCFG(Guider): + def __init__(self, scale: float): + self.scale = scale + + def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + x_u, x_c = x.chunk(2) + x_pred = x_u + self.scale * (x_c - x_u) + return x_pred + + def prepare_inputs(self, x, s, c, uc): + c_out = dict() + + for k in c: + if k in ["vector", "crossattn", "concat"]: + c_out[k] = torch.cat((uc[k], c[k]), 0) + else: + assert c[k] == uc[k] + c_out[k] = c[k] + return torch.cat([x] * 2), torch.cat([s] * 2), c_out + + +class IdentityGuider(Guider): + def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor: + return x + + def prepare_inputs( + self, x: torch.Tensor, s: float, c: Dict, uc: Dict + ) -> Tuple[torch.Tensor, float, Dict]: + c_out = dict() + + for k in c: + c_out[k] = c[k] + + return x, s, c_out + + +class LinearPredictionGuider(Guider): + def __init__( + self, + max_scale: float, + num_frames: int, + min_scale: float = 1.0, + additional_cond_keys: Optional[Union[List[str], str]] = None, + ): + self.min_scale = min_scale + self.max_scale = max_scale + self.num_frames = num_frames + self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) + + additional_cond_keys = default(additional_cond_keys, []) + if isinstance(additional_cond_keys, str): + additional_cond_keys = [additional_cond_keys] + self.additional_cond_keys = additional_cond_keys + + def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + x_u, x_c = x.chunk(2) + + x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) + x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) + scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) + scale = append_dims(scale, x_u.ndim).to(x_u.device) + + return rearrange(x_u + scale * (x_c - x_u), "b t ... 
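EDMDiscretization above is the Karras-style rho-schedule: interpolate linearly in sigma**(1/rho) between sigma_max and sigma_min, raise back to the power rho, and optionally append a trailing zero. A standalone sketch of the same computation, with the defaults copied from the class above:

```python
import torch

def karras_sigmas(n: int, sigma_min=0.002, sigma_max=80.0, rho=7.0, do_append_zero=True):
    # same formula as EDMDiscretization.get_sigmas, wrapped like Discretization.__call__
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    if do_append_zero:
        sigmas = torch.cat([sigmas, sigmas.new_zeros(1)])
    return sigmas

print(karras_sigmas(5))
# decreasing from sigma_max=80.0 down to sigma_min=0.002, plus the appended 0.0
```

LegacyDDPMDiscretization instead recovers sigma_t = sqrt((1 - alphas_cumprod_t) / alphas_cumprod_t) from the linear beta schedule, which is what lets the same samplers drive models trained in the classic DDPM convention.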
-> (b t) ...") + + def prepare_inputs( + self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + c_out = dict() + + for k in c: + if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: + c_out[k] = torch.cat((uc[k], c[k]), 0) + else: + if k == "rgb": + continue + assert c[k] == uc[k] + c_out[k] = c[k] + return torch.cat([x] * 2), torch.cat([s] * 2), c_out + + +class CentralPredictionGuider(Guider): + def __init__( + self, + max_scale: float, + num_frames: int, + min_scale: float = 1.0, + additional_cond_keys: Optional[Union[List[str], str]] = None, + ): + self.min_scale = min_scale + self.max_scale = max_scale + self.num_frames = num_frames + # self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0) + self.scale = torch.linspace(min_scale, 2 * max_scale, num_frames) + self.scale[num_frames // 2 :] = 2 * max_scale - self.scale[num_frames // 2 :] + self.scale = self.scale.unsqueeze(0) + + additional_cond_keys = default(additional_cond_keys, []) + if isinstance(additional_cond_keys, str): + additional_cond_keys = [additional_cond_keys] + self.additional_cond_keys = additional_cond_keys + + def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + x_u, x_c = x.chunk(2) + + x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames) + x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames) + scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0]) + scale = append_dims(scale, x_u.ndim).to(x_u.device) + + return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...") + + def prepare_inputs( + self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict + ) -> Tuple[torch.Tensor, torch.Tensor, dict]: + c_out = dict() + + for k in c: + if k in ["vector", "crossattn", "concat"] + self.additional_cond_keys: + c_out[k] = torch.cat((uc[k], c[k]), 0) + else: + assert c[k] == uc[k] + c_out[k] = c[k] + return torch.cat([x] * 2), torch.cat([s] * 2), c_out diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9b2c437fab37bed10ea79c197560ade7bf511cad --- /dev/null +++ b/sgm/modules/diffusionmodules/loss.py @@ -0,0 +1,187 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from ...modules.autoencoding.lpips.loss.lpips import LPIPS +from ...modules.encoders.modules import GeneralConditioner +from ...util import append_dims, instantiate_from_config +from .denoiser import Denoiser + + +class StandardDiffusionLoss(nn.Module): + def __init__( + self, + sigma_sampler_config: dict, + loss_weighting_config: dict, + loss_type: str = "l2", + offset_noise_level: float = 0.0, + batch2model_keys: Optional[Union[str, List[str]]] = None, + ): + super().__init__() + + assert loss_type in ["l2", "l1", "lpips"] + + self.sigma_sampler = instantiate_from_config(sigma_sampler_config) + self.loss_weighting = instantiate_from_config(loss_weighting_config) + + self.loss_type = loss_type + self.offset_noise_level = offset_noise_level + + if loss_type == "lpips": + self.lpips = LPIPS().eval() + + if not batch2model_keys: + batch2model_keys = [] + + if isinstance(batch2model_keys, str): + batch2model_keys = [batch2model_keys] + + self.batch2model_keys = set(batch2model_keys) + + def get_noised_input( + self, sigmas_bc: torch.Tensor, noise: torch.Tensor, input: torch.Tensor + ) -> torch.Tensor: + noised_input = input + noise * 
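VanillaCFG, and the per-frame variants above, implement classifier-free guidance: prepare_inputs stacks the unconditional and conditional branches into one doubled batch so a single network call covers both, and __call__ splits the result and blends x_u + scale * (x_c - x_u). A stripped-down sketch of that round trip; `fake_denoiser` and the tensor shapes are illustrative, not the repo's API:

```python
import torch

scale = 7.5
x = torch.randn(2, 4, 8, 8)                       # current noisy latents
sigma = torch.full((2,), 1.0)
c = {"crossattn": torch.randn(2, 77, 1024)}       # conditional embeddings
uc = {"crossattn": torch.zeros(2, 77, 1024)}      # unconditional embeddings

# prepare_inputs: one doubled batch, unconditional half first, conditional half second
x_in = torch.cat([x] * 2)
s_in = torch.cat([sigma] * 2)
ctx = torch.cat((uc["crossattn"], c["crossattn"]), 0)

fake_denoiser = lambda x, s, ctx: 0.5 * x         # stand-in for denoiser(network, ...)
out = fake_denoiser(x_in, s_in, ctx)

# __call__: split back into the two branches and blend
x_u, x_c = out.chunk(2)
x_pred = x_u + scale * (x_c - x_u)
print(x_pred.shape)                               # torch.Size([2, 4, 8, 8])
```

LinearPredictionGuider does the same split but ramps the scale linearly from min_scale to max_scale across the num_frames axis, so frames further from the conditioning frame are guided more strongly.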
sigmas_bc + return noised_input + + def forward( + self, + network: nn.Module, + denoiser: Denoiser, + conditioner: GeneralConditioner, + input: torch.Tensor, + batch: Dict, + return_model_output: bool = False, + ) -> torch.Tensor: + cond = conditioner(batch) + # for video diffusion + if "num_video_frames" in batch: + num_frames = batch["num_video_frames"] + for k in ["crossattn", "concat"]: + cond[k] = repeat(cond[k], "b ... -> b t ...", t=num_frames) + cond[k] = rearrange(cond[k], "b t ... -> (b t) ...", t=num_frames) + return self._forward(network, denoiser, cond, input, batch, return_model_output) + + def _forward( + self, + network: nn.Module, + denoiser: Denoiser, + cond: Dict, + input: torch.Tensor, + batch: Dict, + return_model_output: bool = False, + ) -> Tuple[torch.Tensor, Dict]: + additional_model_inputs = { + key: batch[key] for key in self.batch2model_keys.intersection(batch) + } + sigmas = self.sigma_sampler(input.shape[0]).to(input) + + noise = torch.randn_like(input) + if self.offset_noise_level > 0.0: + offset_shape = ( + (input.shape[0], 1, input.shape[2]) + if self.n_frames is not None + else (input.shape[0], input.shape[1]) + ) + noise = noise + self.offset_noise_level * append_dims( + torch.randn(offset_shape, device=input.device), + input.ndim, + ) + sigmas_bc = append_dims(sigmas, input.ndim) + noised_input = self.get_noised_input(sigmas_bc, noise, input) + + model_output = denoiser( + network, noised_input, sigmas, cond, **additional_model_inputs + ) + w = append_dims(self.loss_weighting(sigmas), input.ndim) + if not return_model_output: + return self.get_loss(model_output, input, w) + else: + return self.get_loss(model_output, input, w), model_output + + def get_loss(self, model_output, target, w): + if self.loss_type == "l2": + return torch.mean( + (w * (model_output - target) ** 2).reshape(target.shape[0], -1), 1 + ) + elif self.loss_type == "l1": + return torch.mean( + (w * (model_output - target).abs()).reshape(target.shape[0], -1), 1 + ) + elif self.loss_type == "lpips": + loss = self.lpips(model_output, target).reshape(-1) + return loss + else: + raise NotImplementedError(f"Unknown loss type {self.loss_type}") + + +class StandardDiffusionLossWithPixelNeRFLoss(StandardDiffusionLoss): + def __init__( + self, + sigma_sampler_config: Dict, + loss_weighting_config: Dict, + loss_type: str = "l2", + offset_noise_level: float = 0, + batch2model_keys: str | List[str] | None = None, + pixelnerf_loss_weight: float = 1.0, + pixelnerf_loss_type: str = "l2", + ): + super().__init__( + sigma_sampler_config, + loss_weighting_config, + loss_type, + offset_noise_level, + batch2model_keys, + ) + self.pixelnerf_loss_weight = pixelnerf_loss_weight + self.pixelnerf_loss_type = pixelnerf_loss_type + + def get_pixelnerf_loss(self, model_output, target): + if self.pixelnerf_loss_type == "l2": + return torch.mean( + ((model_output - target) ** 2).reshape(target.shape[0], -1), 1 + ) + elif self.pixelnerf_loss_type == "l1": + return torch.mean( + ((model_output - target).abs()).reshape(target.shape[0], -1), 1 + ) + elif self.pixelnerf_loss_type == "lpips": + loss = self.lpips(model_output, target).reshape(-1) + return loss + else: + raise NotImplementedError(f"Unknown loss type {self.loss_type}") + + def forward( + self, + network: nn.Module, + denoiser: Denoiser, + conditioner: GeneralConditioner, + input: torch.Tensor, + batch: Dict, + return_model_output: bool = False, + ) -> torch.Tensor: + cond = conditioner(batch) + return self._forward(network, denoiser, cond, input, batch, 
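StandardDiffusionLoss above amounts to: draw a per-sample sigma, noise the input as x + sigma * eps, run the preconditioned denoiser, and average a sigma-weighted squared error per example. A compact sketch of that computation with an identity stand-in for the denoiser (so the result is only a shape check) and the EDM weighting; all names below are illustrative:

```python
import torch

def edm_weighting(sigma: torch.Tensor, sigma_data: float = 0.5) -> torch.Tensor:
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2

def diffusion_loss(denoise_fn, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    noise = torch.randn_like(x)
    sigma_bc = sigma.view(-1, 1, 1, 1)             # append_dims for 4D inputs
    noised = x + noise * sigma_bc                  # get_noised_input
    model_out = denoise_fn(noised, sigma)
    w = edm_weighting(sigma).view(-1, 1, 1, 1)
    return torch.mean((w * (model_out - x) ** 2).reshape(x.shape[0], -1), 1)

x = torch.randn(4, 4, 32, 32)
sigma = torch.rand(4) * 10 + 0.1                   # stand-in for the sigma sampler
loss = diffusion_loss(lambda z, s: z, x, sigma)    # identity "denoiser"
print(loss.shape)                                  # torch.Size([4]), one value per example
```

The lpips branch swaps the weighted L2 for a perceptual distance but leaves the sigma sampling and noising unchanged.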
return_model_output) + + def _forward( + self, + network: nn.Module, + denoiser: Denoiser, + cond: Dict, + input: torch.Tensor, + batch: Dict, + return_model_output: bool = False, + ) -> Tuple[torch.Tensor | Dict]: + loss = super()._forward( + network, denoiser, cond, input, batch, return_model_output + ) + pixelnerf_loss = self.get_pixelnerf_loss( + cond["rgb"], batch["pixelnerf_input"]["rgb"] + ) + + if not return_model_output: + return loss + self.pixelnerf_loss_weight * pixelnerf_loss + else: + return loss[0] + self.pixelnerf_loss_weight * pixelnerf_loss, loss[1] diff --git a/sgm/modules/diffusionmodules/loss_weighting.py b/sgm/modules/diffusionmodules/loss_weighting.py new file mode 100644 index 0000000000000000000000000000000000000000..e12c0a76635435babd1af33969e82fa284525af8 --- /dev/null +++ b/sgm/modules/diffusionmodules/loss_weighting.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod + +import torch + + +class DiffusionLossWeighting(ABC): + @abstractmethod + def __call__(self, sigma: torch.Tensor) -> torch.Tensor: + pass + + +class UnitWeighting(DiffusionLossWeighting): + def __call__(self, sigma: torch.Tensor) -> torch.Tensor: + return torch.ones_like(sigma, device=sigma.device) + + +class EDMWeighting(DiffusionLossWeighting): + def __init__(self, sigma_data: float = 0.5): + self.sigma_data = sigma_data + + def __call__(self, sigma: torch.Tensor) -> torch.Tensor: + return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + +class VWeighting(EDMWeighting): + def __init__(self): + super().__init__(sigma_data=1.0) + + +class EpsWeighting(DiffusionLossWeighting): + def __call__(self, sigma: torch.Tensor) -> torch.Tensor: + return sigma**-2.0 diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..4cf9d92140dee8443a0ea6b5cf218f2879ad88f4 --- /dev/null +++ b/sgm/modules/diffusionmodules/model.py @@ -0,0 +1,748 @@ +# pytorch_diffusion + derived encoder decoder +import logging +import math +from typing import Any, Callable, Optional + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from packaging import version + +logpy = logging.getLogger(__name__) + +try: + import xformers + import xformers.ops + + XFORMERS_IS_AVAILABLE = True +except: + XFORMERS_IS_AVAILABLE = False + logpy.warning("no module 'xformers'. Processing without...") + +from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
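The weightings above map onto the usual parameterizations: EDM uses w(sigma) = (sigma**2 + sigma_data**2) / (sigma * sigma_data)**2, v-prediction is the sigma_data = 1 special case (hence VWeighting subclassing EDMWeighting), and eps-prediction uses w(sigma) = sigma**-2. A quick numeric comparison, re-implementing the three formulas standalone:

```python
import torch

sigma = torch.tensor([0.1, 1.0, 10.0])

edm = (sigma**2 + 0.5**2) / (sigma * 0.5) ** 2   # EDMWeighting(sigma_data=0.5)
v   = (sigma**2 + 1.0**2) / (sigma * 1.0) ** 2   # VWeighting == EDMWeighting(sigma_data=1.0)
eps = sigma ** -2.0                              # EpsWeighting

print(edm)  # approx [104.00, 5.00, 4.01]
print(v)    # approx [101.00, 2.00, 1.01]
print(eps)  # approx [100.00, 1.00, 0.01]
```

Note that EDMWeighting is exactly 1 / c_out**2 of EDMScaling, so the effective loss on the network's raw output stays uniform across noise levels.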
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = 
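get_timestep_embedding above is the standard sinusoidal embedding: half of the channels carry sin(t * w_k) and half cos(t * w_k), with frequencies w_k spaced geometrically from 1 down to 1/10000. A quick usage sketch, re-implemented here so it runs standalone and assuming an even embedding_dim (the original pads odd dims with a zero column):

```python
import math
import torch

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    # mirrors get_timestep_embedding for even `dim`
    half = dim // 2
    freqs = torch.exp(torch.arange(half, dtype=torch.float32) * (-math.log(10000) / (half - 1)))
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=1)

t = torch.tensor([0, 10, 500, 999])
emb = sinusoidal_embedding(t, 128)
print(emb.shape)   # torch.Size([4, 128])
# In Model.forward this embedding is pushed through the two Linear layers in
# self.temb.dense before being injected into each ResnetBlock via temb_proj.
```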
torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q, k, v = map( + lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v) + ) + h_ = torch.nn.functional.scaled_dot_product_attention( + q, k, v + ) # scale is dim ** -0.5 per default + # compute attention + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.attention_op: Optional[Any] = None + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, "b c h w -> b (h w) c"), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention( + q, k, v, attn_bias=None, op=self.attention_op + ) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + return rearrange(out, "b (h w) c -> b c h w", b=B, h=H, w=W, c=C) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None, **unused_kwargs): + b, c, h, w = x.shape + x = rearrange(x, "b c h w -> b (h w) c") + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, "b (h w) c -> b c h w", h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in [ + "vanilla", + "vanilla-xformers", + "memory-efficient-cross-attn", + "linear", + "none", + ], f"attn_type {attn_type} unknown" + if ( + version.parse(torch.__version__) < version.parse("2.0.0") + and attn_type != "none" + ): + assert XFORMERS_IS_AVAILABLE, ( + f"We do not support vanilla attention in {torch.__version__} anymore, " + f"as it is too expensive. Please install xformers via e.g. 
'pip install xformers==0.0.16'" + ) + attn_type = "vanilla-xformers" + logpy.info(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + logpy.info( + f"building MemoryEfficientAttnBlock with {in_channels} in_channels..." + ) + return MemoryEfficientAttnBlock(in_channels) + elif type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + 
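One detail worth flagging in make_attn above: the third branch compares the builtin `type` rather than the `attn_type` argument, so (at least on torch >= 2.0, where attn_type is not overridden to "vanilla-xformers") a request for "memory-efficient-cross-attn" can never take that branch and falls through to the final else, returning LinAttnBlock. The two-line check below only demonstrates why; whether that branch is meant to be reachable is for the authors to confirm.

```python
attn_type = "memory-efficient-cross-attn"

# `type` with no qualifier is Python's builtin class constructor, so this is always False:
print(type == "memory-efficient-cross-attn")        # False
# ...whereas the comparison the branch presumably intends does match:
print(attn_type == "memory-efficient-cross-attn")   # True
```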
self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + 
kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + logpy.info( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) + ) + + make_attn_cls = self._make_attn() + make_resblock_cls = self._make_resblock() + make_conv_cls = self._make_conv() + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type) + self.mid.block_2 = make_resblock_cls( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + make_resblock_cls( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn_cls(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = make_conv_cls( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def _make_attn(self) -> Callable: + return make_attn + + def _make_resblock(self) -> Callable: + return ResnetBlock + + def _make_conv(self) -> Callable: + return torch.nn.Conv2d + + def get_last_layer(self, **kwargs): + return self.conv_out.weight + + def forward(self, z, **kwargs): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep 
embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb, **kwargs) + h = self.mid.attn_1(h, **kwargs) + h = self.mid.block_2(h, temb, **kwargs) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb, **kwargs) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h, **kwargs) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h, **kwargs) + if self.tanh_out: + h = torch.tanh(h) + return h diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 0000000000000000000000000000000000000000..e762e6823540def71743e27131e284ea28cdb56e --- /dev/null +++ b/sgm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,863 @@ +import logging +import math +from abc import abstractmethod +from typing import Iterable, List, Optional, Tuple, Union + +import torch +import torch as th +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from functools import partial + +# from torch.utils.checkpoint import checkpoint + +checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + +from ...modules.attention import SpatialTransformer +from ...modules.diffusionmodules.util import ( + avg_pool_nd, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) +from ...modules.video_attention import SpatialVideoTransformer +from ...util import exists + +logpy = logging.getLogger(__name__) + + +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: Optional[int] = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x: th.Tensor) -> th.Tensor: + b, c, _ = x.shape + x = x.reshape(b, c, -1) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x: th.Tensor, emb: th.Tensor): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. 
+ """ + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + context: Optional[th.Tensor] = None, + image_only_indicator: Optional[th.Tensor] = None, + time_context: Optional[int] = None, + num_video_frames: Optional[int] = None, + ): + from ...modules.diffusionmodules.video_model import VideoResBlock + + for layer in self: + module = layer + + if isinstance(module, TimestepBlock) and not isinstance( + module, VideoResBlock + ): + x = layer(x, emb) + elif isinstance(module, VideoResBlock): + x = layer(x, emb, num_video_frames, image_only_indicator) + elif isinstance(module, SpatialVideoTransformer): + x = layer( + x, + context, + time_context, + num_video_frames, + image_only_indicator, + ) + elif isinstance(module, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__( + self, + channels: int, + use_conv: bool, + dims: int = 2, + out_channels: Optional[int] = None, + padding: int = 1, + third_up: bool = False, + kernel_size: int = 3, + scale_factor: int = 2, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.third_up = third_up + self.scale_factor = scale_factor + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, kernel_size, padding=padding + ) + + def forward(self, x: th.Tensor) -> th.Tensor: + assert x.shape[1] == self.channels + + if self.dims == 3: + t_factor = 1 if not self.third_up else self.scale_factor + x = F.interpolate( + x, + ( + t_factor * x.shape[2], + x.shape[3] * self.scale_factor, + x.shape[4] * self.scale_factor, + ), + mode="nearest", + ) + else: + x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__( + self, + channels: int, + use_conv: bool, + dims: int = 2, + out_channels: Optional[int] = None, + padding: int = 1, + third_down: bool = False, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) + if use_conv: + logpy.info(f"Building a Downsample layer with {dims} dims.") + logpy.info( + f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " + f"kernel-size: 3, stride: {stride}, padding: {padding}" + ) + if dims == 3: + logpy.info(f" --> Downsampling third axis (time): {third_down}") + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x: th.Tensor) -> th.Tensor: + assert x.shape[1] == self.channels + + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels: int, + emb_channels: int, + dropout: float, + out_channels: Optional[int] = None, + use_conv: bool = False, + use_scale_shift_norm: bool = False, + dims: int = 2, + use_checkpoint: bool = False, + up: bool = False, + down: bool = False, + kernel_size: int = 3, + exchange_temb_dims: bool = False, + skip_t_emb: bool = False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + if isinstance(kernel_size, Iterable): + padding = [k // 2 for k in kernel_size] + else: + padding = kernel_size // 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.skip_t_emb = skip_t_emb + self.emb_out_channels = ( + 2 * self.out_channels if use_scale_shift_norm else self.out_channels + ) + if self.skip_t_emb: + logpy.info(f"Skipping timestep embedding in {self.__class__.__name__}") + assert not self.use_scale_shift_norm + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + self.emb_out_channels, + ), + ) + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd( + dims, + self.out_channels, + self.out_channels, + kernel_size, + padding=padding, + ) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, kernel_size, padding=padding + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor: + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + if self.use_checkpoint: + return checkpoint(self._forward, x, emb) + else: + return self._forward(x, emb) + + def _forward(self, x: th.Tensor, emb: th.Tensor) -> th.Tensor: + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + + if self.skip_t_emb: + emb_out = th.zeros_like(h) + else: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... 
-> b c t ...") + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + """ + + def __init__( + self, + channels: int, + num_heads: int = 1, + num_head_channels: int = -1, + use_checkpoint: bool = False, + use_new_attention_order: bool = False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x: th.Tensor, **kwargs) -> th.Tensor: + return checkpoint(self._forward, x) + + def _forward(self, x: th.Tensor) -> th.Tensor: + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads: int): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv: th.Tensor) -> th.Tensor: + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads: int): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv: th.Tensor) -> th.Tensor: + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
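QKVAttentionLegacy above operates on a packed [N, 3*H*C, T] tensor: it reshapes into per-head chunks, applies the 1/sqrt(sqrt(C)) scale to q and k separately (the inline comment notes this is friendlier to fp16), and contracts with einsum. A tiny standalone shape check of the same computation, re-implemented for illustration:

```python
import math
import torch

def qkv_attention_legacy(qkv: torch.Tensor, n_heads: int) -> torch.Tensor:
    # mirrors QKVAttentionLegacy.forward
    bs, width, length = qkv.shape
    ch = width // (3 * n_heads)
    q, k, v = qkv.reshape(bs * n_heads, ch * 3, length).split(ch, dim=1)
    scale = 1 / math.sqrt(math.sqrt(ch))
    weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
    weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
    a = torch.einsum("bts,bcs->bct", weight, v)
    return a.reshape(bs, -1, length)

qkv = torch.randn(2, 3 * 4 * 16, 10)     # batch 2, 4 heads of 16 channels, 10 tokens
out = qkv_attention_legacy(qkv, n_heads=4)
print(out.shape)                         # torch.Size([2, 64, 10])
```

QKVAttention differs only in chunking q, k, v before splitting heads, which is why the two classes are selected via use_new_attention_order.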
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + +class Timestep(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, t: th.Tensor) -> th.Tensor: + return timestep_embedding(t, self.dim) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
+ """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + num_res_blocks: int, + attention_resolutions: int, + dropout: float = 0.0, + channel_mult: Union[List, Tuple] = (1, 2, 4, 8), + conv_resample: bool = True, + dims: int = 2, + num_classes: Optional[Union[int, str]] = None, + use_checkpoint: bool = False, + num_heads: int = -1, + num_head_channels: int = -1, + num_heads_upsample: int = -1, + use_scale_shift_norm: bool = False, + resblock_updown: bool = False, + transformer_depth: int = 1, + context_dim: Optional[int] = None, + disable_self_attentions: Optional[List[bool]] = None, + num_attention_blocks: Optional[List[int]] = None, + disable_middle_self_attn: bool = False, + disable_middle_transformer: bool = False, + use_linear_in_transformer: bool = False, + spatial_transformer_attn_type: str = "softmax", + adm_in_channels: Optional[int] = None, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + transformer_depth_middle = transformer_depth[-1] + + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError( + "provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult" + ) + self.num_res_blocks = num_res_blocks + + if disable_self_attentions is not None: + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all( + map( + lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], + range(len(num_attention_blocks)), + ) + ) + logpy.info( + f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set." 
+ ) + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + logpy.info("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if context_dim is not None and exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or nr < num_attention_blocks[level] + ): + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + SpatialTransformer( + ch, + num_heads, + dim_head, + 
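In the input-block loop above, `ds` tracks the cumulative downsampling factor and a SpatialTransformer is appended only when `ds in attention_resolutions` (further gated by num_attention_blocks when that is given). A small sketch of which levels get attention for an illustrative config; the numbers below are examples, not values taken from this diff:

```python
channel_mult = (1, 2, 4, 4)
attention_resolutions = [4, 2, 1]   # downsample factors at which attention is enabled

ds = 1
for level, mult in enumerate(channel_mult):
    has_attn = ds in attention_resolutions
    print(f"level {level}: ds={ds:<2d} channels=model_channels*{mult} attention={has_attn}")
    if level != len(channel_mult) - 1:
        ds *= 2
# level 0: ds=1  ... attention=True
# level 1: ds=2  ... attention=True
# level 2: ds=4  ... attention=True
# level 3: ds=8  ... attention=False
```

The output-block loop repeats the same test while halving `ds` back at each upsampling stage, so attention appears symmetrically on the way up.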
depth=transformer_depth_middle, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + if not disable_middle_transformer + else th.nn.Identity(), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if ( + not exists(num_attention_blocks) + or i < num_attention_blocks[level] + ): + layers.append( + SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_transformer, + attn_type=spatial_transformer_attn_type, + use_checkpoint=use_checkpoint, + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def forward( + self, + x: th.Tensor, + timesteps: Optional[th.Tensor] = None, + context: Optional[th.Tensor] = None, + y: Optional[th.Tensor] = None, + **kwargs, + ) -> th.Tensor: + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + + return self.out(h) diff --git a/sgm/modules/diffusionmodules/sampling.py b/sgm/modules/diffusionmodules/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..6346829c86a76ab549ed69431f1704e01379535a --- /dev/null +++ b/sgm/modules/diffusionmodules/sampling.py @@ -0,0 +1,365 @@ +""" + Partially ported from https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py +""" + + +from typing import Dict, Union + +import torch +from omegaconf import ListConfig, OmegaConf +from tqdm import tqdm + +from ...modules.diffusionmodules.sampling_utils import ( + get_ancestral_step, + linear_multistep_coeff, + to_d, + to_neg_log_sigma, + to_sigma, +) +from ...util import append_dims, default, instantiate_from_config + +DEFAULT_GUIDER = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} + + +class BaseDiffusionSampler: + def __init__( + self, + discretization_config: Union[Dict, ListConfig, OmegaConf], + num_steps: Union[int, None] = None, + guider_config: Union[Dict, ListConfig, OmegaConf, None] = None, + verbose: bool = False, + device: str = "cuda", + ): + self.num_steps = num_steps + self.discretization = instantiate_from_config(discretization_config) + self.guider = instantiate_from_config( + default( + guider_config, + DEFAULT_GUIDER, + ) + ) + self.verbose = verbose + self.device = device + + def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None): + sigmas = self.discretization( + self.num_steps if num_steps is None else num_steps, device=self.device + ) + uc = default(uc, cond) + + x *= torch.sqrt(1.0 + sigmas[0] ** 2.0) + num_sigmas = len(sigmas) + + s_in = x.new_ones([x.shape[0]]) + + return x, s_in, sigmas, num_sigmas, cond, uc + + def denoise(self, x, denoiser, sigma, cond, uc): + denoised = denoiser(*self.guider.prepare_inputs(x, sigma, cond, uc)) + denoised = self.guider(denoised, sigma) + return denoised + + def get_sigma_gen(self, num_sigmas): + sigma_generator = range(num_sigmas - 1) + if self.verbose: + print("#" * 30, " Sampling setting ", "#" * 30) + print(f"Sampler: {self.__class__.__name__}") + print(f"Discretization: {self.discretization.__class__.__name__}") + print(f"Guider: {self.guider.__class__.__name__}") + sigma_generator = tqdm( + sigma_generator, + total=num_sigmas, + desc=f"Sampling with {self.__class__.__name__} for {num_sigmas} steps", + ) + return sigma_generator + + +class SingleStepDiffusionSampler(BaseDiffusionSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc, *args, **kwargs): + raise NotImplementedError + + def euler_step(self, x, d, dt): + return x + dt * d + + +class EDMSampler(SingleStepDiffusionSampler): + def __init__( + self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs + ): + super().__init__(*args, **kwargs) + + self.s_churn = s_churn + self.s_tmin = s_tmin + self.s_tmax = s_tmax + self.s_noise = s_noise + + def 
sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): + sigma_hat = sigma * (gamma + 1.0) + if gamma > 0: + eps = torch.randn_like(x) * self.s_noise + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 + + denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) + d = to_d(x, sigma_hat, denoised) + dt = append_dims(next_sigma - sigma_hat, x.ndim) + + euler_step = self.euler_step(x, d, dt) + x = self.possible_correction_step( + euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) + if self.s_tmin <= sigmas[i] <= self.s_tmax + else 0.0 + ) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, + ) + + return x + + +class AncestralSampler(SingleStepDiffusionSampler): + def __init__(self, eta=1.0, s_noise=1.0, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.eta = eta + self.s_noise = s_noise + self.noise_sampler = lambda x: torch.randn_like(x) + + def ancestral_euler_step(self, x, denoised, sigma, sigma_down): + d = to_d(x, sigma, denoised) + dt = append_dims(sigma_down - sigma, x.ndim) + + return self.euler_step(x, d, dt) + + def ancestral_step(self, x, sigma, next_sigma, sigma_up): + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, + x + self.noise_sampler(x) * self.s_noise * append_dims(sigma_up, x.ndim), + x, + ) + return x + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + for i in self.get_sigma_gen(num_sigmas): + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + ) + + return x + + +class LinearMultistepSampler(BaseDiffusionSampler): + def __init__( + self, + order=4, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.order = order + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + ds = [] + sigmas_cpu = sigmas.detach().cpu().numpy() + for i in self.get_sigma_gen(num_sigmas): + sigma = s_in * sigmas[i] + denoised = denoiser( + *self.guider.prepare_inputs(x, sigma, cond, uc), **kwargs + ) + denoised = self.guider(denoised, sigma) + d = to_d(x, sigma, denoised) + ds.append(d) + if len(ds) > self.order: + ds.pop(0) + cur_order = min(i + 1, self.order) + coeffs = [ + linear_multistep_coeff(cur_order, sigmas_cpu, i, j) + for j in range(cur_order) + ] + x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) + + return x + + +class EulerEDMSampler(EDMSampler): + def possible_correction_step( + self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ): + return euler_step + + +class HeunEDMSampler(EDMSampler): + def possible_correction_step( + self, euler_step, x, d, dt, next_sigma, denoiser, cond, uc + ): + if torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 + return euler_step + else: + denoised = self.denoise(euler_step, denoiser, next_sigma, cond, uc) + d_new = to_d(euler_step, next_sigma, denoised) + d_prime = (d + d_new) / 2.0 + + # apply correction if noise level is not 0 + x = torch.where( + append_dims(next_sigma, 
x.ndim) > 0.0, x + d_prime * dt, euler_step + ) + return x + + +class EulerAncestralSampler(AncestralSampler): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + + return x + + +class DPMPP2SAncestralSampler(AncestralSampler): + def get_variables(self, sigma, sigma_down): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, sigma_down)] + h = t_next - t + s = t + 0.5 * h + return h, s, t, t_next + + def get_mult(self, h, s, t, t_next): + mult1 = to_sigma(s) / to_sigma(t) + mult2 = (-0.5 * h).expm1() + mult3 = to_sigma(t_next) / to_sigma(t) + mult4 = (-h).expm1() + + return mult1, mult2, mult3, mult4 + + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, **kwargs): + sigma_down, sigma_up = get_ancestral_step(sigma, next_sigma, eta=self.eta) + denoised = self.denoise(x, denoiser, sigma, cond, uc) + x_euler = self.ancestral_euler_step(x, denoised, sigma, sigma_down) + + if torch.sum(sigma_down) < 1e-14: + # Save a network evaluation if all noise levels are 0 + x = x_euler + else: + h, s, t, t_next = self.get_variables(sigma, sigma_down) + mult = [ + append_dims(mult, x.ndim) for mult in self.get_mult(h, s, t, t_next) + ] + + x2 = mult[0] * x - mult[1] * denoised + denoised2 = self.denoise(x2, denoiser, to_sigma(s), cond, uc) + x_dpmpp2s = mult[2] * x - mult[3] * denoised2 + + # apply correction if noise level is not 0 + x = torch.where(append_dims(sigma_down, x.ndim) > 0.0, x_dpmpp2s, x_euler) + + x = self.ancestral_step(x, sigma, next_sigma, sigma_up) + return x + + +class DPMPP2MSampler(BaseDiffusionSampler): + def get_variables(self, sigma, next_sigma, previous_sigma=None): + t, t_next = [to_neg_log_sigma(s) for s in (sigma, next_sigma)] + h = t_next - t + + if previous_sigma is not None: + h_last = t - to_neg_log_sigma(previous_sigma) + r = h_last / h + return h, r, t, t_next + else: + return h, None, t, t_next + + def get_mult(self, h, r, t, t_next, previous_sigma): + mult1 = to_sigma(t_next) / to_sigma(t) + mult2 = (-h).expm1() + + if previous_sigma is not None: + mult3 = 1 + 1 / (2 * r) + mult4 = 1 / (2 * r) + return mult1, mult2, mult3, mult4 + else: + return mult1, mult2 + + def sampler_step( + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, + ): + denoised = self.denoise(x, denoiser, sigma, cond, uc) + + h, r, t, t_next = self.get_variables(sigma, next_sigma, previous_sigma) + mult = [ + append_dims(mult, x.ndim) + for mult in self.get_mult(h, r, t, t_next, previous_sigma) + ] + + x_standard = mult[0] * x - mult[1] * denoised + if old_denoised is None or torch.sum(next_sigma) < 1e-14: + # Save a network evaluation if all noise levels are 0 or on the first step + return x_standard, denoised + else: + denoised_d = mult[2] * denoised - mult[3] * old_denoised + x_advanced = mult[0] * x - mult[1] * denoised_d + + # apply correction if noise level is not 0 and not first step + x = torch.where( + append_dims(next_sigma, x.ndim) > 0.0, x_advanced, x_standard + ) + + return x, denoised + + def __call__(self, denoiser, x, cond, uc=None, num_steps=None, **kwargs): + x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( + x, cond, uc, num_steps + ) + + old_denoised = None + for i in self.get_sigma_gen(num_sigmas): + x, old_denoised = 
self.sampler_step( + old_denoised, + None if i == 0 else s_in * sigmas[i - 1], + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc=uc, + ) + + return x diff --git a/sgm/modules/diffusionmodules/sampling_utils.py b/sgm/modules/diffusionmodules/sampling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ce78527ea9052a8bfd0856ed2278901516fb9130 --- /dev/null +++ b/sgm/modules/diffusionmodules/sampling_utils.py @@ -0,0 +1,43 @@ +import torch +from scipy import integrate + +from ...util import append_dims + + +def linear_multistep_coeff(order, t, i, j, epsrel=1e-4): + if order - 1 > i: + raise ValueError(f"Order {order} too high for step {i}") + + def fn(tau): + prod = 1.0 + for k in range(order): + if j == k: + continue + prod *= (tau - t[i - k]) / (t[i - j] - t[i - k]) + return prod + + return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0] + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): + if not eta: + return sigma_to, 0.0 + sigma_up = torch.minimum( + sigma_to, + eta + * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5, + ) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + return sigma_down, sigma_up + + +def to_d(x, sigma, denoised): + return (x - denoised) / append_dims(sigma, x.ndim) + + +def to_neg_log_sigma(sigma): + return sigma.log().neg() + + +def to_sigma(neg_log_sigma): + return neg_log_sigma.neg().exp() diff --git a/sgm/modules/diffusionmodules/sigma_sampling.py b/sgm/modules/diffusionmodules/sigma_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..d54724c6ef6a7b8067784a4192b0fe2f41123063 --- /dev/null +++ b/sgm/modules/diffusionmodules/sigma_sampling.py @@ -0,0 +1,31 @@ +import torch + +from ...util import default, instantiate_from_config + + +class EDMSampling: + def __init__(self, p_mean=-1.2, p_std=1.2): + self.p_mean = p_mean + self.p_std = p_std + + def __call__(self, n_samples, rand=None): + log_sigma = self.p_mean + self.p_std * default(rand, torch.randn((n_samples,))) + return log_sigma.exp() + + +class DiscreteSampling: + def __init__(self, discretization_config, num_idx, do_append_zero=False, flip=True): + self.num_idx = num_idx + self.sigmas = instantiate_from_config(discretization_config)( + num_idx, do_append_zero=do_append_zero, flip=flip + ) + + def idx_to_sigma(self, idx): + return self.sigmas[idx] + + def __call__(self, n_samples, rand=None): + idx = default( + rand, + torch.randint(0, self.num_idx, (n_samples,)), + ) + return self.idx_to_sigma(idx) diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..389f0e449367b1b628d61dca105343d066dbefff --- /dev/null +++ b/sgm/modules/diffusionmodules/util.py @@ -0,0 +1,369 @@ +""" +partially adopted from +https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +and +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +and +https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py + +thanks! 
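+
+Provides beta schedules, gradient-checkpointing helpers, sinusoidal timestep
+embeddings, and small nn building blocks (normalization, conv_nd, AlphaBlender)
+shared by the diffusion modules.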
+""" + +import math +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +def make_beta_schedule( + schedule, + n_timestep, + linear_start=1e-4, + linear_end=2e-2, +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + return betas.numpy() + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def mixed_checkpoint(func, inputs: dict, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. This differs from the original checkpoint function + borrowed from https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py in that + it also works with non-tensor inputs + :param func: the function to evaluate. + :param inputs: the argument dictionary to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + tensor_keys = [key for key in inputs if isinstance(inputs[key], torch.Tensor)] + tensor_inputs = [ + inputs[key] for key in inputs if isinstance(inputs[key], torch.Tensor) + ] + non_tensor_keys = [ + key for key in inputs if not isinstance(inputs[key], torch.Tensor) + ] + non_tensor_inputs = [ + inputs[key] for key in inputs if not isinstance(inputs[key], torch.Tensor) + ] + args = tuple(tensor_inputs) + tuple(non_tensor_inputs) + tuple(params) + return MixedCheckpointFunction.apply( + func, + len(tensor_inputs), + len(non_tensor_inputs), + tensor_keys, + non_tensor_keys, + *args, + ) + else: + return func(**inputs) + + +class MixedCheckpointFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + run_function, + length_tensors, + length_non_tensors, + tensor_keys, + non_tensor_keys, + *args, + ): + ctx.end_tensors = length_tensors + ctx.end_non_tensors = length_tensors + length_non_tensors + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + assert ( + len(tensor_keys) == length_tensors + and len(non_tensor_keys) == length_non_tensors + ) + + ctx.input_tensors = { + key: val for (key, val) in zip(tensor_keys, list(args[: ctx.end_tensors])) + } + ctx.input_non_tensors = { + key: val + for (key, val) in zip( + non_tensor_keys, list(args[ctx.end_tensors : ctx.end_non_tensors]) + ) + } + ctx.run_function = run_function + ctx.input_params = list(args[ctx.end_non_tensors :]) + + with torch.no_grad(): + output_tensors = ctx.run_function( + **ctx.input_tensors, **ctx.input_non_tensors + ) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + # additional_args = {key: ctx.input_tensors[key] for key in ctx.input_tensors if not isinstance(ctx.input_tensors[key],torch.Tensor)} + ctx.input_tensors = { + key: ctx.input_tensors[key].detach().requires_grad_(True) + for key in ctx.input_tensors + } + + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. 
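+            # view_as() creates fresh Tensor objects that share storage with the
+            # detached inputs, so re-running run_function with grad enabled is safe.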
+ shallow_copies = { + key: ctx.input_tensors[key].view_as(ctx.input_tensors[key]) + for key in ctx.input_tensors + } + # shallow_copies.update(additional_args) + output_tensors = ctx.run_function(**shallow_copies, **ctx.input_non_tensors) + input_grads = torch.autograd.grad( + output_tensors, + list(ctx.input_tensors.values()) + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return ( + (None, None, None, None, None) + + input_grads[: ctx.end_tensors] + + (None,) * (ctx.end_non_tensors - ctx.end_tensors) + + input_grads[ctx.end_tensors :] + ) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
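+    Typically applied to the final output convolution (e.g. in `self.out` of the
+    UNet models above) so a freshly initialized block starts by contributing zeros.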
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class AlphaBlender(nn.Module): + strategies = ["learned", "fixed", "learned_with_images"] + + def __init__( + self, + alpha: float, + merge_strategy: str = "learned_with_images", + rearrange_pattern: str = "b t -> (b t) 1 1", + ): + super().__init__() + self.merge_strategy = merge_strategy + self.rearrange_pattern = rearrange_pattern + + assert ( + merge_strategy in self.strategies + ), f"merge_strategy needs to be in {self.strategies}" + + if self.merge_strategy == "fixed": + self.register_buffer("mix_factor", torch.Tensor([alpha])) + elif ( + self.merge_strategy == "learned" + or self.merge_strategy == "learned_with_images" + ): + self.register_parameter( + "mix_factor", torch.nn.Parameter(torch.Tensor([alpha])) + ) + else: + raise ValueError(f"unknown merge strategy {self.merge_strategy}") + + def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor: + if self.merge_strategy == "fixed": + alpha = self.mix_factor + elif self.merge_strategy == "learned": + alpha = torch.sigmoid(self.mix_factor) + elif self.merge_strategy == "learned_with_images": + assert image_only_indicator is not None, "need image_only_indicator ..." + alpha = torch.where( + image_only_indicator.bool(), + torch.ones(1, 1, device=image_only_indicator.device), + rearrange(torch.sigmoid(self.mix_factor), "... -> ... 
1"), + ) + alpha = rearrange(alpha, self.rearrange_pattern) + else: + raise NotImplementedError + return alpha + + def forward( + self, + x_spatial: torch.Tensor, + x_temporal: torch.Tensor, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + alpha = self.get_alpha(image_only_indicator) + x = ( + alpha.to(x_spatial.dtype) * x_spatial + + (1.0 - alpha).to(x_spatial.dtype) * x_temporal + ) + return x diff --git a/sgm/modules/diffusionmodules/video_model.py b/sgm/modules/diffusionmodules/video_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2d077c7d0c7ed1c4a2c21f14105c266abc4926 --- /dev/null +++ b/sgm/modules/diffusionmodules/video_model.py @@ -0,0 +1,493 @@ +from functools import partial +from typing import List, Optional, Union + +from einops import rearrange + +from ...modules.diffusionmodules.openaimodel import * +from ...modules.video_attention import SpatialVideoTransformer +from ...util import default +from .util import AlphaBlender + + +class VideoResBlock(ResBlock): + def __init__( + self, + channels: int, + emb_channels: int, + dropout: float, + video_kernel_size: Union[int, List[int]] = 3, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + out_channels: Optional[int] = None, + use_conv: bool = False, + use_scale_shift_norm: bool = False, + dims: int = 2, + use_checkpoint: bool = False, + up: bool = False, + down: bool = False, + ): + super().__init__( + channels, + emb_channels, + dropout, + out_channels=out_channels, + use_conv=use_conv, + use_scale_shift_norm=use_scale_shift_norm, + dims=dims, + use_checkpoint=use_checkpoint, + up=up, + down=down, + ) + + self.time_stack = ResBlock( + default(out_channels, channels), + emb_channels, + dropout=dropout, + dims=3, + out_channels=default(out_channels, channels), + use_scale_shift_norm=False, + use_conv=False, + up=False, + down=False, + kernel_size=video_kernel_size, + use_checkpoint=use_checkpoint, + exchange_temb_dims=True, + ) + self.time_mixer = AlphaBlender( + alpha=merge_factor, + merge_strategy=merge_strategy, + rearrange_pattern="b t -> b 1 t 1 1", + ) + + def forward( + self, + x: th.Tensor, + emb: th.Tensor, + num_video_frames: int, + image_only_indicator: Optional[th.Tensor] = None, + ) -> th.Tensor: + x = super().forward(x, emb) + + x_mix = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + x = rearrange(x, "(b t) c h w -> b c t h w", t=num_video_frames) + + x = self.time_stack( + x, rearrange(emb, "(b t) ... 
-> b t ...", t=num_video_frames) + ) + x = self.time_mixer( + x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator + ) + x = rearrange(x, "b c t h w -> (b t) c h w") + return x + + +class VideoUNet(nn.Module): + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + num_res_blocks: int, + attention_resolutions: int, + dropout: float = 0.0, + channel_mult: List[int] = (1, 2, 4, 8), + conv_resample: bool = True, + dims: int = 2, + num_classes: Optional[int] = None, + use_checkpoint: bool = False, + num_heads: int = -1, + num_head_channels: int = -1, + num_heads_upsample: int = -1, + use_scale_shift_norm: bool = False, + resblock_updown: bool = False, + transformer_depth: Union[List[int], int] = 1, + transformer_depth_middle: Optional[int] = None, + context_dim: Optional[int] = None, + time_downup: bool = False, + time_context_dim: Optional[int] = None, + extra_ff_mix_layer: bool = False, + use_spatial_context: bool = False, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + spatial_transformer_attn_type: str = "softmax", + video_kernel_size: Union[int, List[int]] = 3, + use_linear_in_transformer: bool = False, + adm_in_channels: Optional[int] = None, + disable_temporal_crossattention: bool = False, + max_ddpm_temb_period: int = 10000, + ): + super().__init__() + assert context_dim is not None + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1 + + if num_head_channels == -1: + assert num_heads != -1 + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(transformer_depth, int): + transformer_depth = len(channel_mult) * [transformer_depth] + transformer_depth_middle = default( + transformer_depth_middle, transformer_depth[-1] + ) + + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "timestep": + self.label_emb = nn.Sequential( + Timestep(model_channels), + nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ), + ) + + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + + def get_attention_layer( + ch, + num_heads, + dim_head, + depth=1, + context_dim=None, + use_checkpoint=False, 
+ disabled_sa=False, + ): + return SpatialVideoTransformer( + ch, + num_heads, + dim_head, + depth=depth, + context_dim=context_dim, + time_context_dim=time_context_dim, + dropout=dropout, + ff_in=extra_ff_mix_layer, + use_spatial_context=use_spatial_context, + merge_strategy=merge_strategy, + merge_factor=merge_factor, + checkpoint=use_checkpoint, + use_linear=use_linear_in_transformer, + attn_mode=spatial_transformer_attn_type, + disable_self_attn=disabled_sa, + disable_temporal_crossattention=disable_temporal_crossattention, + max_time_embed_period=max_ddpm_temb_period, + ) + + def get_resblock( + merge_factor, + merge_strategy, + video_kernel_size, + ch, + time_embed_dim, + dropout, + out_ch, + dims, + use_checkpoint, + use_scale_shift_norm, + down=False, + up=False, + ): + return VideoResBlock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + channels=ch, + emb_channels=time_embed_dim, + dropout=dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=down, + up=up, + ) + + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + layers.append( + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + use_checkpoint=use_checkpoint, + disabled_sa=False, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + ds *= 2 + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, + conv_resample, + dims=dims, + out_channels=out_ch, + third_down=time_downup, + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + self.middle_block = TimestepEmbedSequential( + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + out_ch=None, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth_middle, + context_dim=context_dim, + use_checkpoint=use_checkpoint, + ), + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + out_ch=None, + time_embed_dim=time_embed_dim, + dropout=dropout, + dims=dims, + use_checkpoint=use_checkpoint, + 
use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch + ich, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + + layers.append( + get_attention_layer( + ch, + num_heads, + dim_head, + depth=transformer_depth[level], + context_dim=context_dim, + use_checkpoint=use_checkpoint, + disabled_sa=False, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + ds //= 2 + layers.append( + get_resblock( + merge_factor=merge_factor, + merge_strategy=merge_strategy, + video_kernel_size=video_kernel_size, + ch=ch, + time_embed_dim=time_embed_dim, + dropout=dropout, + out_ch=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample( + ch, + conv_resample, + dims=dims, + out_channels=out_ch, + third_up=time_downup, + ) + ) + + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + + def forward( + self, + x: th.Tensor, + timesteps: th.Tensor, + context: Optional[th.Tensor] = None, + y: Optional[th.Tensor] = None, + time_context: Optional[th.Tensor] = None, + num_video_frames: Optional[int] = None, + image_only_indicator: Optional[th.Tensor] = None, + ): + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional -> no, relax this TODO" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for module in self.input_blocks: + h = module( + h, + emb, + context=context, + image_only_indicator=image_only_indicator, + time_context=time_context, + num_video_frames=num_video_frames, + ) + hs.append(h) + h = self.middle_block( + h, + emb, + context=context, + image_only_indicator=image_only_indicator, + time_context=time_context, + num_video_frames=num_video_frames, + ) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module( + h, + emb, + context=context, + image_only_indicator=image_only_indicator, + time_context=time_context, + num_video_frames=num_video_frames, + ) + h = h.type(x.dtype) + return self.out(h) diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..37449ea63e992b9f89856f1f47c18ba68be8e334 --- /dev/null +++ b/sgm/modules/diffusionmodules/wrappers.py @@ -0,0 +1,34 @@ +import torch +import torch.nn as nn +from packaging import version + +OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper" + + +class IdentityWrapper(nn.Module): + def __init__(self, diffusion_model, 
compile_model: bool = False): + super().__init__() + compile = ( + torch.compile + if (version.parse(torch.__version__) >= version.parse("2.0.0")) + and compile_model + else lambda x: x + ) + self.diffusion_model = compile(diffusion_model) + + def forward(self, *args, **kwargs): + return self.diffusion_model(*args, **kwargs) + + +class OpenAIWrapper(IdentityWrapper): + def forward( + self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs + ) -> torch.Tensor: + x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1) + return self.diffusion_model( + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + **kwargs, + ) diff --git a/sgm/modules/distributions/__init__.py b/sgm/modules/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sgm/modules/distributions/distributions.py b/sgm/modules/distributions/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..016be35523187ea366db9ade391fe8ee276db60b --- /dev/null +++ b/sgm/modules/distributions/distributions.py @@ -0,0 +1,102 @@ +import numpy as np +import torch + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
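+    # After the conversion below, the result is the elementwise
+    # KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ) in nats; no reduction
+    # over batch or spatial dimensions is applied here.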
+ logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/sgm/modules/ema.py b/sgm/modules/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..97b5ae2b230f89b4dba57e44c4f851478ad86f68 --- /dev/null +++ b/sgm/modules/ema.py @@ -0,0 +1,86 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer("num_updates", torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. 
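+        Typical pattern (sketch; `ema` is a LitEma tracking `model`):
+            ema.store(model.parameters())
+            ema.copy_to(model)
+            ...  # validate or save checkpoints with the EMA weights
+            ema.restore(model.parameters())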
+ """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/sgm/modules/encoders/__init__.py b/sgm/modules/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/sgm/modules/encoders/image_encoder.py b/sgm/modules/encoders/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..60d693245bc562987376b7d0fff80086fb936279 --- /dev/null +++ b/sgm/modules/encoders/image_encoder.py @@ -0,0 +1,349 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import importlib + + +def class_for_name(module_name, class_name): + # load the module, will raise ImportError if module cannot be loaded + m = importlib.import_module(module_name) + return getattr(m, class_name) + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + padding_mode="reflect", + ) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False, + padding_mode="reflect", + ) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + # norm_layer = nn.InstanceNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes, track_running_stats=False, affine=True) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes, track_running_stats=False, affine=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 
+ # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion = 4 + + def __init__( + self, + inplanes, + planes, + stride=1, + downsample=None, + groups=1, + base_width=64, + dilation=1, + norm_layer=None, + ): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + # norm_layer = nn.InstanceNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width, track_running_stats=False, affine=True) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width, track_running_stats=False, affine=True) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer( + planes * self.expansion, track_running_stats=False, affine=True + ) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class conv(nn.Module): + def __init__(self, num_in_layers, num_out_layers, kernel_size, stride): + super(conv, self).__init__() + self.kernel_size = kernel_size + self.conv = nn.Conv2d( + num_in_layers, + num_out_layers, + kernel_size=kernel_size, + stride=stride, + padding=(self.kernel_size - 1) // 2, + padding_mode="reflect", + ) + # self.bn = nn.InstanceNorm2d( + # num_out_layers, track_running_stats=False, affine=True + # ) + self.bn = nn.BatchNorm2d(num_out_layers, track_running_stats=False, affine=True) + # self.bn = nn.LayerNorm(num_out_layers) + + def forward(self, x): + return F.elu(self.bn(self.conv(x)), inplace=True) + + +class upconv(nn.Module): + def __init__(self, num_in_layers, num_out_layers, kernel_size, scale): + super(upconv, self).__init__() + self.scale = scale + self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1) + + def forward(self, x): + x = nn.functional.interpolate( + x, scale_factor=self.scale, align_corners=True, mode="bilinear" + ) + return self.conv(x) + + +class ResUNet(nn.Module): + def __init__( + self, + encoder="resnet34", + coarse_out_ch=32, + fine_out_ch=32, + norm_layer=None, + coarse_only=False, + ): + super(ResUNet, self).__init__() + assert encoder in [ + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + ], "Incorrect encoder type" + if encoder in ["resnet18", "resnet34"]: + filters = [64, 128, 256, 512] + else: + filters = [256, 512, 1024, 2048] + self.coarse_only = coarse_only + if self.coarse_only: + fine_out_ch = 0 + self.coarse_out_ch = coarse_out_ch + self.fine_out_ch = fine_out_ch + out_ch = coarse_out_ch + fine_out_ch + + # original + layers = [3, 4, 6, 3] + if norm_layer is None: + norm_layer = nn.BatchNorm2d + # norm_layer = nn.InstanceNorm2d + self._norm_layer = norm_layer + self.dilation = 1 + block = BasicBlock + replace_stride_with_dilation = [False, False, False] + self.inplanes = 64 + self.groups = 1 + self.base_width = 64 + self.conv1 = nn.Conv2d( + 3, + self.inplanes, + kernel_size=7, + stride=2, + padding=3, + bias=False, + padding_mode="reflect", + ) + self.bn1 = 
norm_layer(self.inplanes, track_running_stats=False, affine=True) + self.relu = nn.ReLU(inplace=True) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + + # decoder + self.upconv3 = upconv(filters[2], 128, 3, 2) + self.iconv3 = conv(filters[1] + 128, 128, 3, 1) + self.upconv2 = upconv(128, 64, 3, 2) + self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1) + + # fine-level conv + self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer( + planes * block.expansion, track_running_stats=False, affine=True + ), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + norm_layer, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def skipconnect(self, x1, x2): + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2)) + + # for padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + + x = torch.cat([x2, x1], dim=1) + return x + + def forward(self, x): + x = self.relu(self.bn1(self.conv1(x))) + + x1 = self.layer1(x) + x2 = self.layer2(x1) + x3 = self.layer3(x2) + + x = self.upconv3(x3) + x = self.skipconnect(x2, x) + x = self.iconv3(x) + + x = self.upconv2(x) + x = self.skipconnect(x1, x) + x = self.iconv2(x) + + x_out = self.out_conv(x) + + return x_out + + # if self.coarse_only: + # x_coarse = x_out + # x_fine = None + # else: + # x_coarse = x_out[:, : self.coarse_out_ch, :] + # x_fine = x_out[:, -self.fine_out_ch :, :] + # return x_coarse, x_fine diff --git a/sgm/modules/encoders/image_encoder_v2.py b/sgm/modules/encoders/image_encoder_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..72c782b3edee155fa4367e697a94d6b8b6b86b85 --- /dev/null +++ b/sgm/modules/encoders/image_encoder_v2.py @@ -0,0 +1,160 @@ +""" +UNet Network in PyTorch, modified from https://github.com/milesial/Pytorch-UNet +with architecture referenced from https://keras.io/examples/vision/depth_estimation +for monocular depth estimation from RGB images, i.e. one output channel. +""" + +import torch +from torch import nn + + +class UNet(nn.Module): + """ + The overall UNet architecture. 
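+    Takes a (B, 3, H, W) RGB batch and returns a (B, 1, H, W) sigmoid-activated map;
+    H and W should be divisible by 16 so the four pool/upsample stages match the skips.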
+ """ + + def __init__(self): + super().__init__() + + self.downscale_blocks = nn.ModuleList( + [ + DownBlock(16, 32), + DownBlock(32, 64), + DownBlock(64, 128), + DownBlock(128, 256), + ] + ) + self.upscale_blocks = nn.ModuleList( + [ + UpBlock(256, 128), + UpBlock(128, 64), + UpBlock(64, 32), + UpBlock(32, 16), + ] + ) + + self.input_conv = nn.Conv2d(3, 16, kernel_size=3, padding="same") + self.output_conv = nn.Conv2d(16, 1, kernel_size=1) + self.bridge = BottleNeckBlock(256) + self.activation = nn.Sigmoid() + + def forward(self, x): + x = self.input_conv(x) + + skip_features = [] + for block in self.downscale_blocks: + c, x = block(x) + skip_features.append(c) + + x = self.bridge(x) + + skip_features.reverse() + for block, skip in zip(self.upscale_blocks, skip_features): + x = block(x, skip) + + x = self.output_conv(x) + x = self.activation(x) + return x + + +class DownBlock(nn.Module): + """ + Module that performs downscaling with residual connections. + """ + + def __init__(self, in_channels, out_channels, padding="same", stride=1): + super().__init__() + self.conv1 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=False, + ) + self.conv2 = nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(out_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + self.relu = nn.LeakyReLU(0.2) + self.maxpool = nn.MaxPool2d(2) + + def forward(self, x): + d = self.conv1(x) + x = self.bn1(d) + x = self.relu(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = x + d + p = self.maxpool(x) + return x, p + + +class UpBlock(nn.Module): + """ + Module that performs upscaling after concatenation with skip connections. + """ + + def __init__(self, in_channels, out_channels, padding="same", stride=1): + super().__init__() + self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + self.conv1 = nn.Conv2d( + in_channels * 2, + in_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=False, + ) + self.conv2 = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=padding, + bias=False, + ) + self.bn1 = nn.BatchNorm2d(in_channels) + self.bn2 = nn.BatchNorm2d(out_channels) + self.relu = nn.LeakyReLU(0.2) + + def forward(self, x, skip): + x = self.up(x) + x = torch.cat([x, skip], dim=1) + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + return x + + +class BottleNeckBlock(nn.Module): + """ + BottleNeckBlock that serves as the UNet bridge. 
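+    Two 3x3 same-padding convolutions with LeakyReLU; the spatial size is preserved.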
+ """ + + def __init__(self, channels, padding="same", strides=1): + super().__init__() + self.conv1 = nn.Conv2d(channels, channels, 3, 1, "same") + self.conv2 = nn.Conv2d(channels, channels, 3, 1, "same") + self.relu = nn.LeakyReLU(0.2) + + def forward(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.relu(x) + return x \ No newline at end of file diff --git a/sgm/modules/encoders/math_utils.py b/sgm/modules/encoders/math_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..35e4f6f876ebb5878fa05e93bdf10488cb73e297 --- /dev/null +++ b/sgm/modules/encoders/math_utils.py @@ -0,0 +1,139 @@ +# MIT License + +# Copyright (c) 2022 Petr Kellnhofer + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import torch + + +def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: + """ + Left-multiplies MxM @ NxM. Returns NxM. + """ + res = torch.matmul(vectors4, matrix.T) + return res + + +def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: + """ + Normalize vector lengths. + """ + return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + + +def torch_dot(x: torch.Tensor, y: torch.Tensor): + """ + Dot product of two tensors. + """ + return (x * y).sum(-1) + + +def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): + """ + Author: Petr Kellnhofer + Intersects rays with the [-1, 1] NDC volume. + Returns min and max distance of entry. + Returns -1 for no intersection. + https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection + """ + o_shape = rays_o.shape + rays_o = rays_o.detach().reshape(-1, 3) + rays_d = rays_d.detach().reshape(-1, 3) + + bb_min = [ + -1 * (box_side_length / 2), + -1 * (box_side_length / 2), + -1 * (box_side_length / 2), + ] + bb_max = [ + 1 * (box_side_length / 2), + 1 * (box_side_length / 2), + 1 * (box_side_length / 2), + ] + bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) + is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) + + # Precompute inverse for stability. + invdir = 1 / rays_d + sign = (invdir < 0).long() + + # Intersect with YZ plane. + tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[ + ..., 0 + ] + tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[ + ..., 0 + ] + + # Intersect with XZ plane. 
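+    # (Slab method: each axis pair of planes yields an [entry, exit] interval; the ray
+    # hits the box only if the three intervals overlap, hence the max/min updates below.)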
+ tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[ + ..., 1 + ] + tymax = ( + bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1] + ) * invdir[..., 1] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tymin) + tmax = torch.min(tmax, tymax) + + # Intersect with XY plane. + tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[ + ..., 2 + ] + tzmax = ( + bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2] + ) * invdir[..., 2] + + # Resolve parallel rays. + is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False + + # Use the shortest intersection. + tmin = torch.max(tmin, tzmin) + tmax = torch.min(tmax, tzmax) + + # Mark invalid. + tmin[torch.logical_not(is_valid)] = -1 + tmax[torch.logical_not(is_valid)] = -2 + + return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) + + +def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): + """ + Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. + Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. + """ + # create a tensor of 'num' steps from 0 to 1 + steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) + + # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings + # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript + # "cannot statically infer the expected size of a list in this contex", hence the code below + for i in range(start.ndim): + steps = steps.unsqueeze(-1) + + # the output starts at 'start' and increments until 'stop' in each dimension + out = start[None] + steps * (stop - start)[None] + + return out diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..9860779362c766f4e9171d98c7411a2b178a842d --- /dev/null +++ b/sgm/modules/encoders/modules.py @@ -0,0 +1,1189 @@ +import math +from contextlib import nullcontext +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import kornia +import numpy as np +import open_clip +import torch +import torch.nn as nn +from einops import rearrange, repeat +from omegaconf import ListConfig + +# from torch.utils.checkpoint import checkpoint + +checkpoint = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + +from transformers import ( + ByT5Tokenizer, + CLIPTextModel, + CLIPTokenizer, + T5EncoderModel, + T5Tokenizer, +) + +from ...modules.autoencoding.regularizers import DiagonalGaussianRegularizer +from ...modules.diffusionmodules.model import Encoder +from ...modules.diffusionmodules.openaimodel import Timestep +from ...modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule +from ...modules.distributions.distributions import DiagonalGaussianDistribution +from ...util import ( + append_dims, + autocast, + count_params, + default, + disabled_train, + expand_dims_like, + instantiate_from_config, +) + + +class AbstractEmbModel(nn.Module): + def __init__(self): + super().__init__() + self._is_trainable = None + self._ucg_rate = None + self._input_key = None + + @property + def is_trainable(self) -> bool: + return self._is_trainable + + @property + def ucg_rate(self) -> Union[float, torch.Tensor]: + return self._ucg_rate + + @property + def input_key(self) -> str: 
+ return self._input_key + + @is_trainable.setter + def is_trainable(self, value: bool): + self._is_trainable = value + + @ucg_rate.setter + def ucg_rate(self, value: Union[float, torch.Tensor]): + self._ucg_rate = value + + @input_key.setter + def input_key(self, value: str): + self._input_key = value + + @is_trainable.deleter + def is_trainable(self): + del self._is_trainable + + @ucg_rate.deleter + def ucg_rate(self): + del self._ucg_rate + + @input_key.deleter + def input_key(self): + del self._input_key + + +class GeneralConditioner(nn.Module): + OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"} + KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1} + + def __init__(self, emb_models: Union[List, ListConfig]): + super().__init__() + embedders = [] + for n, embconfig in enumerate(emb_models): + embedder = instantiate_from_config(embconfig) + assert isinstance( + embedder, AbstractEmbModel + ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel" + embedder.is_trainable = embconfig.get("is_trainable", False) + embedder.ucg_rate = embconfig.get("ucg_rate", 0.0) + if not embedder.is_trainable: + embedder.train = disabled_train + for param in embedder.parameters(): + param.requires_grad = False + embedder.eval() + print( + f"Initialized embedder #{n}: {embedder.__class__.__name__} " + f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}" + ) + + if "input_key" in embconfig: + embedder.input_key = embconfig["input_key"] + elif "input_keys" in embconfig: + embedder.input_keys = embconfig["input_keys"] + else: + raise KeyError( + f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}" + ) + + embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None) + if embedder.legacy_ucg_val is not None: + embedder.ucg_prng = np.random.RandomState() + + embedders.append(embedder) + self.embedders = nn.ModuleList(embedders) + + def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: Dict) -> Dict: + assert embedder.legacy_ucg_val is not None + p = embedder.ucg_rate + val = embedder.legacy_ucg_val + for i in range(len(batch[embedder.input_key])): + if embedder.ucg_prng.choice(2, p=[1 - p, p]): + batch[embedder.input_key][i] = val + return batch + + def forward( + self, batch: Dict, force_zero_embeddings: Optional[List] = None + ) -> Dict: + output = dict() + if force_zero_embeddings is None: + force_zero_embeddings = [] + for embedder in self.embedders: + embedding_context = nullcontext if embedder.is_trainable else torch.no_grad + with embedding_context(): + if hasattr(embedder, "input_key") and (embedder.input_key is not None): + if embedder.legacy_ucg_val is not None: + batch = self.possibly_get_ucg_val(embedder, batch) + emb_out = embedder(batch[embedder.input_key]) + elif hasattr(embedder, "input_keys"): + emb_out = embedder(*[batch[k] for k in embedder.input_keys]) + assert isinstance( + emb_out, (torch.Tensor, list, tuple) + ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" + if not isinstance(emb_out, (list, tuple)): + emb_out = [emb_out] + for emb in emb_out: + out_key = self.OUTPUT_DIM2KEYS[emb.dim()] + if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: + emb = ( + expand_dims_like( + torch.bernoulli( + (1.0 - embedder.ucg_rate) + * torch.ones(emb.shape[0], device=emb.device) + ), + emb, + ) + * emb + ) + if ( + hasattr(embedder, "input_key") + and embedder.input_key in force_zero_embeddings + ): + emb = torch.zeros_like(emb) 
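+                    # Route each embedding by tensor rank (OUTPUT_DIM2KEYS: 2 -> "vector", 3 -> "crossattn", 4/5 -> "concat") and concatenate with any existing entry along the axis given by KEY2CATDIM.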
+ if out_key in output: + output[out_key] = torch.cat( + (output[out_key], emb), self.KEY2CATDIM[out_key] + ) + else: + output[out_key] = emb + + # if "num_video_frames" in batch: + # num_frames = batch["num_video_frames"] + # for k in ["crossattn", "concat"]: + # output[k] = repeat(output[k], "b ... -> b t ...", t=num_frames) + # output[k] = rearrange(output[k], "b t ... -> (b t) ...", t=num_frames) + + return output + + def get_unconditional_conditioning( + self, + batch_c: Dict, + batch_uc: Optional[Dict] = None, + force_uc_zero_embeddings: Optional[List[str]] = None, + force_cond_zero_embeddings: Optional[List[str]] = None, + ): + if force_uc_zero_embeddings is None: + force_uc_zero_embeddings = [] + ucg_rates = list() + for embedder in self.embedders: + ucg_rates.append(embedder.ucg_rate) + embedder.ucg_rate = 0.0 + c = self(batch_c, force_cond_zero_embeddings) + uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings) + + for embedder, rate in zip(self.embedders, ucg_rates): + embedder.ucg_rate = rate + return c, uc + + +class InceptionV3(nn.Module): + """Wrapper around the https://github.com/mseitzer/pytorch-fid inception + port with an additional squeeze at the end""" + + def __init__(self, normalize_input=False, **kwargs): + super().__init__() + from pytorch_fid import inception + + kwargs["resize_input"] = True + self.model = inception.InceptionV3(normalize_input=normalize_input, **kwargs) + + def forward(self, inp): + outp = self.model(inp) + + if len(outp) == 1: + return outp[0].squeeze() + + return outp + + +class IdentityEncoder(AbstractEmbModel): + def encode(self, x): + return x + + def forward(self, x): + return x + + +class ClassEmbedder(AbstractEmbModel): + def __init__(self, embed_dim, n_classes=1000, add_sequence_dim=False): + super().__init__() + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.add_sequence_dim = add_sequence_dim + + def forward(self, c): + c = self.embedding(c) + if self.add_sequence_dim: + c = c[:, None, :] + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = ( + self.n_classes - 1 + ) # 1000 classes --> 0 ... 
999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc.long()} + return uc + + +class ClassEmbedderForMultiCond(ClassEmbedder): + def forward(self, batch, key=None, disable_dropout=False): + out = batch + key = default(key, self.key) + islist = isinstance(batch[key], list) + if islist: + batch[key] = batch[key][0] + c_out = super().forward(batch, key, disable_dropout) + out[key] = [c_out] if islist else c_out + return out + + +class FrozenT5Embedder(AbstractEmbModel): + """Uses the T5 transformer encoder for text""" + + def __init__( + self, version="google/t5-v1_1-xxl", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + with torch.autocast("cuda", enabled=False): + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenByT5Embedder(AbstractEmbModel): + """ + Uses the ByT5 transformer encoder for text. Is character-aware. + """ + + def __init__( + self, version="google/byt5-base", device="cuda", max_length=77, freeze=True + ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = ByT5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + with torch.autocast("cuda", enabled=False): + outputs = self.transformer(input_ids=tokens) + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEmbModel): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + max_length=77, + freeze=True, + layer="last", + layer_idx=None, + always_return_pooled=False, + ): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + self.return_pooled = always_return_pooled + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + + for param 
in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, text): + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer( + input_ids=tokens, output_hidden_states=self.layer == "hidden" + ) + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + if self.return_pooled: + return z, outputs.pooler_output + return z + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder2(AbstractEmbModel): + """ + Uses the OpenCLIP transformer encoder for text + """ + + LAYERS = ["pooled", "last", "penultimate"] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + layer="last", + always_return_pooled=False, + legacy=True, + ): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device("cpu"), + pretrained=version, + ) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + self.return_pooled = always_return_pooled + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + self.legacy = legacy + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + if not self.return_pooled and self.legacy: + return z + if self.return_pooled: + assert not self.legacy + return z[self.layer], z["pooled"] + return z[self.layer] + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + if self.legacy: + x = x[self.layer] + x = self.model.ln_final(x) + return x + else: + # x is a dict and will stay a dict + o = x["last"] + o = self.model.ln_final(o) + pooled = self.pool(o, text) + x["pooled"] = pooled + return x + + def pool(self, x, text): + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = ( + x[torch.arange(x.shape[0]), text.argmax(dim=-1)] + @ self.model.text_projection + ) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + outputs = {} + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - 1: + outputs["penultimate"] = x.permute(1, 0, 2) # LND -> NLD + if ( + self.model.transformer.grad_checkpointing + and not torch.jit.is_scripting() + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + outputs["last"] = x.permute(1, 0, 2) # LND -> NLD + return outputs + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPEmbedder(AbstractEmbModel): + LAYERS = [ + # "pooled", + "last", + "penultimate", + ] + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + 
layer="last", + ): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device("cpu"), + pretrained=version, + ) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if ( + self.model.transformer.grad_checkpointing + and not torch.jit.is_scripting() + ): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPImageEmbedder(AbstractEmbModel): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__( + self, + arch="ViT-H-14", + version="laion2b_s32b_b79k", + device="cuda", + max_length=77, + freeze=True, + antialias=True, + ucg_rate=0.0, + unsqueeze_dim=False, + repeat_to_max_len=False, + num_image_crops=0, + output_tokens=False, + init_device=None, + ): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms( + arch, + device=torch.device(default(init_device, "cpu")), + pretrained=version, + ) + del model.transformer + self.model = model + self.max_crops = num_image_crops + self.pad_to_max_len = self.max_crops > 0 + self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + + self.antialias = antialias + + self.register_buffer( + "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False + ) + self.register_buffer( + "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False + ) + self.ucg_rate = ucg_rate + self.unsqueeze_dim = unsqueeze_dim + self.stored_batch = None + self.model.visual.output_tokens = output_tokens + self.output_tokens = output_tokens + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize( + x, + (224, 224), + interpolation="bicubic", + align_corners=True, + antialias=self.antialias, + ) + x = (x + 1.0) / 2.0 + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + tokens = None + if self.output_tokens: + z, tokens = z[0], z[1] + z = z.to(image.dtype) + if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0): + z = ( + torch.bernoulli( + (1.0 - self.ucg_rate) * torch.ones(z.shape[0], 
device=z.device) + )[:, None] + * z + ) + if tokens is not None: + tokens = ( + expand_dims_like( + torch.bernoulli( + (1.0 - self.ucg_rate) + * torch.ones(tokens.shape[0], device=tokens.device) + ), + tokens, + ) + * tokens + ) + if self.unsqueeze_dim: + z = z[:, None, :] + if self.output_tokens: + assert not self.repeat_to_max_len + assert not self.pad_to_max_len + return tokens, z + if self.repeat_to_max_len: + if z.dim() == 2: + z_ = z[:, None, :] + else: + z_ = z + return repeat(z_, "b 1 d -> b n d", n=self.max_length), z + elif self.pad_to_max_len: + assert z.dim() == 3 + z_pad = torch.cat( + ( + z, + torch.zeros( + z.shape[0], + self.max_length - z.shape[1], + z.shape[2], + device=z.device, + ), + ), + 1, + ) + return z_pad, z_pad[:, 0, ...] + return z + + def encode_with_vision_transformer(self, img): + # if self.max_crops > 0: + # img = self.preprocess_by_cropping(img) + if img.dim() == 5: + assert self.max_crops == img.shape[1] + img = rearrange(img, "b n c h w -> (b n) c h w") + img = self.preprocess(img) + if not self.output_tokens: + assert not self.model.visual.output_tokens + x = self.model.visual(img) + tokens = None + else: + assert self.model.visual.output_tokens + x, tokens = self.model.visual(img) + if self.max_crops > 0: + x = rearrange(x, "(b n) d -> b n d", n=self.max_crops) + # drop out between 0 and all along the sequence axis + x = ( + torch.bernoulli( + (1.0 - self.ucg_rate) + * torch.ones(x.shape[0], x.shape[1], 1, device=x.device) + ) + * x + ) + if tokens is not None: + tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops) + print( + f"You are running very experimental token-concat in {self.__class__.__name__}. " + f"Check what you are doing, and then remove this message." + ) + if self.output_tokens: + return x, tokens + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEmbModel): + def __init__( + self, + clip_version="openai/clip-vit-large-patch14", + t5_version="google/t5-v1_1-xl", + device="cuda", + clip_max_length=77, + t5_max_length=77, + ): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder( + clip_version, device, max_length=clip_max_length + ) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print( + f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params." + ) + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + +class SpatialRescaler(nn.Module): + def __init__( + self, + n_stages=1, + method="bilinear", + multiplier=0.5, + in_channels=3, + out_channels=None, + bias=False, + wrap_video=False, + kernel_size=1, + remap_output=False, + ): + super().__init__() + self.n_stages = n_stages + assert self.n_stages >= 0 + assert method in [ + "nearest", + "linear", + "bilinear", + "trilinear", + "bicubic", + "area", + ] + self.multiplier = multiplier + self.interpolator = partial(torch.nn.functional.interpolate, mode=method) + self.remap_output = out_channels is not None or remap_output + if self.remap_output: + print( + f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing." 
+ ) + self.channel_mapper = nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + bias=bias, + padding=kernel_size // 2, + ) + self.wrap_video = wrap_video + + def forward(self, x): + if self.wrap_video and x.ndim == 5: + B, C, T, H, W = x.shape + x = rearrange(x, "b c t h w -> b t c h w") + x = rearrange(x, "b t c h w -> (b t) c h w") + + for stage in range(self.n_stages): + x = self.interpolator(x, scale_factor=self.multiplier) + + if self.wrap_video: + x = rearrange(x, "(b t) c h w -> b t c h w", b=B, t=T, c=C) + x = rearrange(x, "b t c h w -> b c t h w") + if self.remap_output: + x = self.channel_mapper(x) + return x + + def encode(self, x): + return self(x) + + +class LowScaleEncoder(nn.Module): + def __init__( + self, + model_config, + linear_start, + linear_end, + timesteps=1000, + max_noise_level=250, + output_size=64, + scale_factor=1.0, + ): + super().__init__() + self.max_noise_level = max_noise_level + self.model = instantiate_from_config(model_config) + self.augmentation_schedule = self.register_schedule( + timesteps=timesteps, linear_start=linear_start, linear_end=linear_end + ) + self.out_size = output_size + self.scale_factor = scale_factor + + def register_schedule( + self, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def forward(self, x): + z = self.model.encode(x) + if isinstance(z, DiagonalGaussianDistribution): + z = z.sample() + z = z * self.scale_factor + noise_level = torch.randint( + 0, self.max_noise_level, (x.shape[0],), device=x.device + ).long() + z = self.q_sample(z, noise_level) + if self.out_size is not None: + z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") + return z, noise_level + + def decode(self, z): + z = z / self.scale_factor + return self.model.decode(z) + + +class ConcatTimestepEmbedderND(AbstractEmbModel): + """embeds each dimension independently and concatenates 
them""" + + def __init__(self, outdim): + super().__init__() + self.timestep = Timestep(outdim) + self.outdim = outdim + + def forward(self, x): + if x.ndim == 1: + x = x[:, None] + assert len(x.shape) == 2 + b, dims = x.shape[0], x.shape[1] + x = rearrange(x, "b d -> (b d)") + emb = self.timestep(x) + emb = rearrange(emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return emb + + +class GaussianEncoder(Encoder, AbstractEmbModel): + def __init__( + self, weight: float = 1.0, flatten_output: bool = True, *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.posterior = DiagonalGaussianRegularizer() + self.weight = weight + self.flatten_output = flatten_output + + def forward(self, x) -> Tuple[Dict, torch.Tensor]: + z = super().forward(x) + z, log = self.posterior(z) + log["loss"] = log["kl_loss"] + log["weight"] = self.weight + if self.flatten_output: + z = rearrange(z, "b c h w -> b (h w ) c") + return log, z + + +class VideoPredictionEmbedderWithEncoder(AbstractEmbModel): + def __init__( + self, + n_cond_frames: int, + n_copies: int, + encoder_config: dict, + sigma_sampler_config: Optional[dict] = None, + sigma_cond_config: Optional[dict] = None, + is_ae: bool = False, + scale_factor: float = 1.0, + disable_encoder_autocast: bool = False, + en_and_decode_n_samples_a_time: Optional[int] = None, + ): + super().__init__() + + self.n_cond_frames = n_cond_frames + self.n_copies = n_copies + self.encoder = instantiate_from_config(encoder_config) + self.sigma_sampler = ( + instantiate_from_config(sigma_sampler_config) + if sigma_sampler_config is not None + else None + ) + self.sigma_cond = ( + instantiate_from_config(sigma_cond_config) + if sigma_cond_config is not None + else None + ) + self.is_ae = is_ae + self.scale_factor = scale_factor + self.disable_encoder_autocast = disable_encoder_autocast + self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time + + def forward( + self, vid: torch.Tensor + ) -> Union[ + torch.Tensor, + Tuple[torch.Tensor, torch.Tensor], + Tuple[torch.Tensor, dict], + Tuple[Tuple[torch.Tensor, torch.Tensor], dict], + ]: + if self.sigma_sampler is not None: + b = vid.shape[0] // self.n_cond_frames + sigmas = self.sigma_sampler(b).to(vid.device) + if self.sigma_cond is not None: + sigma_cond = self.sigma_cond(sigmas) + sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies) + sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames) + noise = torch.randn_like(vid) + vid = vid + noise * append_dims(sigmas, vid.ndim) + + with torch.autocast("cuda", enabled=not self.disable_encoder_autocast): + n_samples = ( + self.en_and_decode_n_samples_a_time + if self.en_and_decode_n_samples_a_time is not None + else vid.shape[0] + ) + n_rounds = math.ceil(vid.shape[0] / n_samples) + all_out = [] + for n in range(n_rounds): + if self.is_ae: + out = self.encoder.encode(vid[n * n_samples : (n + 1) * n_samples]) + else: + out = self.encoder(vid[n * n_samples : (n + 1) * n_samples]) + all_out.append(out) + + vid = torch.cat(all_out, dim=0) + vid *= self.scale_factor + + vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames) + vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies) + # modified for svd + # vid = repeat(vid, "b 1 c h w -> b t c h w", t=self.n_copies) + + return_val = (vid, sigma_cond) if self.sigma_cond is not None else vid + + return return_val + + +class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel): + def __init__( + self, + open_clip_embedding_config: Dict, + n_cond_frames: int, + 
n_copies: int, + ): + super().__init__() + + self.n_cond_frames = n_cond_frames + self.n_copies = n_copies + self.open_clip = instantiate_from_config(open_clip_embedding_config) + + def forward(self, vid): + vid = self.open_clip(vid) + vid = rearrange(vid, "(b t) d -> b t d", t=self.n_cond_frames) + vid = repeat(vid, "b t d -> (b s) t d", s=self.n_copies) + + return vid + + +class PixelNeRFEmbedder(AbstractEmbModel): + def __init__( + self, + image_encoder_config: dict, + pixelnerf_encoder_config: dict, + render_size: int, + num_video_frames: int, + ): + super().__init__() + self.render_size = render_size + self.num_video_frames = num_video_frames + self.image_encoder = instantiate_from_config(image_encoder_config) + self.pixelnerf_encoder = instantiate_from_config(pixelnerf_encoder_config) + + def forward(self, pixelnerf_input): + if "source_index" not in pixelnerf_input: + source_images = pixelnerf_input["frames"][:, 0] + image_feats = self.image_encoder(source_images) + image_feats = image_feats[:, None] + source_cameras = pixelnerf_input["cameras"][:, :1] + else: + # source_images = pixelnerf_input["frames"][ + # :, pixelnerf_input["source_index"] + # ] + source_images = pixelnerf_input["source_images"] + n_source_images = source_images.shape[1] + source_images = rearrange(source_images, "b t c h w -> (b t) c h w") + image_feats = self.image_encoder(source_images) + image_feats = rearrange( + image_feats, "(b t) c h w -> b t c h w", t=n_source_images + ) + source_cameras = pixelnerf_input["source_cameras"] + cameras = pixelnerf_input["cameras"] + target_cameras = cameras[:, :] + # source_images = source_images[:, None, ...] + source_c2ws = source_cameras[..., :16].reshape(*source_cameras.shape[:-1], 4, 4) + source_intrinsics = source_cameras[..., 16:].reshape( + *source_cameras.shape[:-1], 3, 3 + ) + target_c2ws = target_cameras[..., :16].reshape(*target_cameras.shape[:-1], 4, 4) + target_intrinsics = target_cameras[..., 16:].reshape( + *target_cameras.shape[:-1], 3, 3 + ) + + rgb, feats = self.pixelnerf_encoder( + image_feats, + source_c2ws, + source_intrinsics, + target_c2ws, + target_intrinsics, + self.render_size, + ) + + rgb = rearrange(rgb, "b t c h w -> (b t) c h w") + feats = rearrange(feats, "b t c h w -> (b t) c h w") + + return rgb, feats + + +class ExtraConditioner(GeneralConditioner): + def forward(self, batch: Dict, force_zero_embeddings: List | None = None) -> Dict: + bs = batch["frames"].shape[0] + num_frames = batch["num_video_frames"] + output = dict() + if force_zero_embeddings is None: + force_zero_embeddings = [] + for embedder in self.embedders: + embedding_context = nullcontext if embedder.is_trainable else torch.no_grad + with embedding_context(): + if hasattr(embedder, "input_key") and (embedder.input_key is not None): + if embedder.legacy_ucg_val is not None: + batch = self.possibly_get_ucg_val(embedder, batch) + emb_out = embedder(batch[embedder.input_key]) + elif hasattr(embedder, "input_keys"): + emb_out = embedder(*[batch[k] for k in embedder.input_keys]) + assert isinstance( + emb_out, (torch.Tensor, list, tuple) + ), f"encoder outputs must be tensors or a sequence, but got {type(emb_out)}" + if not isinstance(emb_out, (list, tuple)): + emb_out = [emb_out] + if isinstance(embedder, PixelNeRFEmbedder): + # a hack for pixelnerf input + output["rgb"] = emb_out[0] + emb_out = emb_out[1:] + for emb in emb_out: + out_key = self.OUTPUT_DIM2KEYS[emb.dim()] + if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None: + emb = ( + expand_dims_like( + 
torch.bernoulli( + (1.0 - embedder.ucg_rate) + * torch.ones(emb.shape[0], device=emb.device) + ), + emb, + ) + * emb + ) + if ( + hasattr(embedder, "input_key") + and embedder.input_key in force_zero_embeddings + ): + emb = torch.zeros_like(emb) + if out_key in output: + output[out_key] = torch.cat( + (output[out_key], emb), self.KEY2CATDIM[out_key] + ) + else: + output[out_key] = emb + + if out_key in ["crossattn", "concat"]: + if output[out_key].shape[0] != bs: + output[out_key] = repeat( + output[out_key], "b ... -> (b t) ...", t=num_frames + ) + return output diff --git a/sgm/modules/encoders/pixelnerf.py b/sgm/modules/encoders/pixelnerf.py new file mode 100644 index 0000000000000000000000000000000000000000..515699c3aa52097e27ddde98c3491547c2e3a0b7 --- /dev/null +++ b/sgm/modules/encoders/pixelnerf.py @@ -0,0 +1,368 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.autograd.profiler as profiler +import numpy as np +from einops import rearrange, repeat, einsum + +from .math_utils import get_ray_limits_box, linspace + +from ...modules.diffusionmodules.openaimodel import Timestep + + +class ImageEncoder(nn.Module): + def __init__(self, output_dim: int = 64) -> None: + super().__init__() + self.output_dim = output_dim + + def forward(self, image): + return image + + +class PositionalEncoding(torch.nn.Module): + """ + Implement NeRF's positional encoding + """ + + def __init__(self, num_freqs=6, d_in=3, freq_factor=np.pi, include_input=True): + super().__init__() + self.num_freqs = num_freqs + self.d_in = d_in + self.freqs = freq_factor * 2.0 ** torch.arange(0, num_freqs) + self.d_out = self.num_freqs * 2 * d_in + self.include_input = include_input + if include_input: + self.d_out += d_in + # f1 f1 f2 f2 ... to multiply x by + self.register_buffer( + "_freqs", torch.repeat_interleave(self.freqs, 2).view(1, -1, 1) + ) + # 0 pi/2 0 pi/2 ... so that + # (sin(x + _phases[0]), sin(x + _phases[1]) ...) = (sin(x), cos(x)...) + _phases = torch.zeros(2 * self.num_freqs) + _phases[1::2] = np.pi * 0.5 + self.register_buffer("_phases", _phases.view(1, -1, 1)) + + def forward(self, x): + """ + Apply positional encoding (new implementation) + :param x (batch, self.d_in) + :return (batch, self.d_out) + """ + with profiler.record_function("positional_enc"): + # embed = x.unsqueeze(1).repeat(1, self.num_freqs * 2, 1) + embed = repeat(x, "... C -> ... N C", N=self.num_freqs * 2) + embed = torch.sin(torch.addcmul(self._phases, embed, self._freqs)) + embed = rearrange(embed, "... N C -> ... (N C)") + if self.include_input: + embed = torch.cat((x, embed), dim=-1) + return embed + + +class RayGenerator(torch.nn.Module): + """ + from camera pose and intrinsics to ray origins and directions + """ + + def __init__(self): + super().__init__() + ( + self.ray_origins_h, + self.ray_directions, + self.depths, + self.image_coords, + self.rendering_options, + ) = (None, None, None, None, None) + + def forward(self, cam2world_matrix, intrinsics, render_size): + """ + Create batches of rays and return origins and directions. 
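+        Pixel centers are lifted to camera space using the resolution-normalized intrinsics and transformed to world space with cam2world_matrix; the returned directions are unit-normalized.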
+ + cam2world_matrix: (N, 4, 4) + intrinsics: (N, 3, 3) + render_size: int + + ray_origins: (N, M, 3) + ray_dirs: (N, M, 2) + """ + + N, M = cam2world_matrix.shape[0], render_size**2 + cam_locs_world = cam2world_matrix[:, :3, 3] + fx = intrinsics[:, 0, 0] + fy = intrinsics[:, 1, 1] + cx = intrinsics[:, 0, 2] + cy = intrinsics[:, 1, 2] + sk = intrinsics[:, 0, 1] + + uv = torch.stack( + torch.meshgrid( + torch.arange( + render_size, dtype=torch.float32, device=cam2world_matrix.device + ), + torch.arange( + render_size, dtype=torch.float32, device=cam2world_matrix.device + ), + indexing="ij", + ) + ) + uv = uv.flip(0).reshape(2, -1).transpose(1, 0) + uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) + + x_cam = uv[:, :, 0].view(N, -1) * (1.0 / render_size) + (0.5 / render_size) + y_cam = uv[:, :, 1].view(N, -1) * (1.0 / render_size) + (0.5 / render_size) + z_cam = torch.ones((N, M), device=cam2world_matrix.device) + + x_lift = ( + ( + x_cam + - cx.unsqueeze(-1) + + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1) + - sk.unsqueeze(-1) * y_cam / fy.unsqueeze(-1) + ) + / fx.unsqueeze(-1) + * z_cam + ) + y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam + + cam_rel_points = torch.stack( + (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1 + ) + + # NOTE: this should be named _blender2opencv + _opencv2blender = ( + torch.tensor( + [ + [1, 0, 0, 0], + [0, -1, 0, 0], + [0, 0, -1, 0], + [0, 0, 0, 1], + ], + dtype=torch.float32, + device=cam2world_matrix.device, + ) + .unsqueeze(0) + .repeat(N, 1, 1) + ) + + cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) + + world_rel_points = torch.bmm( + cam2world_matrix, cam_rel_points.permute(0, 2, 1) + ).permute(0, 2, 1)[:, :, :3] + + ray_dirs = world_rel_points - cam_locs_world[:, None, :] + ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2) + + ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) + + return ray_origins, ray_dirs + + +class RaySampler(torch.nn.Module): + def __init__( + self, + num_samples_per_ray, + bbox_length=1.0, + near=0.5, + far=10000.0, + disparity=False, + ): + super().__init__() + self.num_samples_per_ray = num_samples_per_ray + self.bbox_length = bbox_length + self.near = near + self.far = far + self.disparity = disparity + + def forward(self, ray_origins, ray_directions): + if not self.disparity: + t_start, t_end = get_ray_limits_box( + ray_origins, ray_directions, 2 * self.bbox_length + ) + else: + t_start = torch.full_like(ray_origins, self.near) + t_end = torch.full_like(ray_origins, self.far) + is_ray_valid = t_end > t_start + if torch.any(is_ray_valid).item(): + t_start[~is_ray_valid] = t_start[is_ray_valid].min() + t_end[~is_ray_valid] = t_start[is_ray_valid].max() + + if not self.disparity: + depths = linspace(t_start, t_end, self.num_samples_per_ray) + depths += ( + torch.rand_like(depths) + * (t_end - t_start) + / (self.num_samples_per_ray - 1) + ) + else: + step = 1.0 / self.num_samples_per_ray + z_steps = torch.linspace( + 0, 1 - step, self.num_samples_per_ray, device=ray_origins.device + ) + z_steps += torch.rand_like(z_steps) * step + depths = 1 / (1 / self.near * (1 - z_steps) + 1 / self.far * z_steps) + depths = depths[..., None, None, None] + + return ray_origins[None] + ray_directions[None] * depths + + +class PixelNeRF(torch.nn.Module): + def __init__( + self, + num_samples_per_ray: int = 128, + feature_dim: int = 64, + interp: str = "bilinear", + padding: str = "border", + disparity: bool = False, + near: float = 0.5, + far: float = 
10000.0, + use_feats_std: bool = False, + use_pos_emb: bool = False, + ) -> None: + super().__init__() + # self.positional_encoder = Timestep(3) # TODO + self.num_samples_per_ray = num_samples_per_ray + self.ray_generator = RayGenerator() + self.ray_sampler = RaySampler( + num_samples_per_ray, near=near, far=far, disparity=disparity + ) # TODO + self.interp = interp + self.padding = padding + + self.positional_encoder = PositionalEncoding() + + # self.feature_aggregator = nn.Linear(128, 129) # TODO + self.use_feats_std = use_feats_std + self.use_pos_emb = use_pos_emb + d_in = feature_dim + if use_feats_std: + d_in += feature_dim + if use_pos_emb: + d_in += self.positional_encoder.d_out + self.feature_aggregator = nn.Sequential( + nn.Linear(d_in, 128), + nn.ReLU(), + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 129), + ) + + # self.decoder = nn.Linear(128, 131) # TODO + self.decoder = nn.Sequential( + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 131), + ) + + def project(self, ray_samples, source_c2ws, source_instrincs): + # TODO: implement + # S for number of source cameras + # ray_samples: [B, N, H * W, N_sample, 3] + # source_c2ws: [B, S, 4, 4] + # source_intrinsics: [B, S, 3, 3] + # return [B, S, N, H * W, N_sample, 2] + S = source_c2ws.shape[1] + B = ray_samples.shape[0] + N = ray_samples.shape[1] + HW = ray_samples.shape[2] + ray_samples = repeat( + ray_samples, + "B N HW N_sample C -> B S N HW N_sample C", + S=source_c2ws.shape[1], + ) + padding = torch.ones((B, S, N, HW, self.num_samples_per_ray, 1)).to(ray_samples) + ray_samples_homo = torch.cat([ray_samples, padding], dim=-1) + source_c2ws = repeat(source_c2ws, "B S C1 C2 -> B S N 1 1 C1 C2", N=N) + source_instrincs = repeat(source_instrincs, "B S C1 C2 -> B S N 1 1 C1 C2", N=N) + source_w2c = source_c2ws.inverse() + projected_samples = einsum( + source_w2c, ray_samples_homo, "... i j, ... j -> ... 
i" + )[..., :3] + # NOTE: assumes opengl convention + projected_samples = -1 * projected_samples[..., :2] / projected_samples[..., 2:] + # NOTE: intrinsics here are normalized by resolution + fx = source_instrincs[..., 0, 0] + fy = source_instrincs[..., 1, 1] + cx = source_instrincs[..., 0, 2] + cy = source_instrincs[..., 1, 2] + x = projected_samples[..., 0] * fx + cx + # negative sign here is caused by opengl, F.grid_sample is consistent with openCV convention + y = -projected_samples[..., 1] * fy + cy + + return torch.stack([x, y], dim=-1) + + def forward( + self, image_feats, source_c2ws, source_intrinsics, c2ws, intrinsics, render_size + ): + # image_feats: [B S C H W] + B = c2ws.shape[0] + T = c2ws.shape[1] + ray_origins, ray_directions = self.ray_generator( + c2ws.reshape(-1, 4, 4), intrinsics.reshape(-1, 3, 3), render_size + ) # [B * N, H * W, 3] + # breakpoint() + + ray_samples = self.ray_sampler( + ray_origins, ray_directions + ) # [N_sample, B * N, H * W, 3] + ray_samples = rearrange(ray_samples, "Ns (B N) HW C -> B N HW Ns C", B=B) + + projected_samples = self.project(ray_samples, source_c2ws, source_intrinsics) + # # debug + # p = projected_samples[:, :, 0, :, 0, :] + # p = p.reshape(p.shape[0] * p.shape[1], *p.shape[2:]) + + # breakpoint() + + # image_feats = repeat(image_feats, "B S C H W -> (B S N) C H W", N=T) + image_feats = rearrange(image_feats, "B S C H W -> (B S) C H W") + projected_samples = rearrange( + projected_samples, "B S N HW Ns xy -> (B S) (N Ns) HW xy" + ) + # make sure the projected samples are in the range of [-1, 1], as required by F.grid_sample + joint = F.grid_sample( + image_feats, + projected_samples * 2.0 - 1.0, + padding_mode=self.padding, + mode=self.interp, + align_corners=True, + ) + # print("image_feats", image_feats.max(), image_feats.min()) + # print("samples", projected_samples.max(), projected_samples.min()) + joint = rearrange( + joint, + "(B S) C (N Ns) HW -> B S N HW Ns C", + B=B, + Ns=self.num_samples_per_ray, + ) + + reduced = torch.mean(joint, dim=1) # reduce on source dimension + if self.use_feats_std: + if not joint.shape[1] == 1: + reduced = torch.cat((reduced, joint.std(dim=1)), dim=-1) + else: + reduced = torch.cat((reduced, torch.zeros_like(reduced)), dim=-1) + + if self.use_pos_emb: + reduced = torch.cat((reduced, self.positional_encoder(ray_samples)), dim=-1) + reduced = self.feature_aggregator(reduced) + + feats, weights = reduced.split([reduced.shape[-1] - 1, 1], dim=-1) + # feats: [B, N, H * W, N_samples, N_c] + # weights: [B, N, H * W, N_samples, 1] + weights = F.softmax(weights, dim=-2) + + feats = torch.sum(feats * weights, dim=-2) + + rgb, feats = self.decoder(feats).split([3, 128], dim=-1) + + rgb = F.sigmoid(rgb) + rgb = rearrange(rgb, "B N (H W) C -> B N C H W", H=render_size) + feats = rearrange(feats, "B N (H W) C -> B N C H W", H=render_size) + + # print(rgb.max(), rgb.min()) + # print(feats.max(), feats.min()) + + return rgb, feats diff --git a/sgm/modules/video_attention.py b/sgm/modules/video_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..783395aa554144936766b57380f35dab29c093c3 --- /dev/null +++ b/sgm/modules/video_attention.py @@ -0,0 +1,301 @@ +import torch + +from ..modules.attention import * +from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding + + +class TimeMixSequential(nn.Sequential): + def forward(self, x, context=None, timesteps=None): + for layer in self: + x = layer(x, context, timesteps) + + return x + + +class 
VideoTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, + "softmax-xformers": MemoryEfficientCrossAttention, + } + + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + timesteps=None, + ff_in=False, + inner_dim=None, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + switch_temporal_ca_to_sa=False, + ): + super().__init__() + + attn_cls = self.ATTENTION_MODES[attn_mode] + + self.ff_in = ff_in or inner_dim is not None + if inner_dim is None: + inner_dim = dim + + assert int(n_heads * d_head) == inner_dim + + self.is_res = inner_dim == dim + + if self.ff_in: + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward( + dim, dim_out=inner_dim, dropout=dropout, glu=gated_ff + ) + + self.timesteps = timesteps + self.disable_self_attn = disable_self_attn + if self.disable_self_attn: + self.attn1 = attn_cls( + query_dim=inner_dim, + heads=n_heads, + dim_head=d_head, + context_dim=context_dim, + dropout=dropout, + ) # is a cross-attention + else: + self.attn1 = attn_cls( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + + self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff) + + if disable_temporal_crossattention: + if switch_temporal_ca_to_sa: + raise ValueError + else: + self.attn2 = None + else: + self.norm2 = nn.LayerNorm(inner_dim) + if switch_temporal_ca_to_sa: + self.attn2 = attn_cls( + query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + else: + self.attn2 = attn_cls( + query_dim=inner_dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + + self.norm1 = nn.LayerNorm(inner_dim) + self.norm3 = nn.LayerNorm(inner_dim) + self.switch_temporal_ca_to_sa = switch_temporal_ca_to_sa + + self.checkpoint = checkpoint + if self.checkpoint: + print(f"{self.__class__.__name__} is using checkpointing") + + def forward( + self, x: torch.Tensor, context: torch.Tensor = None, timesteps: int = None + ) -> torch.Tensor: + if self.checkpoint: + return checkpoint(self._forward, x, context, timesteps) + else: + return self._forward(x, context, timesteps=timesteps) + + def _forward(self, x, context=None, timesteps=None): + assert self.timesteps or timesteps + assert not (self.timesteps and timesteps) or self.timesteps == timesteps + timesteps = self.timesteps or timesteps + B, S, C = x.shape + x = rearrange(x, "(b t) s c -> (b s) t c", t=timesteps) + + if self.ff_in: + x_skip = x + x = self.ff_in(self.norm_in(x)) + if self.is_res: + x += x_skip + + if self.disable_self_attn: + x = self.attn1(self.norm1(x), context=context) + x + else: + x = self.attn1(self.norm1(x)) + x + + if self.attn2 is not None: + if self.switch_temporal_ca_to_sa: + x = self.attn2(self.norm2(x)) + x + else: + x = self.attn2(self.norm2(x), context=context) + x + x_skip = x + x = self.ff(self.norm3(x)) + if self.is_res: + x += x_skip + + x = rearrange( + x, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps + ) + return x + + def get_last_layer(self): + return self.ff.net[-1].weight + + +class SpatialVideoTransformer(SpatialTransformer): + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + 
ff_in=False, + checkpoint=False, + time_depth=1, + attn_mode="softmax", + disable_self_attn=False, + disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + ): + super().__init__( + in_channels, + n_heads, + d_head, + depth=depth, + dropout=dropout, + attn_type=attn_mode, + use_checkpoint=checkpoint, + context_dim=context_dim, + use_linear=use_linear, + disable_self_attn=disable_self_attn, + ) + self.time_depth = time_depth + self.depth = depth + self.max_time_embed_period = max_time_embed_period + + time_mix_d_head = d_head + n_time_mix_heads = n_heads + + time_mix_inner_dim = int(time_mix_d_head * n_time_mix_heads) + + inner_dim = n_heads * d_head + if use_spatial_context: + time_context_dim = context_dim + + self.time_stack = nn.ModuleList( + [ + VideoTransformerBlock( + inner_dim, + n_time_mix_heads, + time_mix_d_head, + dropout=dropout, + context_dim=time_context_dim, + timesteps=timesteps, + checkpoint=checkpoint, + ff_in=ff_in, + inner_dim=time_mix_inner_dim, + attn_mode=attn_mode, + disable_self_attn=disable_self_attn, + disable_temporal_crossattention=disable_temporal_crossattention, + ) + for _ in range(self.depth) + ] + ) + + assert len(self.time_stack) == len(self.transformer_blocks) + + self.use_spatial_context = use_spatial_context + self.in_channels = in_channels + + time_embed_dim = self.in_channels * 4 + self.time_pos_embed = nn.Sequential( + linear(self.in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, self.in_channels), + ) + + self.time_mixer = AlphaBlender( + alpha=merge_factor, merge_strategy=merge_strategy + ) + + def forward( + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + _, _, h, w = x.shape + x_in = x + spatial_context = None + if exists(context): + spatial_context = context + + if self.use_spatial_context: + assert ( + context.ndim == 3 + ), f"n dims of spatial context should be 3 but are {context.ndim}" + + time_context = context + time_context_first_timestep = time_context[::timesteps] + time_context = repeat( + time_context_first_timestep, "b ... -> (b n) ...", n=h * w + ) + elif time_context is not None and not self.use_spatial_context: + time_context = repeat(time_context, "b ... 
-> (b n) ...", n=h * w) + if time_context.ndim == 2: + time_context = rearrange(time_context, "b c -> b 1 c") + + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + if self.use_linear: + x = self.proj_in(x) + + num_frames = torch.arange(timesteps, device=x.device) + num_frames = repeat(num_frames, "t -> b t", b=x.shape[0] // timesteps) + num_frames = rearrange(num_frames, "b t -> (b t)") + t_emb = timestep_embedding( + num_frames, + self.in_channels, + repeat_only=False, + max_period=self.max_time_embed_period, + ) + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + for it_, (block, mix_block) in enumerate( + zip(self.transformer_blocks, self.time_stack) + ): + x = block( + x, + context=spatial_context, + ) + + x_mix = x + x_mix = x_mix + emb + + x_mix = mix_block(x_mix, context=time_context, timesteps=timesteps) + x = self.time_mixer( + x_spatial=x, + x_temporal=x_mix, + image_only_indicator=image_only_indicator, + ) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + if not self.use_linear: + x = self.proj_out(x) + out = x + x_in + return out diff --git a/sgm/sampling/hier.py b/sgm/sampling/hier.py new file mode 100644 index 0000000000000000000000000000000000000000..375261c89b9f2fb38b2b853af8872ef4f0f500af --- /dev/null +++ b/sgm/sampling/hier.py @@ -0,0 +1 @@ +# hierachical sampling, (autogressive sampling like GeNVS) diff --git a/sgm/util.py b/sgm/util.py new file mode 100644 index 0000000000000000000000000000000000000000..49cc0df0e14326087e1adaf515b76137c2977fbe --- /dev/null +++ b/sgm/util.py @@ -0,0 +1,310 @@ +import functools +import importlib +import os +from functools import partial +from inspect import isfunction + +import fsspec +import numpy as np +import torch +from PIL import Image, ImageDraw, ImageFont +from safetensors.torch import load_file as load_safetensors +from einops import rearrange +from mediapy import write_image + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def get_string_from_tuple(s): + try: + # Check if the string starts and ends with parentheses + if s[0] == "(" and s[-1] == ")": + # Convert the string to a tuple + t = eval(s) + # Check if the type of t is tuple + if type(t) == tuple: + return t[0] + else: + pass + except: + pass + return s + + +def is_power_of_two(n): + """ + chat.openai.com/chat + Return True if n is a power of 2, otherwise return False. + + The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False. + The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False. + If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise. + Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False. 
+ + """ + if n <= 0: + return False + return (n & (n - 1)) == 0 + + +def autocast(f, enabled=True): + def do_autocast(*args, **kwargs): + with torch.cuda.amp.autocast( + enabled=enabled, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled(), + ): + return f(*args, **kwargs) + + return do_autocast + + +def load_partial_from_config(config): + return partial(get_obj_from_str(config["target"]), **config.get("params", dict())) + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype("data/DejaVuSans.ttf", size=size) + nc = int(40 * (wh[0] / 256)) + if isinstance(xc[bi], list): + text_seq = xc[bi][0] + else: + text_seq = xc[bi] + lines = "\n".join( + text_seq[start : start + nc] for start in range(0, len(text_seq), nc) + ) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +def make_path_absolute(path): + fs, p = fsspec.core.url_to_fs(path) + if fs.protocol == "file": + return os.path.abspath(p) + return path + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def isheatmap(x): + if not isinstance(x, torch.Tensor): + return False + + return x.ndim == 2 + + +def isneighbors(x): + if not isinstance(x, torch.Tensor): + return False + return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1) + + +def exists(x): + return x is not None + + +def expand_dims_like(x, y): + while x.dim() != y.dim(): + x = x.unsqueeze(-1) + return x + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False, invalidate_cache=True): + module, cls = string.rsplit(".", 1) + if invalidate_cache: + importlib.invalidate_caches() + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def append_zero(x): + return torch.cat([x, x.new_zeros([1])]) + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + +def load_model_from_config(config, ckpt, verbose=True, freeze=True): + print(f"Loading model from {ckpt}") + if ckpt.endswith("ckpt"): + pl_sd = torch.load(ckpt, map_location="cpu") + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + elif ckpt.endswith("safetensors"): + sd = load_safetensors(ckpt) + else: + raise NotImplementedError + + model = instantiate_from_config(config.model) + + m, u = model.load_state_dict(sd, strict=False) + + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if freeze: + for param in model.parameters(): + param.requires_grad = False + + model.eval() + return model + + +def get_configs_path() -> str: + """ + Get the `configs` directory. + For a working copy, this is the one in the root of the repository, + but for an installed copy, it's in the `sgm` package (see pyproject.toml). + """ + this_dir = os.path.dirname(__file__) + candidates = ( + os.path.join(this_dir, "configs"), + os.path.join(this_dir, "..", "configs"), + ) + for candidate in candidates: + candidate = os.path.abspath(candidate) + if os.path.isdir(candidate): + return candidate + raise FileNotFoundError(f"Could not find SGM configs in {candidates}") + + +def get_nested_attribute(obj, attribute_path, depth=None, return_key=False): + """ + Will return the result of a recursive get attribute call. + E.g.: + a.b.c + = getattr(getattr(a, "b"), "c") + = get_nested_attribute(a, "b.c") + If any part of the attribute call is an integer x with current obj a, will + try to call a[x] instead of a.x first. 
+ """ + attributes = attribute_path.split(".") + if depth is not None and depth > 0: + attributes = attributes[:depth] + assert len(attributes) > 0, "At least one attribute should be selected" + current_attribute = obj + current_key = None + for level, attribute in enumerate(attributes): + current_key = ".".join(attributes[: level + 1]) + try: + id_ = int(attribute) + current_attribute = current_attribute[id_] + except ValueError: + current_attribute = getattr(current_attribute, attribute) + + return (current_attribute, current_key) if return_key else current_attribute + + +def video_frames_as_grid(frames, save_path): + # frames: [T, C, H, W] + frames = frames.detach().cpu() + frames = rearrange(frames, "t c h w -> h (t w) c") + write_image(save_path, frames) + + +def server_safe_call(keep_trying: bool = False): + """Decorator for server calls. If the call fails, it will keep trying until it succeeds. + + Args: + keep_trying (bool, optional): whether to call again if the first try fails. Defaults to False. + """ + + def decorator(func): + def wrapper(*args, **kwargs): + success = False + while not success: + try: + ret = func(*args, **kwargs) + success = True + except KeyboardInterrupt: + raise + except: + if not keep_trying: + break + return ret + + return wrapper + + return decorator