diff --git a/configs/aligned_shape_latents/shapevae-256.yaml b/configs/aligned_shape_latents/shapevae-256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..87af1842a4d9b876d40047e452b972703ccf9bcc --- /dev/null +++ b/configs/aligned_shape_latents/shapevae-256.yaml @@ -0,0 +1,46 @@ +model: + target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule + params: + shape_module_cfg: + target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver + params: + num_latents: 256 + embed_dim: 64 + point_feats: 3 # normal + num_freqs: 8 + include_pi: false + heads: 12 + width: 768 + num_encoder_layers: 8 + num_decoder_layers: 16 + use_ln_post: true + init_scale: 0.25 + qkv_bias: false + use_checkpoint: true + aligned_module_cfg: + target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule + params: + clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" + + loss_cfg: + target: michelangelo.models.tsal.loss.ContrastKLNearFar + params: + contrast_weight: 0.1 + near_weight: 0.1 + kl_weight: 0.001 + + optimizer_cfg: + optimizer: + target: torch.optim.AdamW + params: + betas: [0.9, 0.99] + eps: 1.e-6 + weight_decay: 1.e-2 + + scheduler: + target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler + params: + warm_up_steps: 5000 + f_start: 1.e-6 + f_min: 1.e-3 + f_max: 1.0 diff --git a/configs/deploy/clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000_deploy.yaml b/configs/deploy/clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000_deploy.yaml new file mode 100644 index 0000000000000000000000000000000000000000..16491c2db374dcf1df3329ac33074ed4c9ff7862 --- /dev/null +++ b/configs/deploy/clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000_deploy.yaml @@ -0,0 +1,181 @@ +name: "0630_clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000" +#wandb: +# project: "image_diffuser" +# offline: false + + +training: + steps: 500000 + use_amp: true + ckpt_path: "" + base_lr: 1.e-4 + gradient_clip_val: 5.0 + gradient_clip_algorithm: "norm" + every_n_train_steps: 5000 + val_check_interval: 1024 + limit_val_batches: 16 + +dataset: + target: michelangelo.data.asl_webdataset.MultiAlignedShapeLatentModule + params: + batch_size: 38 + num_workers: 4 + val_num_workers: 4 + buffer_size: 256 + return_normal: true + random_crop: false + surface_sampling: true + pc_size: &pc_size 4096 + image_size: 384 + mean: &mean [0.5, 0.5, 0.5] + std: &std [0.5, 0.5, 0.5] + cond_stage_key: "image" + + meta_info: + 3D-FUTURE: + render_folder: "/root/workspace/cq_workspace/datasets/3D-FUTURE/renders" + tar_folder: "/root/workspace/datasets/make_tars/3D-FUTURE" + + ABO: + render_folder: "/root/workspace/cq_workspace/datasets/ABO/renders" + tar_folder: "/root/workspace/datasets/make_tars/ABO" + + GSO: + render_folder: "/root/workspace/cq_workspace/datasets/GSO/renders" + tar_folder: "/root/workspace/datasets/make_tars/GSO" + + TOYS4K: + render_folder: "/root/workspace/cq_workspace/datasets/TOYS4K/TOYS4K/renders" + tar_folder: "/root/workspace/datasets/make_tars/TOYS4K" + + 3DCaricShop: + render_folder: "/root/workspace/cq_workspace/datasets/3DCaricShop/renders" + tar_folder: "/root/workspace/datasets/make_tars/3DCaricShop" + + Thingi10K: + render_folder: "/root/workspace/cq_workspace/datasets/Thingi10K/renders" + tar_folder: "/root/workspace/datasets/make_tars/Thingi10K" + 
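+        # Each meta_info entry pairs a folder of pre-rendered views (render_folder)
+        # with the webdataset tar shards used for training (tar_folder); the absolute
+        # /root/workspace paths are machine-specific and need to be adapted locally.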
+ shapenet: + render_folder: "/root/workspace/cq_workspace/datasets/shapenet/renders" + tar_folder: "/root/workspace/datasets/make_tars/shapenet" + + pokemon: + render_folder: "/root/workspace/cq_workspace/datasets/pokemon/renders" + tar_folder: "/root/workspace/datasets/make_tars/pokemon" + + objaverse: + render_folder: "/root/workspace/cq_workspace/datasets/objaverse/renders" + tar_folder: "/root/workspace/datasets/make_tars/objaverse" + +model: + target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser + params: + first_stage_config: + target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule + params: + shape_module_cfg: + target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver + params: + num_latents: &num_latents 256 + embed_dim: &embed_dim 64 + point_feats: 3 # normal + num_freqs: 8 + include_pi: false + heads: 12 + width: 768 + num_encoder_layers: 8 + num_decoder_layers: 16 + use_ln_post: true + init_scale: 0.25 + qkv_bias: false + use_checkpoint: false + aligned_module_cfg: + target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule + params: + clip_model_version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14" + # clip_model_version: "/root/workspace/checkpoints/clip/clip-vit-large-patch14" + + loss_cfg: + target: torch.nn.Identity + + cond_stage_config: + target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder + params: + version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14" + # version: "/root/workspace/checkpoints/clip/clip-vit-large-patch14" + zero_embedding_radio: 0.1 + + first_stage_key: "surface" + cond_stage_key: "image" + scale_by_std: false + + denoiser_cfg: + target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser + params: + input_channels: *embed_dim + output_channels: *embed_dim + n_ctx: *num_latents + width: 768 + layers: 6 # 2 * 6 + 1 = 13 + heads: 12 + context_dim: 1024 + init_scale: 1.0 + skip_ln: true + use_checkpoint: true + + scheduler_cfg: + guidance_scale: 7.5 + num_inference_steps: 50 + eta: 0.0 + + noise: + target: diffusers.schedulers.DDPMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + variance_type: "fixed_small" + clip_sample: false + denoise: + target: diffusers.schedulers.DDIMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + clip_sample: false # clip sample to -1~1 + set_alpha_to_one: false + steps_offset: 1 + + optimizer_cfg: + optimizer: + target: torch.optim.AdamW + params: + betas: [0.9, 0.99] + eps: 1.e-6 + weight_decay: 1.e-2 + + scheduler: + target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler + params: + warm_up_steps: 5000 + f_start: 1.e-6 + f_min: 1.e-3 + f_max: 1.0 + + loss_cfg: + loss_type: "mse" + +logger: + target: michelangelo.utils.trainings.mesh_log_callback.ImageConditionalASLDiffuserLogger + params: + step_frequency: 2000 + num_samples: 4 + sample_times: 4 + mean: *mean + std: *std + bounds: [-1.1, -1.1, -1.1, 1.1, 1.1, 1.1] + octree_depth: 7 + num_chunks: 10000 diff --git a/configs/deploy/clip_sp+pk_aslperceiver=256_01_4096_8_udt=03.yaml b/configs/deploy/clip_sp+pk_aslperceiver=256_01_4096_8_udt=03.yaml new file mode 100644 index 0000000000000000000000000000000000000000..acebd9b2e868b81ac9ef5a86ab317d2b4b518a75 --- /dev/null +++ 
b/configs/deploy/clip_sp+pk_aslperceiver=256_01_4096_8_udt=03.yaml @@ -0,0 +1,180 @@ +name: "0428_clip_subsp+pk_sal_perceiver=256_01_4096_8_udt=03" +#wandb: +# project: "image_diffuser" +# offline: false + +training: + steps: 500000 + use_amp: true + ckpt_path: "" + base_lr: 1.e-4 + gradient_clip_val: 5.0 + gradient_clip_algorithm: "norm" + every_n_train_steps: 5000 + val_check_interval: 1024 + limit_val_batches: 16 + +# dataset +dataset: + target: michelangelo.data.asl_torch_dataset.MultiAlignedShapeImageTextModule + params: + batch_size: 38 + num_workers: 4 + val_num_workers: 4 + buffer_size: 256 + return_normal: true + random_crop: false + surface_sampling: true + pc_size: &pc_size 4096 + image_size: 384 + mean: &mean [0.5, 0.5, 0.5] + std: &std [0.5, 0.5, 0.5] + + cond_stage_key: "text" + + meta_info: + 3D-FUTURE: + render_folder: "/root/workspace/cq_workspace/datasets/3D-FUTURE/renders" + tar_folder: "/root/workspace/datasets/make_tars/3D-FUTURE" + + ABO: + render_folder: "/root/workspace/cq_workspace/datasets/ABO/renders" + tar_folder: "/root/workspace/datasets/make_tars/ABO" + + GSO: + render_folder: "/root/workspace/cq_workspace/datasets/GSO/renders" + tar_folder: "/root/workspace/datasets/make_tars/GSO" + + TOYS4K: + render_folder: "/root/workspace/cq_workspace/datasets/TOYS4K/TOYS4K/renders" + tar_folder: "/root/workspace/datasets/make_tars/TOYS4K" + + 3DCaricShop: + render_folder: "/root/workspace/cq_workspace/datasets/3DCaricShop/renders" + tar_folder: "/root/workspace/datasets/make_tars/3DCaricShop" + + Thingi10K: + render_folder: "/root/workspace/cq_workspace/datasets/Thingi10K/renders" + tar_folder: "/root/workspace/datasets/make_tars/Thingi10K" + + shapenet: + render_folder: "/root/workspace/cq_workspace/datasets/shapenet/renders" + tar_folder: "/root/workspace/datasets/make_tars/shapenet" + + pokemon: + render_folder: "/root/workspace/cq_workspace/datasets/pokemon/renders" + tar_folder: "/root/workspace/datasets/make_tars/pokemon" + + objaverse: + render_folder: "/root/workspace/cq_workspace/datasets/objaverse/renders" + tar_folder: "/root/workspace/datasets/make_tars/objaverse" + +model: + target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser + params: + first_stage_config: + target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule + params: + # ckpt_path: "/root/workspace/cq_workspace/michelangelo/experiments/aligned_shape_latents/clip_aslperceiver_sp+pk_01_01/ckpt/ckpt-step=00230000.ckpt" + shape_module_cfg: + target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver + params: + num_latents: &num_latents 256 + embed_dim: &embed_dim 64 + point_feats: 3 # normal + num_freqs: 8 + include_pi: false + heads: 12 + width: 768 + num_encoder_layers: 8 + num_decoder_layers: 16 + use_ln_post: true + init_scale: 0.25 + qkv_bias: false + use_checkpoint: true + aligned_module_cfg: + target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule + params: + clip_model_version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14" + + loss_cfg: + target: torch.nn.Identity + + cond_stage_config: + target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder + params: + version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14" + zero_embedding_radio: 0.1 + max_length: 77 + + first_stage_key: "surface" + cond_stage_key: "text" + scale_by_std: false + + denoiser_cfg: + target: 
michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser + params: + input_channels: *embed_dim + output_channels: *embed_dim + n_ctx: *num_latents + width: 768 + layers: 8 # 2 * 6 + 1 = 13 + heads: 12 + context_dim: 768 + init_scale: 1.0 + skip_ln: true + use_checkpoint: true + + scheduler_cfg: + guidance_scale: 7.5 + num_inference_steps: 50 + eta: 0.0 + + noise: + target: diffusers.schedulers.DDPMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + variance_type: "fixed_small" + clip_sample: false + denoise: + target: diffusers.schedulers.DDIMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + clip_sample: false # clip sample to -1~1 + set_alpha_to_one: false + steps_offset: 1 + + optimizer_cfg: + optimizer: + target: torch.optim.AdamW + params: + betas: [0.9, 0.99] + eps: 1.e-6 + weight_decay: 1.e-2 + + scheduler: + target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler + params: + warm_up_steps: 5000 + f_start: 1.e-6 + f_min: 1.e-3 + f_max: 1.0 + + loss_cfg: + loss_type: "mse" + +logger: + target: michelangelo.utils.trainings.mesh_log_callback.TextConditionalASLDiffuserLogger + params: + step_frequency: 1000 + num_samples: 4 + sample_times: 4 + bounds: [-1.1, -1.1, -1.1, 1.1, 1.1, 1.1] + octree_depth: 7 + num_chunks: 10000 diff --git a/configs/image_cond_diffuser_asl/image-ASLDM-256.yaml b/configs/image_cond_diffuser_asl/image-ASLDM-256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0a0b47e06ddfaad8fd888c99340e935397b204b0 --- /dev/null +++ b/configs/image_cond_diffuser_asl/image-ASLDM-256.yaml @@ -0,0 +1,97 @@ +model: + target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser + params: + first_stage_config: + target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule + params: + shape_module_cfg: + target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver + params: + num_latents: &num_latents 256 + embed_dim: &embed_dim 64 + point_feats: 3 # normal + num_freqs: 8 + include_pi: false + heads: 12 + width: 768 + num_encoder_layers: 8 + num_decoder_layers: 16 + use_ln_post: true + init_scale: 0.25 + qkv_bias: false + use_checkpoint: false + aligned_module_cfg: + target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule + params: + clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" + + loss_cfg: + target: torch.nn.Identity + + cond_stage_config: + target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder + params: + version: "./checkpoints/clip/clip-vit-large-patch14" + zero_embedding_radio: 0.1 + + first_stage_key: "surface" + cond_stage_key: "image" + scale_by_std: false + + denoiser_cfg: + target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser + params: + input_channels: *embed_dim + output_channels: *embed_dim + n_ctx: *num_latents + width: 768 + layers: 6 # 2 * 6 + 1 = 13 + heads: 12 + context_dim: 1024 + init_scale: 1.0 + skip_ln: true + use_checkpoint: true + + scheduler_cfg: + guidance_scale: 7.5 + num_inference_steps: 50 + eta: 0.0 + + noise: + target: diffusers.schedulers.DDPMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + variance_type: "fixed_small" + clip_sample: false + denoise: + target: diffusers.schedulers.DDIMScheduler + params: + 
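+        # DDIM is used only at sampling time; the inference step count (50), guidance
+        # scale (7.5) and eta come from scheduler_cfg above, while the DDPM "noise"
+        # scheduler above defines the 1000-step training noise schedule.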
num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + clip_sample: false # clip sample to -1~1 + set_alpha_to_one: false + steps_offset: 1 + + optimizer_cfg: + optimizer: + target: torch.optim.AdamW + params: + betas: [0.9, 0.99] + eps: 1.e-6 + weight_decay: 1.e-2 + + scheduler: + target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler + params: + warm_up_steps: 5000 + f_start: 1.e-6 + f_min: 1.e-3 + f_max: 1.0 + + loss_cfg: + loss_type: "mse" diff --git a/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml b/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1888ab21483ec1c76cae914d308962792587074d --- /dev/null +++ b/configs/text_cond_diffuser_asl/text-ASLDM-256.yaml @@ -0,0 +1,98 @@ +model: + target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser + params: + first_stage_config: + target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule + params: + shape_module_cfg: + target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver + params: + num_latents: &num_latents 256 + embed_dim: &embed_dim 64 + point_feats: 3 # normal + num_freqs: 8 + include_pi: false + heads: 12 + width: 768 + num_encoder_layers: 8 + num_decoder_layers: 16 + use_ln_post: true + init_scale: 0.25 + qkv_bias: false + use_checkpoint: true + aligned_module_cfg: + target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule + params: + clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" + + loss_cfg: + target: torch.nn.Identity + + cond_stage_config: + target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder + params: + version: "./checkpoints/clip/clip-vit-large-patch14" + zero_embedding_radio: 0.1 + max_length: 77 + + first_stage_key: "surface" + cond_stage_key: "text" + scale_by_std: false + + denoiser_cfg: + target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser + params: + input_channels: *embed_dim + output_channels: *embed_dim + n_ctx: *num_latents + width: 768 + layers: 8 # 2 * 6 + 1 = 13 + heads: 12 + context_dim: 768 + init_scale: 1.0 + skip_ln: true + use_checkpoint: true + + scheduler_cfg: + guidance_scale: 7.5 + num_inference_steps: 50 + eta: 0.0 + + noise: + target: diffusers.schedulers.DDPMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + variance_type: "fixed_small" + clip_sample: false + denoise: + target: diffusers.schedulers.DDIMScheduler + params: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "scaled_linear" + clip_sample: false # clip sample to -1~1 + set_alpha_to_one: false + steps_offset: 1 + + optimizer_cfg: + optimizer: + target: torch.optim.AdamW + params: + betas: [0.9, 0.99] + eps: 1.e-6 + weight_decay: 1.e-2 + + scheduler: + target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler + params: + warm_up_steps: 5000 + f_start: 1.e-6 + f_min: 1.e-3 + f_max: 1.0 + + loss_cfg: + loss_type: "mse" \ No newline at end of file diff --git a/example_data/image/car.jpg b/example_data/image/car.jpg new file mode 100644 index 0000000000000000000000000000000000000000..be302971eb9829a0e1784a152987f8627d7bebb8 Binary files /dev/null and b/example_data/image/car.jpg differ diff --git a/example_data/surface/surface.npz b/example_data/surface/surface.npz new file mode 100644 index 
0000000000000000000000000000000000000000..590a2793c899b983de838fef4352923e01ea64d6 --- /dev/null +++ b/example_data/surface/surface.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0893e44d82ada683baa656a718beaf6ec19fc28b6816b451f56645530d5bb962 +size 1201024 diff --git a/gradio_cached_dir/example/img_example/airplane.jpg b/gradio_cached_dir/example/img_example/airplane.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cad243b686ef46dfdb0114338ba173b88dfbfc87 Binary files /dev/null and b/gradio_cached_dir/example/img_example/airplane.jpg differ diff --git a/gradio_cached_dir/example/img_example/alita.jpg b/gradio_cached_dir/example/img_example/alita.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6baaa50f20418156431438b26c37bde3c2f18ad4 Binary files /dev/null and b/gradio_cached_dir/example/img_example/alita.jpg differ diff --git a/gradio_cached_dir/example/img_example/bag.jpg b/gradio_cached_dir/example/img_example/bag.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6bd414617f81fb772b00b3a7030ab1ec73053bb8 Binary files /dev/null and b/gradio_cached_dir/example/img_example/bag.jpg differ diff --git a/gradio_cached_dir/example/img_example/bench.jpg b/gradio_cached_dir/example/img_example/bench.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7ae09b212e7db87bd405490e215e5ca85f53d254 Binary files /dev/null and b/gradio_cached_dir/example/img_example/bench.jpg differ diff --git a/gradio_cached_dir/example/img_example/building.jpg b/gradio_cached_dir/example/img_example/building.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9d4fb7946d8b89fda2faea098f9d16b91d1a3d18 Binary files /dev/null and b/gradio_cached_dir/example/img_example/building.jpg differ diff --git a/gradio_cached_dir/example/img_example/burger.jpg b/gradio_cached_dir/example/img_example/burger.jpg new file mode 100644 index 0000000000000000000000000000000000000000..c3c71d73d53b13fd75f5ea33ece4688eb5ddae97 Binary files /dev/null and b/gradio_cached_dir/example/img_example/burger.jpg differ diff --git a/gradio_cached_dir/example/img_example/car.jpg b/gradio_cached_dir/example/img_example/car.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f318fd1619b2c94778a3f571397f68ba0baf2d6c Binary files /dev/null and b/gradio_cached_dir/example/img_example/car.jpg differ diff --git a/gradio_cached_dir/example/img_example/loopy.jpg b/gradio_cached_dir/example/img_example/loopy.jpg new file mode 100644 index 0000000000000000000000000000000000000000..34496b9ace26ec5e11ca97193790893485a7b74b Binary files /dev/null and b/gradio_cached_dir/example/img_example/loopy.jpg differ diff --git a/gradio_cached_dir/example/img_example/mario.jpg b/gradio_cached_dir/example/img_example/mario.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a73effa1d53a7c09172b4a0e4379b85895684322 Binary files /dev/null and b/gradio_cached_dir/example/img_example/mario.jpg differ diff --git a/gradio_cached_dir/example/img_example/ship.jpg b/gradio_cached_dir/example/img_example/ship.jpg new file mode 100644 index 0000000000000000000000000000000000000000..122d8a3391e171ea97693ac4eed148ee9dcda19f Binary files /dev/null and b/gradio_cached_dir/example/img_example/ship.jpg differ diff --git a/michelangelo/__init__.py b/michelangelo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ 
b/michelangelo/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/michelangelo/__pycache__/__init__.cpython-39.pyc b/michelangelo/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ded5a28b63f0346c976ec9fdd2bde11de30248b4 Binary files /dev/null and b/michelangelo/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/data/__init__.py b/michelangelo/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/michelangelo/data/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/michelangelo/data/__pycache__/__init__.cpython-39.pyc b/michelangelo/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..012b2b0e3ea5d606900b4cbc11e4be4bc54a903c Binary files /dev/null and b/michelangelo/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/data/__pycache__/asl_webdataset.cpython-39.pyc b/michelangelo/data/__pycache__/asl_webdataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e5418d9fba15f411e4e18ee6b66a953fc851727 Binary files /dev/null and b/michelangelo/data/__pycache__/asl_webdataset.cpython-39.pyc differ diff --git a/michelangelo/data/__pycache__/tokenizer.cpython-39.pyc b/michelangelo/data/__pycache__/tokenizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19c42c9492a0481ccecc23e6e4c917801b423211 Binary files /dev/null and b/michelangelo/data/__pycache__/tokenizer.cpython-39.pyc differ diff --git a/michelangelo/data/__pycache__/transforms.cpython-39.pyc b/michelangelo/data/__pycache__/transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a34923a320447270d95ad85c7b5563a03bbe6ff Binary files /dev/null and b/michelangelo/data/__pycache__/transforms.cpython-39.pyc differ diff --git a/michelangelo/data/__pycache__/utils.cpython-39.pyc b/michelangelo/data/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2b687235248f623cab697c277171c66901d5d16 Binary files /dev/null and b/michelangelo/data/__pycache__/utils.cpython-39.pyc differ diff --git a/michelangelo/data/templates.json b/michelangelo/data/templates.json new file mode 100644 index 0000000000000000000000000000000000000000..f1a355f6bb61a98b4e229e7aece7b0035c3b0592 --- /dev/null +++ b/michelangelo/data/templates.json @@ -0,0 +1,69 @@ +{ + "shape": [ + "a point cloud model of {}.", + "There is a {} in the scene.", + "There is the {} in the scene.", + "a photo of a {} in the scene.", + "a photo of the {} in the scene.", + "a photo of one {} in the scene.", + "itap of a {}.", + "itap of my {}.", + "itap of the {}.", + "a photo of a {}.", + "a photo of my {}.", + "a photo of the {}.", + "a photo of one {}.", + "a photo of many {}.", + "a good photo of a {}.", + "a good photo of the {}.", + "a bad photo of a {}.", + "a bad photo of the {}.", + "a photo of a nice {}.", + "a photo of the nice {}.", + "a photo of a cool {}.", + "a photo of the cool {}.", + "a photo of a weird {}.", + "a photo of the weird {}.", + "a photo of a small {}.", + "a photo of the small {}.", + "a photo of a large {}.", + "a photo of the large {}.", + "a photo of a clean {}.", + "a photo of the clean {}.", + "a photo of a dirty {}.", + "a photo of the dirty {}.", + "a bright photo of a {}.", + "a bright photo of the {}.", + "a dark photo of a {}.", + "a dark 
photo of the {}.", + "a photo of a hard to see {}.", + "a photo of the hard to see {}.", + "a low resolution photo of a {}.", + "a low resolution photo of the {}.", + "a cropped photo of a {}.", + "a cropped photo of the {}.", + "a close-up photo of a {}.", + "a close-up photo of the {}.", + "a jpeg corrupted photo of a {}.", + "a jpeg corrupted photo of the {}.", + "a blurry photo of a {}.", + "a blurry photo of the {}.", + "a pixelated photo of a {}.", + "a pixelated photo of the {}.", + "a black and white photo of the {}.", + "a black and white photo of a {}", + "a plastic {}.", + "the plastic {}.", + "a toy {}.", + "the toy {}.", + "a plushie {}.", + "the plushie {}.", + "a cartoon {}.", + "the cartoon {}.", + "an embroidered {}.", + "the embroidered {}.", + "a painting of the {}.", + "a painting of a {}." + ] + +} \ No newline at end of file diff --git a/michelangelo/data/transforms.py b/michelangelo/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..452e0ee93481932c91cf35dc5eb0b90658a7823e --- /dev/null +++ b/michelangelo/data/transforms.py @@ -0,0 +1,407 @@ +# -*- coding: utf-8 -*- +import os +import time +import numpy as np +import warnings +import random +from omegaconf.listconfig import ListConfig +from webdataset import pipelinefilter +import torch +import torchvision.transforms.functional as TVF +from torchvision.transforms import InterpolationMode +from torchvision.transforms.transforms import _interpolation_modes_from_int +from typing import Sequence + +from michelangelo.utils import instantiate_from_config + + +def _uid_buffer_pick(buf_dict, rng): + uid_keys = list(buf_dict.keys()) + selected_uid = rng.choice(uid_keys) + buf = buf_dict[selected_uid] + + k = rng.randint(0, len(buf) - 1) + sample = buf[k] + buf[k] = buf[-1] + buf.pop() + + if len(buf) == 0: + del buf_dict[selected_uid] + + return sample + + +def _add_to_buf_dict(buf_dict, sample): + key = sample["__key__"] + uid, uid_sample_id = key.split("_") + if uid not in buf_dict: + buf_dict[uid] = [] + buf_dict[uid].append(sample) + + return buf_dict + + +def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None): + """Shuffle the data in the stream. + + This uses a buffer of size `bufsize`. Shuffling at + startup is less random; this is traded off against + yielding samples quickly. + + data: iterator + bufsize: buffer size for shuffling + returns: iterator + rng: either random module or random.Random instance + + """ + if rng is None: + rng = random.Random(int((os.getpid() + time.time()) * 1e9)) + initial = min(initial, bufsize) + buf_dict = dict() + current_samples = 0 + for sample in data: + _add_to_buf_dict(buf_dict, sample) + current_samples += 1 + + if current_samples < bufsize: + try: + _add_to_buf_dict(buf_dict, next(data)) # skipcq: PYL-R1708 + current_samples += 1 + except StopIteration: + pass + + if current_samples >= initial: + current_samples -= 1 + yield _uid_buffer_pick(buf_dict, rng) + + while current_samples > 0: + current_samples -= 1 + yield _uid_buffer_pick(buf_dict, rng) + + +uid_shuffle = pipelinefilter(_uid_shuffle) + + +class RandomSample(object): + def __init__(self, + num_volume_samples: int = 1024, + num_near_samples: int = 1024): + + super().__init__() + + self.num_volume_samples = num_volume_samples + self.num_near_samples = num_near_samples + + def __call__(self, sample): + rng = np.random.default_rng() + + # 1. 
sample surface input + total_surface = sample["surface"] + ind = rng.choice(total_surface.shape[0], replace=False) + surface = total_surface[ind] + + # 2. sample volume/near geometric points + vol_points = sample["vol_points"] + vol_label = sample["vol_label"] + near_points = sample["near_points"] + near_label = sample["near_label"] + + ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False) + vol_points = vol_points[ind] + vol_label = vol_label[ind] + vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1) + + ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False) + near_points = near_points[ind] + near_label = near_label[ind] + near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1) + + # concat sampled volume and near points + geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0) + + sample = { + "surface": surface, + "geo_points": geo_points + } + + return sample + + +class SplitRandomSample(object): + def __init__(self, + use_surface_sample: bool = False, + num_surface_samples: int = 4096, + num_volume_samples: int = 1024, + num_near_samples: int = 1024): + + super().__init__() + + self.use_surface_sample = use_surface_sample + self.num_surface_samples = num_surface_samples + self.num_volume_samples = num_volume_samples + self.num_near_samples = num_near_samples + + def __call__(self, sample): + + rng = np.random.default_rng() + + # 1. sample surface input + surface = sample["surface"] + + if self.use_surface_sample: + replace = surface.shape[0] < self.num_surface_samples + ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace) + surface = surface[ind] + + # 2. sample volume/near geometric points + vol_points = sample["vol_points"] + vol_label = sample["vol_label"] + near_points = sample["near_points"] + near_label = sample["near_label"] + + ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False) + vol_points = vol_points[ind] + vol_label = vol_label[ind] + vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1) + + ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False) + near_points = near_points[ind] + near_label = near_label[ind] + near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1) + + # concat sampled volume and near points + geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0) + + sample = { + "surface": surface, + "geo_points": geo_points + } + + return sample + + +class FeatureSelection(object): + + VALID_SURFACE_FEATURE_DIMS = { + "none": [0, 1, 2], # xyz + "watertight_normal": [0, 1, 2, 3, 4, 5], # xyz, normal + "normal": [0, 1, 2, 6, 7, 8] + } + + def __init__(self, surface_feature_type: str): + + self.surface_feature_type = surface_feature_type + self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type] + + def __call__(self, sample): + sample["surface"] = sample["surface"][:, self.surface_dims] + return sample + + +class AxisScaleTransform(object): + def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005): + assert isinstance(interval, (tuple, list, ListConfig)) + self.interval = interval + self.min_val = interval[0] + self.max_val = interval[1] + self.inter_size = interval[1] - interval[0] + self.jitter = jitter + self.jitter_scale = jitter_scale + + def __call__(self, sample): + + surface = sample["surface"][..., 0:3] + geo_points = sample["geo_points"][..., 0:3] + 
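+        # Draw one anisotropic per-axis scale from `interval`, apply it to both the
+        # surface points and the geometric query points, renormalize so the surface
+        # fits inside the unit cube (with a small safety margin), then optionally
+        # jitter the surface with Gaussian noise.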
+ scaling = torch.rand(1, 3) * self.inter_size + self.min_val + # print(scaling) + surface = surface * scaling + geo_points = geo_points * scaling + + scale = (1 / torch.abs(surface).max().item()) * 0.999999 + surface *= scale + geo_points *= scale + + if self.jitter: + surface += self.jitter_scale * torch.randn_like(surface) + surface.clamp_(min=-1.015, max=1.015) + + sample["surface"][..., 0:3] = surface + sample["geo_points"][..., 0:3] = geo_points + + return sample + + +class ToTensor(object): + + def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")): + self.tensor_keys = tensor_keys + + def __call__(self, sample): + for key in self.tensor_keys: + if key not in sample: + continue + + sample[key] = torch.tensor(sample[key], dtype=torch.float32) + + return sample + + +class AxisScale(object): + def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005): + assert isinstance(interval, (tuple, list, ListConfig)) + self.interval = interval + self.jitter = jitter + self.jitter_scale = jitter_scale + + def __call__(self, surface, *args): + scaling = torch.rand(1, 3) * 0.5 + 0.75 + # print(scaling) + surface = surface * scaling + scale = (1 / torch.abs(surface).max().item()) * 0.999999 + surface *= scale + + args_outputs = [] + for _arg in args: + _arg = _arg * scaling * scale + args_outputs.append(_arg) + + if self.jitter: + surface += self.jitter_scale * torch.randn_like(surface) + surface.clamp_(min=-1, max=1) + + if len(args) == 0: + return surface + else: + return surface, *args_outputs + + +class RandomResize(torch.nn.Module): + """Apply randomly Resize with a given probability.""" + + def __init__( + self, + size, + resize_radio=(0.5, 1), + allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR), + interpolation=InterpolationMode.BICUBIC, + max_size=None, + antialias=None, + ): + super().__init__() + if not isinstance(size, (int, Sequence)): + raise TypeError(f"Size should be int or sequence. Got {type(size)}") + if isinstance(size, Sequence) and len(size) not in (1, 2): + raise ValueError("If size is a sequence, it should have 1 or 2 values") + + self.size = size + self.max_size = max_size + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. " + "Please use InterpolationMode enum." 
+ ) + interpolation = _interpolation_modes_from_int(interpolation) + + self.interpolation = interpolation + self.antialias = antialias + + self.resize_radio = resize_radio + self.allow_resize_interpolations = allow_resize_interpolations + + def random_resize_params(self): + radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0] + + if isinstance(self.size, int): + size = int(self.size * radio) + elif isinstance(self.size, Sequence): + size = list(self.size) + size = (int(size[0] * radio), int(size[1] * radio)) + else: + raise RuntimeError() + + interpolation = self.allow_resize_interpolations[ + torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,)) + ] + return size, interpolation + + def forward(self, img): + size, interpolation = self.random_resize_params() + img = TVF.resize(img, size, interpolation, self.max_size, self.antialias) + img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias) + return img + + def __repr__(self) -> str: + detail = f"(size={self.size}, interpolation={self.interpolation.value}," + detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}" + return f"{self.__class__.__name__}{detail}" + + +class Compose(object): + """Composes several transforms together. This transform does not support torchscript. + Please, see the note below. + + Args: + transforms (list of ``Transform`` objects): list of transforms to compose. + + Example: + >>> transforms.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + + .. note:: + In order to script the transformations, please use ``torch.nn.Sequential`` as below. + + >>> transforms = torch.nn.Sequential( + >>> transforms.CenterCrop(10), + >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), + >>> ) + >>> scripted_transforms = torch.jit.script(transforms) + + Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require + `lambda` functions or ``PIL.Image``. 
+ + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, *args): + for t in self.transforms: + args = t(*args) + return args + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.transforms: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +def identity(*args, **kwargs): + if len(args) == 1: + return args[0] + else: + return args + + +def build_transforms(cfg): + + if cfg is None: + return identity + + transforms = [] + + for transform_name, cfg_instance in cfg.items(): + transform_instance = instantiate_from_config(cfg_instance) + transforms.append(transform_instance) + print(f"Build transform: {transform_instance}") + + transforms = Compose(transforms) + + return transforms + diff --git a/michelangelo/data/utils.py b/michelangelo/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af06ed0c8849819a5d2b72ece805e8ec26079ea9 --- /dev/null +++ b/michelangelo/data/utils.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +import torch +import numpy as np + + +def worker_init_fn(_): + worker_info = torch.utils.data.get_worker_info() + worker_id = worker_info.id + + # dataset = worker_info.dataset + # split_size = dataset.num_records // worker_info.num_workers + # # reset num_records to the true number to retain reliable length information + # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size] + # current_id = np.random.choice(len(np.random.get_state()[1]), 1) + # return np.random.seed(np.random.get_state()[1][current_id] + worker_id) + + return np.random.seed(np.random.get_state()[1][0] + worker_id) + + +def collation_fn(samples, combine_tensors=True, combine_scalars=True): + """ + + Args: + samples (list[dict]): + combine_tensors: + combine_scalars: + + Returns: + + """ + + result = {} + + keys = samples[0].keys() + + for key in keys: + result[key] = [] + + for sample in samples: + for key in keys: + val = sample[key] + result[key].append(val) + + for key in keys: + val_list = result[key] + if isinstance(val_list[0], (int, float)): + if combine_scalars: + result[key] = np.array(result[key]) + + elif isinstance(val_list[0], torch.Tensor): + if combine_tensors: + result[key] = torch.stack(val_list) + + elif isinstance(val_list[0], np.ndarray): + if combine_tensors: + result[key] = np.stack(val_list) + + return result diff --git a/michelangelo/graphics/__init__.py b/michelangelo/graphics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/michelangelo/graphics/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/michelangelo/graphics/__pycache__/__init__.cpython-39.pyc b/michelangelo/graphics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ed330fa9bac2e2f45495d63c220ddc9af996a10 Binary files /dev/null and b/michelangelo/graphics/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/graphics/primitives/__init__.py b/michelangelo/graphics/primitives/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb910878f98a83209a41b562d339d12d39f42e89 --- /dev/null +++ b/michelangelo/graphics/primitives/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .volume import generate_dense_grid_points + +from .mesh import ( + MeshOutput, + save_obj, + savemeshtes2 +) diff --git 
a/michelangelo/graphics/primitives/__pycache__/__init__.cpython-39.pyc b/michelangelo/graphics/primitives/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48b7127c6a9a5ff7b51c5808a65977871d8671ab Binary files /dev/null and b/michelangelo/graphics/primitives/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/graphics/primitives/__pycache__/extract_texture_map.cpython-39.pyc b/michelangelo/graphics/primitives/__pycache__/extract_texture_map.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..761e57da19c8cbbbe5e315bd21515d56daf6bf30 Binary files /dev/null and b/michelangelo/graphics/primitives/__pycache__/extract_texture_map.cpython-39.pyc differ diff --git a/michelangelo/graphics/primitives/__pycache__/mesh.cpython-39.pyc b/michelangelo/graphics/primitives/__pycache__/mesh.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b4e89a313d8d16485dc281474e220c2c8d4956c Binary files /dev/null and b/michelangelo/graphics/primitives/__pycache__/mesh.cpython-39.pyc differ diff --git a/michelangelo/graphics/primitives/__pycache__/volume.cpython-39.pyc b/michelangelo/graphics/primitives/__pycache__/volume.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9840c9429b8f5c1e54b3bf9ce12ae39bf21d5a84 Binary files /dev/null and b/michelangelo/graphics/primitives/__pycache__/volume.cpython-39.pyc differ diff --git a/michelangelo/graphics/primitives/mesh.py b/michelangelo/graphics/primitives/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5e8a551378b8e86d041967736cacaf904dbf54 --- /dev/null +++ b/michelangelo/graphics/primitives/mesh.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- + +import os +import cv2 +import numpy as np +import PIL.Image +from typing import Optional + +import trimesh + + +def save_obj(pointnp_px3, facenp_fx3, fname): + fid = open(fname, "w") + write_str = "" + for pidx, p in enumerate(pointnp_px3): + pp = p + write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2]) + + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2]) + fid.write(write_str) + fid.close() + return + + +def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname): + fol, na = os.path.split(fname) + na, _ = os.path.splitext(na) + + matname = "%s/%s.mtl" % (fol, na) + fid = open(matname, "w") + fid.write("newmtl material_0\n") + fid.write("Kd 1 1 1\n") + fid.write("Ka 0 0 0\n") + fid.write("Ks 0.4 0.4 0.4\n") + fid.write("Ns 10\n") + fid.write("illum 2\n") + fid.write("map_Kd %s.png\n" % na) + fid.close() + #### + + fid = open(fname, "w") + fid.write("mtllib %s.mtl\n" % na) + + for pidx, p in enumerate(pointnp_px3): + pp = p + fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2])) + + for pidx, p in enumerate(tcoords_px2): + pp = p + fid.write("vt %f %f\n" % (pp[0], pp[1])) + + fid.write("usemtl material_0\n") + for i, f in enumerate(facenp_fx3): + f1 = f + 1 + f2 = facetex_fx3[i] + 1 + fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) + fid.close() + + PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save( + os.path.join(fol, "%s.png" % na)) + + return + + +class MeshOutput(object): + + def __init__(self, + mesh_v: np.ndarray, + mesh_f: np.ndarray, + vertex_colors: Optional[np.ndarray] = None, + uvs: Optional[np.ndarray] = None, + mesh_tex_idx: Optional[np.ndarray] = None, + tex_map: Optional[np.ndarray] = None): + + 
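+        # Plain container for an extracted mesh: mesh_v / mesh_f hold vertex positions
+        # and triangle indices, while the optional uvs / mesh_tex_idx / tex_map
+        # (UV-textured) or vertex_colors (per-vertex colors) decide which writer
+        # export() dispatches to below.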
self.mesh_v = mesh_v + self.mesh_f = mesh_f + self.vertex_colors = vertex_colors + self.uvs = uvs + self.mesh_tex_idx = mesh_tex_idx + self.tex_map = tex_map + + def contain_uv_texture(self): + return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None) + + def contain_vertex_colors(self): + return self.vertex_colors is not None + + def export(self, fname): + + if self.contain_uv_texture(): + savemeshtes2( + self.mesh_v, + self.uvs, + self.mesh_f, + self.mesh_tex_idx, + self.tex_map, + fname + ) + + elif self.contain_vertex_colors(): + mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors) + mesh_obj.export(fname) + + else: + save_obj( + self.mesh_v, + self.mesh_f, + fname + ) + + + diff --git a/michelangelo/graphics/primitives/volume.py b/michelangelo/graphics/primitives/volume.py new file mode 100644 index 0000000000000000000000000000000000000000..e8cb1d3f41fd00d18af5a6c751d49c68770fe04a --- /dev/null +++ b/michelangelo/graphics/primitives/volume.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- + +import numpy as np + + +def generate_dense_grid_points(bbox_min: np.ndarray, + bbox_max: np.ndarray, + octree_depth: int, + indexing: str = "ij"): + length = bbox_max - bbox_min + num_cells = np.exp2(octree_depth) + x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) + y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) + z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) + [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) + xyz = np.stack((xs, ys, zs), axis=-1) + xyz = xyz.reshape(-1, 3) + grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + + return xyz, grid_size, length + diff --git a/michelangelo/models/__init__.py b/michelangelo/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/michelangelo/models/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/michelangelo/models/__pycache__/__init__.cpython-39.pyc b/michelangelo/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd656db92c763f471415cb8c0fe3f56dff610fac Binary files /dev/null and b/michelangelo/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/models/asl_diffusion/__init__.py b/michelangelo/models/asl_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/michelangelo/models/asl_diffusion/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/michelangelo/models/asl_diffusion/__pycache__/__init__.cpython-39.pyc b/michelangelo/models/asl_diffusion/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c3d11328753cffe2f1107577c490a996a6a5dfe Binary files /dev/null and b/michelangelo/models/asl_diffusion/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/models/asl_diffusion/__pycache__/asl_udt.cpython-39.pyc b/michelangelo/models/asl_diffusion/__pycache__/asl_udt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb6f3d9c1e04878c510129bbbfce7741f7dffac8 Binary files /dev/null and b/michelangelo/models/asl_diffusion/__pycache__/asl_udt.cpython-39.pyc differ diff --git 
a/michelangelo/models/asl_diffusion/__pycache__/clip_asl_diffuser_pl_module.cpython-39.pyc b/michelangelo/models/asl_diffusion/__pycache__/clip_asl_diffuser_pl_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cbba4aea42ca4564766f095d32cf35659acbec8 Binary files /dev/null and b/michelangelo/models/asl_diffusion/__pycache__/clip_asl_diffuser_pl_module.cpython-39.pyc differ diff --git a/michelangelo/models/asl_diffusion/__pycache__/inference_utils.cpython-39.pyc b/michelangelo/models/asl_diffusion/__pycache__/inference_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac8dfc7fcab2e6b1b8f2e0b09fd631dbb079395a Binary files /dev/null and b/michelangelo/models/asl_diffusion/__pycache__/inference_utils.cpython-39.pyc differ diff --git a/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py b/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..558521196ebb97ea3e4c94b8f0710483eb068784 --- /dev/null +++ b/michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py @@ -0,0 +1,483 @@ +# -*- coding: utf-8 -*- + +from omegaconf import DictConfig +from typing import List, Tuple, Dict, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import lr_scheduler +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only + +from einops import rearrange + +from diffusers.schedulers import ( + DDPMScheduler, + DDIMScheduler, + KarrasVeScheduler, + DPMSolverMultistepScheduler +) + +from michelangelo.utils import instantiate_from_config +# from michelangelo.models.tsal.tsal_base import ShapeAsLatentPLModule +from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule +from michelangelo.models.asl_diffusion.inference_utils import ddim_sample + +SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler] + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class ASLDiffuser(pl.LightningModule): + first_stage_model: Optional[AlignedShapeAsLatentPLModule] + # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]] + model: nn.Module + + def __init__(self, *, + first_stage_config, + denoiser_cfg, + scheduler_cfg, + optimizer_cfg, + loss_cfg, + first_stage_key: str = "surface", + cond_stage_key: str = "image", + cond_stage_trainable: bool = True, + scale_by_std: bool = False, + z_scale_factor: float = 1.0, + ckpt_path: Optional[str] = None, + ignore_keys: Union[Tuple[str], List[str]] = ()): + + super().__init__() + + self.first_stage_key = first_stage_key + self.cond_stage_key = cond_stage_key + self.cond_stage_trainable = cond_stage_trainable + + # 1. initialize first stage. + # Note: the condition model contained in the first stage model. + self.first_stage_config = first_stage_config + self.first_stage_model = None + # self.instantiate_first_stage(first_stage_config) + + # 2. initialize conditional stage + # self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_model = { + "image": self.encode_image, + "image_unconditional_embedding": self.empty_img_cond, + "text": self.encode_text, + "text_unconditional_embedding": self.empty_text_cond, + "surface": self.encode_surface, + "surface_unconditional_embedding": self.empty_surface_cond, + } + + # 3. 
diffusion model + self.model = instantiate_from_config( + denoiser_cfg, device=None, dtype=None + ) + + self.optimizer_cfg = optimizer_cfg + + # 4. scheduling strategy + self.scheduler_cfg = scheduler_cfg + + self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise) + self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise) + + # 5. loss configures + self.loss_cfg = loss_cfg + + self.scale_by_std = scale_by_std + if scale_by_std: + self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor)) + else: + self.z_scale_factor = z_scale_factor + + self.ckpt_path = ckpt_path + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + self.first_stage_model = self.first_stage_model.to(self.device) + + # def instantiate_cond_stage(self, config): + # if not self.cond_stage_trainable: + # if config == "__is_first_stage__": + # print("Using first stage also as cond stage.") + # self.cond_stage_model = self.first_stage_model + # elif config == "__is_unconditional__": + # print(f"Training {self.__class__.__name__} as an unconditional model.") + # self.cond_stage_model = None + # # self.be_unconditional = True + # else: + # model = instantiate_from_config(config) + # self.cond_stage_model = model.eval() + # self.cond_stage_model.train = disabled_train + # for param in self.cond_stage_model.parameters(): + # param.requires_grad = False + # else: + # assert config != "__is_first_stage__" + # assert config != "__is_unconditional__" + # model = instantiate_from_config(config) + # self.cond_stage_model = model + + def init_from_ckpt(self, path, ignore_keys=()): + state_dict = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + + missing, unexpected = self.load_state_dict(state_dict, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + @property + def zero_rank(self): + if self._trainer: + zero_rank = self.trainer.local_rank == 0 + else: + zero_rank = True + + return zero_rank + + def configure_optimizers(self) -> Tuple[List, List]: + + lr = self.learning_rate + + trainable_parameters = list(self.model.parameters()) + # if the conditional encoder is trainable + + # if self.cond_stage_trainable: + # conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad] + # trainable_parameters += conditioner_params + # print(f"number of trainable conditional parameters: {len(conditioner_params)}.") + + if self.optimizer_cfg is None: + optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] + schedulers = [] + else: + optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) + scheduler_func = instantiate_from_config( + self.optimizer_cfg.scheduler, + max_decay_steps=self.trainer.max_steps, + lr_max=lr + ) + scheduler = { + "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), + "interval": "step", + 
"frequency": 1 + } + optimizers = [optimizer] + schedulers = [scheduler] + + return optimizers, schedulers + + @torch.no_grad() + def encode_text(self, text): + + b = text.shape[0] + text_tokens = rearrange(text, "b t l -> (b t) l") + text_embed = self.first_stage_model.model.encode_text_embed(text_tokens) + text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) + text_embed = text_embed.mean(dim=1) + text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + + return text_embed + + @torch.no_grad() + def encode_image(self, img): + + return self.first_stage_model.model.encode_image_embed(img) + + @torch.no_grad() + def encode_surface(self, surface): + + return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False) + + @torch.no_grad() + def empty_text_cond(self, cond): + + return torch.zeros_like(cond, device=cond.device) + + @torch.no_grad() + def empty_img_cond(self, cond): + + return torch.zeros_like(cond, device=cond.device) + + @torch.no_grad() + def empty_surface_cond(self, cond): + + return torch.zeros_like(cond, device=cond.device) + + @torch.no_grad() + def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True): + + z_q = self.first_stage_model.encode(surface, sample_posterior) + z_q = self.z_scale_factor * z_q + + return z_q + + @torch.no_grad() + def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs): + + z_q = 1. / self.z_scale_factor * z_q + latents = self.first_stage_model.decode(z_q, **kwargs) + return latents + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \ + and batch_idx == 0 and self.ckpt_path is None: + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + + z_q = self.encode_first_stage(batch[self.first_stage_key]) + z = z_q.detach() + + del self.z_scale_factor + self.register_buffer("z_scale_factor", 1. 
/ z.flatten().std()) + print(f"setting self.z_scale_factor to {self.z_scale_factor}") + + print("### USING STD-RESCALING ###") + + def compute_loss(self, model_outputs, split): + """ + + Args: + model_outputs (dict): + - x_0: + - noise: + - noise_prior: + - noise_pred: + - noise_pred_prior: + + split (str): + + Returns: + + """ + + pred = model_outputs["pred"] + + if self.noise_scheduler.prediction_type == "epsilon": + target = model_outputs["noise"] + elif self.noise_scheduler.prediction_type == "sample": + target = model_outputs["x_0"] + else: + raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") + + if self.loss_cfg.loss_type == "l1": + simple = F.l1_loss(pred, target, reduction="mean") + elif self.loss_cfg.loss_type in ["mse", "l2"]: + simple = F.mse_loss(pred, target, reduction="mean") + else: + raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.") + + total_loss = simple + + loss_dict = { + f"{split}/total_loss": total_loss.clone().detach(), + f"{split}/simple": simple.detach(), + } + + return total_loss, loss_dict + + def forward(self, batch): + """ + + Args: + batch: + + Returns: + + """ + + if self.first_stage_model is None: + self.instantiate_first_stage(self.first_stage_config) + + latents = self.encode_first_stage(batch[self.first_stage_key]) + + # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key]) + + conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1) + + mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1 + conditions = conditions * mask.to(conditions) + + # Sample noise that we"ll add to the latents + # [batch_size, n_token, latent_dim] + noise = torch.randn_like(latents) + bs = latents.shape[0] + # Sample a random timestep for each motion + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bs,), + device=latents.device, + ) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) + + # diffusion model forward + noise_pred = self.model(noisy_z, timesteps, conditions) + + diffusion_outputs = { + "x_0": noisy_z, + "noise": noise, + "pred": noise_pred + } + + return diffusion_outputs + + def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface (torch.FloatTensor): + - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1] + - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1] + - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1] + - text (list of str): + + batch_idx (int): + + optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + diffusion_outputs = self(batch) + + loss, loss_dict = self.compute_loss(diffusion_outputs, "train") + self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) + + return loss + + def validation_step(self, batch: Dict[str, torch.FloatTensor], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface_pc (torch.FloatTensor): [n_pts, 4] + - surface_feats (torch.FloatTensor): [n_pts, c] + - text (list of str): + + batch_idx (int): + + 
optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + diffusion_outputs = self(batch) + + loss, loss_dict = self.compute_loss(diffusion_outputs, "val") + self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) + + return loss + + @torch.no_grad() + def sample(self, + batch: Dict[str, Union[torch.FloatTensor, List[str]]], + sample_times: int = 1, + steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + eta: float = 0.0, + return_intermediates: bool = False, **kwargs): + + if self.first_stage_model is None: + self.instantiate_first_stage(self.first_stage_config) + + if steps is None: + steps = self.scheduler_cfg.num_inference_steps + + if guidance_scale is None: + guidance_scale = self.scheduler_cfg.guidance_scale + do_classifier_free_guidance = guidance_scale > 0 + + # conditional encode + xc = batch[self.cond_stage_key] + # cond = self.cond_stage_model[self.cond_stage_key](xc) + cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1) + + if do_classifier_free_guidance: + """ + Note: There are two kinds of uncond for text. + 1: using "" as uncond text; (in SAL diffusion) + 2: zeros_like(cond) as uncond text; (in MDM) + """ + # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc)) + un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond) + # un_cond = torch.zeros_like(cond, device=cond.device) + cond = torch.cat([un_cond, cond], dim=0) + + outputs = [] + latents = None + + if not return_intermediates: + for _ in range(sample_times): + sample_loop = ddim_sample( + self.denoise_scheduler, + self.model, + shape=self.first_stage_model.latent_shape, + cond=cond, + steps=steps, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + device=self.device, + eta=eta, + disable_prog=not self.zero_rank + ) + for sample, t in sample_loop: + latents = sample + outputs.append(self.decode_first_stage(latents, **kwargs)) + else: + + sample_loop = ddim_sample( + self.denoise_scheduler, + self.model, + shape=self.first_stage_model.latent_shape, + cond=cond, + steps=steps, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + device=self.device, + eta=eta, + disable_prog=not self.zero_rank + ) + + iter_size = steps // sample_times + i = 0 + for sample, t in sample_loop: + latents = sample + if i % iter_size == 0 or i == steps - 1: + outputs.append(self.decode_first_stage(latents, **kwargs)) + i += 1 + + return outputs diff --git a/michelangelo/models/asl_diffusion/asl_udt.py b/michelangelo/models/asl_diffusion/asl_udt.py new file mode 100644 index 0000000000000000000000000000000000000000..83a02341035d02c268f92ce089267f93eeae888b --- /dev/null +++ b/michelangelo/models/asl_diffusion/asl_udt.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from typing import Optional +from diffusers.models.embeddings import Timesteps +import math + +from michelangelo.models.modules.transformer_blocks import MLP +from michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer + + +class ConditionalASLUDTDenoiser(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + input_channels: int, + output_channels: int, + n_ctx: int, + width: int, + layers: int, + heads: int, + context_dim: int, + context_ln: bool = True, + skip_ln: bool = False, + init_scale: float = 0.25, + flip_sin_to_cos: bool = False, + 
use_checkpoint: bool = False):
+        super().__init__()
+
+        self.use_checkpoint = use_checkpoint
+
+        init_scale = init_scale * math.sqrt(1.0 / width)
+
+        self.backbone = UNetDiffusionTransformer(
+            device=device,
+            dtype=dtype,
+            n_ctx=n_ctx,
+            width=width,
+            layers=layers,
+            heads=heads,
+            skip_ln=skip_ln,
+            init_scale=init_scale,
+            use_checkpoint=use_checkpoint
+        )
+        self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
+        self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
+        self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
+
+        # timestep embedding
+        self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
+        self.time_proj = MLP(
+            device=device, dtype=dtype, width=width, init_scale=init_scale
+        )
+
+        # context (condition) projector; with context_ln the linear projection is preceded by a LayerNorm
+        if context_ln:
+            self.context_embed = nn.Sequential(
+                nn.LayerNorm(context_dim, device=device, dtype=dtype),
+                nn.Linear(context_dim, width, device=device, dtype=dtype),
+            )
+        else:
+            self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
+
+    def forward(self,
+                model_input: torch.FloatTensor,
+                timestep: torch.LongTensor,
+                context: torch.FloatTensor):
+
+        r"""
+        Args:
+            model_input (torch.FloatTensor): [bs, n_data, c]
+            timestep (torch.LongTensor): [bs,]
+            context (torch.FloatTensor): [bs, context_tokens, c]
+
+        Returns:
+            sample (torch.FloatTensor): [bs, n_data, c]
+
+        """
+
+        _, n_data, _ = model_input.shape
+
+        # 1. time
+        t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
+
+        # 2. conditions projector
+        context = self.context_embed(context)
+
+        # 3.
denoiser + x = self.input_proj(model_input) + x = torch.cat([t_emb, context, x], dim=1) + x = self.backbone(x) + x = self.ln_post(x) + x = x[:, -n_data:] + sample = self.output_proj(x) + + return sample + + diff --git a/michelangelo/models/asl_diffusion/base.py b/michelangelo/models/asl_diffusion/base.py new file mode 100644 index 0000000000000000000000000000000000000000..a979197ae9990929aecbca42ce081a2b1aa1f465 --- /dev/null +++ b/michelangelo/models/asl_diffusion/base.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn + + +class BaseDenoiser(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, t, context): + raise NotImplementedError diff --git a/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py b/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..5755c1f42bd7c8e69a5a8136925da79ea39a84b5 --- /dev/null +++ b/michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py @@ -0,0 +1,393 @@ +# -*- coding: utf-8 -*- + +from omegaconf import DictConfig +from typing import List, Tuple, Dict, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import lr_scheduler +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only + +from diffusers.schedulers import ( + DDPMScheduler, + DDIMScheduler, + KarrasVeScheduler, + DPMSolverMultistepScheduler +) + +from michelangelo.utils import instantiate_from_config +from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule +from michelangelo.models.asl_diffusion.inference_utils import ddim_sample + +SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler] + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class ClipASLDiffuser(pl.LightningModule): + first_stage_model: Optional[AlignedShapeAsLatentPLModule] + cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]] + model: nn.Module + + def __init__(self, *, + first_stage_config, + cond_stage_config, + denoiser_cfg, + scheduler_cfg, + optimizer_cfg, + loss_cfg, + first_stage_key: str = "surface", + cond_stage_key: str = "image", + scale_by_std: bool = False, + z_scale_factor: float = 1.0, + ckpt_path: Optional[str] = None, + ignore_keys: Union[Tuple[str], List[str]] = ()): + + super().__init__() + + self.first_stage_key = first_stage_key + self.cond_stage_key = cond_stage_key + + # 1. lazy initialize first stage + self.instantiate_first_stage(first_stage_config) + + # 2. initialize conditional stage + self.instantiate_cond_stage(cond_stage_config) + + # 3. diffusion model + self.model = instantiate_from_config( + denoiser_cfg, device=None, dtype=None + ) + + self.optimizer_cfg = optimizer_cfg + + # 4. scheduling strategy + self.scheduler_cfg = scheduler_cfg + + self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise) + self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise) + + # 5. 
loss configures + self.loss_cfg = loss_cfg + + self.scale_by_std = scale_by_std + if scale_by_std: + self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor)) + else: + self.z_scale_factor = z_scale_factor + + self.ckpt_path = ckpt_path + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def instantiate_non_trainable_model(self, config): + model = instantiate_from_config(config) + model = model.eval() + model.train = disabled_train + for param in model.parameters(): + param.requires_grad = False + + return model + + def instantiate_first_stage(self, first_stage_config): + self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config) + self.first_stage_model.set_shape_model_only() + + def instantiate_cond_stage(self, cond_stage_config): + self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config) + + def init_from_ckpt(self, path, ignore_keys=()): + state_dict = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + + missing, unexpected = self.load_state_dict(state_dict, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + @property + def zero_rank(self): + if self._trainer: + zero_rank = self.trainer.local_rank == 0 + else: + zero_rank = True + + return zero_rank + + def configure_optimizers(self) -> Tuple[List, List]: + + lr = self.learning_rate + + trainable_parameters = list(self.model.parameters()) + if self.optimizer_cfg is None: + optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] + schedulers = [] + else: + optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) + scheduler_func = instantiate_from_config( + self.optimizer_cfg.scheduler, + max_decay_steps=self.trainer.max_steps, + lr_max=lr + ) + scheduler = { + "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), + "interval": "step", + "frequency": 1 + } + optimizers = [optimizer] + schedulers = [scheduler] + + return optimizers, schedulers + + @torch.no_grad() + def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True): + + z_q = self.first_stage_model.encode(surface, sample_posterior) + z_q = self.z_scale_factor * z_q + + return z_q + + @torch.no_grad() + def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs): + + z_q = 1. / self.z_scale_factor * z_q + latents = self.first_stage_model.decode(z_q, **kwargs) + return latents + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \ + and batch_idx == 0 and self.ckpt_path is None: + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + + z_q = self.encode_first_stage(batch[self.first_stage_key]) + z = z_q.detach() + + del self.z_scale_factor + self.register_buffer("z_scale_factor", 1. 
/ z.flatten().std()) + print(f"setting self.z_scale_factor to {self.z_scale_factor}") + + print("### USING STD-RESCALING ###") + + def compute_loss(self, model_outputs, split): + """ + + Args: + model_outputs (dict): + - x_0: + - noise: + - noise_prior: + - noise_pred: + - noise_pred_prior: + + split (str): + + Returns: + + """ + + pred = model_outputs["pred"] + + if self.noise_scheduler.prediction_type == "epsilon": + target = model_outputs["noise"] + elif self.noise_scheduler.prediction_type == "sample": + target = model_outputs["x_0"] + else: + raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.") + + if self.loss_cfg.loss_type == "l1": + simple = F.l1_loss(pred, target, reduction="mean") + elif self.loss_cfg.loss_type in ["mse", "l2"]: + simple = F.mse_loss(pred, target, reduction="mean") + else: + raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.") + + total_loss = simple + + loss_dict = { + f"{split}/total_loss": total_loss.clone().detach(), + f"{split}/simple": simple.detach(), + } + + return total_loss, loss_dict + + def forward(self, batch): + """ + + Args: + batch: + + Returns: + + """ + + latents = self.encode_first_stage(batch[self.first_stage_key]) + conditions = self.cond_stage_model.encode(batch[self.cond_stage_key]) + + # Sample noise that we"ll add to the latents + # [batch_size, n_token, latent_dim] + noise = torch.randn_like(latents) + bs = latents.shape[0] + # Sample a random timestep for each motion + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (bs,), + device=latents.device, + ) + timesteps = timesteps.long() + # Add noise to the latents according to the noise magnitude at each timestep + noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps) + + # diffusion model forward + noise_pred = self.model(noisy_z, timesteps, conditions) + + diffusion_outputs = { + "x_0": noisy_z, + "noise": noise, + "pred": noise_pred + } + + return diffusion_outputs + + def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface (torch.FloatTensor): + - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1] + - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1] + - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1] + - text (list of str): + + batch_idx (int): + + optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + diffusion_outputs = self(batch) + + loss, loss_dict = self.compute_loss(diffusion_outputs, "train") + self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) + + return loss + + def validation_step(self, batch: Dict[str, torch.FloatTensor], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface_pc (torch.FloatTensor): [n_pts, 4] + - surface_feats (torch.FloatTensor): [n_pts, c] + - text (list of str): + + batch_idx (int): + + optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + diffusion_outputs = self(batch) + + loss, loss_dict = self.compute_loss(diffusion_outputs, "val") + self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True) + + return loss + + @torch.no_grad() + def sample(self, + batch: Dict[str, Union[torch.FloatTensor, 
List[str]]], + sample_times: int = 1, + steps: Optional[int] = None, + guidance_scale: Optional[float] = None, + eta: float = 0.0, + return_intermediates: bool = False, **kwargs): + + if steps is None: + steps = self.scheduler_cfg.num_inference_steps + + if guidance_scale is None: + guidance_scale = self.scheduler_cfg.guidance_scale + do_classifier_free_guidance = guidance_scale > 0 + + # conditional encode + xc = batch[self.cond_stage_key] + + # print(self.first_stage_model.device, self.cond_stage_model.device, self.device) + + cond = self.cond_stage_model(xc) + + if do_classifier_free_guidance: + un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc)) + cond = torch.cat([un_cond, cond], dim=0) + + outputs = [] + latents = None + + if not return_intermediates: + for _ in range(sample_times): + sample_loop = ddim_sample( + self.denoise_scheduler, + self.model, + shape=self.first_stage_model.latent_shape, + cond=cond, + steps=steps, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + device=self.device, + eta=eta, + disable_prog=not self.zero_rank + ) + for sample, t in sample_loop: + latents = sample + outputs.append(self.decode_first_stage(latents, **kwargs)) + else: + + sample_loop = ddim_sample( + self.denoise_scheduler, + self.model, + shape=self.first_stage_model.latent_shape, + cond=cond, + steps=steps, + guidance_scale=guidance_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + device=self.device, + eta=eta, + disable_prog=not self.zero_rank + ) + + iter_size = steps // sample_times + i = 0 + for sample, t in sample_loop: + latents = sample + if i % iter_size == 0 or i == steps - 1: + outputs.append(self.decode_first_stage(latents, **kwargs)) + i += 1 + + return outputs diff --git a/michelangelo/models/asl_diffusion/inference_utils.py b/michelangelo/models/asl_diffusion/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..967d5c52a8e33a6759d1c4891b0d21d1c9f95442 --- /dev/null +++ b/michelangelo/models/asl_diffusion/inference_utils.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +import torch +from tqdm import tqdm +from typing import Tuple, List, Union, Optional +from diffusers.schedulers import DDIMScheduler + + +__all__ = ["ddim_sample"] + + +def ddim_sample(ddim_scheduler: DDIMScheduler, + diffusion_model: torch.nn.Module, + shape: Union[List[int], Tuple[int]], + cond: torch.FloatTensor, + steps: int, + eta: float = 0.0, + guidance_scale: float = 3.0, + do_classifier_free_guidance: bool = True, + generator: Optional[torch.Generator] = None, + device: torch.device = "cuda:0", + disable_prog: bool = True): + + assert steps > 0, f"{steps} must > 0." 
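+    # NOTE: with classifier-free guidance, `cond` already stacks [unconditional, conditional]
+    # embeddings along the batch dim, so the number of latent samples is cond.shape[0] // 2.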
+ + # init latents + bsz = cond.shape[0] + if do_classifier_free_guidance: + bsz = bsz // 2 + + latents = torch.randn( + (bsz, *shape), + generator=generator, + device=cond.device, + dtype=cond.dtype, + ) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * ddim_scheduler.init_noise_sigma + # set timesteps + ddim_scheduler.set_timesteps(steps) + timesteps = ddim_scheduler.timesteps.to(device) + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, and between [0, 1] + extra_step_kwargs = { + "eta": eta, + "generator": generator + } + + # reverse + for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) + if do_classifier_free_guidance + else latents + ) + # latent_model_input = scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + timestep_tensor = torch.tensor([t], dtype=torch.long, device=device) + timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0]) + noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + # text_embeddings_for_guidance = encoder_hidden_states.chunk( + # 2)[1] if do_classifier_free_guidance else encoder_hidden_states + # compute the previous noisy sample x_t -> x_t-1 + latents = ddim_scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + yield latents, t + + +def karra_sample(): + pass diff --git a/michelangelo/models/conditional_encoders/__init__.py b/michelangelo/models/conditional_encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f644ce0eac101dbd60ffdb0225a7560a5dc25735 --- /dev/null +++ b/michelangelo/models/conditional_encoders/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +from .clip import CLIPEncoder diff --git a/michelangelo/models/conditional_encoders/__pycache__/__init__.cpython-39.pyc b/michelangelo/models/conditional_encoders/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e05a9136857bc4455fbcb650e784c336c2e32e6 Binary files /dev/null and b/michelangelo/models/conditional_encoders/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/models/conditional_encoders/__pycache__/clip.cpython-39.pyc b/michelangelo/models/conditional_encoders/__pycache__/clip.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1226ce7dd72985bb84c251b607055cd07ab58cd1 Binary files /dev/null and b/michelangelo/models/conditional_encoders/__pycache__/clip.cpython-39.pyc differ diff --git a/michelangelo/models/conditional_encoders/__pycache__/encoder_factory.cpython-39.pyc b/michelangelo/models/conditional_encoders/__pycache__/encoder_factory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39df0b1fc91424b3fbfc77c348b9a2993bfd14d8 Binary files /dev/null and b/michelangelo/models/conditional_encoders/__pycache__/encoder_factory.cpython-39.pyc differ diff --git a/michelangelo/models/conditional_encoders/clip.py b/michelangelo/models/conditional_encoders/clip.py new file mode 100644 index 
0000000000000000000000000000000000000000..099b237d543981cca70f92ccbbb0c1c560aa0f2a
--- /dev/null
+++ b/michelangelo/models/conditional_encoders/clip.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+
+import torch
+import numpy as np
+from PIL import Image
+from dataclasses import dataclass
+from torchvision.transforms import Normalize
+from transformers import CLIPModel, CLIPTokenizer
+from transformers.utils import ModelOutput
+from typing import Iterable, Optional, Union, List
+
+
+ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
+
+
+@dataclass
+class CLIPEmbedOutput(ModelOutput):
+    last_hidden_state: torch.FloatTensor = None
+    pooler_output: torch.FloatTensor = None
+    embeds: torch.FloatTensor = None
+
+
+class CLIPEncoder(torch.nn.Module):
+
+    def __init__(self, model_path="openai/clip-vit-base-patch32"):
+
+        super().__init__()
+
+        # Load the CLIP model and processor
+        self.model: CLIPModel = CLIPModel.from_pretrained(model_path)
+        self.tokenizer = CLIPTokenizer.from_pretrained(model_path)
+        self.image_preprocess = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+        self.model.eval()
+        for p in self.model.parameters():
+            p.requires_grad = False
+
+    @torch.no_grad()
+    def encode_image(self, images: Iterable[Optional[ImageType]]):
+        pixel_values = self.image_preprocess(images)
+
+        vision_outputs = self.model.vision_model(pixel_values=pixel_values)
+
+        pooler_output = vision_outputs[1]  # pooled_output
+        image_features = self.model.visual_projection(pooler_output)
+
+        visual_embeds = CLIPEmbedOutput(
+            last_hidden_state=vision_outputs.last_hidden_state,
+            pooler_output=pooler_output,
+            embeds=image_features
+        )
+
+        return visual_embeds
+
+    @torch.no_grad()
+    def encode_text(self, texts: List[str]):
+        text_inputs = self.tokenizer(texts, padding=True, return_tensors="pt")
+
+        text_outputs = self.model.text_model(input_ids=text_inputs.input_ids, attention_mask=text_inputs.attention_mask)
+
+        pooler_output = text_outputs[1]  # pooled_output
+        text_features = self.model.text_projection(pooler_output)
+
+        text_embeds = CLIPEmbedOutput(
+            last_hidden_state=text_outputs.last_hidden_state,
+            pooler_output=pooler_output,
+            embeds=text_features
+        )
+
+        return text_embeds
+
+    def forward(self,
+                images: Iterable[Optional[ImageType]],
+                texts: List[str]):
+
+        visual_embeds = self.encode_image(images)
+        text_embeds = self.encode_text(texts)
+
+        return visual_embeds, text_embeds
+
+
+
+
+
+
+
+
+
+
diff --git a/michelangelo/models/conditional_encoders/encoder_factory.py b/michelangelo/models/conditional_encoders/encoder_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e6da9d354817bfa8e4eea43b2dacf784488af73
--- /dev/null
+++ b/michelangelo/models/conditional_encoders/encoder_factory.py
@@ -0,0 +1,562 @@
+# -*- coding: utf-8 -*-
+import os
+
+import torch
+import torch.nn as nn
+import clip  # OpenAI CLIP package; provides clip.load used by MoECLIPImageEncoder below
+from torchvision import transforms
+from transformers import CLIPModel, CLIPTokenizer
+from collections import OrderedDict
+
+from michelangelo.data.transforms import RandomResize
+
+
+class AbstractEncoder(nn.Module):
+    embedding_dim: int
+
+    def __init__(self):
+        super().__init__()
+
+    def encode(self, *args, **kwargs):
+        raise NotImplementedError
+
+
+class ClassEmbedder(nn.Module):
+    def __init__(self, embed_dim, n_classes=1000, key="class"):
+        super().__init__()
+        self.key = key
+        self.embedding = nn.Embedding(n_classes, embed_dim)
+
+    def forward(self, batch, key=None):
+        if key is None:
+            key = self.key
+        # this is for use in crossattn
+        c = batch[key][:, None]
+        c = self.embedding(c)
+        return c
+
+
+class
FrozenCLIPTextEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__( + self, + version="openai/clip-vit-large-patch14", + tokenizer_version=None, + device="cuda", + max_length=77, + zero_embedding_radio: float = 0.1, + ): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version) + + self.device = device + self.max_length = max_length + self.zero_embedding_radio = zero_embedding_radio + + self.clip_dict = OrderedDict() + self.clip_name = os.path.split(version)[-1] + + transformer = CLIPModel.from_pretrained(version).text_model + + for param in transformer.parameters(): + param.requires_grad = False + self.clip_dict[self.clip_name] = transformer + + self._move_flag = False + + @property + def clip(self): + return self.clip_dict[self.clip_name] + + def move(self): + if self._move_flag: + return + + self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device) + self._move_flag = True + + def unconditional_embedding(self, batch_size): + empty_text = [""] * batch_size + empty_z = self.forward(empty_text) + return empty_z + + def forward(self, text): + self.move() + + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.clip(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + batch_size = len(text) + batch_mask = torch.rand((batch_size,)) + for i in range(batch_size): + if batch_mask[i] < self.zero_embedding_radio: + text[i] = "" + + return self(text) + +class FrozenAlignedCLIPTextEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__( + self, + version="openai/clip-vit-large-patch14", + tokenizer_version=None, + device="cuda", + max_length=77, + zero_embedding_radio: float = 0.1, + ): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_version or version) + + self.device = device + self.max_length = max_length + self.zero_embedding_radio = zero_embedding_radio + + self.clip_dict = OrderedDict() + self.clip_name = os.path.split(version)[-1] + + transformer = CLIPModel.from_pretrained(version).text_model + + for param in transformer.parameters(): + param.requires_grad = False + self.clip_dict[self.clip_name] = transformer + + self._move_flag = False + + @property + def clip(self): + return self.clip_dict[self.clip_name] + + def move(self): + if self._move_flag: + return + + self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device) + self._move_flag = True + + def unconditional_embedding(self, batch_size): + empty_text = [""] * batch_size + empty_z = self.forward(empty_text) + return empty_z + + def forward(self, text): + self.move() + + batch_encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding="max_length", + return_tensors="pt", + ) + + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.clip(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + batch_size = len(text) + batch_mask = torch.rand((batch_size,)) + for i in range(batch_size): + if batch_mask[i] < self.zero_embedding_radio: + text[i] = "" + + return self(text) + + +class 
FrozenCLIPImageEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + zero_embedding_radio=0.1, + normalize_embedding=True, + num_projection_vector=0, + linear_mapping_bias=True, + reverse_visual_projection=False, + ): + super().__init__() + + self.device = device + + self.clip_dict = OrderedDict() + self.clip_name = os.path.split(version)[-1] + + clip_model = CLIPModel.from_pretrained(version) + clip_model.text_model = None + clip_model.text_projection = None + clip_model = clip_model.eval() + for param in self.parameters(): + param.requires_grad = False + self.clip_dict[self.clip_name] = clip_model + + self.transform = transforms.Compose( + [ + transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True), + transforms.CenterCrop(224), # crop a (224, 224) square + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ), + ] + ) + self.zero_embedding_radio = zero_embedding_radio + + self.num_projection_vector = num_projection_vector + self.reverse_visual_projection = reverse_visual_projection + self.normalize_embedding = normalize_embedding + + embedding_dim = ( + clip_model.visual_projection.in_features + if reverse_visual_projection + else clip_model.visual_projection.out_features + ) + self.embedding_dim = embedding_dim + if self.num_projection_vector > 0: + self.projection = nn.Linear( + embedding_dim, + clip_model.visual_projection.out_features * num_projection_vector, + bias=linear_mapping_bias, + ) + nn.init.normal_(self.projection.weight, std=embedding_dim ** -0.5) + + self._move_flag = False + + @property + def clip(self): + return self.clip_dict[self.clip_name] + + def unconditional_embedding(self, batch_size): + zero = torch.zeros( + batch_size, + 1, + self.embedding_dim, + device=self.device, + dtype=self.clip.visual_projection.weight.dtype, + ) + if self.num_projection_vector > 0: + zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1) + return zero + + def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0): + if value_range is not None: + low, high = value_range + image = (image - low) / (high - low) + + image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype) + + if self.reverse_visual_projection: + z = self.clip.vision_model(self.transform(image))[1] + else: + z = self.clip.get_image_features(self.transform(image)) + + if self.normalize_embedding: + z = z / z.norm(dim=-1, keepdim=True) + if z.ndim == 2: + z = z.unsqueeze(dim=-2) + + if zero_embedding_radio > 0: + mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) < zero_embedding_radio + z = z * mask.to(z) + + if self.num_projection_vector > 0: + z = self.projection(z).view(len(image), self.num_projection_vector, -1) + + return z + + def move(self): + if self._move_flag: + return + + self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device) + self._move_flag = True + + def encode(self, image): + self.move() + return self(image, zero_embedding_radio=self.zero_embedding_radio) + + +class FrozenCLIPImageGridEmbedder(AbstractEncoder): + + def __init__( + self, + version="openai/clip-vit-large-patch14", + device="cuda", + zero_embedding_radio=0.1, + ): + super().__init__() + + self.device = device + + self.clip_dict = OrderedDict() + self.clip_name = os.path.split(version)[-1] + + clip_model: CLIPModel = 
CLIPModel.from_pretrained(version) + clip_model.text_model = None + clip_model.text_projection = None + clip_model = clip_model.eval() + for param in self.parameters(): + param.requires_grad = False + self.clip_dict[self.clip_name] = clip_model + + self.transform = transforms.Compose( + [ + transforms.Resize(224, transforms.InterpolationMode.BILINEAR, antialias=True), + transforms.CenterCrop(224), # crop a (224, 224) square + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ), + ] + ) + self.zero_embedding_radio = zero_embedding_radio + self.embedding_dim = clip_model.vision_embed_dim + + self._move_flag = False + + @property + def clip(self): + return self.clip_dict[self.clip_name] + + def move(self): + if self._move_flag: + return + + self.clip_dict[self.clip_name] = self.clip_dict[self.clip_name].to(self.device) + self._move_flag = True + + def unconditional_embedding(self, batch_size): + zero = torch.zeros( + batch_size, + self.clip.vision_model.embeddings.num_positions, + self.embedding_dim, + device=self.device, + dtype=self.clip.visual_projection.weight.dtype, + ) + return zero + + def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0): + self.move() + + if value_range is not None: + low, high = value_range + image = (image - low) / (high - low) + + image = image.to(self.device, dtype=self.clip.visual_projection.weight.dtype) + + z = self.clip.vision_model(self.transform(image)).last_hidden_state + + if zero_embedding_radio > 0: + mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio + z = z * mask.to(z) + + return z + + def encode(self, image): + return self(image, zero_embedding_radio=self.zero_embedding_radio) + + +class MoECLIPImageEncoder(nn.Module): + def __init__( + self, + versions, + hidden_state_dim, + num_projection_vector=8, + zero_embedding_radio=0.1, + device="cuda", + precision="fp16", + normalize=False, + clip_max=0, + transform_type="base", + argument_p=0.2, + ): + super().__init__() + + self.device = torch.device(device) + self.hidden_state_dim = hidden_state_dim + self.zero_embedding_radio = zero_embedding_radio + self.num_projection_vector = num_projection_vector + self.dtype = dict(fp16=torch.float16, fp32=torch.float32, bf16=torch.bfloat16)[precision] + self.normalize = normalize + self.clip_max = clip_max + + if transform_type == "base": + self.transform = transforms.Compose( + [ + transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True), + transforms.CenterCrop(224), # crop a (224, 224) square + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ), + ] + ) + elif transform_type == "crop_blur_resize": + self.transform = transforms.Compose( + [ + transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True), + transforms.CenterCrop(224), # crop a (224, 224) square + transforms.RandomApply( + transforms=[ + transforms.RandomResizedCrop( + size=224, + scale=(0.8, 1.0), + ratio=(0.99, 1.01), + interpolation=transforms.InterpolationMode.BICUBIC, + ), + ], + p=argument_p, + ), + transforms.RandomApply( + transforms=[ + transforms.GaussianBlur(kernel_size=9, sigma=(0.1, 5)), + ], + p=argument_p, + ), + transforms.RandomApply( + transforms=[ + RandomResize(size=224, resize_radio=(0.2, 1)), + ], + p=argument_p, + ), + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711], + ), + ] + ) + else: + raise 
ValueError(f"invalid {transform_type=}")
+
+        if isinstance(versions, str):
+            versions = (versions,)
+
+        # If the CLIP models were registered as submodules of this class,
+        # 1. the checkpoint would store several sets of redundant CLIP weights, and
+        # 2. PyTorch Lightning would call .to() on them, casting the LayerNorm weights to fp16 as well.
+        clips = OrderedDict()
+
+        for v in versions:
+            # Because the CLIP models are not submodules, loading them with device="cuda" directly
+            # would wrongly place all of the CLIP weights on cuda:0.
+            clips[v], _ = clip.load(name=v, device="cpu", jit=False, download_root=None)
+            delattr(clips[v], "transformer")
+            clips[v].eval()
+            clips[v].requires_grad_(False)
+
+        self.clips_hidden_dim = sum(clips[v].ln_final.weight.size(0) for v in clips)
+
+        if self.num_projection_vector == 0:
+            self.projection = nn.Identity()
+        else:
+            self.projection = nn.Linear(self.clips_hidden_dim, hidden_state_dim * self.num_projection_vector, bias=True)
+            self.projection.to(dtype=self.dtype)
+            nn.init.normal_(self.projection.weight, std=self.clips_hidden_dim ** -0.5)
+
+        self.clips = clips
+
+        self._move_flag = False
+
+    def move(self):
+        if self._move_flag:
+            return
+
+        def convert_weights(model: nn.Module):
+            """Convert applicable model parameters to fp16"""
+
+            def _convert_weights_to_fp16(l):
+                if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+                    l.weight.data = l.weight.data.type(self.dtype)
+                    if l.bias is not None:
+                        l.bias.data = l.bias.data.type(self.dtype)
+
+                if isinstance(l, nn.MultiheadAttention):
+                    for attr in [
+                        *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
+                        "in_proj_bias",
+                        "bias_k",
+                        "bias_v",
+                    ]:
+                        tensor = getattr(l, attr)
+                        if tensor is not None:
+                            tensor.data = tensor.data.type(self.dtype)
+
+                for name in ["text_projection", "proj"]:
+                    if hasattr(l, name):
+                        attr = getattr(l, name)
+                        if attr is not None:
+                            attr.data = attr.data.type(self.dtype)
+
+            model.apply(_convert_weights_to_fp16)
+
+        for k in self.clips:
+            self.clips[k].to(self.device)
+            convert_weights(self.clips[k])  # fp32 -> self.dtype
+        self._move_flag = True
+
+    def unconditional_embedding(self, batch_size=None):
+        zero = torch.zeros(
+            batch_size,
+            self.clips_hidden_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        if self.num_projection_vector > 0:
+            zero = self.projection(zero).view(batch_size, self.num_projection_vector, -1)
+        return zero
+
+    def convert_embedding(self, z):
+        if self.num_projection_vector > 0:
+            z = self.projection(z.type(self.projection.weight.dtype)).view(len(z), self.num_projection_vector, -1)
+        return z
+
+    def forward(self, image, value_range=(-1, 1), zero_embedding_radio=0):
+        if value_range is not None:
+            low, high = value_range
+            image = (image - low) / (high - low)
+
+        image = self.transform(image)
+
+        with torch.no_grad():
+            embs = []
+            for v in self.clips:
+                x = self.clips[v].encode_image(image)
+                if self.normalize:
+                    x = x / x.norm(p=2, dim=-1, keepdim=True) * (x.size(-1) ** 0.5)
+                # clip_max only works with normalization
+                if self.clip_max > 0:
+                    x = x.clamp(-self.clip_max, self.clip_max)
+                embs.append(x)
+
+            z = torch.cat(embs, dim=-1)
+            if self.normalize:
+                z /= z.size(-1) ** 0.5
+
+        if zero_embedding_radio > 0:
+            mask = torch.rand((len(image), 1, 1), device=z.device, dtype=z.dtype) >= zero_embedding_radio
+            z = z * mask.to(z)  # randomly zero out embeddings (unconditional dropout)
+
+        if self.num_projection_vector > 0:
+            z = self.projection(z).view(len(image), self.num_projection_vector, -1)
+        return z
+
+    def encode(self, image):
+        self.move()
+        return self(image, zero_embedding_radio=self.zero_embedding_radio)
diff --git a/michelangelo/models/modules/__init__.py b/michelangelo/models/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0729b49eadf964584d3524d9c0f6adec3f04a6a9
--- /dev/null
+++
b/michelangelo/models/modules/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +from .checkpoint import checkpoint diff --git a/michelangelo/models/modules/__pycache__/__init__.cpython-39.pyc b/michelangelo/models/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..348900e920a85fe0d57351a17c39749033bbb410 Binary files /dev/null and b/michelangelo/models/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/models/modules/__pycache__/checkpoint.cpython-39.pyc b/michelangelo/models/modules/__pycache__/checkpoint.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b08696fe6915b54cd3eae60de3bc2b7913aad097 Binary files /dev/null and b/michelangelo/models/modules/__pycache__/checkpoint.cpython-39.pyc differ diff --git a/michelangelo/models/modules/__pycache__/diffusion_transformer.cpython-39.pyc b/michelangelo/models/modules/__pycache__/diffusion_transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e93c113e6db8b8cc772447f18f8af9ed7112eea2 Binary files /dev/null and b/michelangelo/models/modules/__pycache__/diffusion_transformer.cpython-39.pyc differ diff --git a/michelangelo/models/modules/__pycache__/distributions.cpython-39.pyc b/michelangelo/models/modules/__pycache__/distributions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3007bb32e0994a8a6d2339b0422a39ed72c020b5 Binary files /dev/null and b/michelangelo/models/modules/__pycache__/distributions.cpython-39.pyc differ diff --git a/michelangelo/models/modules/__pycache__/embedder.cpython-39.pyc b/michelangelo/models/modules/__pycache__/embedder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1d8c035e958efe1b3c4dfc01115cafc58869551 Binary files /dev/null and b/michelangelo/models/modules/__pycache__/embedder.cpython-39.pyc differ diff --git a/michelangelo/models/modules/__pycache__/transformer_blocks.cpython-39.pyc b/michelangelo/models/modules/__pycache__/transformer_blocks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f8c267a1d5e3fae972d72e51fb7c9444ad80743 Binary files /dev/null and b/michelangelo/models/modules/__pycache__/transformer_blocks.cpython-39.pyc differ diff --git a/michelangelo/models/modules/checkpoint.py b/michelangelo/models/modules/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..2e5481ea2768afd53b54619cd33aa936ff6afc11 --- /dev/null +++ b/michelangelo/models/modules/checkpoint.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +""" +Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 +""" + +import torch +from typing import Callable, Iterable, Sequence, Union + + +def checkpoint( + func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], + inputs: Sequence[torch.Tensor], + params: Iterable[torch.Tensor], + flag: bool, + use_deepspeed: bool = False +): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. 
+ :param use_deepspeed: if True, use deepspeed + """ + if flag: + if use_deepspeed: + return deepspeed.checkpointing.checkpoint(func, *inputs) + + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads diff --git a/michelangelo/models/modules/diffusion_transformer.py b/michelangelo/models/modules/diffusion_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0a8345410bbddaf7b483418190a243fc5abb2ea5 --- /dev/null +++ b/michelangelo/models/modules/diffusion_transformer.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- + +import math +import torch +import torch.nn as nn +from typing import Optional + +from michelangelo.models.modules.checkpoint import checkpoint +from michelangelo.models.modules.transformer_blocks import ( + init_linear, + MLP, + MultiheadCrossAttention, + MultiheadAttention, + ResidualAttentionBlock +) + + +class AdaLayerNorm(nn.Module): + def __init__(self, + device: torch.device, + dtype: torch.dtype, + width: int): + + super().__init__() + + self.silu = nn.SiLU(inplace=True) + self.linear = nn.Linear(width, width * 2, device=device, dtype=dtype) + self.layernorm = nn.LayerNorm(width, elementwise_affine=False, device=device, dtype=dtype) + + def forward(self, x, timestep): + emb = self.linear(timestep) + scale, shift = torch.chunk(emb, 2, dim=2) + x = self.layernorm(x) * (1 + scale) + shift + return x + + +class DitBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + context_dim: int, + qkv_bias: bool = False, + init_scale: float = 1.0, + use_checkpoint: bool = False + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias + ) + self.ln_1 = AdaLayerNorm(device, dtype, width) + + if context_dim is not None: + self.ln_2 = AdaLayerNorm(device, dtype, width) + self.cross_attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + width=width, + heads=heads, + data_width=context_dim, + init_scale=init_scale, + qkv_bias=qkv_bias + ) + + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_3 = AdaLayerNorm(device, dtype, width) + + def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): + return 
checkpoint(self._forward, (x, t, context), self.parameters(), self.use_checkpoint) + + def _forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): + x = x + self.attn(self.ln_1(x, t)) + if context is not None: + x = x + self.cross_attn(self.ln_2(x, t), context) + x = x + self.mlp(self.ln_3(x, t)) + return x + + +class DiT(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + context_dim: int, + init_scale: float = 0.25, + qkv_bias: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + + self.resblocks = nn.ModuleList( + [ + DitBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + context_dim=context_dim, + qkv_bias=qkv_bias, + init_scale=init_scale, + use_checkpoint=use_checkpoint + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor, t: torch.Tensor, context: Optional[torch.Tensor] = None): + for block in self.resblocks: + x = block(x, t, context) + return x + + +class UNetDiffusionTransformer(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = False, + skip_ln: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + + self.n_ctx = n_ctx + self.width = width + self.layers = layers + + self.encoder = nn.ModuleList() + for _ in range(layers): + resblock = ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_checkpoint=use_checkpoint + ) + self.encoder.append(resblock) + + self.middle_block = ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_checkpoint=use_checkpoint + ) + + self.decoder = nn.ModuleList() + for _ in range(layers): + resblock = ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_checkpoint=use_checkpoint + ) + linear = nn.Linear(width * 2, width, device=device, dtype=dtype) + init_linear(linear, init_scale) + + layer_norm = nn.LayerNorm(width, device=device, dtype=dtype) if skip_ln else None + + self.decoder.append(nn.ModuleList([resblock, linear, layer_norm])) + + def forward(self, x: torch.Tensor): + + enc_outputs = [] + for block in self.encoder: + x = block(x) + enc_outputs.append(x) + + x = self.middle_block(x) + + for i, (resblock, linear, layer_norm) in enumerate(self.decoder): + x = torch.cat([enc_outputs.pop(), x], dim=-1) + x = linear(x) + + if layer_norm is not None: + x = layer_norm(x) + + x = resblock(x) + + return x + + + diff --git a/michelangelo/models/modules/distributions.py b/michelangelo/models/modules/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..cf1cdcd53f1eb534b55d92ae1bd0b9854f6b890c --- /dev/null +++ b/michelangelo/models/modules/distributions.py @@ -0,0 +1,100 @@ +import torch +import numpy as np +from typing import Union, List + + +class AbstractDistribution(object): + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = 
value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): + self.feat_dim = feat_dim + self.parameters = parameters + + if isinstance(parameters, list): + self.mean = parameters[0] + self.logvar = parameters[1] + else: + self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean) + + def sample(self): + x = self.mean + self.std * torch.randn_like(self.mean) + return x + + def kl(self, other=None, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.mean(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=dims) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=dims) + + def nll(self, sample, dims=(1, 2, 3)): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/michelangelo/models/modules/embedder.py b/michelangelo/models/modules/embedder.py new file mode 100644 index 0000000000000000000000000000000000000000..223de828f44903a3ce96b59d1cc5621e0989b535 --- /dev/null +++ b/michelangelo/models/modules/embedder.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- + +import numpy as np +import torch +import torch.nn as nn +import math + +VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] + + +class FourierEmbedder(nn.Module): + """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts + each feature dimension of `x[..., i]` into: + [ + sin(x[..., i]), + sin(f_1*x[..., i]), + sin(f_2*x[..., i]), + ... + sin(f_N * x[..., i]), + cos(x[..., i]), + cos(f_1*x[..., i]), + cos(f_2*x[..., i]), + ... + cos(f_N * x[..., i]), + x[..., i] # only present if include_input is True. + ], here f_i is the frequency. + + Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. 
+ If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; + Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. + + Args: + num_freqs (int): the number of frequencies, default is 6; + logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]; + input_dim (int): the input dimension, default is 3; + include_input (bool): include the input tensor or not, default is True. + + Attributes: + frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...], + otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1); + + out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1), + otherwise, it is input_dim * num_freqs * 2. + + """ + + def __init__(self, + num_freqs: int = 6, + logspace: bool = True, + input_dim: int = 3, + include_input: bool = True, + include_pi: bool = True) -> None: + + """The initialization""" + + super().__init__() + + if logspace: + frequencies = 2.0 ** torch.arange( + num_freqs, + dtype=torch.float32 + ) + else: + frequencies = torch.linspace( + 1.0, + 2.0 ** (num_freqs - 1), + num_freqs, + dtype=torch.float32 + ) + + if include_pi: + frequencies *= torch.pi + + self.register_buffer("frequencies", frequencies, persistent=False) + self.include_input = include_input + self.num_freqs = num_freqs + + self.out_dim = self.get_dims(input_dim) + + def get_dims(self, input_dim): + temp = 1 if self.include_input or self.num_freqs == 0 else 0 + out_dim = input_dim * (self.num_freqs * 2 + temp) + + return out_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ Forward process. + + Args: + x: tensor of shape [..., dim] + + Returns: + embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)] + where temp is 1 if include_input is True and 0 otherwise. 
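+
+        Example (illustrative): with num_freqs=2, input_dim=3, include_input=True and
+        include_pi=False, an input of shape [4, 3] is mapped to shape [4, 3 * (2 * 2 + 1)] = [4, 15].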
+ """ + + if self.num_freqs > 0: + embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) + if self.include_input: + return torch.cat((x, embed.sin(), embed.cos()), dim=-1) + else: + return torch.cat((embed.sin(), embed.cos()), dim=-1) + else: + return x + + +class LearnedFourierEmbedder(nn.Module): + """ following @crowsonkb "s lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, in_channels, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + per_channel_dim = half_dim // in_channels + self.weights = nn.Parameter(torch.randn(per_channel_dim)) + + def forward(self, x): + """ + + Args: + x (torch.FloatTensor): [..., c] + + Returns: + x (torch.FloatTensor): [..., d] + """ + + # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] + freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) + fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) + return fouriered + + +class TriplaneLearnedFourierEmbedder(nn.Module): + def __init__(self, in_channels, dim): + super().__init__() + + self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) + + self.out_dim = in_channels + dim + + def forward(self, x): + + yz_embed = self.yz_plane_embedder(x) + xz_embed = self.xz_plane_embedder(x) + xy_embed = self.xy_plane_embedder(x) + + embed = yz_embed + xz_embed + xy_embed + + return embed + + +def sequential_pos_embed(num_len, embed_dim): + assert embed_dim % 2 == 0 + + pos = torch.arange(num_len, dtype=torch.float32) + omega = torch.arange(embed_dim // 2, dtype=torch.float32) + omega /= embed_dim / 2. + omega = 1. / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) + + return embeddings + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].to(timesteps.dtype) * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4, + num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16, + log2_hashmap_size=19, desired_resolution=None): + if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): + return nn.Identity(), input_dim + + elif embed_type == "fourier": + embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim, + logspace=True, include_input=True) + return embedder_obj, embedder_obj.out_dim + + elif embed_type == "hashgrid": + raise NotImplementedError + + elif embed_type == "sphere_harmonic": + raise NotImplementedError + + else: + raise ValueError(f"{embed_type} is not valid. Currently only supprts {VALID_EMBED_TYPES}") diff --git a/michelangelo/models/modules/transformer_blocks.py b/michelangelo/models/modules/transformer_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..c77dd3720d54aa1e83274c8d35e6cae95d28ba9d --- /dev/null +++ b/michelangelo/models/modules/transformer_blocks.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Optional + +from michelangelo.models.modules.checkpoint import checkpoint + + +def init_linear(l, stddev): + nn.init.normal_(l.weight, std=stddev) + if l.bias is not None: + nn.init.constant_(l.bias, 0.0) + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool, + flash: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash) + init_linear(self.c_qkv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + self.flash = flash + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + + if self.flash: + out = F.scaled_dot_product_attention(q, k, v) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + 
dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + init_scale: float = 1.0, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) + + def _forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + def forward(self, x: torch.Tensor): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + width: int, + heads: int, + init_scale: float, + qkv_bias: bool = True, + flash: bool = False, + n_data: Optional[int] = None, + data_width: Optional[int] = None, + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadCrossAttention( + device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash + ) + init_linear(self.c_q, init_scale) + init_linear(self.c_kv, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, + flash: bool = False, n_data: Optional[int] = None): + + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_data = n_data + self.flash = flash + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + + if self.flash: + out = F.scaled_dot_product_attention(q, k, v) + else: + weight = torch.einsum( + "bthc,bshc->bhts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + return out + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_data: Optional[int] = None, + width: int, + heads: int, + data_width: Optional[int] = None, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + self.ln_1 = nn.LayerNorm(width, 
device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class MLP(nn.Module): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + width: int, + init_scale: float): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) + self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) + self.gelu = nn.GELU() + init_linear(self.c_fc, init_scale) + init_linear(self.c_proj, init_scale) + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class Transformer(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + for _ in range(layers) + ] + ) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x diff --git a/michelangelo/models/modules/transformer_vit.py b/michelangelo/models/modules/transformer_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..a1999b1abd2442511a351ecade6c6aa59c10d716 --- /dev/null +++ b/michelangelo/models/modules/transformer_vit.py @@ -0,0 +1,310 @@ +# -*- coding: utf-8 -*- + +import math +import torch +import torch.nn as nn +from typing import Optional +import warnings + +from michelangelo.models.modules.checkpoint import checkpoint + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. 
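+    # The erfinv_ call below is the inverse of norm_cdf above (up to the
+    # sqrt(2) factor applied later), so the uniform samples in [2l-1, 2u-1]
+    # become standard-normal samples truncated to the requested interval.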
+ tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor | nn.Parameter, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are + applied while sampling the normal with mean/std applied, therefore a, b args + should be adjusted to match the range of mean, std args. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + with torch.no_grad(): + return _trunc_normal_(tensor, mean, std, a, b) + + +def init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + +class MultiheadAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + qkv_bias: bool + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx) + + def forward(self, x): + x = self.c_qkv(x) + x = checkpoint(self.attention, (x,), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_ctx = n_ctx + + def forward(self, qkv): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.heads // 3 + scale = 1 / math.sqrt(attn_ch) + qkv = qkv.view(bs, n_ctx, self.heads, -1) + q, k, v = torch.split(qkv, attn_ch, dim=-1) + weight = torch.einsum("bthc,bshc->bhts", q, k) * scale + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + n_ctx: int, + width: int, + heads: int, + qkv_bias: bool = True, + use_checkpoint: bool = False + ): + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.attn = MultiheadAttention( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + qkv_bias=qkv_bias + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.mlp = MLP(device=device, 
dtype=dtype, width=width) + self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) + + def _forward(self, x: torch.Tensor): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + def forward(self, x: torch.Tensor): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + +class MultiheadCrossAttention(nn.Module): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + width: int, + heads: int, + qkv_bias: bool = True, + n_data: Optional[int] = None, + data_width: Optional[int] = None, + ): + super().__init__() + self.n_data = n_data + self.width = width + self.heads = heads + self.data_width = width if data_width is None else data_width + self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) + self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) + self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) + self.attention = QKVMultiheadCrossAttention( + device=device, dtype=dtype, heads=heads, n_data=n_data + ) + + def forward(self, x, data): + x = self.c_q(x) + data = self.c_kv(data) + x = checkpoint(self.attention, (x, data), (), True) + x = self.c_proj(x) + return x + + +class QKVMultiheadCrossAttention(nn.Module): + def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_data: Optional[int] = None): + super().__init__() + self.device = device + self.dtype = dtype + self.heads = heads + self.n_data = n_data + + def forward(self, q, kv): + _, n_ctx, _ = q.shape + bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 + scale = 1 / math.sqrt(attn_ch) + q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) + k, v = torch.split(kv, attn_ch, dim=-1) + weight = torch.einsum("bthc,bshc->bhts", q, k) * scale + wdtype = weight.dtype + weight = torch.softmax(weight.float(), dim=-1).type(wdtype) + return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + +class ResidualCrossAttentionBlock(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_data: Optional[int] = None, + width: int, + heads: int, + data_width: Optional[int] = None, + qkv_bias: bool = True + ): + super().__init__() + + if data_width is None: + data_width = width + + self.attn = MultiheadCrossAttention( + device=device, + dtype=dtype, + n_data=n_data, + width=width, + heads=heads, + data_width=data_width, + qkv_bias=qkv_bias + ) + self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) + self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) + self.mlp = MLP(device=device, dtype=dtype, width=width) + self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, data: torch.Tensor): + x = x + self.attn(self.ln_1(x), self.ln_2(data)) + x = x + self.mlp(self.ln_3(x)) + return x + + +class MLP(nn.Module): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + width: int): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) + self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) + self.gelu = nn.GELU() + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class Transformer(nn.Module): + def __init__( + self, + *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + n_ctx: int, + width: int, + layers: int, + heads: int, + qkv_bias: bool = True, + 
use_checkpoint: bool = False + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + device=device, + dtype=dtype, + n_ctx=n_ctx, + width=width, + heads=heads, + qkv_bias=qkv_bias, + use_checkpoint=use_checkpoint + ) + for _ in range(layers) + ] + ) + + self.apply(init_weights) + + def forward(self, x: torch.Tensor): + for block in self.resblocks: + x = block(x) + return x + + diff --git a/michelangelo/models/tsal/__init__.py b/michelangelo/models/tsal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/michelangelo/models/tsal/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/michelangelo/models/tsal/__pycache__/__init__.cpython-39.pyc b/michelangelo/models/tsal/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e104d0d3907e3d20673006beef7ae3876f44e15 Binary files /dev/null and b/michelangelo/models/tsal/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/models/tsal/__pycache__/asl_pl_module.cpython-39.pyc b/michelangelo/models/tsal/__pycache__/asl_pl_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f96ce4cc02229edbe40da383ced267c0c077801a Binary files /dev/null and b/michelangelo/models/tsal/__pycache__/asl_pl_module.cpython-39.pyc differ diff --git a/michelangelo/models/tsal/__pycache__/clip_asl_module.cpython-39.pyc b/michelangelo/models/tsal/__pycache__/clip_asl_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d82be61c419ff6be1e7c24c30a62b3a7c7fdb97e Binary files /dev/null and b/michelangelo/models/tsal/__pycache__/clip_asl_module.cpython-39.pyc differ diff --git a/michelangelo/models/tsal/__pycache__/inference_utils.cpython-39.pyc b/michelangelo/models/tsal/__pycache__/inference_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3899ea895f31301e57a8f57cf11724458bd18708 Binary files /dev/null and b/michelangelo/models/tsal/__pycache__/inference_utils.cpython-39.pyc differ diff --git a/michelangelo/models/tsal/__pycache__/loss.cpython-39.pyc b/michelangelo/models/tsal/__pycache__/loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77b11131cae438de9687f56990ff97d970c18e04 Binary files /dev/null and b/michelangelo/models/tsal/__pycache__/loss.cpython-39.pyc differ diff --git a/michelangelo/models/tsal/__pycache__/sal_perceiver.cpython-39.pyc b/michelangelo/models/tsal/__pycache__/sal_perceiver.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b599a267b76f456bed55259d30db4b4b6edbc28 Binary files /dev/null and b/michelangelo/models/tsal/__pycache__/sal_perceiver.cpython-39.pyc differ diff --git a/michelangelo/models/tsal/__pycache__/tsal_base.cpython-39.pyc b/michelangelo/models/tsal/__pycache__/tsal_base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92cfefdc031200fb79eeb5806754f70efb20fe95 Binary files /dev/null and b/michelangelo/models/tsal/__pycache__/tsal_base.cpython-39.pyc differ diff --git a/michelangelo/models/tsal/asl_pl_module.py b/michelangelo/models/tsal/asl_pl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..4d6523b6bf6d86cf385fea2d44d9b7d36ebc9f77 --- /dev/null +++ b/michelangelo/models/tsal/asl_pl_module.py @@ -0,0 +1,354 @@ +# -*- coding: 
utf-8 -*- + +from typing import List, Tuple, Dict, Optional +from omegaconf import DictConfig + +import torch +import torch.nn.functional as F +from torch.optim import lr_scheduler +import pytorch_lightning as pl +from typing import Union +from functools import partial + +from michelangelo.utils import instantiate_from_config + +from .inference_utils import extract_geometry +from .tsal_base import ( + AlignedShapeAsLatentModule, + ShapeAsLatentModule, + Latent2MeshOutput, + AlignedMeshOutput +) + + +class AlignedShapeAsLatentPLModule(pl.LightningModule): + + def __init__(self, *, + shape_module_cfg, + aligned_module_cfg, + loss_cfg, + optimizer_cfg: Optional[DictConfig] = None, + ckpt_path: Optional[str] = None, + ignore_keys: Union[Tuple[str], List[str]] = ()): + + super().__init__() + + shape_model: ShapeAsLatentModule = instantiate_from_config( + shape_module_cfg, device=None, dtype=None + ) + self.model: AlignedShapeAsLatentModule = instantiate_from_config( + aligned_module_cfg, shape_model=shape_model + ) + + self.loss = instantiate_from_config(loss_cfg) + + self.optimizer_cfg = optimizer_cfg + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + self.save_hyperparameters() + + def set_shape_model_only(self): + self.model.set_shape_model_only() + + @property + def latent_shape(self): + return self.model.shape_model.latent_shape + + @property + def zero_rank(self): + if self._trainer: + zero_rank = self.trainer.local_rank == 0 + else: + zero_rank = True + + return zero_rank + + def init_from_ckpt(self, path, ignore_keys=()): + state_dict = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + + missing, unexpected = self.load_state_dict(state_dict, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def configure_optimizers(self) -> Tuple[List, List]: + lr = self.learning_rate + + trainable_parameters = list(self.model.parameters()) + + if self.optimizer_cfg is None: + optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] + schedulers = [] + else: + optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters) + scheduler_func = instantiate_from_config( + self.optimizer_cfg.scheduler, + max_decay_steps=self.trainer.max_steps, + lr_max=lr + ) + scheduler = { + "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), + "interval": "step", + "frequency": 1 + } + optimizers = [optimizer] + schedulers = [scheduler] + + return optimizers, schedulers + + def forward(self, + surface: torch.FloatTensor, + image: torch.FloatTensor, + text: torch.FloatTensor, + volume_queries: torch.FloatTensor): + + """ + + Args: + surface (torch.FloatTensor): + image (torch.FloatTensor): + text (torch.FloatTensor): + volume_queries (torch.FloatTensor): + + Returns: + + """ + + embed_outputs, shape_z = self.model(surface, image, text) + + shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) + latents = self.model.shape_model.decode(shape_zq) + logits = self.model.shape_model.query_geometry(volume_queries, latents) + + return embed_outputs, logits, posterior + + def encode(self, surface: torch.FloatTensor, 
sample_posterior=True): + + pc = surface[..., 0:3] + feats = surface[..., 3:6] + + shape_embed, shape_zq, posterior = self.model.shape_model.encode( + pc=pc, feats=feats, sample_posterior=sample_posterior + ) + + return shape_zq + + def decode(self, + z_q, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Latent2MeshOutput]: + + latents = self.model.shape_model.decode(z_q) # latents: [bs, num_latents, dim] + outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) + + return outputs + + def training_step(self, batch: Dict[str, torch.FloatTensor], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] + - image (torch.FloatTensor): [bs, 3, 224, 224] + - text (torch.FloatTensor): [bs, num_templates, 77] + - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] + + batch_idx (int): + + optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + surface = batch["surface"] + image = batch["image"] + text = batch["text"] + + volume_queries = batch["geo_points"][..., 0:3] + shape_labels = batch["geo_points"][..., -1] + + embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) + + aeloss, log_dict_ae = self.loss( + **embed_outputs, + posteriors=posteriors, + shape_logits=shape_logits, + shape_labels=shape_labels, + split="train" + ) + + self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], + sync_dist=False, rank_zero_only=True) + + return aeloss + + def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor: + + surface = batch["surface"] + image = batch["image"] + text = batch["text"] + + volume_queries = batch["geo_points"][..., 0:3] + shape_labels = batch["geo_points"][..., -1] + + embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries) + + aeloss, log_dict_ae = self.loss( + **embed_outputs, + posteriors=posteriors, + shape_logits=shape_logits, + shape_labels=shape_labels, + split="val" + ) + self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0], + sync_dist=False, rank_zero_only=True) + + return aeloss + + def visual_alignment(self, + surface: torch.FloatTensor, + image: torch.FloatTensor, + text: torch.FloatTensor, + description: Optional[List[str]] = None, + bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), + octree_depth: int = 7, + num_chunks: int = 10000) -> List[AlignedMeshOutput]: + + """ + + Args: + surface: + image: + text: + description: + bounds: + octree_depth: + num_chunks: + + Returns: + mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list. 
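+            Each non-None output also carries shape_text_similarity and
+            shape_image_similarity, taken from the diagonal of the softmaxed
+            B x B similarity matrices, i.e. how strongly sample i's shape
+            embedding matches its own text and image conditions.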
+ + """ + + outputs = [] + + device = surface.device + bs = surface.shape[0] + + embed_outputs, shape_z = self.model(surface, image, text) + + # calculate the similarity + image_embed = embed_outputs["image_embed"] + text_embed = embed_outputs["text_embed"] + shape_embed = embed_outputs["shape_embed"] + + # normalized features + shape_embed = F.normalize(shape_embed, dim=-1, p=2) + text_embed = F.normalize(text_embed, dim=-1, p=2) + image_embed = F.normalize(image_embed, dim=-1, p=2) + + # B x B + shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1) + + # B x B + shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1) + + # shape reconstruction + shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z) + latents = self.model.shape_model.decode(shape_zq) + geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) + + # 2. decode geometry + mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=device, + batch_size=bs, + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=not self.zero_rank + ) + + # 3. decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + + out = AlignedMeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + out.surface = surface[i].cpu().numpy() + out.image = image[i].cpu().numpy() + if description is not None: + out.text = description[i] + out.shape_text_similarity = shape_text_similarity[i, i] + out.shape_image_similarity = shape_image_similarity[i, i] + + outputs.append(out) + + return outputs + + def latent2mesh(self, + latents: torch.FloatTensor, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Latent2MeshOutput]: + + """ + + Args: + latents: [bs, num_latents, dim] + bounds: + octree_depth: + num_chunks: + + Returns: + mesh_outputs (List[MeshOutput]): the mesh outputs list. + + """ + + outputs = [] + + geometric_func = partial(self.model.shape_model.query_geometry, latents=latents) + + # 2. decode geometry + device = latents.device + mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=device, + batch_size=len(latents), + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=not self.zero_rank + ) + + # 3. 
decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + + out = Latent2MeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + + outputs.append(out) + + return outputs + diff --git a/michelangelo/models/tsal/clip_asl_module.py b/michelangelo/models/tsal/clip_asl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..64f938a632a286a006180092c81b84e78c7032d5 --- /dev/null +++ b/michelangelo/models/tsal/clip_asl_module.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- + +import torch +from torch import nn +from einops import rearrange +from transformers import CLIPModel + +from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule + + +class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule): + + def __init__(self, *, + shape_model, + clip_model_version: str = "openai/clip-vit-large-patch14"): + + super().__init__() + + self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version) + for params in self.clip_model.parameters(): + params.requires_grad = False + + self.shape_model = shape_model + self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.clip_model.projection_dim)) + nn.init.normal_(self.shape_projection, std=self.clip_model.projection_dim ** -0.5) + + def set_shape_model_only(self): + self.clip_model = None + + def encode_shape_embed(self, surface, return_latents: bool = False): + """ + + Args: + surface (torch.FloatTensor): [bs, n, 3 + c] + return_latents (bool): + + Returns: + x (torch.FloatTensor): [bs, projection_dim] + shape_latents (torch.FloatTensor): [bs, m, d] + """ + + pc = surface[..., 0:3] + feats = surface[..., 3:] + + shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats) + x = shape_embed @ self.shape_projection + + if return_latents: + return x, shape_latents + else: + return x + + def encode_image_embed(self, image): + """ + + Args: + image (torch.FloatTensor): [bs, 3, h, w] + + Returns: + x (torch.FloatTensor): [bs, projection_dim] + """ + + x = self.clip_model.get_image_features(image) + + return x + + def encode_text_embed(self, text): + x = self.clip_model.get_text_features(text) + return x + + def forward(self, surface, image, text): + """ + + Args: + surface (torch.FloatTensor): + image (torch.FloatTensor): [bs, 3, 224, 224] + text (torch.LongTensor): [bs, num_templates, 77] + + Returns: + embed_outputs (dict): the embedding outputs, and it contains: + - image_embed (torch.FloatTensor): + - text_embed (torch.FloatTensor): + - shape_embed (torch.FloatTensor): + - logit_scale (float): + """ + + # # text embedding + # text_embed_all = [] + # for i in range(text.shape[0]): + # text_for_one_sample = text[i] + # text_embed = self.encode_text_embed(text_for_one_sample) + # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + # text_embed = text_embed.mean(dim=0) + # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + # text_embed_all.append(text_embed) + # text_embed_all = torch.stack(text_embed_all) + + b = text.shape[0] + text_tokens = rearrange(text, "b t l -> (b t) l") + text_embed = self.encode_text_embed(text_tokens) + text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) + text_embed = text_embed.mean(dim=1) + text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) + + # image embedding + image_embed = self.encode_image_embed(image) + + # shape embedding + shape_embed, shape_latents = self.encode_shape_embed(surface, 
return_latents=True) + + embed_outputs = { + "image_embed": image_embed, + "text_embed": text_embed, + "shape_embed": shape_embed, + "logit_scale": self.clip_model.logit_scale.exp() + } + + return embed_outputs, shape_latents diff --git a/michelangelo/models/tsal/inference_utils.py b/michelangelo/models/tsal/inference_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b55b5bfe22c04931c3788a70ea2ea350e021d2 --- /dev/null +++ b/michelangelo/models/tsal/inference_utils.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- + +import torch +from tqdm import tqdm +from einops import repeat +import numpy as np +from typing import Callable, Tuple, List, Union, Optional +from skimage import measure + +from michelangelo.graphics.primitives import generate_dense_grid_points + + +@torch.no_grad() +def extract_geometry(geometric_func: Callable, + device: torch.device, + batch_size: int = 1, + bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), + octree_depth: int = 7, + num_chunks: int = 10000, + disable: bool = True): + """ + + Args: + geometric_func: + device: + bounds: + octree_depth: + batch_size: + num_chunks: + disable: + + Returns: + + """ + + if isinstance(bounds, float): + bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] + + bbox_min = np.array(bounds[0:3]) + bbox_max = np.array(bounds[3:6]) + bbox_size = bbox_max - bbox_min + + xyz_samples, grid_size, length = generate_dense_grid_points( + bbox_min=bbox_min, + bbox_max=bbox_max, + octree_depth=octree_depth, + indexing="ij" + ) + xyz_samples = torch.FloatTensor(xyz_samples) + + batch_logits = [] + for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), + desc="Implicit Function:", disable=disable, leave=False): + queries = xyz_samples[start: start + num_chunks, :].to(device) + batch_queries = repeat(queries, "p c -> b p c", b=batch_size) + + logits = geometric_func(batch_queries) + batch_logits.append(logits.cpu()) + + grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy() + + mesh_v_f = [] + has_surface = np.zeros((batch_size,), dtype=np.bool_) + for i in range(batch_size): + try: + vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") + vertices = vertices / grid_size * bbox_size + bbox_min + # vertices[:, [0, 1]] = vertices[:, [1, 0]] + mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) + has_surface[i] = True + + except ValueError: + mesh_v_f.append((None, None)) + has_surface[i] = False + + except RuntimeError: + mesh_v_f.append((None, None)) + has_surface[i] = False + + return mesh_v_f, has_surface diff --git a/michelangelo/models/tsal/loss.py b/michelangelo/models/tsal/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e99553532eb679030ccb0c9a67c8ae8448aa676a --- /dev/null +++ b/michelangelo/models/tsal/loss.py @@ -0,0 +1,303 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Dict + +from michelangelo.models.modules.distributions import DiagonalGaussianDistribution +from michelangelo.utils.eval import compute_psnr +from michelangelo.utils import misc + + +class KLNearFar(nn.Module): + def __init__(self, + near_weight: float = 0.1, + kl_weight: float = 1.0, + num_near_samples: Optional[int] = None): + + super().__init__() + + self.near_weight = near_weight + self.kl_weight = kl_weight + self.num_near_samples = num_near_samples + 
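+        # Occupancy is supervised with binary cross-entropy on the raw logits;
+        # the same criterion is applied to the volume ("far") and near-surface
+        # points, which near_weight and kl_weight rebalance in forward().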
self.geo_criterion = nn.BCEWithLogitsLoss() + + def forward(self, + posteriors: Optional[DiagonalGaussianDistribution], + logits: torch.FloatTensor, + labels: torch.FloatTensor, + split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: + + """ + + Args: + posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): + logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; + labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; + split (str): + **kwargs: + + Returns: + loss (torch.Tensor): (,) + log (dict): + + """ + + if self.num_near_samples is None: + num_vol = logits.shape[1] // 2 + else: + num_vol = logits.shape[1] - self.num_near_samples + + vol_logits = logits[:, 0:num_vol] + vol_labels = labels[:, 0:num_vol] + + near_logits = logits[:, num_vol:] + near_labels = labels[:, num_vol:] + + # occupancy loss + # vol_bce = self.geo_criterion(vol_logits, vol_labels) + # near_bce = self.geo_criterion(near_logits, near_labels) + vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) + near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) + + if posteriors is None: + kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) + else: + kl_loss = posteriors.kl(dims=(1, 2)) + kl_loss = torch.mean(kl_loss) + + loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + + with torch.no_grad(): + preds = logits >= 0 + accuracy = (preds == labels).float() + accuracy = accuracy.mean() + pos_ratio = torch.mean(labels) + + log = { + "{}/total_loss".format(split): loss.clone().detach(), + "{}/near".format(split): near_bce.detach(), + "{}/far".format(split): vol_bce.detach(), + "{}/kl".format(split): kl_loss.detach(), + "{}/accuracy".format(split): accuracy, + "{}/pos_ratio".format(split): pos_ratio + } + + if posteriors is not None: + log[f"{split}/mean"] = posteriors.mean.mean().detach() + log[f"{split}/std_mean"] = posteriors.std.mean().detach() + log[f"{split}/std_max"] = posteriors.std.max().detach() + + return loss, log + + +class KLNearFarColor(nn.Module): + def __init__(self, + near_weight: float = 0.1, + kl_weight: float = 1.0, + color_weight: float = 1.0, + color_criterion: str = "mse", + num_near_samples: Optional[int] = None): + + super().__init__() + + self.color_weight = color_weight + self.near_weight = near_weight + self.kl_weight = kl_weight + self.num_near_samples = num_near_samples + + if color_criterion == "mse": + self.color_criterion = nn.MSELoss() + + elif color_criterion == "l1": + self.color_criterion = nn.L1Loss() + + else: + raise ValueError(f"{color_criterion} must be [`mse`, `l1`].") + + self.geo_criterion = nn.BCEWithLogitsLoss() + + def forward(self, + posteriors: Optional[DiagonalGaussianDistribution], + logits: torch.FloatTensor, + labels: torch.FloatTensor, + pred_colors: torch.FloatTensor, + gt_colors: torch.FloatTensor, + split: Optional[str] = "train", **kwargs) -> Tuple[torch.FloatTensor, Dict[str, float]]: + + """ + + Args: + posteriors (DiagonalGaussianDistribution or torch.distributions.Normal): + logits (torch.FloatTensor): [B, 2*N], logits[:, 0:N] is the volume points; logits[:, N:2N] is the near points; + labels (torch.FloatTensor): [B, 2*N], labels[:, 0:N] is the volume points; labels[:, N:2N] is the near points; + pred_colors (torch.FloatTensor): [B, M, 3] + gt_colors (torch.FloatTensor): [B, M, 3] + split (str): + **kwargs: + + Returns: + 
loss (torch.Tensor): (,) + log (dict): + + """ + + if self.num_near_samples is None: + num_vol = logits.shape[1] // 2 + else: + num_vol = logits.shape[1] - self.num_near_samples + + vol_logits = logits[:, 0:num_vol] + vol_labels = labels[:, 0:num_vol] + + near_logits = logits[:, num_vol:] + near_labels = labels[:, num_vol:] + + # occupancy loss + # vol_bce = self.geo_criterion(vol_logits, vol_labels) + # near_bce = self.geo_criterion(near_logits, near_labels) + vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) + near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) + + # surface color loss + color = self.color_criterion(pred_colors, gt_colors) + + if posteriors is None: + kl_loss = torch.tensor(0.0, dtype=pred_colors.dtype, device=pred_colors.device) + else: + kl_loss = posteriors.kl(dims=(1, 2)) + kl_loss = torch.mean(kl_loss) + + loss = vol_bce + near_bce * self.near_weight + color * self.color_weight + kl_loss * self.kl_weight + + with torch.no_grad(): + preds = logits >= 0 + accuracy = (preds == labels).float() + accuracy = accuracy.mean() + psnr = compute_psnr(pred_colors, gt_colors) + + log = { + "{}/total_loss".format(split): loss.clone().detach(), + "{}/near".format(split): near_bce.detach(), + "{}/far".format(split): vol_bce.detach(), + "{}/color".format(split): color.detach(), + "{}/kl".format(split): kl_loss.detach(), + "{}/psnr".format(split): psnr.detach(), + "{}/accuracy".format(split): accuracy + } + + return loss, log + + +class ContrastKLNearFar(nn.Module): + def __init__(self, + contrast_weight: float = 1.0, + near_weight: float = 0.1, + kl_weight: float = 1.0, + num_near_samples: Optional[int] = None): + + super().__init__() + + self.labels = None + self.last_local_batch_size = None + + self.contrast_weight = contrast_weight + self.near_weight = near_weight + self.kl_weight = kl_weight + self.num_near_samples = num_near_samples + self.geo_criterion = nn.BCEWithLogitsLoss() + + def forward(self, + shape_embed: torch.FloatTensor, + text_embed: torch.FloatTensor, + image_embed: torch.FloatTensor, + logit_scale: torch.FloatTensor, + posteriors: Optional[DiagonalGaussianDistribution], + shape_logits: torch.FloatTensor, + shape_labels: torch.FloatTensor, + split: Optional[str] = "train", **kwargs): + + local_batch_size = shape_embed.size(0) + + if local_batch_size != self.last_local_batch_size: + self.labels = local_batch_size * misc.get_rank() + torch.arange( + local_batch_size, device=shape_embed.device + ).long() + self.last_local_batch_size = local_batch_size + + # normalized features + shape_embed = F.normalize(shape_embed, dim=-1, p=2) + text_embed = F.normalize(text_embed, dim=-1, p=2) + image_embed = F.normalize(image_embed, dim=-1, p=2) + + # gather features from all GPUs + shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch( + [shape_embed, text_embed, image_embed] + ) + + # cosine similarity as logits + logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t() + logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t() + logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t() + logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t() + contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) + + F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \ + (F.cross_entropy(logits_per_shape_image, self.labels) + + F.cross_entropy(logits_per_image_shape, self.labels)) / 2 + + # shape reconstruction + if self.num_near_samples is None: + 
num_vol = shape_logits.shape[1] // 2 + else: + num_vol = shape_logits.shape[1] - self.num_near_samples + + vol_logits = shape_logits[:, 0:num_vol] + vol_labels = shape_labels[:, 0:num_vol] + + near_logits = shape_logits[:, num_vol:] + near_labels = shape_labels[:, num_vol:] + + # occupancy loss + vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) + near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) + + if posteriors is None: + kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) + else: + kl_loss = posteriors.kl(dims=(1, 2)) + kl_loss = torch.mean(kl_loss) + + loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight + + # compute accuracy + with torch.no_grad(): + pred = torch.argmax(logits_per_shape_text, dim=-1) + correct = pred.eq(self.labels).sum() + shape_text_acc = 100 * correct / local_batch_size + + pred = torch.argmax(logits_per_shape_image, dim=-1) + correct = pred.eq(self.labels).sum() + shape_image_acc = 100 * correct / local_batch_size + + preds = shape_logits >= 0 + accuracy = (preds == shape_labels).float() + accuracy = accuracy.mean() + + log = { + "{}/contrast".format(split): contrast_loss.clone().detach(), + "{}/near".format(split): near_bce.detach(), + "{}/far".format(split): vol_bce.detach(), + "{}/kl".format(split): kl_loss.detach(), + "{}/shape_text_acc".format(split): shape_text_acc, + "{}/shape_image_acc".format(split): shape_image_acc, + "{}/total_loss".format(split): loss.clone().detach(), + "{}/accuracy".format(split): accuracy, + } + + if posteriors is not None: + log[f"{split}/mean"] = posteriors.mean.mean().detach() + log[f"{split}/std_mean"] = posteriors.std.mean().detach() + log[f"{split}/std_max"] = posteriors.std.max().detach() + + return loss, log diff --git a/michelangelo/models/tsal/sal_perceiver.py b/michelangelo/models/tsal/sal_perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..786f978755e8f3ec72df9b879a83f7bf26b4446a --- /dev/null +++ b/michelangelo/models/tsal/sal_perceiver.py @@ -0,0 +1,423 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from typing import Optional +from einops import repeat +import math + +from michelangelo.models.modules import checkpoint +from michelangelo.models.modules.embedder import FourierEmbedder +from michelangelo.models.modules.distributions import DiagonalGaussianDistribution +from michelangelo.models.modules.transformer_blocks import ( + ResidualCrossAttentionBlock, + Transformer +) + +from .tsal_base import ShapeAsLatentModule + + +class CrossAttentionEncoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + fourier_embedder: FourierEmbedder, + point_feats: int, + width: int, + heads: int, + layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.num_latents = num_latents + + self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02) + + self.fourier_embedder = fourier_embedder + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype) + self.cross_attn = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + ) + + self.self_attn = 
Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=False + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device) + else: + self.ln_post = None + + def _forward(self, pc, feats): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + + """ + + bs = pc.shape[0] + + data = self.fourier_embedder(pc) + if feats is not None: + data = torch.cat([data, feats], dim=-1) + data = self.input_proj(data) + + query = repeat(self.query, "m c -> b m c", b=bs) + latents = self.cross_attn(query, data) + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + return latents, pc + + def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + dict + """ + + return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint) + + +class CrossAttentionDecoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.fourier_embedder = fourier_embedder + + self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype) + + self.cross_attn_decoder = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + n_data=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash + ) + + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype) + + def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + queries = self.query_proj(self.fourier_embedder(queries)) + x = self.cross_attn_decoder(queries, latents) + x = self.ln_post(x) + x = self.output_proj(x) + return x + + def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint) + + +class ShapeAsLatentPerceiver(ShapeAsLatentModule): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.num_latents = num_latents + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + init_scale = init_scale * math.sqrt(1.0 / width) + self.encoder = CrossAttentionEncoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + num_latents=num_latents, + point_feats=point_feats, + width=width, + heads=heads, + layers=num_encoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint + ) + 
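+        # Optional KL bottleneck: when embed_dim > 0, pre_kl predicts per-latent
+        # mean/log-variance pairs (hence embed_dim * 2) and post_kl maps sampled
+        # codes back to the transformer width before decoding.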
+ self.embed_dim = embed_dim + if embed_dim > 0: + # VAE embed + self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype) + self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype) + self.latent_shape = (num_latents, embed_dim) + else: + self.latent_shape = (num_latents, width) + + self.transformer = Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=width, + layers=num_decoder_layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + + # geometry decoder + self.geo_decoder = CrossAttentionDecoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + out_channels=1, + num_latents=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_checkpoint=use_checkpoint + ) + + def encode(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + sample_posterior (bool): + + Returns: + latents (torch.FloatTensor) + center_pos (torch.FloatTensor or None): + posterior (DiagonalGaussianDistribution or None): + """ + + latents, center_pos = self.encoder(pc, feats) + + posterior = None + if self.embed_dim > 0: + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + + if sample_posterior: + latents = posterior.sample() + else: + latents = posterior.mode() + + return latents, center_pos, posterior + + def decode(self, latents: torch.FloatTensor): + latents = self.post_kl(latents) + return self.transformer(latents) + + def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + logits = self.geo_decoder(queries, latents).squeeze(-1) + return logits + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + volume_queries (torch.FloatTensor): [B, P, 3] + sample_posterior (bool): + + Returns: + logits (torch.FloatTensor): [B, P] + center_pos (torch.FloatTensor): [B, M, 3] + posterior (DiagonalGaussianDistribution or None). 
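+            Note that the encoder returns the input point cloud unchanged as
+            center_pos, so M equals N here.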
+ + """ + + latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) + + latents = self.decode(latents) + logits = self.query_geometry(volume_queries, latents) + + return logits, center_pos, posterior + + +class AlignedShapeLatentPerceiver(ShapeAsLatentPerceiver): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + heads: int, + num_encoder_layers: int, + num_decoder_layers: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + flash: bool = False, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__( + device=device, + dtype=dtype, + num_latents=1 + num_latents, + point_feats=point_feats, + embed_dim=embed_dim, + num_freqs=num_freqs, + include_pi=include_pi, + width=width, + heads=heads, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + init_scale=init_scale, + qkv_bias=qkv_bias, + flash=flash, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint + ) + + self.width = width + + def encode(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, c] + sample_posterior (bool): + + Returns: + shape_embed (torch.FloatTensor) + kl_embed (torch.FloatTensor): + posterior (DiagonalGaussianDistribution or None): + """ + + shape_embed, latents = self.encode_latents(pc, feats) + kl_embed, posterior = self.encode_kl_embed(latents, sample_posterior) + + return shape_embed, kl_embed, posterior + + def encode_latents(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None): + + x, _ = self.encoder(pc, feats) + + shape_embed = x[:, 0] + latents = x[:, 1:] + + return shape_embed, latents + + def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): + posterior = None + if self.embed_dim > 0: + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + + if sample_posterior: + kl_embed = posterior.sample() + else: + kl_embed = posterior.mode() + else: + kl_embed = latents + + return kl_embed, posterior + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + volume_queries (torch.FloatTensor): [B, P, 3] + sample_posterior (bool): + + Returns: + shape_embed (torch.FloatTensor): [B, projection_dim] + logits (torch.FloatTensor): [B, M] + posterior (DiagonalGaussianDistribution or None). 
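+
+            Example (an illustrative sketch with deliberately small, made-up
+            hyper-parameters; device=None and dtype=None fall back to the
+            framework defaults, and only the output shapes are of interest):
+
+                >>> import torch
+                >>> model = AlignedShapeLatentPerceiver(
+                ...     device=None, dtype=None, num_latents=16, embed_dim=8,
+                ...     point_feats=3, num_freqs=8, width=64, heads=4,
+                ...     num_encoder_layers=1, num_decoder_layers=1)
+                >>> pc = torch.rand(2, 1024, 3)
+                >>> feats = torch.rand(2, 1024, 3)
+                >>> queries = torch.rand(2, 128, 3)
+                >>> shape_embed, logits, posterior = model(pc, feats, queries)
+                >>> shape_embed.shape, logits.shape
+                (torch.Size([2, 64]), torch.Size([2, 128]))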
+ + """ + + shape_embed, kl_embed, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) + + latents = self.decode(kl_embed) + logits = self.query_geometry(volume_queries, latents) + + return shape_embed, logits, posterior diff --git a/michelangelo/models/tsal/sal_pl_module.py b/michelangelo/models/tsal/sal_pl_module.py new file mode 100644 index 0000000000000000000000000000000000000000..52680aaa1cbe0eac018eeeb81426bb79a273a647 --- /dev/null +++ b/michelangelo/models/tsal/sal_pl_module.py @@ -0,0 +1,290 @@ +# -*- coding: utf-8 -*- + +from typing import List, Tuple, Dict, Optional +from omegaconf import DictConfig + +import torch +from torch.optim import lr_scheduler +import pytorch_lightning as pl +from typing import Union +from functools import partial + +from michelangelo.utils import instantiate_from_config + +from .inference_utils import extract_geometry +from .tsal_base import ( + ShapeAsLatentModule, + Latent2MeshOutput, + Point2MeshOutput +) + + +class ShapeAsLatentPLModule(pl.LightningModule): + + def __init__(self, *, + module_cfg, + loss_cfg, + optimizer_cfg: Optional[DictConfig] = None, + ckpt_path: Optional[str] = None, + ignore_keys: Union[Tuple[str], List[str]] = ()): + + super().__init__() + + self.sal: ShapeAsLatentModule = instantiate_from_config(module_cfg, device=None, dtype=None) + + self.loss = instantiate_from_config(loss_cfg) + + self.optimizer_cfg = optimizer_cfg + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + self.save_hyperparameters() + + @property + def latent_shape(self): + return self.sal.latent_shape + + @property + def zero_rank(self): + if self._trainer: + zero_rank = self.trainer.local_rank == 0 + else: + zero_rank = True + + return zero_rank + + def init_from_ckpt(self, path, ignore_keys=()): + state_dict = torch.load(path, map_location="cpu")["state_dict"] + + keys = list(state_dict.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del state_dict[k] + + missing, unexpected = self.load_state_dict(state_dict, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def configure_optimizers(self) -> Tuple[List, List]: + lr = self.learning_rate + + # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-4)] + # optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] + + if self.optimizer_cfg is None: + optimizers = [torch.optim.AdamW(self.sal.parameters(), lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)] + schedulers = [] + else: + optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=self.sal.parameters()) + scheduler_func = instantiate_from_config( + self.optimizer_cfg.scheduler, + max_decay_steps=self.trainer.max_steps, + lr_max=lr + ) + scheduler = { + "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule), + "interval": "step", + "frequency": 1 + } + optimizers = [optimizer] + schedulers = [scheduler] + + return optimizers, schedulers + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor): + + logits, center_pos, posterior = self.sal(pc, feats, volume_queries) + + return posterior, logits + + def encode(self, surface: torch.FloatTensor, sample_posterior=True): + + pc = 
surface[..., 0:3] + feats = surface[..., 3:6] + + latents, center_pos, posterior = self.sal.encode( + pc=pc, feats=feats, sample_posterior=sample_posterior + ) + + return latents + + def decode(self, + z_q, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Latent2MeshOutput]: + + latents = self.sal.decode(z_q) # latents: [bs, num_latents, dim] + outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks) + + return outputs + + def training_step(self, batch: Dict[str, torch.FloatTensor], + batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor: + """ + + Args: + batch (dict): the batch sample, and it contains: + - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)] + - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)] + + batch_idx (int): + + optimizer_idx (int): + + Returns: + loss (torch.FloatTensor): + + """ + + pc = batch["surface"][..., 0:3] + feats = batch["surface"][..., 3:] + + volume_queries = batch["geo_points"][..., 0:3] + volume_labels = batch["geo_points"][..., -1] + + posterior, logits = self( + pc=pc, feats=feats, volume_queries=volume_queries + ) + aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="train") + + self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0], + sync_dist=False, rank_zero_only=True) + + return aeloss + + def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor: + + pc = batch["surface"][..., 0:3] + feats = batch["surface"][..., 3:] + + volume_queries = batch["geo_points"][..., 0:3] + volume_labels = batch["geo_points"][..., -1] + + posterior, logits = self( + pc=pc, feats=feats, volume_queries=volume_queries, + ) + aeloss, log_dict_ae = self.loss(posterior, logits, volume_labels, split="val") + + self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=logits.shape[0], + sync_dist=False, rank_zero_only=True) + + return aeloss + + def point2mesh(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Point2MeshOutput]: + + """ + + Args: + pc: + feats: + bounds: + octree_depth: + num_chunks: + + Returns: + mesh_outputs (List[MeshOutput]): the mesh outputs list. + + """ + + outputs = [] + + device = pc.device + bs = pc.shape[0] + + # 1. point encoder + latents transformer + latents, center_pos, posterior = self.sal.encode(pc, feats) + latents = self.sal.decode(latents) # latents: [bs, num_latents, dim] + + geometric_func = partial(self.sal.query_geometry, latents=latents) + + # 2. decode geometry + mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=device, + batch_size=bs, + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=not self.zero_rank + ) + + # 3. 
decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + + out = Point2MeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + out.pc = torch.cat([pc[i], feats[i]], dim=-1).cpu().numpy() + + if center_pos is not None: + out.center = center_pos[i].cpu().numpy() + + outputs.append(out) + + return outputs + + def latent2mesh(self, + latents: torch.FloatTensor, + bounds: Union[Tuple[float], List[float], float] = 1.1, + octree_depth: int = 7, + num_chunks: int = 10000) -> List[Latent2MeshOutput]: + + """ + + Args: + latents: [bs, num_latents, dim] + bounds: + octree_depth: + num_chunks: + + Returns: + mesh_outputs (List[MeshOutput]): the mesh outputs list. + + """ + + outputs = [] + + geometric_func = partial(self.sal.query_geometry, latents=latents) + + # 2. decode geometry + device = latents.device + mesh_v_f, has_surface = extract_geometry( + geometric_func=geometric_func, + device=device, + batch_size=len(latents), + bounds=bounds, + octree_depth=octree_depth, + num_chunks=num_chunks, + disable=not self.zero_rank + ) + + # 3. decode texture + for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)): + if not is_surface: + outputs.append(None) + continue + + out = Latent2MeshOutput() + out.mesh_v = mesh_v + out.mesh_f = mesh_f + + outputs.append(out) + + return outputs diff --git a/michelangelo/models/tsal/sal_transformer.py b/michelangelo/models/tsal/sal_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a117b51aa9db1ddeb50d7e1f832287863f6ef476 --- /dev/null +++ b/michelangelo/models/tsal/sal_transformer.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- + +import torch +import torch.nn as nn +from torch_cluster import fps +from typing import Optional +import math + +from michelangelo.models.modules import checkpoint +from michelangelo.models.modules.embedder import FourierEmbedder +from michelangelo.models.modules.distributions import DiagonalGaussianDistribution +from michelangelo.models.modules.transformer_blocks import ( + ResidualCrossAttentionBlock, + Transformer +) + +from .tsal_base import ShapeAsLatentModule + + +class CrossAttentionEncoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + fourier_embedder: FourierEmbedder, + point_feats: int, + width: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.num_latents = num_latents + self.fourier_embedder = fourier_embedder + + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype) + self.cross_attn_encoder = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias + ) + if use_ln_post: + self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device) + else: + self.ln_post = None + + def _forward(self, pc, feats): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + + """ + + B, N, _ = pc.shape + batch = torch.arange(B).to(pc.device) + batch = torch.repeat_interleave(batch, N) + + data = self.fourier_embedder(pc) + if feats is not None: + data = torch.cat([data, feats], dim=-1) + data = self.input_proj(data) + + ratio = self.num_latents / N + flatten_pos = pc.view(B * N, 
-1) # [B * N, 3] + flatten_data = data.view(B * N, -1) # [B * N, C] + idx = fps(flatten_pos, batch, ratio=ratio) + center_pos = flatten_pos[idx].view(B, self.num_latents, -1) + query = flatten_data[idx].view(B, self. num_latents, -1) + + latents = self.cross_attn_encoder(query, data) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + return latents, center_pos + + def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + + Returns: + dict + """ + + return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint) + + +class CrossAttentionDecoder(nn.Module): + + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + out_channels: int, + fourier_embedder: FourierEmbedder, + width: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + self.fourier_embedder = fourier_embedder + + self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype) + + self.cross_attn_decoder = ResidualCrossAttentionBlock( + device=device, + dtype=dtype, + n_data=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias + ) + + self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype) + self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype) + + def _forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + queries = self.query_proj(self.fourier_embedder(queries)) + x = self.cross_attn_decoder(queries, latents) + x = self.ln_post(x) + x = self.output_proj(x) + return x + + def forward(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + return checkpoint(self._forward, (queries, latents), self.parameters(), self.use_checkpoint) + + +class ShapeAsLatentTransformer(ShapeAsLatentModule): + def __init__(self, *, + device: Optional[torch.device], + dtype: Optional[torch.dtype], + num_latents: int, + point_feats: int = 0, + embed_dim: int = 0, + num_freqs: int = 8, + include_pi: bool = True, + width: int, + layers: int, + heads: int, + init_scale: float = 0.25, + qkv_bias: bool = True, + use_ln_post: bool = False, + use_checkpoint: bool = False): + + super().__init__() + + self.use_checkpoint = use_checkpoint + + self.num_latents = num_latents + self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + + init_scale = init_scale * math.sqrt(1.0 / width) + self.encoder = CrossAttentionEncoder( + device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + num_latents=num_latents, + point_feats=point_feats, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_ln_post=use_ln_post, + use_checkpoint=use_checkpoint + ) + + self.embed_dim = embed_dim + if embed_dim > 0: + # VAE embed + self.pre_kl = nn.Linear(width, embed_dim * 2, device=device, dtype=dtype) + self.post_kl = nn.Linear(embed_dim, width, device=device, dtype=dtype) + self.latent_shape = (num_latents, embed_dim) + else: + self.latent_shape = (num_latents, width) + + self.transformer = Transformer( + device=device, + dtype=dtype, + n_ctx=num_latents, + width=width, + layers=layers, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_checkpoint=use_checkpoint + ) + + # geometry decoder + self.geo_decoder = CrossAttentionDecoder( + 
device=device, + dtype=dtype, + fourier_embedder=self.fourier_embedder, + out_channels=1, + num_latents=num_latents, + width=width, + heads=heads, + init_scale=init_scale, + qkv_bias=qkv_bias, + use_checkpoint=use_checkpoint + ) + + def encode(self, + pc: torch.FloatTensor, + feats: Optional[torch.FloatTensor] = None, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + sample_posterior (bool): + + Returns: + latents (torch.FloatTensor) + center_pos (torch.FloatTensor): + posterior (DiagonalGaussianDistribution or None): + """ + + latents, center_pos = self.encoder(pc, feats) + + posterior = None + if self.embed_dim > 0: + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) + + if sample_posterior: + latents = posterior.sample() + else: + latents = posterior.mode() + + return latents, center_pos, posterior + + def decode(self, latents: torch.FloatTensor): + latents = self.post_kl(latents) + return self.transformer(latents) + + def query_geometry(self, queries: torch.FloatTensor, latents: torch.FloatTensor): + logits = self.geo_decoder(queries, latents).squeeze(-1) + return logits + + def forward(self, + pc: torch.FloatTensor, + feats: torch.FloatTensor, + volume_queries: torch.FloatTensor, + sample_posterior: bool = True): + """ + + Args: + pc (torch.FloatTensor): [B, N, 3] + feats (torch.FloatTensor or None): [B, N, C] + volume_queries (torch.FloatTensor): [B, P, 3] + sample_posterior (bool): + + Returns: + logits (torch.FloatTensor): [B, P] + center_pos (torch.FloatTensor): [B, M, 3] + posterior (DiagonalGaussianDistribution or None). + + """ + + latents, center_pos, posterior = self.encode(pc, feats, sample_posterior=sample_posterior) + + latents = self.decode(latents) + logits = self.query_geometry(volume_queries, latents) + + return logits, center_pos, posterior diff --git a/michelangelo/models/tsal/tsal_base.py b/michelangelo/models/tsal/tsal_base.py new file mode 100644 index 0000000000000000000000000000000000000000..233a8afbdd0eb24024a6f915e770a286361cf0fe --- /dev/null +++ b/michelangelo/models/tsal/tsal_base.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- + +import torch.nn as nn +from typing import Tuple, List, Optional +import pytorch_lightning as pl + + +class Point2MeshOutput(object): + def __init__(self): + self.mesh_v = None + self.mesh_f = None + self.center = None + self.pc = None + + +class Latent2MeshOutput(object): + + def __init__(self): + self.mesh_v = None + self.mesh_f = None + + +class AlignedMeshOutput(object): + + def __init__(self): + self.mesh_v = None + self.mesh_f = None + self.surface = None + self.image = None + self.text: Optional[str] = None + self.shape_text_similarity: Optional[float] = None + self.shape_image_similarity: Optional[float] = None + + +class ShapeAsLatentPLModule(pl.LightningModule): + latent_shape: Tuple[int] + + def encode(self, surface, *args, **kwargs): + raise NotImplementedError + + def decode(self, z_q, *args, **kwargs): + raise NotImplementedError + + def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: + raise NotImplementedError + + def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: + raise NotImplementedError + + +class ShapeAsLatentModule(nn.Module): + latent_shape: Tuple[int, int] + + def __init__(self, *args, **kwargs): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise 
NotImplementedError + + def query_geometry(self, *args, **kwargs): + raise NotImplementedError + + +class AlignedShapeAsLatentPLModule(pl.LightningModule): + latent_shape: Tuple[int] + + def set_shape_model_only(self): + raise NotImplementedError + + def encode(self, surface, *args, **kwargs): + raise NotImplementedError + + def decode(self, z_q, *args, **kwargs): + raise NotImplementedError + + def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: + raise NotImplementedError + + def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: + raise NotImplementedError + + +class AlignedShapeAsLatentModule(nn.Module): + shape_model: ShapeAsLatentModule + latent_shape: Tuple[int, int] + + def __init__(self, *args, **kwargs): + super().__init__() + + def set_shape_model_only(self): + raise NotImplementedError + + def encode_image_embed(self, *args, **kwargs): + raise NotImplementedError + + def encode_text_embed(self, *args, **kwargs): + raise NotImplementedError + + def encode_shape_embed(self, *args, **kwargs): + raise NotImplementedError + + +class TexturedShapeAsLatentModule(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + def decode(self, *args, **kwargs): + raise NotImplementedError + + def query_geometry(self, *args, **kwargs): + raise NotImplementedError + + def query_color(self, *args, **kwargs): + raise NotImplementedError diff --git a/michelangelo/utils/__init__.py b/michelangelo/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76d2dd39781034eaa33293a2243ebee3b3c982c6 --- /dev/null +++ b/michelangelo/utils/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from .misc import get_config_from_file +from .misc import instantiate_from_config diff --git a/michelangelo/utils/__pycache__/__init__.cpython-39.pyc b/michelangelo/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9174eb0235c972cc30ad044d199e9933833db62 Binary files /dev/null and b/michelangelo/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/utils/__pycache__/eval.cpython-39.pyc b/michelangelo/utils/__pycache__/eval.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c04731139086293b04a91917de0cef083ba0c2cd Binary files /dev/null and b/michelangelo/utils/__pycache__/eval.cpython-39.pyc differ diff --git a/michelangelo/utils/__pycache__/io.cpython-39.pyc b/michelangelo/utils/__pycache__/io.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4212398bf91ccb7d285477818cbe48c830ef21df Binary files /dev/null and b/michelangelo/utils/__pycache__/io.cpython-39.pyc differ diff --git a/michelangelo/utils/__pycache__/misc.cpython-39.pyc b/michelangelo/utils/__pycache__/misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..662cf6f524b2773c8f88db948370659a8f2a819f Binary files /dev/null and b/michelangelo/utils/__pycache__/misc.cpython-39.pyc differ diff --git a/michelangelo/utils/eval.py b/michelangelo/utils/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..954b9ae2643c8adb6c9af6141ede2b38a329db22 --- /dev/null +++ b/michelangelo/utils/eval.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- + +import torch + + +def compute_psnr(x, y, data_range: float = 2, eps: float = 1e-7): + + mse = torch.mean((x - y) ** 2) + psnr = 10 * torch.log10(data_range / (mse + eps)) + + return 
psnr + diff --git a/michelangelo/utils/io.py b/michelangelo/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..e651e5a8750ab485b5fbd59a70b38e339b6ed79b --- /dev/null +++ b/michelangelo/utils/io.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- + +import os +import io +import tarfile +import json +import numpy as np +import numpy.lib.format + + +def mkdir(path): + os.makedirs(path, exist_ok=True) + return path + + +def npy_loads(data): + stream = io.BytesIO(data) + return np.lib.format.read_array(stream) + + +def npz_loads(data): + return np.load(io.BytesIO(data)) + + +def json_loads(data): + return json.loads(data) + + +def load_json(filepath): + with open(filepath, "r") as f: + data = json.load(f) + return data + + +def write_json(filepath, data): + with open(filepath, "w") as f: + json.dump(data, f, indent=2) + + +def extract_tar(tar_path, tar_cache_folder): + + with tarfile.open(tar_path, "r") as tar: + tar.extractall(path=tar_cache_folder) + + tar_uids = sorted(os.listdir(tar_cache_folder)) + print(f"extract tar: {tar_path} to {tar_cache_folder}") + return tar_uids diff --git a/michelangelo/utils/misc.py b/michelangelo/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..bbef357bc7c63d3c7f33d048aec68dda2b0e3992 --- /dev/null +++ b/michelangelo/utils/misc.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +import importlib +from omegaconf import OmegaConf, DictConfig, ListConfig + +import torch +import torch.distributed as dist +from typing import Union + + +def get_config_from_file(config_file: str) -> Union[DictConfig, ListConfig]: + config_file = OmegaConf.load(config_file) + + if 'base_config' in config_file.keys(): + if config_file['base_config'] == "default_base": + base_config = OmegaConf.create() + # base_config = get_default_config() + elif config_file['base_config'].endswith(".yaml"): + base_config = get_config_from_file(config_file['base_config']) + else: + raise ValueError(f"{config_file} must be `.yaml` file or it contains `base_config` key.") + + config_file = {key: value for key, value in config_file if key != "base_config"} + + return OmegaConf.merge(base_config, config_file) + + return config_file + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def get_obj_from_config(config): + if "target" not in config: + raise KeyError("Expected key `target` to instantiate.") + + return get_obj_from_str(config["target"]) + + +def instantiate_from_config(config, **kwargs): + if "target" not in config: + raise KeyError("Expected key `target` to instantiate.") + + cls = get_obj_from_str(config["target"]) + + params = config.get("params", dict()) + # params.update(kwargs) + # instance = cls(**params) + kwargs.update(params) + instance = cls(**kwargs) + + return instance + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def all_gather_batch(tensors): + """ + Performs all_gather operation on the provided tensors. 
+ """ + # Queue the gathered tensors + world_size = get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + tensor_list = [] + output_tensor = [] + for tensor in tensors: + tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] + dist.all_gather( + tensor_all, + tensor, + async_op=False # performance opt + ) + + tensor_list.append(tensor_all) + + for tensor_all in tensor_list: + output_tensor.append(torch.cat(tensor_all, dim=0)) + return output_tensor diff --git a/michelangelo/utils/visualizers/__init__.py b/michelangelo/utils/visualizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/michelangelo/utils/visualizers/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/michelangelo/utils/visualizers/__pycache__/__init__.cpython-39.pyc b/michelangelo/utils/visualizers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02a8fe81750ca84f7f04312d1f48c452b76495fa Binary files /dev/null and b/michelangelo/utils/visualizers/__pycache__/__init__.cpython-39.pyc differ diff --git a/michelangelo/utils/visualizers/__pycache__/color_util.cpython-39.pyc b/michelangelo/utils/visualizers/__pycache__/color_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..469115bdef4e58c0483c8becc599c8c5e1cae849 Binary files /dev/null and b/michelangelo/utils/visualizers/__pycache__/color_util.cpython-39.pyc differ diff --git a/michelangelo/utils/visualizers/__pycache__/html_util.cpython-39.pyc b/michelangelo/utils/visualizers/__pycache__/html_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3b407a4fc92c856501c7c3c9f2d22e438e49215 Binary files /dev/null and b/michelangelo/utils/visualizers/__pycache__/html_util.cpython-39.pyc differ diff --git a/michelangelo/utils/visualizers/__pycache__/pythreejs_viewer.cpython-39.pyc b/michelangelo/utils/visualizers/__pycache__/pythreejs_viewer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc843978c58e21870048dbd288bc40b306ebb705 Binary files /dev/null and b/michelangelo/utils/visualizers/__pycache__/pythreejs_viewer.cpython-39.pyc differ diff --git a/michelangelo/utils/visualizers/color_util.py b/michelangelo/utils/visualizers/color_util.py new file mode 100644 index 0000000000000000000000000000000000000000..7983243fd37f5fee47bc51475dc58c460a067830 --- /dev/null +++ b/michelangelo/utils/visualizers/color_util.py @@ -0,0 +1,43 @@ +import numpy as np +import matplotlib.pyplot as plt + + +# Helper functions +def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None): + colormap = plt.cm.get_cmap(colormap) + if normalize: + vmin = np.min(inp) + vmax = np.max(inp) + + norm = plt.Normalize(vmin, vmax) + return colormap(norm(inp))[:, :3] + + +def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256): + # tex dims need to be power of two. + array = np.ones((width, height, 3), dtype='float32') + + # width in texels of each checker + checker_w = width / n_checkers_x + checker_h = height / n_checkers_y + + for y in range(height): + for x in range(width): + color_key = int(x / checker_w) + int(y / checker_h) + if color_key % 2 == 0: + array[x, y, :] = [1., 0.874, 0.0] + else: + array[x, y, :] = [0., 0., 0.] 
+ return array + + +def gen_circle(width=256, height=256): + xx, yy = np.mgrid[:width, :height] + circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2 + array = np.ones((width, height, 4), dtype='float32') + array[:, :, 0] = (circle <= width) + array[:, :, 1] = (circle <= width) + array[:, :, 2] = (circle <= width) + array[:, :, 3] = circle <= width + return array + diff --git a/michelangelo/utils/visualizers/html_util.py b/michelangelo/utils/visualizers/html_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f90fe6cfefe6108655b48c36d60db537589993d5 --- /dev/null +++ b/michelangelo/utils/visualizers/html_util.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +import io +import base64 +import numpy as np +from PIL import Image + + +def to_html_frame(content): + + html_frame = f""" + + + {content} + + + """ + + return html_frame + + +def to_single_row_table(caption: str, content: str): + + table_html = f""" + + + + + +
<table border="1"> + <caption>{caption}</caption> + <tr> +
<td>{content}</td> + </tr> + </table>
+ """ + + return table_html + + +def to_image_embed_tag(image: np.ndarray): + + # Convert np.ndarray to bytes + img = Image.fromarray(image) + raw_bytes = io.BytesIO() + img.save(raw_bytes, "PNG") + + # Encode bytes to base64 + image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8") + + image_tag = f""" + Embedded Image + """ + + return image_tag diff --git a/michelangelo/utils/visualizers/pythreejs_viewer.py b/michelangelo/utils/visualizers/pythreejs_viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..b3ce0f88f26fcd5e007fde2cec4816901a74ad33 --- /dev/null +++ b/michelangelo/utils/visualizers/pythreejs_viewer.py @@ -0,0 +1,534 @@ +import numpy as np +from ipywidgets import embed +import pythreejs as p3s +import uuid + +from .color_util import get_colors, gen_circle, gen_checkers + + +EMBED_URL = "https://cdn.jsdelivr.net/npm/@jupyter-widgets/html-manager@1.0.1/dist/embed-amd.js" + + +class PyThreeJSViewer(object): + + def __init__(self, settings, render_mode="WEBSITE"): + self.render_mode = render_mode + self.__update_settings(settings) + self._light = p3s.DirectionalLight(color='white', position=[0, 0, 1], intensity=0.6) + self._light2 = p3s.AmbientLight(intensity=0.5) + self._cam = p3s.PerspectiveCamera(position=[0, 0, 1], lookAt=[0, 0, 0], fov=self.__s["fov"], + aspect=self.__s["width"] / self.__s["height"], children=[self._light]) + self._orbit = p3s.OrbitControls(controlling=self._cam) + self._scene = p3s.Scene(children=[self._cam, self._light2], background=self.__s["background"]) # "#4c4c80" + self._renderer = p3s.Renderer(camera=self._cam, scene=self._scene, controls=[self._orbit], + width=self.__s["width"], height=self.__s["height"], + antialias=self.__s["antialias"]) + + self.__objects = {} + self.__cnt = 0 + + def jupyter_mode(self): + self.render_mode = "JUPYTER" + + def offline(self): + self.render_mode = "OFFLINE" + + def website(self): + self.render_mode = "WEBSITE" + + def __get_shading(self, shading): + shad = {"flat": True, "wireframe": False, "wire_width": 0.03, "wire_color": "black", + "side": 'DoubleSide', "colormap": "viridis", "normalize": [None, None], + "bbox": False, "roughness": 0.5, "metalness": 0.25, "reflectivity": 1.0, + "line_width": 1.0, "line_color": "black", + "point_color": "red", "point_size": 0.01, "point_shape": "circle", + "text_color": "red" + } + for k in shading: + shad[k] = shading[k] + return shad + + def __update_settings(self, settings={}): + sett = {"width": 600, "height": 600, "antialias": True, "scale": 1.5, "background": "#ffffff", + "fov": 30} + for k in settings: + sett[k] = settings[k] + self.__s = sett + + def __add_object(self, obj, parent=None): + if not parent: # Object is added to global scene and objects dict + self.__objects[self.__cnt] = obj + self.__cnt += 1 + self._scene.add(obj["mesh"]) + else: # Object is added to parent object and NOT to objects dict + parent.add(obj["mesh"]) + + self.__update_view() + + if self.render_mode == "JUPYTER": + return self.__cnt - 1 + elif self.render_mode == "WEBSITE": + return self + + def __add_line_geometry(self, lines, shading, obj=None): + lines = lines.astype("float32", copy=False) + mi = np.min(lines, axis=0) + ma = np.max(lines, axis=0) + + geometry = p3s.LineSegmentsGeometry(positions=lines.reshape((-1, 2, 3))) + material = p3s.LineMaterial(linewidth=shading["line_width"], color=shading["line_color"]) + # , vertexColors='VertexColors'), + lines = p3s.LineSegments2(geometry=geometry, material=material) # type='LinePieces') + line_obj = 
{"geometry": geometry, "mesh": lines, "material": material, + "max": ma, "min": mi, "type": "Lines", "wireframe": None} + + if obj: + return self.__add_object(line_obj, obj), line_obj + else: + return self.__add_object(line_obj) + + def __update_view(self): + if len(self.__objects) == 0: + return + ma = np.zeros((len(self.__objects), 3)) + mi = np.zeros((len(self.__objects), 3)) + for r, obj in enumerate(self.__objects): + ma[r] = self.__objects[obj]["max"] + mi[r] = self.__objects[obj]["min"] + ma = np.max(ma, axis=0) + mi = np.min(mi, axis=0) + diag = np.linalg.norm(ma - mi) + mean = ((ma - mi) / 2 + mi).tolist() + scale = self.__s["scale"] * (diag) + self._orbit.target = mean + self._cam.lookAt(mean) + self._cam.position = [mean[0], mean[1], mean[2] + scale] + self._light.position = [mean[0], mean[1], mean[2] + scale] + + self._orbit.exec_three_obj_method('update') + self._cam.exec_three_obj_method('updateProjectionMatrix') + + def __get_bbox(self, v): + m = np.min(v, axis=0) + M = np.max(v, axis=0) + + # Corners of the bounding box + v_box = np.array([[m[0], m[1], m[2]], [M[0], m[1], m[2]], [M[0], M[1], m[2]], [m[0], M[1], m[2]], + [m[0], m[1], M[2]], [M[0], m[1], M[2]], [M[0], M[1], M[2]], [m[0], M[1], M[2]]]) + + f_box = np.array([[0, 1], [1, 2], [2, 3], [3, 0], [4, 5], [5, 6], [6, 7], [7, 4], + [0, 4], [1, 5], [2, 6], [7, 3]], dtype=np.uint32) + return v_box, f_box + + def __get_colors(self, v, f, c, sh): + coloring = "VertexColors" + if type(c) == np.ndarray and c.size == 3: # Single color + colors = np.ones_like(v) + colors[:, 0] = c[0] + colors[:, 1] = c[1] + colors[:, 2] = c[2] + # print("Single colors") + elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[1] == 3: # Color values for + if c.shape[0] == f.shape[0]: # faces + colors = np.hstack([c, c, c]).reshape((-1, 3)) + coloring = "FaceColors" + # print("Face color values") + elif c.shape[0] == v.shape[0]: # vertices + colors = c + # print("Vertex color values") + else: # Wrong size, fallback + print("Invalid color array given! Supported are numpy arrays.", type(c)) + colors = np.ones_like(v) + colors[:, 0] = 1.0 + colors[:, 1] = 0.874 + colors[:, 2] = 0.0 + elif type(c) == np.ndarray and c.size == f.shape[0]: # Function values for faces + normalize = sh["normalize"][0] != None and sh["normalize"][1] != None + cc = get_colors(c, sh["colormap"], normalize=normalize, + vmin=sh["normalize"][0], vmax=sh["normalize"][1]) + # print(cc.shape) + colors = np.hstack([cc, cc, cc]).reshape((-1, 3)) + coloring = "FaceColors" + # print("Face function values") + elif type(c) == np.ndarray and c.size == v.shape[0]: # Function values for vertices + normalize = sh["normalize"][0] != None and sh["normalize"][1] != None + colors = get_colors(c, sh["colormap"], normalize=normalize, + vmin=sh["normalize"][0], vmax=sh["normalize"][1]) + # print("Vertex function values") + + else: + colors = np.ones_like(v) + colors[:, 0] = 1.0 + colors[:, 1] = 0.874 + colors[:, 2] = 0.0 + + # No color + if c is not None: + print("Invalid color array given! 
Supported are numpy arrays.", type(c)) + + return colors, coloring + + def __get_point_colors(self, v, c, sh): + v_color = True + if c is None: # No color given, use global color + # conv = mpl.colors.ColorConverter() + colors = sh["point_color"] # np.array(conv.to_rgb(sh["point_color"])) + v_color = False + elif isinstance(c, str): # No color given, use global color + # conv = mpl.colors.ColorConverter() + colors = c # np.array(conv.to_rgb(c)) + v_color = False + elif type(c) == np.ndarray and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] == 3: + # Point color + colors = c.astype("float32", copy=False) + + elif isinstance(c, np.ndarray) and len(c.shape) == 2 and c.shape[0] == v.shape[0] and c.shape[1] != 3: + # Function values for vertices, but the colors are features + c_norm = np.linalg.norm(c, ord=2, axis=-1) + normalize = sh["normalize"][0] != None and sh["normalize"][1] != None + colors = get_colors(c_norm, sh["colormap"], normalize=normalize, + vmin=sh["normalize"][0], vmax=sh["normalize"][1]) + colors = colors.astype("float32", copy=False) + + elif type(c) == np.ndarray and c.size == v.shape[0]: # Function color + normalize = sh["normalize"][0] != None and sh["normalize"][1] != None + colors = get_colors(c, sh["colormap"], normalize=normalize, + vmin=sh["normalize"][0], vmax=sh["normalize"][1]) + colors = colors.astype("float32", copy=False) + # print("Vertex function values") + + else: + print("Invalid color array given! Supported are numpy arrays.", type(c)) + colors = sh["point_color"] + v_color = False + + return colors, v_color + + def add_mesh(self, v, f, c=None, uv=None, n=None, shading={}, texture_data=None, **kwargs): + shading.update(kwargs) + sh = self.__get_shading(shading) + mesh_obj = {} + + # it is a tet + if v.shape[1] == 3 and f.shape[1] == 4: + f_tmp = np.ndarray([f.shape[0] * 4, 3], dtype=f.dtype) + for i in range(f.shape[0]): + f_tmp[i * 4 + 0] = np.array([f[i][1], f[i][0], f[i][2]]) + f_tmp[i * 4 + 1] = np.array([f[i][0], f[i][1], f[i][3]]) + f_tmp[i * 4 + 2] = np.array([f[i][1], f[i][2], f[i][3]]) + f_tmp[i * 4 + 3] = np.array([f[i][2], f[i][0], f[i][3]]) + f = f_tmp + + if v.shape[1] == 2: + v = np.append(v, np.zeros([v.shape[0], 1]), 1) + + # Type adjustment vertices + v = v.astype("float32", copy=False) + + # Color setup + colors, coloring = self.__get_colors(v, f, c, sh) + + # Type adjustment faces and colors + c = colors.astype("float32", copy=False) + + # Material and geometry setup + ba_dict = {"color": p3s.BufferAttribute(c)} + if coloring == "FaceColors": + verts = np.zeros((f.shape[0] * 3, 3), dtype="float32") + for ii in range(f.shape[0]): + # print(ii*3, f[ii]) + verts[ii * 3] = v[f[ii, 0]] + verts[ii * 3 + 1] = v[f[ii, 1]] + verts[ii * 3 + 2] = v[f[ii, 2]] + v = verts + else: + f = f.astype("uint32", copy=False).ravel() + ba_dict["index"] = p3s.BufferAttribute(f, normalized=False) + + ba_dict["position"] = p3s.BufferAttribute(v, normalized=False) + + if uv is not None: + uv = (uv - np.min(uv)) / (np.max(uv) - np.min(uv)) + if texture_data is None: + texture_data = gen_checkers(20, 20) + tex = p3s.DataTexture(data=texture_data, format="RGBFormat", type="FloatType") + material = p3s.MeshStandardMaterial(map=tex, reflectivity=sh["reflectivity"], side=sh["side"], + roughness=sh["roughness"], metalness=sh["metalness"], + flatShading=sh["flat"], + polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5) + ba_dict["uv"] = p3s.BufferAttribute(uv.astype("float32", copy=False)) + else: + material = 
p3s.MeshStandardMaterial(vertexColors=coloring, reflectivity=sh["reflectivity"], + side=sh["side"], roughness=sh["roughness"], metalness=sh["metalness"], + flatShading=sh["flat"], + polygonOffset=True, polygonOffsetFactor=1, polygonOffsetUnits=5) + + if type(n) != type(None) and coloring == "VertexColors": # TODO: properly handle normals for FaceColors as well + ba_dict["normal"] = p3s.BufferAttribute(n.astype("float32", copy=False), normalized=True) + + geometry = p3s.BufferGeometry(attributes=ba_dict) + + if coloring == "VertexColors" and type(n) == type(None): + geometry.exec_three_obj_method('computeVertexNormals') + elif coloring == "FaceColors" and type(n) == type(None): + geometry.exec_three_obj_method('computeFaceNormals') + + # Mesh setup + mesh = p3s.Mesh(geometry=geometry, material=material) + + # Wireframe setup + mesh_obj["wireframe"] = None + if sh["wireframe"]: + wf_geometry = p3s.WireframeGeometry(mesh.geometry) # WireframeGeometry + wf_material = p3s.LineBasicMaterial(color=sh["wire_color"], linewidth=sh["wire_width"]) + wireframe = p3s.LineSegments(wf_geometry, wf_material) + mesh.add(wireframe) + mesh_obj["wireframe"] = wireframe + + # Bounding box setup + if sh["bbox"]: + v_box, f_box = self.__get_bbox(v) + _, bbox = self.add_edges(v_box, f_box, sh, mesh) + mesh_obj["bbox"] = [bbox, v_box, f_box] + + # Object setup + mesh_obj["max"] = np.max(v, axis=0) + mesh_obj["min"] = np.min(v, axis=0) + mesh_obj["geometry"] = geometry + mesh_obj["mesh"] = mesh + mesh_obj["material"] = material + mesh_obj["type"] = "Mesh" + mesh_obj["shading"] = sh + mesh_obj["coloring"] = coloring + mesh_obj["arrays"] = [v, f, c] # TODO replays with proper storage or remove if not needed + + return self.__add_object(mesh_obj) + + def add_lines(self, beginning, ending, shading={}, obj=None, **kwargs): + shading.update(kwargs) + if len(beginning.shape) == 1: + if len(beginning) == 2: + beginning = np.array([[beginning[0], beginning[1], 0]]) + else: + if beginning.shape[1] == 2: + beginning = np.append( + beginning, np.zeros([beginning.shape[0], 1]), 1) + if len(ending.shape) == 1: + if len(ending) == 2: + ending = np.array([[ending[0], ending[1], 0]]) + else: + if ending.shape[1] == 2: + ending = np.append( + ending, np.zeros([ending.shape[0], 1]), 1) + + sh = self.__get_shading(shading) + lines = np.hstack([beginning, ending]) + lines = lines.reshape((-1, 3)) + return self.__add_line_geometry(lines, sh, obj) + + def add_edges(self, vertices, edges, shading={}, obj=None, **kwargs): + shading.update(kwargs) + if vertices.shape[1] == 2: + vertices = np.append( + vertices, np.zeros([vertices.shape[0], 1]), 1) + sh = self.__get_shading(shading) + lines = np.zeros((edges.size, 3)) + cnt = 0 + for e in edges: + lines[cnt, :] = vertices[e[0]] + lines[cnt + 1, :] = vertices[e[1]] + cnt += 2 + return self.__add_line_geometry(lines, sh, obj) + + def add_points(self, points, c=None, shading={}, obj=None, **kwargs): + shading.update(kwargs) + if len(points.shape) == 1: + if len(points) == 2: + points = np.array([[points[0], points[1], 0]]) + else: + if points.shape[1] == 2: + points = np.append( + points, np.zeros([points.shape[0], 1]), 1) + sh = self.__get_shading(shading) + points = points.astype("float32", copy=False) + mi = np.min(points, axis=0) + ma = np.max(points, axis=0) + + g_attributes = {"position": p3s.BufferAttribute(points, normalized=False)} + m_attributes = {"size": sh["point_size"]} + + if sh["point_shape"] == "circle": # Plot circles + tex = p3s.DataTexture(data=gen_circle(16, 16), 
format="RGBAFormat", type="FloatType") + m_attributes["map"] = tex + m_attributes["alphaTest"] = 0.5 + m_attributes["transparency"] = True + else: # Plot squares + pass + + colors, v_colors = self.__get_point_colors(points, c, sh) + if v_colors: # Colors per point + m_attributes["vertexColors"] = 'VertexColors' + g_attributes["color"] = p3s.BufferAttribute(colors, normalized=False) + + else: # Colors for all points + m_attributes["color"] = colors + + material = p3s.PointsMaterial(**m_attributes) + geometry = p3s.BufferGeometry(attributes=g_attributes) + points = p3s.Points(geometry=geometry, material=material) + point_obj = {"geometry": geometry, "mesh": points, "material": material, + "max": ma, "min": mi, "type": "Points", "wireframe": None} + + if obj: + return self.__add_object(point_obj, obj), point_obj + else: + return self.__add_object(point_obj) + + def remove_object(self, obj_id): + if obj_id not in self.__objects: + print("Invalid object id. Valid ids are: ", list(self.__objects.keys())) + return + self._scene.remove(self.__objects[obj_id]["mesh"]) + del self.__objects[obj_id] + self.__update_view() + + def reset(self): + for obj_id in list(self.__objects.keys()).copy(): + self._scene.remove(self.__objects[obj_id]["mesh"]) + del self.__objects[obj_id] + self.__update_view() + + def update_object(self, oid=0, vertices=None, colors=None, faces=None): + obj = self.__objects[oid] + if type(vertices) != type(None): + if obj["coloring"] == "FaceColors": + f = obj["arrays"][1] + verts = np.zeros((f.shape[0] * 3, 3), dtype="float32") + for ii in range(f.shape[0]): + # print(ii*3, f[ii]) + verts[ii * 3] = vertices[f[ii, 0]] + verts[ii * 3 + 1] = vertices[f[ii, 1]] + verts[ii * 3 + 2] = vertices[f[ii, 2]] + v = verts + + else: + v = vertices.astype("float32", copy=False) + obj["geometry"].attributes["position"].array = v + # self.wireframe.attributes["position"].array = v # Wireframe updates? + obj["geometry"].attributes["position"].needsUpdate = True + # obj["geometry"].exec_three_obj_method('computeVertexNormals') + if type(colors) != type(None): + colors, coloring = self.__get_colors(obj["arrays"][0], obj["arrays"][1], colors, obj["shading"]) + colors = colors.astype("float32", copy=False) + obj["geometry"].attributes["color"].array = colors + obj["geometry"].attributes["color"].needsUpdate = True + if type(faces) != type(None): + if obj["coloring"] == "FaceColors": + print("Face updates are currently only possible in vertex color mode.") + return + f = faces.astype("uint32", copy=False).ravel() + print(obj["geometry"].attributes) + obj["geometry"].attributes["index"].array = f + # self.wireframe.attributes["position"].array = v # Wireframe updates? 
+ obj["geometry"].attributes["index"].needsUpdate = True + # obj["geometry"].exec_three_obj_method('computeVertexNormals') + # self.mesh.geometry.verticesNeedUpdate = True + # self.mesh.geometry.elementsNeedUpdate = True + # self.update() + if self.render_mode == "WEBSITE": + return self + + # def update(self): + # self.mesh.exec_three_obj_method('update') + # self.orbit.exec_three_obj_method('update') + # self.cam.exec_three_obj_method('updateProjectionMatrix') + # self.scene.exec_three_obj_method('update') + + def add_text(self, text, shading={}, **kwargs): + shading.update(kwargs) + sh = self.__get_shading(shading) + tt = p3s.TextTexture(string=text, color=sh["text_color"]) + sm = p3s.SpriteMaterial(map=tt) + text = p3s.Sprite(material=sm, scaleToTexture=True) + self._scene.add(text) + + # def add_widget(self, widget, callback): + # self.widgets.append(widget) + # widget.observe(callback, names='value') + + # def add_dropdown(self, options, default, desc, cb): + # widget = widgets.Dropdown(options=options, value=default, description=desc) + # self.__widgets.append(widget) + # widget.observe(cb, names="value") + # display(widget) + + # def add_button(self, text, cb): + # button = widgets.Button(description=text) + # self.__widgets.append(button) + # button.on_click(cb) + # display(button) + + def to_html(self, imports=True, html_frame=True): + # Bake positions (fixes centering bug in offline rendering) + if len(self.__objects) == 0: + return + ma = np.zeros((len(self.__objects), 3)) + mi = np.zeros((len(self.__objects), 3)) + for r, obj in enumerate(self.__objects): + ma[r] = self.__objects[obj]["max"] + mi[r] = self.__objects[obj]["min"] + ma = np.max(ma, axis=0) + mi = np.min(mi, axis=0) + diag = np.linalg.norm(ma - mi) + mean = (ma - mi) / 2 + mi + for r, obj in enumerate(self.__objects): + v = self.__objects[obj]["geometry"].attributes["position"].array + v -= mean + v += np.array([0.0, .9, 0.0]) #! to move the obj to the center of window + + scale = self.__s["scale"] * (diag) + self._orbit.target = [0.0, 0.0, 0.0] + self._cam.lookAt([0.0, 0.0, 0.0]) + # self._cam.position = [0.0, 0.0, scale] + self._cam.position = [0.0, 0.5, scale * 1.3] #! show four complete meshes in the window + self._light.position = [0.0, 0.0, scale] + + state = embed.dependency_state(self._renderer) + + # Somehow these entries are missing when the state is exported in python. + # Exporting from the GUI works, so we are inserting the missing entries. + for k in state: + if state[k]["model_name"] == "OrbitControlsModel": + state[k]["state"]["maxAzimuthAngle"] = "inf" + state[k]["state"]["maxDistance"] = "inf" + state[k]["state"]["maxZoom"] = "inf" + state[k]["state"]["minAzimuthAngle"] = "-inf" + + tpl = embed.load_requirejs_template + if not imports: + embed.load_requirejs_template = "" + + s = embed.embed_snippet(self._renderer, state=state, embed_url=EMBED_URL) + # s = embed.embed_snippet(self.__w, state=state) + embed.load_requirejs_template = tpl + + if html_frame: + s = "\n\n" + s + "\n\n" + + # Revert changes + for r, obj in enumerate(self.__objects): + v = self.__objects[obj]["geometry"].attributes["position"].array + v += mean + self.__update_view() + + return s + + def save(self, filename=""): + if filename == "": + uid = str(uuid.uuid4()) + ".html" + else: + filename = filename.replace(".html", "") + uid = filename + '.html' + with open(uid, "w") as f: + f.write(self.to_html()) + print("Plot saved to file %s." % uid)