diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..5be3c327f9b882eb1139fc1e7fadc9aa294c8f00 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +example/boots.mp4 filter=lfs diff=lfs merge=lfs -text +example/Donut.mp4 filter=lfs diff=lfs merge=lfs -text +example/durian.mp4 filter=lfs diff=lfs merge=lfs -text +example/pillow_huskies.mp4 filter=lfs diff=lfs merge=lfs -text +example/wooden_car.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..99df8a8719bba2a40379faad3f762129150f1733 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +*.pyc +.vscode +output +build +output/ +point_e_model_cache/ \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..c39b2bab5591b8b7282bbd38ca23410285a8409b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "submodules/diff-gaussian-rasterization"] + path = submodules/diff-gaussian-rasterization + url = https://github.com/YixunLiang/diff-gaussian-rasterization.git +[submodule "submodules/simple-knn"] + path = submodules/simple-knn + url = https://github.com/YixunLiang/simple-knn.git diff --git a/GAUSSIAN_SPLATTING_LICENSE.md b/GAUSSIAN_SPLATTING_LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..ef71f1375508a836c6550547b34e8310ebea7d8a --- /dev/null +++ b/GAUSSIAN_SPLATTING_LICENSE.md @@ -0,0 +1,83 @@ +Gaussian-Splatting License +=========================== + +**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**. +The *Software* is in the process of being registered with the Agence pour la Protection des +Programmes (APP). + +The *Software* is still being developed by the *Licensor*. + +*Licensor*'s goal is to allow the research community to use, test and evaluate +the *Software*. + +## 1. Definitions + +*Licensee* means any person or entity that uses the *Software* and distributes +its *Work*. + +*Licensor* means the owners of the *Software*, i.e Inria and MPII + +*Software* means the original work of authorship made available under this +License ie gaussian-splatting. + +*Work* means the *Software* and any additions to or derivative works of the +*Software* that are made available under this License. + + +## 2. Purpose +This license is intended to define the rights granted to the *Licensee* by +Licensors under the *Software*. + +## 3. Rights granted + +For the above reasons Licensors have decided to distribute the *Software*. +Licensors grant non-exclusive rights to use the *Software* for research purposes +to research users (both academic and industrial), free of charge, without right +to sublicense.. The *Software* may be used "non-commercially", i.e., for research +and/or evaluation purposes only. + +Subject to the terms and conditions of this License, you are granted a +non-exclusive, royalty-free, license to reproduce, prepare derivative works of, +publicly display, publicly perform and distribute its *Work* and any resulting +derivative works in any form. + +## 4. 
Limitations + +**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do +so under this License, (b) you include a complete copy of this License with +your distribution, and (c) you retain without modification any copyright, +patent, trademark, or attribution notices that are present in the *Work*. + +**4.2 Derivative Works.** You may specify that additional or different terms apply +to the use, reproduction, and distribution of your derivative works of the *Work* +("Your Terms") only if (a) Your Terms provide that the use limitation in +Section 2 applies to your derivative works, and (b) you identify the specific +derivative works that are subject to Your Terms. Notwithstanding Your Terms, +this License (including the redistribution requirements in Section 3.1) will +continue to apply to the *Work* itself. + +**4.3** Any other use without of prior consent of Licensors is prohibited. Research +users explicitly acknowledge having received from Licensors all information +allowing to appreciate the adequacy between of the *Software* and their needs and +to undertake all necessary precautions for its execution and use. + +**4.4** The *Software* is provided both as a compiled library file and as source +code. In case of using the *Software* for a publication or other results obtained +through the use of the *Software*, users are strongly encouraged to cite the +corresponding publications as explained in the documentation of the *Software*. + +## 5. Disclaimer + +THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES +WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY +UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL +CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES +OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL +USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR +ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE +AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION) +HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR +IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*. diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..0668abea143aabf2b1e3070c40302841fe583e9b --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 dreamgaussian + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/arguments/__init__.py b/arguments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f02c780966039a8d7d3579fc86cdc981f64739a --- /dev/null +++ b/arguments/__init__.py @@ -0,0 +1,258 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from argparse import ArgumentParser, Namespace +import sys +import os + +class GroupParams: + pass + +class ParamGroup: + def __init__(self, parser: ArgumentParser, name : str, fill_none = False): + group = parser.add_argument_group(name) + for key, value in vars(self).items(): + shorthand = False + if key.startswith("_"): + shorthand = True + key = key[1:] + t = type(value) + value = value if not fill_none else None + if shorthand: + if t == bool: + group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true") + else: + group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t) + else: + if t == bool: + group.add_argument("--" + key, default=value, action="store_true") + else: + group.add_argument("--" + key, default=value, type=t) + + def extract(self, args): + group = GroupParams() + for arg in vars(args).items(): + if arg[0] in vars(self) or ("_" + arg[0]) in vars(self): + setattr(group, arg[0], arg[1]) + return group + + def load_yaml(self, opts=None): + if opts is None: + return + else: + for key, value in opts.items(): + try: + setattr(self, key, value) + except: + raise Exception(f'Unknown attribute {key}') + +class GuidanceParams(ParamGroup): + def __init__(self, parser, opts=None): + self.guidance = "SD" + self.g_device = "cuda" + + self.model_key = None + self.is_safe_tensor = False + self.base_model_key = None + + self.controlnet_model_key = None + + self.perpneg = True + self.negative_w = -2. + self.front_decay_factor = 2. + self.side_decay_factor = 10. + + self.vram_O = False + self.fp16 = True + self.hf_key = None + self.t_range = [0.02, 0.5] + self.max_t_range = 0.98 + + self.scheduler_type = 'DDIM' + self.num_train_timesteps = None + + self.sds = False + self.fix_noise = False + self.noise_seed = 0 + + self.ddim_inv = False + self.delta_t = 80 + self.delta_t_start = 100 + self.annealing_intervals = True + self.text = '' + self.inverse_text = '' + self.textual_inversion_path = None + self.LoRA_path = None + self.controlnet_ratio = 0.5 + self.negative = "" + self.guidance_scale = 7.5 + self.denoise_guidance_scale = 1.0 + self.lambda_guidance = 1. 
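+        # lambda_guidance scales the diffusion-guidance loss; the xs_* values below (meaning assumed from
+        # their names) control the stride, step count, and eta of the DDIM inversion used to reach x_s.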
+ + self.xs_delta_t = 200 + self.xs_inv_steps = 5 + self.xs_eta = 0.0 + + # multi-batch + self.C_batch_size = 1 + + self.vis_interval = 100 + + super().__init__(parser, "Guidance Model Parameters") + + +class ModelParams(ParamGroup): + def __init__(self, parser, sentinel=False, opts=None): + self.sh_degree = 0 + self._source_path = "" + self._model_path = "" + self.pretrained_model_path = None + self._images = "images" + self.workspace = "debug" + self.batch = 10 + self._resolution = -1 + self._white_background = True + self.data_device = "cuda" + self.eval = False + self.opt_path = None + + # augmentation + self.sh_deg_aug_ratio = 0.1 + self.bg_aug_ratio = 0.5 + self.shs_aug_ratio = 0.0 + self.scale_aug_ratio = 1.0 + super().__init__(parser, "Loading Parameters", sentinel) + + def extract(self, args): + g = super().extract(args) + g.source_path = os.path.abspath(g.source_path) + return g + + +class PipelineParams(ParamGroup): + def __init__(self, parser, opts=None): + self.convert_SHs_python = False + self.compute_cov3D_python = False + self.debug = False + super().__init__(parser, "Pipeline Parameters") + + +class OptimizationParams(ParamGroup): + def __init__(self, parser, opts=None): + self.iterations = 5000# 10_000 + self.position_lr_init = 0.00016 + self.position_lr_final = 0.0000016 + self.position_lr_delay_mult = 0.01 + self.position_lr_max_steps = 30_000 + self.feature_lr = 0.0050 + self.feature_lr_final = 0.0030 + + self.opacity_lr = 0.05 + self.scaling_lr = 0.005 + self.rotation_lr = 0.001 + + + self.geo_iter = 0 + self.as_latent_ratio = 0.2 + # dense + + self.resnet_lr = 1e-4 + self.resnet_lr_init = 2e-3 + self.resnet_lr_final = 5e-5 + + + self.scaling_lr_final = 0.001 + self.rotation_lr_final = 0.0002 + + self.percent_dense = 0.003 + self.densify_grad_threshold = 0.00075 + + self.lambda_tv = 1.0 # 0.1 + self.lambda_bin = 10.0 + self.lambda_scale = 1.0 + self.lambda_sat = 1.0 + self.lambda_radius = 1.0 + self.densification_interval = 100 + self.opacity_reset_interval = 300 + self.densify_from_iter = 100 + self.densify_until_iter = 30_00 + + self.use_control_net_iter = 10000000 + self.warmup_iter = 1500 + + self.use_progressive = False + self.save_process = True + self.pro_frames_num = 600 + self.pro_render_45 = False + self.progressive_view_iter = 500 + self.progressive_view_init_ratio = 0.2 + + self.scale_up_cameras_iter = 500 + self.scale_up_factor = 0.95 + self.fovy_scale_up_factor = [0.75, 1.1] + self.phi_scale_up_factor = 1.5 + super().__init__(parser, "Optimization Parameters") + + +class GenerateCamParams(ParamGroup): + def __init__(self, parser): + self.init_shape = 'sphere' + self.init_prompt = '' + self.use_pointe_rgb = False + self.radius_range = [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5] + self.max_radius_range = [3.5, 5.0] + self.default_radius = 3.5 + self.theta_range = [45, 105] + self.max_theta_range = [45, 105] + self.phi_range = [-180, 180] + self.max_phi_range = [-180, 180] + self.fovy_range = [0.32, 0.60] #[0.3, 1.5] #[0.5, 0.8] #[10, 30] + self.max_fovy_range = [0.16, 0.60] + self.rand_cam_gamma = 1.0 + self.angle_overhead = 30 + self.angle_front =60 + self.render_45 = True + self.uniform_sphere_rate = 0 + self.image_w = 512 + self.image_h = 512 # 512 + self.SSAA = 1 + self.init_num_pts = 100_000 + self.default_polar = 90 + self.default_azimuth = 0 + self.default_fovy = 0.55 #20 + self.jitter_pose = True + self.jitter_center = 0.05 + self.jitter_target = 0.05 + self.jitter_up = 0.01 + self.device = "cuda" + super().__init__(parser, "Generate Cameras Parameters") + 
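The `ParamGroup` base class above turns every attribute set in a subclass `__init__` into an argparse flag (a leading underscore also registers a one-letter shorthand), `extract()` copies that group's values back out of the parsed namespace, and `load_yaml()` copies a config section such as those under `configs/` onto the group instance. A minimal, hypothetical usage sketch — only the classes defined in this file plus PyYAML (listed in environment.yml) are assumed; the real `train.py` may wire things up differently:

```python
from argparse import ArgumentParser
import yaml  # PyYAML is listed in environment.yml

# Hypothetical wiring; train.py may assemble these groups differently.
parser = ArgumentParser(description="param-group sketch")
lp = ModelParams(parser)         # registers --workspace, -s/--source_path, -w/--white_background, ...
op = OptimizationParams(parser)  # registers --iterations, --feature_lr, ...

# Configs such as configs/axe.yaml are grouped by class name; load_yaml()
# simply copies each key onto the matching group instance.
with open("configs/axe.yaml") as f:
    opts = yaml.safe_load(f)
lp.load_yaml(opts.get("ModelParams"))
op.load_yaml(opts.get("OptimizationParams"))

# CLI flags still work; extract() pulls only this group's fields back out
# of the parsed namespace (underscored attrs lose their leading "_").
args = parser.parse_args(["--workspace", "demo"])
model_args = lp.extract(args)
print(model_args.workspace, op.iterations)
```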
+def get_combined_args(parser : ArgumentParser): + cmdlne_string = sys.argv[1:] + cfgfile_string = "Namespace()" + args_cmdline = parser.parse_args(cmdlne_string) + + try: + cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args") + print("Looking for config file in", cfgfilepath) + with open(cfgfilepath) as cfg_file: + print("Config file found: {}".format(cfgfilepath)) + cfgfile_string = cfg_file.read() + except TypeError: + print("Config file not found at") + pass + args_cfgfile = eval(cfgfile_string) + + merged_dict = vars(args_cfgfile).copy() + for k,v in vars(args_cmdline).items(): + if v != None: + merged_dict[k] = v + return Namespace(**merged_dict) diff --git a/configs/axe.yaml b/configs/axe.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bfdc9a78f6544db27c5b9cf7d14bd846079cf503 --- /dev/null +++ b/configs/axe.yaml @@ -0,0 +1,76 @@ +port: 2355 +save_video: true +seed: 0 + +PipelineParams: + convert_SHs_python: False #true = using direct rgb +ModelParams: + workspace: viking_axe + sh_degree: 0 + bg_aug_ratio: 0.66 + +GuidanceParams: + model_key: 'stabilityai/stable-diffusion-2-1-base' + text: 'Viking axe, fantasy, weapon, blender, 8k, HDR.' + negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution, oversaturation.' + inverse_text: '' + perpneg: false + C_batch_size: 4 + + t_range: [0.02, 0.5] + max_t_range: 0.98 + lambda_guidance: 0.1 + guidance_scale: 7.5 + denoise_guidance_scale: 1.0 + noise_seed: 0 + + ddim_inv: true + accum: false + annealing_intervals: true + + xs_delta_t: 200 + xs_inv_steps: 5 + xs_eta: 0.0 + + delta_t: 25 + delta_t_start: 100 + +GenerateCamParams: + init_shape: 'pointe' + init_prompt: 'A flag.' + use_pointe_rgb: false + init_num_pts: 100_000 + phi_range: [-180, 180] + max_phi_range: [-180, 180] + rand_cam_gamma: 1. + + theta_range: [45, 105] + max_theta_range: [45, 105] + + radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5] + max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5] + default_radius: 3.5 + + default_fovy: 0.55 + fovy_range: [0.32, 0.60] + max_fovy_range: [0.16, 0.60] + +OptimizationParams: + iterations: 5000 + save_process: True + pro_frames_num: 600 + pro_render_45: False + warmup_iter: 1500 # 2500 + + as_latent_ratio : 0.2 + geo_iter : 0 + densify_from_iter: 100 + densify_until_iter: 3000 + percent_dense: 0.003 + densify_grad_threshold: 0.00075 + progressive_view_iter: 500 #1500 + opacity_reset_interval: 300 #500 + + scale_up_cameras_iter: 500 + fovy_scale_up_factor: [0.75, 1.1] + phi_scale_up_factor: 1.5 \ No newline at end of file diff --git a/configs/bagel.yaml b/configs/bagel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dabb640a7217b63c3c80b604edd0e3058aca14bd --- /dev/null +++ b/configs/bagel.yaml @@ -0,0 +1,74 @@ +port: 2355 +save_video: true +seed: 0 + +PipelineParams: + convert_SHs_python: False #true = using direct rgb +ModelParams: + workspace: bagel + sh_degree: 0 + bg_aug_ratio: 0.66 + +GuidanceParams: + model_key: 'stabilityai/stable-diffusion-2-1-base' + text: 'a DSLR photo of a bagel filled with cream cheese and lox.' + negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.' 
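+  # inverse_text: prompt used to condition the DDIM inversion pass (assumed from the option name;
+  # an empty string amounts to the unconditional text embedding).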
+ inverse_text: '' + perpneg: false + C_batch_size: 4 + t_range: [0.02, 0.5] + max_t_range: 0.98 + lambda_guidance: 0.1 + guidance_scale: 7.5 + denoise_guidance_scale: 1.0 + noise_seed: 0 + + ddim_inv: true + annealing_intervals: true + + xs_delta_t: 200 + xs_inv_steps: 5 + xs_eta: 0.0 + + delta_t: 80 + delta_t_start: 100 + +GenerateCamParams: + init_shape: 'pointe' + init_prompt: 'a bagel.' + use_pointe_rgb: false + init_num_pts: 100_000 + phi_range: [-180, 180] + max_phi_range: [-180, 180] + rand_cam_gamma: 1. + + theta_range: [45, 105] + max_theta_range: [45, 105] + + radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5] + max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5] + default_radius: 3.5 + + default_fovy: 0.55 + fovy_range: [0.32, 0.60] + max_fovy_range: [0.16, 0.60] + +OptimizationParams: + iterations: 5000 + save_process: True + pro_frames_num: 600 + pro_render_45: False + warmup_iter: 1500 # 2500 + + as_latent_ratio : 0.2 + geo_iter : 0 + densify_from_iter: 100 + densify_until_iter: 3000 + percent_dense: 0.003 + densify_grad_threshold: 0.00075 + progressive_view_iter: 500 #1500 + opacity_reset_interval: 300 #500 + + scale_up_cameras_iter: 500 + fovy_scale_up_factor: [0.75, 1.1] + phi_scale_up_factor: 1.5 \ No newline at end of file diff --git a/configs/cat_armor.yaml b/configs/cat_armor.yaml new file mode 100644 index 0000000000000000000000000000000000000000..10f0e4700ddf880a772b08553786949cb8e7372b --- /dev/null +++ b/configs/cat_armor.yaml @@ -0,0 +1,74 @@ +port: 2355 +save_video: true +seed: 0 + +PipelineParams: + convert_SHs_python: False #true = using direct rgb +ModelParams: + workspace: cat_armor + sh_degree: 0 + bg_aug_ratio: 0.66 + +GuidanceParams: + model_key: 'stabilityai/stable-diffusion-2-1-base' + text: 'a DSLR photo of a cat wearing armor.' + negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution, oversaturation.' + inverse_text: '' + perpneg: true + C_batch_size: 4 + t_range: [0.02, 0.5] + max_t_range: 0.98 + lambda_guidance: 0.1 + guidance_scale: 7.5 + denoise_guidance_scale: 1.0 + noise_seed: 0 + + ddim_inv: true + annealing_intervals: true + + xs_delta_t: 200 + xs_inv_steps: 5 + xs_eta: 0.0 + + delta_t: 80 + delta_t_start: 100 + +GenerateCamParams: + init_shape: 'pointe' + init_prompt: 'a cat.' 
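+  # use_pointe_rgb: whether to keep the colors Point-E predicts for the initial point cloud
+  # (assumed from the option name).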
+ use_pointe_rgb: false + init_num_pts: 100_000 + phi_range: [-180, 180] + max_phi_range: [-180, 180] + rand_cam_gamma: 1.5 + + theta_range: [60, 90] + max_theta_range: [60, 90] + + radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5] + max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5] + default_radius: 3.5 + + default_fovy: 0.55 + fovy_range: [0.32, 0.60] + max_fovy_range: [0.16, 0.60] + +OptimizationParams: + iterations: 5000 + save_process: True + pro_frames_num: 600 + pro_render_45: False + warmup_iter: 1500 # 2500 + + as_latent_ratio : 0.2 + geo_iter : 0 + densify_from_iter: 100 + densify_until_iter: 3000 + percent_dense: 0.003 + densify_grad_threshold: 0.00075 + progressive_view_iter: 500 #1500 + opacity_reset_interval: 300 #500 + + scale_up_cameras_iter: 500 + fovy_scale_up_factor: [0.75, 1.1] + phi_scale_up_factor: 1.5 \ No newline at end of file diff --git a/configs/crown.yaml b/configs/crown.yaml new file mode 100644 index 0000000000000000000000000000000000000000..61c16dfc84b4225cc85ee9c55e8718728135664b --- /dev/null +++ b/configs/crown.yaml @@ -0,0 +1,74 @@ +port: 2355 +save_video: true +seed: 0 + +PipelineParams: + convert_SHs_python: False #true = using direct rgb +ModelParams: + workspace: crown + sh_degree: 0 + bg_aug_ratio: 0.66 + +GuidanceParams: + model_key: 'stabilityai/stable-diffusion-2-1-base' + text: 'a DSLR photo of the Imperial State Crown of England.' + negative: 'unrealistic, blurry, low quality.' + inverse_text: '' + perpneg: false + C_batch_size: 4 + t_range: [0.02, 0.5] + max_t_range: 0.98 + lambda_guidance: 0.1 + guidance_scale: 7.5 + denoise_guidance_scale: 1.0 + noise_seed: 0 + + ddim_inv: true + annealing_intervals: true + + xs_delta_t: 200 + xs_inv_steps: 5 + xs_eta: 0.0 + + delta_t: 80 + delta_t_start: 100 + +GenerateCamParams: + init_shape: 'pointe' + init_prompt: 'the Imperial State Crown of England.' + use_pointe_rgb: false + init_num_pts: 100_000 + phi_range: [-180, 180] + max_phi_range: [-180, 180] + rand_cam_gamma: 1. + + theta_range: [45, 105] + max_theta_range: [45, 105] + + radius_range: [5.2, 5.5] + max_radius_range: [3.5, 5.0] + default_radius: 3.5 + + default_fovy: 0.55 + fovy_range: [0.32, 0.60] + max_fovy_range: [0.16, 0.60] + +OptimizationParams: + iterations: 5000 + save_process: True + pro_frames_num: 600 + pro_render_45: False + warmup_iter: 1500 # 2500 + + as_latent_ratio : 0.2 + geo_iter : 0 + densify_from_iter: 100 + densify_until_iter: 3000 + percent_dense: 0.003 + densify_grad_threshold: 0.00075 + progressive_view_iter: 500 + opacity_reset_interval: 300 + + scale_up_cameras_iter: 500 + fovy_scale_up_factor: [0.75, 1.1] + phi_scale_up_factor: 1.5 \ No newline at end of file diff --git a/configs/football_helmet.yaml b/configs/football_helmet.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51d3557ebb36220a17e4fa49ceb382fb77230bd4 --- /dev/null +++ b/configs/football_helmet.yaml @@ -0,0 +1,75 @@ +port: 2355 +save_video: true +seed: 0 + +PipelineParams: + convert_SHs_python: False #true = using direct rgb +ModelParams: + workspace: football_helmet + sh_degree: 0 + bg_aug_ratio: 0.66 + +GuidanceParams: + model_key: 'stabilityai/stable-diffusion-2-1-base' + text: 'a DSLR photo of a football helmet.' + negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution, oversaturation.' 
+ inverse_text: '' + perpneg: false + C_batch_size: 4 + t_range: [0.02, 0.5] + max_t_range: 0.98 + lambda_guidance: 0.1 + guidance_scale: 7.5 + denoise_guidance_scale: 1.0 + + noise_seed: 0 + + ddim_inv: true + accum: false + annealing_intervals: true + + xs_delta_t: 200 + xs_inv_steps: 5 + xs_eta: 0.0 + + delta_t: 50 + delta_t_start: 100 + +GenerateCamParams: + init_shape: 'pointe' + init_prompt: 'a football helmet.' + use_pointe_rgb: false + init_num_pts: 100_000 + phi_range: [-180, 180] + max_phi_range: [-180, 180] + + theta_range: [45, 90] + max_theta_range: [45, 90] + + radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5] + max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5] + default_radius: 3.5 + + default_fovy: 0.55 + fovy_range: [0.32, 0.60] + max_fovy_range: [0.16, 0.60] + +OptimizationParams: + iterations: 5000 + save_process: True + pro_frames_num: 600 + pro_render_45: False + warmup_iter: 1500 # 2500 + + as_latent_ratio : 0.2 + geo_iter : 0 + densify_from_iter: 100 + densify_until_iter: 3000 + percent_dense: 0.003 + densify_grad_threshold: 0.00075 + progressive_view_iter: 500 #1500 + opacity_reset_interval: 300 #500 + + scale_up_cameras_iter: 500 + fovy_scale_up_factor: [0.75, 1.1] + phi_scale_up_factor: 1.5 \ No newline at end of file diff --git a/configs/hamburger.yaml b/configs/hamburger.yaml new file mode 100644 index 0000000000000000000000000000000000000000..626c1d07f2db7770b0cba86c5b9619b4a3be72ac --- /dev/null +++ b/configs/hamburger.yaml @@ -0,0 +1,75 @@ +port: 2355 +save_video: true +seed: 0 + +PipelineParams: + convert_SHs_python: False #true = using direct rgb +ModelParams: + workspace: hamburger + sh_degree: 0 + bg_aug_ratio: 0.66 + +GuidanceParams: + model_key: 'stabilityai/stable-diffusion-2-1-base' + text: 'A delicious hamburger.' + negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.' + inverse_text: '' + perpneg: false + C_batch_size: 4 + t_range: [0.02, 0.5] + max_t_range: 0.98 + lambda_guidance: 0.1 + guidance_scale: 7.5 + denoise_guidance_scale: 1.0 + + noise_seed: 0 + + ddim_inv: true + annealing_intervals: true + + xs_delta_t: 200 + xs_inv_steps: 5 + xs_eta: 0.0 + + delta_t: 50 + delta_t_start: 100 + +GenerateCamParams: + init_shape: 'sphere' + init_prompt: '.' + use_pointe_rgb: false + init_num_pts: 100_000 + phi_range: [-180, 180] + max_phi_range: [-180, 180] + rand_cam_gamma: 1. 
+ + theta_range: [45, 105] + max_theta_range: [45, 105] + + radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5] + max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5] + default_radius: 3.5 + + default_fovy: 0.55 + fovy_range: [0.32, 0.60] + max_fovy_range: [0.16, 0.60] + +OptimizationParams: + iterations: 5000 + save_process: True + pro_frames_num: 600 + pro_render_45: False + warmup_iter: 1500 # 2500 + + as_latent_ratio : 0.2 + geo_iter : 0 + densify_from_iter: 100 + densify_until_iter: 3000 + percent_dense: 0.003 + densify_grad_threshold: 0.00075 + progressive_view_iter: 500 #1500 + opacity_reset_interval: 300 #500 + + scale_up_cameras_iter: 500 + fovy_scale_up_factor: [0.75, 1.1] + phi_scale_up_factor: 1.5 \ No newline at end of file diff --git a/configs/ts_lora.yaml b/configs/ts_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3f22c98b2ea99e92d10d339b573445b35f5b82cb --- /dev/null +++ b/configs/ts_lora.yaml @@ -0,0 +1,76 @@ +port: 2355 +save_video: true +seed: 0 + +PipelineParams: + convert_SHs_python: False #true = using direct rgb +ModelParams: + workspace: TS_lora + sh_degree: 0 + bg_aug_ratio: 0.66 + +GuidanceParams: + model_key: 'stabilityai/stable-diffusion-2-1-base' + text: 'A wearing sunglasses.' + LoRA_path: "./custom_example/lora/Taylor_Swift/step_inv_1000.safetensors" + negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.' + inverse_text: '' + perpneg: true + C_batch_size: 4 + t_range: [0.02, 0.5] + max_t_range: 0.98 + lambda_guidance: 0.1 + guidance_scale: 7.5 + denoise_guidance_scale: 1.0 + + noise_seed: 0 + + ddim_inv: true + annealing_intervals: true + + xs_delta_t: 200 + xs_inv_steps: 5 + xs_eta: 0.0 + + delta_t: 80 + delta_t_start: 100 + +GenerateCamParams: + init_shape: 'pointe' + init_prompt: 'a girl head.' + use_pointe_rgb: false + init_num_pts: 100_000 + phi_range: [-80, 80] + max_phi_range: [-180, 180] + rand_cam_gamma: 1.5 + + theta_range: [60, 120] + max_theta_range: [60, 120] + + radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5] + max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5] + default_radius: 3.5 + + default_fovy: 0.55 + fovy_range: [0.32, 0.60] + max_fovy_range: [0.16, 0.60] + +OptimizationParams: + iterations: 5000 + save_process: True + pro_frames_num: 600 + pro_render_45: False + warmup_iter: 1500 # 2500 + + as_latent_ratio : 0.2 + geo_iter : 0 + densify_from_iter: 100 + densify_until_iter: 3000 + percent_dense: 0.003 + densify_grad_threshold: 0.00075 + progressive_view_iter: 500 #1500 + opacity_reset_interval: 300 #500 + + scale_up_cameras_iter: 500 + fovy_scale_up_factor: [0.75, 1.1] + phi_scale_up_factor: 1.5 \ No newline at end of file diff --git a/configs/white_hair_ironman.yaml b/configs/white_hair_ironman.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17b80caaed6b2af3becf08622c94e7d7f70f6f34 --- /dev/null +++ b/configs/white_hair_ironman.yaml @@ -0,0 +1,73 @@ +port: 2355 +save_video: true +seed: 0 + +PipelineParams: + convert_SHs_python: False #true = using direct rgb +ModelParams: + workspace: white_hair_IRONMAN + sh_degree: 0 + bg_aug_ratio: 0.66 + +GuidanceParams: + model_key: 'stabilityai/stable-diffusion-2-1-base' + text: 'A portrait of IRONMAN, white hair, head, photorealistic, 8K, HDR.' + negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution.' 
+ inverse_text: '' + perpneg: true + C_batch_size: 4 + max_t_range: 0.98 + lambda_guidance: 0.1 + guidance_scale: 7.5 + denoise_guidance_scale: 1.0 + noise_seed: 0 + + ddim_inv: true + annealing_intervals: true + + xs_delta_t: 200 + xs_inv_steps: 5 + xs_eta: 0.0 + + delta_t: 50 + delta_t_start: 100 + +GenerateCamParams: + init_shape: 'pointe' + init_prompt: 'a man head.' + use_pointe_rgb: false + init_num_pts: 100_000 + phi_range: [-80, 80] + max_phi_range: [-180, 180] + rand_cam_gamma: 1.5 + + theta_range: [45, 90] + max_theta_range: [45, 90] + + radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5] + max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5] + default_radius: 3.5 + + default_fovy: 0.55 + fovy_range: [0.32, 0.60] + max_fovy_range: [0.16, 0.60] + +OptimizationParams: + iterations: 5000 + save_process: True + pro_frames_num: 600 + pro_render_45: False + warmup_iter: 1500 # 2500 + + as_latent_ratio : 0.2 + geo_iter : 0 + densify_from_iter: 100 + densify_until_iter: 3000 + percent_dense: 0.003 + densify_grad_threshold: 0.00075 + progressive_view_iter: 500 #1500 + opacity_reset_interval: 300 #500 + + scale_up_cameras_iter: 500 + fovy_scale_up_factor: [0.75, 1.1] + phi_scale_up_factor: 1.5 \ No newline at end of file diff --git a/configs/zombie_joker.yaml b/configs/zombie_joker.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b5c2627973232c954818de72071b23eb703c98b2 --- /dev/null +++ b/configs/zombie_joker.yaml @@ -0,0 +1,75 @@ +port: 2355 +save_video: true +seed: 0 + +PipelineParams: + convert_SHs_python: False #true = using direct rgb +ModelParams: + workspace: zombie_joker + sh_degree: 0 + bg_aug_ratio: 0.66 + +GuidanceParams: + model_key: 'stabilityai/stable-diffusion-2-1-base' + text: 'Zombie JOKER, head, photorealistic, 8K, HDR.' + negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.' + inverse_text: '' + perpneg: true + C_batch_size: 4 + + t_range: [0.02, 0.5] + max_t_range: 0.98 + lambda_guidance: 0.1 + guidance_scale: 7.5 + denoise_guidance_scale: 1.0 + noise_seed: 0 + + ddim_inv: true + annealing_intervals: true + + xs_delta_t: 200 + xs_inv_steps: 5 + xs_eta: 0.0 + + delta_t: 50 + delta_t_start: 100 + +GenerateCamParams: + init_shape: 'pointe' + init_prompt: 'a man head.' 
+ use_pointe_rgb: false + init_num_pts: 100_000 + phi_range: [-80, 80] + max_phi_range: [-180, 180] + rand_cam_gamma: 1.5 + + theta_range: [45, 90] + max_theta_range: [45, 90] + + radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5] + max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5] + default_radius: 3.5 + + default_fovy: 0.55 + fovy_range: [0.32, 0.60] + max_fovy_range: [0.16, 0.60] + +OptimizationParams: + iterations: 5000 + save_process: True + pro_frames_num: 600 + pro_render_45: False + warmup_iter: 1500 # 2500 + + as_latent_ratio : 0.2 + geo_iter : 0 + densify_from_iter: 100 + densify_until_iter: 3000 + percent_dense: 0.003 + densify_grad_threshold: 0.00075 + progressive_view_iter: 500 #1500 + opacity_reset_interval: 300 #500 + + scale_up_cameras_iter: 500 + fovy_scale_up_factor: [0.75, 1.1] + phi_scale_up_factor: 1.5 \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..d1224d3e9c27c6f9672b676e9c25f29a350d245b --- /dev/null +++ b/environment.yml @@ -0,0 +1,29 @@ +name: LucidDreamer +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - cudatoolkit=11.6 + - plyfile=0.8.1 + - python=3.9 + - pip=22.3.1 + - pytorch=1.12.1 + - torchaudio=0.12.1 + - torchvision=0.15.2 + - tqdm + - pip: + - mediapipe + - Pillow + - diffusers==0.18.2 + - xformers==0.0.20 + - transformers==4.30.2 + - fire==0.5.0 + - huggingface_hub==0.16.4 + - imageio==2.31.1 + - imageio-ffmpeg + - PyYAML + - safetensors + - wandb + - accelerate + - triton \ No newline at end of file diff --git a/example/Donut.mp4 b/example/Donut.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ac2dd74e0dd1dde4da0810a774c563fd0cc7042d --- /dev/null +++ b/example/Donut.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4633e31ff1ff161e0bd015c166c507cad140e14aef616eecef95c32da5dd1902 +size 2264633 diff --git a/example/boots.mp4 b/example/boots.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..8ea41360ed494ea5ed2ad16f4c4a9228e2647073 --- /dev/null +++ b/example/boots.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f117d721a095ae913d17072ee5ed4373c95f1a8851ca6e9e254bf5efeaf56cb +size 5358683 diff --git a/example/durian.mp4 b/example/durian.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9d5b86cf69dc207008f4392eaf6a5c9ee0def67b --- /dev/null +++ b/example/durian.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:da35c90e1212627da08180fcb513a0d402dc189c613c00d18a3b3937c992b47d +size 9316285 diff --git a/example/pillow_huskies.mp4 b/example/pillow_huskies.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ffb0dbb3a82fd782a8321af7042ed1e450407f41 --- /dev/null +++ b/example/pillow_huskies.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fc53845fdca59e413765833aed51a9a93e2962719a63f4471e9fa7943e217cf6 +size 3586741 diff --git a/example/wooden_car.mp4 b/example/wooden_car.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9d6432016300e254a3f091ba5ca044c3a6f39cdc --- /dev/null +++ b/example/wooden_car.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c4e3b6b31a1d2c9e3791c4c2c7b278d1eb3ae209c94b5d5f3834b4ea5d6d3c16 +size 1660564 diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..5d8a85c4b0a6bfcb9044f78b45912644e766195b --- /dev/null +++ b/gaussian_renderer/__init__.py @@ -0,0 +1,168 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer +from scene.gaussian_model import GaussianModel +from utils.sh_utils import eval_sh, SH2RGB +from utils.graphics_utils import fov2focal +import random + + +def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, black_video = False, + override_color = None, sh_deg_aug_ratio = 0.1, bg_aug_ratio = 0.3, shs_aug_ratio=1.0, scale_aug_ratio=1.0, test = False): + """ + Render the scene. + + Background tensor (bg_color) must be on GPU! + """ + + # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means + screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 + try: + screenspace_points.retain_grad() + except: + pass + + if black_video: + bg_color = torch.zeros_like(bg_color) + #Aug + if random.random() < sh_deg_aug_ratio and not test: + act_SH = 0 + else: + act_SH = pc.active_sh_degree + + if random.random() < bg_aug_ratio and not test: + if random.random() < 0.5: + bg_color = torch.rand_like(bg_color) + else: + bg_color = torch.zeros_like(bg_color) + # bg_color = torch.zeros_like(bg_color) + + #bg_color = torch.zeros_like(bg_color) + # Set up rasterization configuration + tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) + tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) + try: + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=act_SH, + campos=viewpoint_camera.camera_center, + prefiltered=False + ) + except TypeError as e: + raster_settings = GaussianRasterizationSettings( + image_height=int(viewpoint_camera.image_height), + image_width=int(viewpoint_camera.image_width), + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=bg_color, + scale_modifier=scaling_modifier, + viewmatrix=viewpoint_camera.world_view_transform, + projmatrix=viewpoint_camera.full_proj_transform, + sh_degree=act_SH, + campos=viewpoint_camera.camera_center, + prefiltered=False, + debug=False + ) + + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means3D = pc.get_xyz + means2D = screenspace_points + opacity = pc.get_opacity + + # If precomputed 3d covariance is provided, use it. If not, then it will be computed from + # scaling / rotation by the rasterizer. + scales = None + rotations = None + cov3D_precomp = None + if pipe.compute_cov3D_python: + cov3D_precomp = pc.get_covariance(scaling_modifier) + else: + scales = pc.get_scaling + rotations = pc.get_rotation + + # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors + # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 
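+    # With convert_SHs_python, the DC band of the SH features is read as raw RGB and squashed with a
+    # sigmoid (the configs here use sh_degree=0, so that band is the entire color signal); otherwise
+    # the full SH coefficients are handed to the rasterizer.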
+ shs = None + colors_precomp = None + if colors_precomp is None: + if pipe.convert_SHs_python: + raw_rgb = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2).squeeze()[:,:3] + rgb = torch.sigmoid(raw_rgb) + colors_precomp = rgb + else: + shs = pc.get_features + else: + colors_precomp = override_color + + if random.random() < shs_aug_ratio and not test: + variance = (0.2 ** 0.5) * shs + shs = shs + (torch.randn_like(shs) * variance) + + # add noise to scales + if random.random() < scale_aug_ratio and not test: + variance = (0.2 ** 0.5) * scales / 4 + scales = torch.clamp(scales + (torch.randn_like(scales) * variance), 0.0) + + # Rasterize visible Gaussians to image, obtain their radii (on screen). + + rendered_image, radii, depth_alpha = rasterizer( + means3D = means3D, + means2D = means2D, + shs = shs, + colors_precomp = colors_precomp, + opacities = opacity, + scales = scales, + rotations = rotations, + cov3D_precomp = cov3D_precomp) + depth, alpha = torch.chunk(depth_alpha, 2) + # bg_train = pc.get_background + # rendered_image = bg_train*alpha.repeat(3,1,1) + rendered_image +# focal = 1 / (2 * math.tan(viewpoint_camera.FoVx / 2)) #torch.tan(torch.tensor(viewpoint_camera.FoVx) / 2) * (2. / 2 +# disparity = focal / (depth + 1e-9) +# max_disp = torch.max(disparity) +# min_disp = torch.min(disparity[disparity > 0]) +# norm_disparity = (disparity - min_disp) / (max_disp - min_disp) +# # Those Gaussians that were frustum culled or had a radius of 0 were not visible. +# # They will be excluded from value updates used in the splitting criteria. +# return {"render": rendered_image, +# "depth": norm_disparity, + + focal = 1 / (2 * math.tan(viewpoint_camera.FoVx / 2)) + disp = focal / (depth + (alpha * 10) + 1e-5) + + try: + min_d = disp[alpha <= 0.1].min() + except Exception: + min_d = disp.min() + + disp = torch.clamp((disp - min_d) / (disp.max() - min_d), 0.0, 1.0) + + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. + # They will be excluded from value updates used in the splitting criteria. + return {"render": rendered_image, + "depth": disp, + "alpha": alpha, + "viewspace_points": screenspace_points, + "visibility_filter" : radii > 0, + "radii": radii, + "scales": scales} diff --git a/gaussian_renderer/network_gui.py b/gaussian_renderer/network_gui.py new file mode 100644 index 0000000000000000000000000000000000000000..abb01ed3d5c0b8762c9bd78d2cd3a4478f85c827 --- /dev/null +++ b/gaussian_renderer/network_gui.py @@ -0,0 +1,95 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import traceback +import socket +import json +from scene.cameras import MiniCam + +host = "127.0.0.1" +port = 6009 + +conn = None +addr = None + +listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + +def init(wish_host, wish_port): + global host, port, listener + host = wish_host + port = wish_port + cnt = 0 + while True: + try: + listener.bind((host, port)) + break + except: + if cnt == 10: + break + cnt += 1 + port += 1 + listener.listen() + listener.settimeout(0) + +def try_connect(): + global conn, addr, listener + try: + conn, addr = listener.accept() + print(f"\nConnected by {addr}") + conn.settimeout(None) + except Exception as inst: + pass + +def read(): + global conn + messageLength = conn.recv(4) + messageLength = int.from_bytes(messageLength, 'little') + message = conn.recv(messageLength) + return json.loads(message.decode("utf-8")) + +def send(message_bytes, verify): + global conn + if message_bytes != None: + conn.sendall(message_bytes) + conn.sendall(len(verify).to_bytes(4, 'little')) + conn.sendall(bytes(verify, 'ascii')) + +def receive(): + message = read() + + width = message["resolution_x"] + height = message["resolution_y"] + + if width != 0 and height != 0: + try: + do_training = bool(message["train"]) + fovy = message["fov_y"] + fovx = message["fov_x"] + znear = message["z_near"] + zfar = message["z_far"] + do_shs_python = bool(message["shs_python"]) + do_rot_scale_python = bool(message["rot_scale_python"]) + keep_alive = bool(message["keep_alive"]) + scaling_modifier = message["scaling_modifier"] + world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() + world_view_transform[:,1] = -world_view_transform[:,1] + world_view_transform[:,2] = -world_view_transform[:,2] + full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() + full_proj_transform[:,1] = -full_proj_transform[:,1] + custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) + except Exception as e: + print("") + traceback.print_exc() + raise e + return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier + else: + return None, None, None, None, None, None \ No newline at end of file diff --git a/gradio_demo.py b/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..cec71cb95a7147a9549f0614320b11ff456a66fd --- /dev/null +++ b/gradio_demo.py @@ -0,0 +1,62 @@ +import gradio as gr +import numpy as np +from train import * + +example_inputs = [[ + "A DSLR photo of a Rugged, vintage-inspired hiking boots with a weathered leather finish, best quality, 4K, HD.", + "Rugged, vintage-inspired hiking boots with a weathered leather finish." + ], [ + "a DSLR photo of a Cream Cheese Donut.", + "a Donut." + ], [ + "A durian, 8k, HDR.", + "A durian" + ], [ + "A pillow with huskies printed on it", + "A pillow" + ], [ + "A DSLR photo of a wooden car, super detailed, best quality, 4K, HD.", + "a wooden car." 
+]] +example_outputs = [ + gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/boots.mp4'), autoplay=True), + gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/Donut.mp4'), autoplay=True), + gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/durian.mp4'), autoplay=True), + gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/pillow_huskies.mp4'), autoplay=True), + gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/wooden_car.mp4'), autoplay=True) +] + +def main(prompt, init_prompt, negative_prompt, num_iter, CFG, seed): + if [prompt, init_prompt] in example_inputs: + return example_outputs[example_inputs.index([prompt, init_prompt])] + args, lp, op, pp, gcp, gp = args_parser(default_opt=os.path.join(os.path.dirname(__file__), 'configs/white_hair_ironman.yaml')) + gp.text = prompt + gp.negative = negative_prompt + if len(init_prompt) > 1: + gcp.init_shape = 'pointe' + gcp.init_prompt = init_prompt + else: + gcp.init_shape = 'sphere' + gcp.init_prompt = '.' + op.iterations = num_iter + gp.guidance_scale = CFG + gp.noise_seed = int(seed) + lp.workspace = 'gradio_demo' + video_path = start_training(args, lp, op, pp, gcp, gp) + return gr.Video(value=video_path, autoplay=True) + +with gr.Blocks() as demo: + gr.Markdown("#
LucidDreamer: Towards High-Fidelity Text-to-3D Generation via Interval Score Matching
") + gr.Markdown("
Yixun Liang*, Xin Yang*, Jiantao Lin, Haodong Li, Xiaogang Xu, Yingcong Chen**
") + gr.Markdown("
*: Equal contribution. **: Corresponding author.
") + gr.Markdown("We present a text-to-3D generation framework, named the *LucidDreamer*, to distill high-fidelity textures and shapes from pretrained 2D diffusion models.") + gr.Markdown("
CLICK for the full abstract: The recent advancements in text-to-3D generation mark a significant milestone in generative models, unlocking new possibilities for creating imaginative 3D assets across various real-world scenarios. While recent advancements in text-to-3D generation have shown promise, they often fall short in rendering detailed and high-quality 3D models. This problem is especially prevalent as many methods base themselves on Score Distillation Sampling (SDS). This paper identifies a notable deficiency in SDS, that it brings inconsistent and low-quality updating direction for the 3D model, causing the over-smoothing effect. To address this, we propose a novel approach called Interval Score Matching (ISM). ISM employs deterministic diffusing trajectories and utilizes interval-based score matching to counteract over-smoothing. Furthermore, we incorporate 3D Gaussian Splatting into our text-to-3D generation pipeline. Extensive experiments show that our model largely outperforms the state-of-the-art in quality and training efficiency.
") + gr.Interface(fn=main, inputs=[gr.Textbox(lines=2, value="A portrait of IRONMAN, white hair, head, photorealistic, 8K, HDR.", label="Your prompt"), + gr.Textbox(lines=1, value="a man head.", label="Point-E init prompt (optional)"), + gr.Textbox(lines=2, value="unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution.", label="Negative prompt (optional)"), + gr.Slider(1000, 5000, value=5000, label="Number of iterations"), + gr.Slider(7.5, 100, value=7.5, label="CFG"), + gr.Number(value=0, label="Seed")], + outputs="playable_video", + examples=example_inputs) +demo.launch() diff --git a/guidance/perpneg_utils.py b/guidance/perpneg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..79f4f98dd49284d7d616bec7ff4e635d08421c44 --- /dev/null +++ b/guidance/perpneg_utils.py @@ -0,0 +1,48 @@ +import torch + +# Please refer to the https://perp-neg.github.io/ for details about the paper and algorithm +def get_perpendicular_component(x, y): + assert x.shape == y.shape + return x - ((torch.mul(x, y).sum())/max(torch.norm(y)**2, 1e-6)) * y + + +def batch_get_perpendicular_component(x, y): + assert x.shape == y.shape + result = [] + for i in range(x.shape[0]): + result.append(get_perpendicular_component(x[i], y[i])) + return torch.stack(result) + + +def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size): + """ + Notes: + - weights: an array with the weights for combining the noise predictions + - delta_noise_preds: [B x K, 4, 64, 64], K = max_prompts_per_dir + """ + delta_noise_preds = delta_noise_preds.split(batch_size, dim=0) # K x [B, 4, 64, 64] + weights = weights.split(batch_size, dim=0) # K x [B] + # print(f"{weights[0].shape = } {weights = }") + + assert torch.all(weights[0] == 1.0) + + main_positive = delta_noise_preds[0] # [B, 4, 64, 64] + + accumulated_output = torch.zeros_like(main_positive) + for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1): + # print(f"\n{i = }, {weights[i] = }, {weights[i].shape = }\n") + + idx_non_zero = torch.abs(weights[i]) > 1e-4 + + # print(f"{idx_non_zero.shape = }, {idx_non_zero = }") + # print(f"{weights[i][idx_non_zero].shape = }, {weights[i][idx_non_zero] = }") + # print(f"{complementary_noise_pred.shape = }, {complementary_noise_pred[idx_non_zero].shape = }") + # print(f"{main_positive.shape = }, {main_positive[idx_non_zero].shape = }") + if sum(idx_non_zero) == 0: + continue + accumulated_output[idx_non_zero] += weights[i][idx_non_zero].reshape(-1, 1, 1, 1) * batch_get_perpendicular_component(complementary_noise_pred[idx_non_zero], main_positive[idx_non_zero]) + + #assert accumulated_output.shape == main_positive.shape,# f"{accumulated_output.shape = }, {main_positive.shape = }" + + + return accumulated_output + main_positive \ No newline at end of file diff --git a/guidance/sd_step.py b/guidance/sd_step.py new file mode 100644 index 0000000000000000000000000000000000000000..a196c3df34efe7f1f13a0db75e9b6a0de4677c45 --- /dev/null +++ b/guidance/sd_step.py @@ -0,0 +1,264 @@ +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import StableDiffusionPipeline, DiffusionPipeline, DDPMScheduler, DDIMScheduler, EulerDiscreteScheduler, \ + EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ControlNetModel, \ + DDIMInverseScheduler +from diffusers.utils import BaseOutput, deprecate + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T + +from typing 
import List, Optional, Tuple, Union +from dataclasses import dataclass + +from diffusers.utils import BaseOutput, randn_tensor + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise +def ddim_add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, +) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.step +def ddim_step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + delta_timestep: int = None, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + **kwargs +) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. 
+ variance_noise (`torch.FloatTensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + + if delta_timestep is None: + # 1. get previous step value (=t+1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + else: + prev_timestep = timestep - delta_timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + # if prev_timestep < timestep: + # else: + # variance = abs(self._get_variance(prev_timestep, timestep)) + + variance = abs(self._get_variance(timestep, prev_timestep)) + + std_dev_t = eta * variance + std_dev_t = min((1 - alpha_prod_t_prev) / 2, std_dev_t) ** 0.5 + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. 
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + prev_sample = torch.nan_to_num(prev_sample) + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + +def pred_original( + self, + model_output: torch.FloatTensor, + timesteps: int, + sample: torch.FloatTensor, +): + if isinstance(self, DDPMScheduler) or isinstance(self, DDIMScheduler): + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + # 1. compute alphas, betas + alpha_prod_t = alphas_cumprod[timesteps] + while len(alpha_prod_t.shape) < len(sample.shape): + alpha_prod_t = alpha_prod_t.unsqueeze(-1) + + beta_prod_t = 1 - alpha_prod_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMScheduler." + ) + + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + elif isinstance(self, EulerAncestralDiscreteScheduler) or isinstance(self, EulerDiscreteScheduler): + timestep = timesteps.to(self.timesteps.device) + + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index].to(device=sample.device, dtype=sample.dtype) + + # 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + else: + raise NotImplementedError + + return pred_original_sample \ No newline at end of file diff --git a/guidance/sd_utils.py b/guidance/sd_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..41aafe779e56b95469bddccf8eb05fed4f08a4a0 --- /dev/null +++ b/guidance/sd_utils.py @@ -0,0 +1,487 @@ +from audioop import mul +from transformers import CLIPTextModel, CLIPTokenizer, logging +from diffusers import StableDiffusionPipeline, DiffusionPipeline, DDPMScheduler, DDIMScheduler, EulerDiscreteScheduler, \ + EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ControlNetModel, \ + DDIMInverseScheduler, UNet2DConditionModel +from diffusers.utils.import_utils import is_xformers_available +from os.path import isfile +from pathlib import Path +import os +import random + +import torchvision.transforms as T +# suppress partial model loading warning +logging.set_verbosity_error() + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as T +from torchvision.utils import save_image +from torch.cuda.amp import custom_bwd, custom_fwd +from .perpneg_utils import weighted_perpendicular_aggregator + +from .sd_step import * + +def rgb2sat(img, T=None): + max_ = torch.max(img, dim=1, keepdim=True).values + 1e-5 + min_ = torch.min(img, dim=1, keepdim=True).values + sat = (max_ - min_) / max_ + if T is not None: + sat = (1 - T) * sat + return sat + +class SpecifyGradient(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, input_tensor, gt_grad): + ctx.save_for_backward(gt_grad) + # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward. 
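+        # Typical usage (as in train_step below): the SDS-style gradient is computed
+        # by hand and routed back to the latents through this dummy scalar loss:
+        #     grad = w * (pred_noise - target)
+        #     loss = SpecifyGradient.apply(latents, grad)
+        #     loss.backward()   # latents receive `grad`, scaled by AMP's grad scale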
+ return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype) + + @staticmethod + @custom_bwd + def backward(ctx, grad_scale): + gt_grad, = ctx.saved_tensors + gt_grad = gt_grad * grad_scale + return gt_grad, None + +def seed_everything(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + #torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.benchmark = True + +class StableDiffusion(nn.Module): + def __init__(self, device, fp16, vram_O, t_range=[0.02, 0.98], max_t_range=0.98, num_train_timesteps=None, + ddim_inv=False, use_control_net=False, textual_inversion_path = None, + LoRA_path = None, guidance_opt=None): + super().__init__() + + self.device = device + self.precision_t = torch.float16 if fp16 else torch.float32 + + print(f'[INFO] loading stable diffusion...') + + model_key = guidance_opt.model_key + assert model_key is not None + + is_safe_tensor = guidance_opt.is_safe_tensor + base_model_key = "stabilityai/stable-diffusion-v1-5" if guidance_opt.base_model_key is None else guidance_opt.base_model_key # for finetuned model only + + if is_safe_tensor: + pipe = StableDiffusionPipeline.from_single_file(model_key, use_safetensors=True, torch_dtype=self.precision_t, load_safety_checker=False) + else: + pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.precision_t) + + self.ism = not guidance_opt.sds + self.scheduler = DDIMScheduler.from_pretrained(model_key if not is_safe_tensor else base_model_key, subfolder="scheduler", torch_dtype=self.precision_t) + self.sche_func = ddim_step + + if use_control_net: + controlnet_model_key = guidance_opt.controlnet_model_key + self.controlnet_depth = ControlNetModel.from_pretrained(controlnet_model_key,torch_dtype=self.precision_t).to(device) + + if vram_O: + pipe.enable_sequential_cpu_offload() + pipe.enable_vae_slicing() + pipe.unet.to(memory_format=torch.channels_last) + pipe.enable_attention_slicing(1) + pipe.enable_model_cpu_offload() + + pipe.enable_xformers_memory_efficient_attention() + + pipe = pipe.to(self.device) + if textual_inversion_path is not None: + pipe.load_textual_inversion(textual_inversion_path) + print("load textual inversion in:.{}".format(textual_inversion_path)) + + if LoRA_path is not None: + from lora_diffusion import tune_lora_scale, patch_pipe + print("load lora in:.{}".format(LoRA_path)) + patch_pipe( + pipe, + LoRA_path, + patch_text=True, + patch_ti=True, + patch_unet=True, + ) + tune_lora_scale(pipe.unet, 1.00) + tune_lora_scale(pipe.text_encoder, 1.00) + + self.pipe = pipe + self.vae = pipe.vae + self.tokenizer = pipe.tokenizer + self.text_encoder = pipe.text_encoder + self.unet = pipe.unet + + self.num_train_timesteps = num_train_timesteps if num_train_timesteps is not None else self.scheduler.config.num_train_timesteps + self.scheduler.set_timesteps(self.num_train_timesteps, device=device) + + self.timesteps = torch.flip(self.scheduler.timesteps, dims=(0, )) + self.min_step = int(self.num_train_timesteps * t_range[0]) + self.max_step = int(self.num_train_timesteps * t_range[1]) + self.warmup_step = int(self.num_train_timesteps*(max_t_range-t_range[1])) + + self.noise_temp = None + self.noise_gen = torch.Generator(self.device) + self.noise_gen.manual_seed(guidance_opt.noise_seed) + + self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience + self.rgb_latent_factors = torch.tensor([ + # R G B + [ 0.298, 0.207, 0.208], + [ 0.187, 0.286, 0.173], + [-0.158, 0.189, 0.264], + [-0.184, -0.271, -0.473] + ], device=self.device) + + + 
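+        # `rgb_latent_factors` above is a fixed 4x3 linear map approximating the VAE decoder:
+        # it projects the 4 SD latent channels to RGB so intermediate latents can be
+        # visualised cheaply (see `lat2rgb` in the train steps) without a full VAE decode.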
print(f'[INFO] loaded stable diffusion!') + + def augmentation(self, *tensors): + augs = T.Compose([ + T.RandomHorizontalFlip(p=0.5), + ]) + + channels = [ten.shape[1] for ten in tensors] + tensors_concat = torch.concat(tensors, dim=1) + tensors_concat = augs(tensors_concat) + + results = [] + cur_c = 0 + for i in range(len(channels)): + results.append(tensors_concat[:, cur_c:cur_c + channels[i], ...]) + cur_c += channels[i] + return (ten for ten in results) + + def add_noise_with_cfg(self, latents, noise, + ind_t, ind_prev_t, + text_embeddings=None, cfg=1.0, + delta_t=1, inv_steps=1, + is_noisy_latent=False, + eta=0.0): + + text_embeddings = text_embeddings.to(self.precision_t) + if cfg <= 1.0: + uncond_text_embedding = text_embeddings.reshape(2, -1, text_embeddings.shape[-2], text_embeddings.shape[-1])[1] + + unet = self.unet + + if is_noisy_latent: + prev_noisy_lat = latents + else: + prev_noisy_lat = self.scheduler.add_noise(latents, noise, self.timesteps[ind_prev_t]) + + cur_ind_t = ind_prev_t + cur_noisy_lat = prev_noisy_lat + + pred_scores = [] + + for i in range(inv_steps): + # pred noise + cur_noisy_lat_ = self.scheduler.scale_model_input(cur_noisy_lat, self.timesteps[cur_ind_t]).to(self.precision_t) + + if cfg > 1.0: + latent_model_input = torch.cat([cur_noisy_lat_, cur_noisy_lat_]) + timestep_model_input = self.timesteps[cur_ind_t].reshape(1, 1).repeat(latent_model_input.shape[0], 1).reshape(-1) + unet_output = unet(latent_model_input, timestep_model_input, + encoder_hidden_states=text_embeddings).sample + + uncond, cond = torch.chunk(unet_output, chunks=2) + + unet_output = cond + cfg * (uncond - cond) # reverse cfg to enhance the distillation + else: + timestep_model_input = self.timesteps[cur_ind_t].reshape(1, 1).repeat(cur_noisy_lat_.shape[0], 1).reshape(-1) + unet_output = unet(cur_noisy_lat_, timestep_model_input, + encoder_hidden_states=uncond_text_embedding).sample + + pred_scores.append((cur_ind_t, unet_output)) + + next_ind_t = min(cur_ind_t + delta_t, ind_t) + cur_t, next_t = self.timesteps[cur_ind_t], self.timesteps[next_ind_t] + delta_t_ = next_t-cur_t if isinstance(self.scheduler, DDIMScheduler) else next_ind_t-cur_ind_t + + cur_noisy_lat = self.sche_func(self.scheduler, unet_output, cur_t, cur_noisy_lat, -delta_t_, eta).prev_sample + cur_ind_t = next_ind_t + + del unet_output + torch.cuda.empty_cache() + + if cur_ind_t == ind_t: + break + + return prev_noisy_lat, cur_noisy_lat, pred_scores[::-1] + + + @torch.no_grad() + def get_text_embeds(self, prompt, resolution=(512, 512)): + inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt') + embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0] + return embeddings + + def train_step_perpneg(self, text_embeddings, pred_rgb, pred_depth=None, pred_alpha=None, + grad_scale=1,use_control_net=False, + save_folder:Path=None, iteration=0, warm_up_rate = 0, weights = 0, + resolution=(512, 512), guidance_opt=None,as_latent=False, embedding_inverse = None): + + + # flip aug + pred_rgb, pred_depth, pred_alpha = self.augmentation(pred_rgb, pred_depth, pred_alpha) + + B = pred_rgb.shape[0] + K = text_embeddings.shape[0] - 1 + + if as_latent: + latents,_ = self.encode_imgs(pred_depth.repeat(1,3,1,1).to(self.precision_t)) + else: + latents,_ = self.encode_imgs(pred_rgb.to(self.precision_t)) + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + + weights = weights.reshape(-1) + noise = torch.randn((latents.shape[0], 4, resolution[0] 
// 8, resolution[1] // 8, ), dtype=latents.dtype, device=latents.device, generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1) + + inverse_text_embeddings = embedding_inverse.unsqueeze(1).repeat(1, B, 1, 1).reshape(-1, embedding_inverse.shape[-2], embedding_inverse.shape[-1]) + + text_embeddings = text_embeddings.reshape(-1, text_embeddings.shape[-2], text_embeddings.shape[-1]) # make it k+1, c * t, ... + + if guidance_opt.annealing_intervals: + current_delta_t = int(guidance_opt.delta_t + np.ceil((warm_up_rate)*(guidance_opt.delta_t_start - guidance_opt.delta_t))) + else: + current_delta_t = guidance_opt.delta_t + + ind_t = torch.randint(self.min_step, self.max_step + int(self.warmup_step*warm_up_rate), (1, ), dtype=torch.long, generator=self.noise_gen, device=self.device)[0] + ind_prev_t = max(ind_t - current_delta_t, torch.ones_like(ind_t) * 0) + + t = self.timesteps[ind_t] + prev_t = self.timesteps[ind_prev_t] + + with torch.no_grad(): + # step unroll via ddim inversion + if not self.ism: + prev_latents_noisy = self.scheduler.add_noise(latents, noise, prev_t) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + target = noise + else: + # Step 1: sample x_s with larger steps + xs_delta_t = guidance_opt.xs_delta_t if guidance_opt.xs_delta_t is not None else current_delta_t + xs_inv_steps = guidance_opt.xs_inv_steps if guidance_opt.xs_inv_steps is not None else int(np.ceil(ind_prev_t / xs_delta_t)) + starting_ind = max(ind_prev_t - xs_delta_t * xs_inv_steps, torch.ones_like(ind_t) * 0) + + _, prev_latents_noisy, pred_scores_xs = self.add_noise_with_cfg(latents, noise, ind_prev_t, starting_ind, inverse_text_embeddings, + guidance_opt.denoise_guidance_scale, xs_delta_t, xs_inv_steps, eta=guidance_opt.xs_eta) + # Step 2: sample x_t + _, latents_noisy, pred_scores_xt = self.add_noise_with_cfg(prev_latents_noisy, noise, ind_t, ind_prev_t, inverse_text_embeddings, + guidance_opt.denoise_guidance_scale, current_delta_t, 1, is_noisy_latent=True) + + pred_scores = pred_scores_xt + pred_scores_xs + target = pred_scores[0][1] + + + with torch.no_grad(): + latent_model_input = latents_noisy[None, :, ...].repeat(1 + K, 1, 1, 1, 1).reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, ) + tt = t.reshape(1, 1).repeat(latent_model_input.shape[0], 1).reshape(-1) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, tt[0]) + if use_control_net: + pred_depth_input = pred_depth_input[None, :, ...].repeat(1 + K, 1, 3, 1, 1).reshape(-1, 3, 512, 512).half() + down_block_res_samples, mid_block_res_sample = self.controlnet_depth( + latent_model_input, + tt, + encoder_hidden_states=text_embeddings, + controlnet_cond=pred_depth_input, + return_dict=False, + ) + unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample).sample + else: + unet_output = self.unet(latent_model_input.to(self.precision_t), tt.to(self.precision_t), encoder_hidden_states=text_embeddings.to(self.precision_t)).sample + + unet_output = unet_output.reshape(1 + K, -1, 4, resolution[0] // 8, resolution[1] // 8, ) + noise_pred_uncond, noise_pred_text = unet_output[:1].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, ), unet_output[1:].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, ) + delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1) + delta_DSD = 
weighted_perpendicular_aggregator(delta_noise_preds,\ + weights,\ + B) + + pred_noise = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD + w = lambda alphas: (((1 - alphas) / alphas) ** 0.5) + + grad = w(self.alphas[t]) * (pred_noise - target) + + grad = torch.nan_to_num(grad_scale * grad) + loss = SpecifyGradient.apply(latents, grad) + + if iteration % guidance_opt.vis_interval == 0: + noise_pred_post = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD + lat2rgb = lambda x: torch.clip((x.permute(0,2,3,1) @ self.rgb_latent_factors.to(x.dtype)).permute(0,3,1,2), 0., 1.) + save_path_iter = os.path.join(save_folder,"iter_{}_step_{}.jpg".format(iteration,prev_t.item())) + with torch.no_grad(): + pred_x0_latent_sp = pred_original(self.scheduler, noise_pred_uncond, prev_t, prev_latents_noisy) + pred_x0_latent_pos = pred_original(self.scheduler, noise_pred_post, prev_t, prev_latents_noisy) + pred_x0_pos = self.decode_latents(pred_x0_latent_pos.type(self.precision_t)) + pred_x0_sp = self.decode_latents(pred_x0_latent_sp.type(self.precision_t)) + + grad_abs = torch.abs(grad.detach()) + norm_grad = F.interpolate((grad_abs / grad_abs.max()).mean(dim=1,keepdim=True), (resolution[0], resolution[1]), mode='bilinear', align_corners=False).repeat(1,3,1,1) + + latents_rgb = F.interpolate(lat2rgb(latents), (resolution[0], resolution[1]), mode='bilinear', align_corners=False) + latents_sp_rgb = F.interpolate(lat2rgb(pred_x0_latent_sp), (resolution[0], resolution[1]), mode='bilinear', align_corners=False) + + viz_images = torch.cat([pred_rgb, + pred_depth.repeat(1, 3, 1, 1), + pred_alpha.repeat(1, 3, 1, 1), + rgb2sat(pred_rgb, pred_alpha).repeat(1, 3, 1, 1), + latents_rgb, latents_sp_rgb, + norm_grad, + pred_x0_sp, pred_x0_pos],dim=0) + save_image(viz_images, save_path_iter) + + + return loss + + + def train_step(self, text_embeddings, pred_rgb, pred_depth=None, pred_alpha=None, + grad_scale=1,use_control_net=False, + save_folder:Path=None, iteration=0, warm_up_rate = 0, + resolution=(512, 512), guidance_opt=None,as_latent=False, embedding_inverse = None): + + pred_rgb, pred_depth, pred_alpha = self.augmentation(pred_rgb, pred_depth, pred_alpha) + + B = pred_rgb.shape[0] + K = text_embeddings.shape[0] - 1 + + if as_latent: + latents,_ = self.encode_imgs(pred_depth.repeat(1,3,1,1).to(self.precision_t)) + else: + latents,_ = self.encode_imgs(pred_rgb.to(self.precision_t)) + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + + if self.noise_temp is None: + self.noise_temp = torch.randn((latents.shape[0], 4, resolution[0] // 8, resolution[1] // 8, ), dtype=latents.dtype, device=latents.device, generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1) + + if guidance_opt.fix_noise: + noise = self.noise_temp + else: + noise = torch.randn((latents.shape[0], 4, resolution[0] // 8, resolution[1] // 8, ), dtype=latents.dtype, device=latents.device, generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1) + + text_embeddings = text_embeddings[:, :, ...] + text_embeddings = text_embeddings.reshape(-1, text_embeddings.shape[-2], text_embeddings.shape[-1]) # make it k+1, c * t, ... 
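+        # The noise drawn above is i.i.d. Gaussian plus a small (0.1x) per-channel offset shared
+        # across spatial positions; with `guidance_opt.fix_noise` it is sampled once and reused.
+        # `inverse_text_embeddings` below are the embeddings fed to `add_noise_with_cfg` for the
+        # "step unroll via ddim inversion" branch.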
+ + inverse_text_embeddings = embedding_inverse.unsqueeze(1).repeat(1, B, 1, 1).reshape(-1, embedding_inverse.shape[-2], embedding_inverse.shape[-1]) + + if guidance_opt.annealing_intervals: + current_delta_t = int(guidance_opt.delta_t + (warm_up_rate)*(guidance_opt.delta_t_start - guidance_opt.delta_t)) + else: + current_delta_t = guidance_opt.delta_t + + ind_t = torch.randint(self.min_step, self.max_step + int(self.warmup_step*warm_up_rate), (1, ), dtype=torch.long, generator=self.noise_gen, device=self.device)[0] + ind_prev_t = max(ind_t - current_delta_t, torch.ones_like(ind_t) * 0) + + t = self.timesteps[ind_t] + prev_t = self.timesteps[ind_prev_t] + + with torch.no_grad(): + # step unroll via ddim inversion + if not self.ism: + prev_latents_noisy = self.scheduler.add_noise(latents, noise, prev_t) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + target = noise + else: + # Step 1: sample x_s with larger steps + xs_delta_t = guidance_opt.xs_delta_t if guidance_opt.xs_delta_t is not None else current_delta_t + xs_inv_steps = guidance_opt.xs_inv_steps if guidance_opt.xs_inv_steps is not None else int(np.ceil(ind_prev_t / xs_delta_t)) + starting_ind = max(ind_prev_t - xs_delta_t * xs_inv_steps, torch.ones_like(ind_t) * 0) + + _, prev_latents_noisy, pred_scores_xs = self.add_noise_with_cfg(latents, noise, ind_prev_t, starting_ind, inverse_text_embeddings, + guidance_opt.denoise_guidance_scale, xs_delta_t, xs_inv_steps, eta=guidance_opt.xs_eta) + # Step 2: sample x_t + _, latents_noisy, pred_scores_xt = self.add_noise_with_cfg(prev_latents_noisy, noise, ind_t, ind_prev_t, inverse_text_embeddings, + guidance_opt.denoise_guidance_scale, current_delta_t, 1, is_noisy_latent=True) + + pred_scores = pred_scores_xt + pred_scores_xs + target = pred_scores[0][1] + + + with torch.no_grad(): + latent_model_input = latents_noisy[None, :, ...].repeat(2, 1, 1, 1, 1).reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, ) + tt = t.reshape(1, 1).repeat(latent_model_input.shape[0], 1).reshape(-1) + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, tt[0]) + if use_control_net: + pred_depth_input = pred_depth_input[None, :, ...].repeat(1 + K, 1, 3, 1, 1).reshape(-1, 3, 512, 512).half() + down_block_res_samples, mid_block_res_sample = self.controlnet_depth( + latent_model_input, + tt, + encoder_hidden_states=text_embeddings, + controlnet_cond=pred_depth_input, + return_dict=False, + ) + unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample).sample + else: + unet_output = self.unet(latent_model_input.to(self.precision_t), tt.to(self.precision_t), encoder_hidden_states=text_embeddings.to(self.precision_t)).sample + + unet_output = unet_output.reshape(2, -1, 4, resolution[0] // 8, resolution[1] // 8, ) + noise_pred_uncond, noise_pred_text = unet_output[:1].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, ), unet_output[1:].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, ) + delta_DSD = noise_pred_text - noise_pred_uncond + + pred_noise = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD + + w = lambda alphas: (((1 - alphas) / alphas) ** 0.5) + + grad = w(self.alphas[t]) * (pred_noise - target) + + grad = torch.nan_to_num(grad_scale * grad) + loss = SpecifyGradient.apply(latents, grad) + + if iteration % guidance_opt.vis_interval == 0: + noise_pred_post = noise_pred_uncond + 7.5* delta_DSD + lat2rgb = lambda 
x: torch.clip((x.permute(0,2,3,1) @ self.rgb_latent_factors.to(x.dtype)).permute(0,3,1,2), 0., 1.) + save_path_iter = os.path.join(save_folder,"iter_{}_step_{}.jpg".format(iteration,prev_t.item())) + with torch.no_grad(): + pred_x0_latent_sp = pred_original(self.scheduler, noise_pred_uncond, prev_t, prev_latents_noisy) + pred_x0_latent_pos = pred_original(self.scheduler, noise_pred_post, prev_t, prev_latents_noisy) + pred_x0_pos = self.decode_latents(pred_x0_latent_pos.type(self.precision_t)) + pred_x0_sp = self.decode_latents(pred_x0_latent_sp.type(self.precision_t)) + # pred_x0_uncond = pred_x0_sp[:1, ...] + + grad_abs = torch.abs(grad.detach()) + norm_grad = F.interpolate((grad_abs / grad_abs.max()).mean(dim=1,keepdim=True), (resolution[0], resolution[1]), mode='bilinear', align_corners=False).repeat(1,3,1,1) + + latents_rgb = F.interpolate(lat2rgb(latents), (resolution[0], resolution[1]), mode='bilinear', align_corners=False) + latents_sp_rgb = F.interpolate(lat2rgb(pred_x0_latent_sp), (resolution[0], resolution[1]), mode='bilinear', align_corners=False) + + viz_images = torch.cat([pred_rgb, + pred_depth.repeat(1, 3, 1, 1), + pred_alpha.repeat(1, 3, 1, 1), + rgb2sat(pred_rgb, pred_alpha).repeat(1, 3, 1, 1), + latents_rgb, latents_sp_rgb, norm_grad, + pred_x0_sp, pred_x0_pos],dim=0) + save_image(viz_images, save_path_iter) + + return loss + + def decode_latents(self, latents): + target_dtype = latents.dtype + latents = latents / self.vae.config.scaling_factor + + imgs = self.vae.decode(latents.to(self.vae.dtype)).sample + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs.to(target_dtype) + + def encode_imgs(self, imgs): + target_dtype = imgs.dtype + # imgs: [B, 3, H, W] + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs.to(self.vae.dtype)).latent_dist + kl_divergence = posterior.kl() + + latents = posterior.sample() * self.vae.config.scaling_factor + + return latents.to(target_dtype), kl_divergence \ No newline at end of file diff --git a/lora_diffusion/__init__.py b/lora_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..286e4fecaaf7190bba2739bbc6e826a8969fece9 --- /dev/null +++ b/lora_diffusion/__init__.py @@ -0,0 +1,5 @@ +from .lora import * +from .dataset import * +from .utils import * +from .preprocess_files import * +from .lora_manager import * diff --git a/lora_diffusion/cli_lora_add.py b/lora_diffusion/cli_lora_add.py new file mode 100644 index 0000000000000000000000000000000000000000..fc7f7e4ace4253ba63a1631a304cf935a23b034a --- /dev/null +++ b/lora_diffusion/cli_lora_add.py @@ -0,0 +1,187 @@ +from typing import Literal, Union, Dict +import os +import shutil +import fire +from diffusers import StableDiffusionPipeline +from safetensors.torch import safe_open, save_file + +import torch +from .lora import ( + tune_lora_scale, + patch_pipe, + collapse_lora, + monkeypatch_remove_lora, +) +from .lora_manager import lora_join +from .to_ckpt_v2 import convert_to_ckpt + + +def _text_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) + + +def add( + path_1: str, + path_2: str, + output_path: str, + alpha_1: float = 0.5, + alpha_2: float = 0.5, + mode: Literal[ + "lpl", + "upl", + "upl-ckpt-v2", + ] = "lpl", + with_text_lora: bool = False, +): + print("Lora Add, mode " + mode) + if mode == "lpl": + if path_1.endswith(".pt") and path_2.endswith(".pt"): + for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + ( + 
[(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")] + if with_text_lora + else [] + ): + print("Loading", _path_1, _path_2) + out_list = [] + if opt == "text_encoder": + if not os.path.exists(_path_1): + print(f"No text encoder found in {_path_1}, skipping...") + continue + if not os.path.exists(_path_2): + print(f"No text encoder found in {_path_1}, skipping...") + continue + + l1 = torch.load(_path_1) + l2 = torch.load(_path_2) + + l1pairs = zip(l1[::2], l1[1::2]) + l2pairs = zip(l2[::2], l2[1::2]) + + for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs): + # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape) + x1.data = alpha_1 * x1.data + alpha_2 * x2.data + y1.data = alpha_1 * y1.data + alpha_2 * y2.data + + out_list.append(x1) + out_list.append(y1) + + if opt == "unet": + + print("Saving merged UNET to", output_path) + torch.save(out_list, output_path) + + elif opt == "text_encoder": + print("Saving merged text encoder to", _text_lora_path(output_path)) + torch.save( + out_list, + _text_lora_path(output_path), + ) + + elif path_1.endswith(".safetensors") and path_2.endswith(".safetensors"): + safeloras_1 = safe_open(path_1, framework="pt", device="cpu") + safeloras_2 = safe_open(path_2, framework="pt", device="cpu") + + metadata = dict(safeloras_1.metadata()) + metadata.update(dict(safeloras_2.metadata())) + + ret_tensor = {} + + for keys in set(list(safeloras_1.keys()) + list(safeloras_2.keys())): + if keys.startswith("text_encoder") or keys.startswith("unet"): + + tens1 = safeloras_1.get_tensor(keys) + tens2 = safeloras_2.get_tensor(keys) + + tens = alpha_1 * tens1 + alpha_2 * tens2 + ret_tensor[keys] = tens + else: + if keys in safeloras_1.keys(): + + tens1 = safeloras_1.get_tensor(keys) + else: + tens1 = safeloras_2.get_tensor(keys) + + ret_tensor[keys] = tens1 + + save_file(ret_tensor, output_path, metadata) + + elif mode == "upl": + + print( + f"Merging UNET/CLIP from {path_1} with LoRA from {path_2} to {output_path}. Merging ratio : {alpha_1}." + ) + + loaded_pipeline = StableDiffusionPipeline.from_pretrained( + path_1, + ).to("cpu") + + patch_pipe(loaded_pipeline, path_2) + + collapse_lora(loaded_pipeline.unet, alpha_1) + collapse_lora(loaded_pipeline.text_encoder, alpha_1) + + monkeypatch_remove_lora(loaded_pipeline.unet) + monkeypatch_remove_lora(loaded_pipeline.text_encoder) + + loaded_pipeline.save_pretrained(output_path) + + elif mode == "upl-ckpt-v2": + + assert output_path.endswith(".ckpt"), "Only .ckpt files are supported" + name = os.path.basename(output_path)[0:-5] + + print( + f"You will be using {name} as the token in A1111 webui. Make sure {name} is unique enough token." 
+ ) + + loaded_pipeline = StableDiffusionPipeline.from_pretrained( + path_1, + ).to("cpu") + + tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False) + + collapse_lora(loaded_pipeline.unet, alpha_1) + collapse_lora(loaded_pipeline.text_encoder, alpha_1) + + monkeypatch_remove_lora(loaded_pipeline.unet) + monkeypatch_remove_lora(loaded_pipeline.text_encoder) + + _tmp_output = output_path + ".tmp" + + loaded_pipeline.save_pretrained(_tmp_output) + convert_to_ckpt(_tmp_output, output_path, as_half=True) + # remove the tmp_output folder + shutil.rmtree(_tmp_output) + + keys = sorted(tok_dict.keys()) + tok_catted = torch.stack([tok_dict[k] for k in keys]) + ret = { + "string_to_token": {"*": torch.tensor(265)}, + "string_to_param": {"*": tok_catted}, + "name": name, + } + + torch.save(ret, output_path[:-5] + ".pt") + print( + f"Textual embedding saved as {output_path[:-5]}.pt, put it in the embedding folder and use it as {name} in A1111 repo, " + ) + elif mode == "ljl": + print("Using Join mode : alpha will not have an effect here.") + assert path_1.endswith(".safetensors") and path_2.endswith( + ".safetensors" + ), "Only .safetensors files are supported" + + safeloras_1 = safe_open(path_1, framework="pt", device="cpu") + safeloras_2 = safe_open(path_2, framework="pt", device="cpu") + + total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2]) + save_file(total_tensor, output_path, total_metadata) + + else: + print("Unknown mode", mode) + raise ValueError(f"Unknown mode {mode}") + + +def main(): + fire.Fire(add) diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py new file mode 100644 index 0000000000000000000000000000000000000000..7de4bae1506d9959708f6af49893217810520f05 --- /dev/null +++ b/lora_diffusion/cli_lora_pti.py @@ -0,0 +1,1040 @@ +# Bootstrapped from: +# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py + +import argparse +import hashlib +import inspect +import itertools +import math +import os +import random +import re +from pathlib import Path +from typing import Optional, List, Literal + +import torch +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.checkpoint +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer +import wandb +import fire + +from lora_diffusion import ( + PivotalTuningDatasetCapation, + extract_lora_ups_down, + inject_trainable_lora, + inject_trainable_lora_extended, + inspect_lora, + save_lora_weight, + save_all, + prepare_clip_model_sets, + evaluate_pipe, + UNET_EXTENDED_TARGET_REPLACE, +) + + +def get_models( + pretrained_model_name_or_path, + pretrained_vae_name_or_path, + revision, + placeholder_tokens: List[str], + initializer_tokens: List[str], + device="cuda:0", +): + + tokenizer = CLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, + subfolder="tokenizer", + revision=revision, + ) + + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + + placeholder_token_ids = [] + + for token, init_tok in zip(placeholder_tokens, initializer_tokens): + num_added_tokens = tokenizer.add_tokens(token) + if 
num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + placeholder_token_id = tokenizer.convert_tokens_to_ids(token) + + placeholder_token_ids.append(placeholder_token_id) + + # Load models and create wrapper for stable diffusion + + text_encoder.resize_token_embeddings(len(tokenizer)) + token_embeds = text_encoder.get_input_embeddings().weight.data + if init_tok.startswith(", e.g. + sigma_val = float(re.findall(r"", init_tok)[0]) + + token_embeds[placeholder_token_id] = ( + torch.randn_like(token_embeds[0]) * sigma_val + ) + print( + f"Initialized {token} with random noise (sigma={sigma_val}), empirically {token_embeds[placeholder_token_id].mean().item():.3f} +- {token_embeds[placeholder_token_id].std().item():.3f}" + ) + print(f"Norm : {token_embeds[placeholder_token_id].norm():.4f}") + + elif init_tok == "": + token_embeds[placeholder_token_id] = torch.zeros_like(token_embeds[0]) + else: + token_ids = tokenizer.encode(init_tok, add_special_tokens=False) + # Check if initializer_token is a single token or a sequence of tokens + if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + + initializer_token_id = token_ids[0] + token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] + + vae = AutoencoderKL.from_pretrained( + pretrained_vae_name_or_path or pretrained_model_name_or_path, + subfolder=None if pretrained_vae_name_or_path else "vae", + revision=None if pretrained_vae_name_or_path else revision, + ) + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, + subfolder="unet", + revision=revision, + ) + + return ( + text_encoder.to(device), + vae.to(device), + unet.to(device), + tokenizer, + placeholder_token_ids, + ) + + +@torch.no_grad() +def text2img_dataloader( + train_dataset, + train_batch_size, + tokenizer, + vae, + text_encoder, + cached_latents: bool = False, +): + + if cached_latents: + cached_latents_dataset = [] + for idx in tqdm(range(len(train_dataset))): + batch = train_dataset[idx] + # rint(batch) + latents = vae.encode( + batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device) + ).latent_dist.sample() + latents = latents * 0.18215 + batch["instance_images"] = latents.squeeze(0) + cached_latents_dataset.append(batch) + + def collate_fn(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = tokenizer.pad( + {"input_ids": input_ids}, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + + if examples[0].get("mask", None) is not None: + batch["mask"] = torch.stack([example["mask"] for example in examples]) + + return batch + + if cached_latents: + + train_dataloader = torch.utils.data.DataLoader( + cached_latents_dataset, + batch_size=train_batch_size, + shuffle=True, + collate_fn=collate_fn, + ) + + print("PTI : Using cached latent.") + + else: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + collate_fn=collate_fn, + ) + + return train_dataloader + + +def inpainting_dataloader( + train_dataset, train_batch_size, tokenizer, 
vae, text_encoder +): + def collate_fn(examples): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + mask_values = [example["instance_masks"] for example in examples] + masked_image_values = [ + example["instance_masked_images"] for example in examples + ] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if examples[0].get("class_prompt_ids", None) is not None: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + mask_values += [example["class_masks"] for example in examples] + masked_image_values += [ + example["class_masked_images"] for example in examples + ] + + pixel_values = ( + torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float() + ) + mask_values = ( + torch.stack(mask_values).to(memory_format=torch.contiguous_format).float() + ) + masked_image_values = ( + torch.stack(masked_image_values) + .to(memory_format=torch.contiguous_format) + .float() + ) + + input_ids = tokenizer.pad( + {"input_ids": input_ids}, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + "mask_values": mask_values, + "masked_image_values": masked_image_values, + } + + if examples[0].get("mask", None) is not None: + batch["mask"] = torch.stack([example["mask"] for example in examples]) + + return batch + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=True, + collate_fn=collate_fn, + ) + + return train_dataloader + + +def loss_step( + batch, + unet, + vae, + text_encoder, + scheduler, + train_inpainting=False, + t_mutliplier=1.0, + mixed_precision=False, + mask_temperature=1.0, + cached_latents: bool = False, +): + weight_dtype = torch.float32 + if not cached_latents: + latents = vae.encode( + batch["pixel_values"].to(dtype=weight_dtype).to(unet.device) + ).latent_dist.sample() + latents = latents * 0.18215 + + if train_inpainting: + masked_image_latents = vae.encode( + batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device) + ).latent_dist.sample() + masked_image_latents = masked_image_latents * 0.18215 + mask = F.interpolate( + batch["mask_values"].to(dtype=weight_dtype).to(unet.device), + scale_factor=1 / 8, + ) + else: + latents = batch["pixel_values"] + + if train_inpainting: + masked_image_latents = batch["masked_image_latents"] + mask = batch["mask_values"] + + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + timesteps = torch.randint( + 0, + int(scheduler.config.num_train_timesteps * t_mutliplier), + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + noisy_latents = scheduler.add_noise(latents, noise, timesteps) + + if train_inpainting: + latent_model_input = torch.cat( + [noisy_latents, mask, masked_image_latents], dim=1 + ) + else: + latent_model_input = noisy_latents + + if mixed_precision: + with torch.cuda.amp.autocast(): + + encoder_hidden_states = text_encoder( + batch["input_ids"].to(text_encoder.device) + )[0] + + model_pred = unet( + latent_model_input, timesteps, encoder_hidden_states + ).sample + else: + + encoder_hidden_states = text_encoder( + batch["input_ids"].to(text_encoder.device) + )[0] + + model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample + + if 
scheduler.config.prediction_type == "epsilon": + target = noise + elif scheduler.config.prediction_type == "v_prediction": + target = scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}") + + if batch.get("mask", None) is not None: + + mask = ( + batch["mask"] + .to(model_pred.device) + .reshape( + model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8 + ) + ) + # resize to match model_pred + mask = F.interpolate( + mask.float(), + size=model_pred.shape[-2:], + mode="nearest", + ) + + mask = (mask + 0.01).pow(mask_temperature) + + mask = mask / mask.max() + + model_pred = model_pred * mask + + target = target * mask + + loss = ( + F.mse_loss(model_pred.float(), target.float(), reduction="none") + .mean([1, 2, 3]) + .mean() + ) + + return loss + + +def train_inversion( + unet, + vae, + text_encoder, + dataloader, + num_steps: int, + scheduler, + index_no_updates, + optimizer, + save_steps: int, + placeholder_token_ids, + placeholder_tokens, + save_path: str, + tokenizer, + lr_scheduler, + test_image_path: str, + cached_latents: bool, + accum_iter: int = 1, + log_wandb: bool = False, + wandb_log_prompt_cnt: int = 10, + class_token: str = "person", + train_inpainting: bool = False, + mixed_precision: bool = False, + clip_ti_decay: bool = True, +): + + progress_bar = tqdm(range(num_steps)) + progress_bar.set_description("Steps") + global_step = 0 + + # Original Emb for TI + orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone() + + if log_wandb: + preped_clip = prepare_clip_model_sets() + + index_updates = ~index_no_updates + loss_sum = 0.0 + + for epoch in range(math.ceil(num_steps / len(dataloader))): + unet.eval() + text_encoder.train() + for batch in dataloader: + + lr_scheduler.step() + + with torch.set_grad_enabled(True): + loss = ( + loss_step( + batch, + unet, + vae, + text_encoder, + scheduler, + train_inpainting=train_inpainting, + mixed_precision=mixed_precision, + cached_latents=cached_latents, + ) + / accum_iter + ) + + loss.backward() + loss_sum += loss.detach().item() + + if global_step % accum_iter == 0: + # print gradient of text encoder embedding + print( + text_encoder.get_input_embeddings() + .weight.grad[index_updates, :] + .norm(dim=-1) + .mean() + ) + optimizer.step() + optimizer.zero_grad() + + with torch.no_grad(): + + # normalize embeddings + if clip_ti_decay: + pre_norm = ( + text_encoder.get_input_embeddings() + .weight[index_updates, :] + .norm(dim=-1, keepdim=True) + ) + + lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0]) + text_encoder.get_input_embeddings().weight[ + index_updates + ] = F.normalize( + text_encoder.get_input_embeddings().weight[ + index_updates, : + ], + dim=-1, + ) * ( + pre_norm + lambda_ * (0.4 - pre_norm) + ) + print(pre_norm) + + current_norm = ( + text_encoder.get_input_embeddings() + .weight[index_updates, :] + .norm(dim=-1) + ) + + text_encoder.get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds_params[index_no_updates] + + print(f"Current Norm : {current_norm}") + + global_step += 1 + progress_bar.update(1) + + logs = { + "loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + + if global_step % save_steps == 0: + save_all( + unet=unet, + text_encoder=text_encoder, + placeholder_token_ids=placeholder_token_ids, + placeholder_tokens=placeholder_tokens, + save_path=os.path.join( + save_path, f"step_inv_{global_step}.safetensors" + ), + 
save_lora=False, + ) + if log_wandb: + with torch.no_grad(): + pipe = StableDiffusionPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + ) + + # open all images in test_image_path + images = [] + for file in os.listdir(test_image_path): + if ( + file.lower().endswith(".png") + or file.lower().endswith(".jpg") + or file.lower().endswith(".jpeg") + ): + images.append( + Image.open(os.path.join(test_image_path, file)) + ) + + wandb.log({"loss": loss_sum / save_steps}) + loss_sum = 0.0 + wandb.log( + evaluate_pipe( + pipe, + target_images=images, + class_token=class_token, + learnt_token="".join(placeholder_tokens), + n_test=wandb_log_prompt_cnt, + n_step=50, + clip_model_sets=preped_clip, + ) + ) + + if global_step >= num_steps: + return + + +def perform_tuning( + unet, + vae, + text_encoder, + dataloader, + num_steps, + scheduler, + optimizer, + save_steps: int, + placeholder_token_ids, + placeholder_tokens, + save_path, + lr_scheduler_lora, + lora_unet_target_modules, + lora_clip_target_modules, + mask_temperature, + out_name: str, + tokenizer, + test_image_path: str, + cached_latents: bool, + log_wandb: bool = False, + wandb_log_prompt_cnt: int = 10, + class_token: str = "person", + train_inpainting: bool = False, +): + + progress_bar = tqdm(range(num_steps)) + progress_bar.set_description("Steps") + global_step = 0 + + weight_dtype = torch.float16 + + unet.train() + text_encoder.train() + + if log_wandb: + preped_clip = prepare_clip_model_sets() + + loss_sum = 0.0 + + for epoch in range(math.ceil(num_steps / len(dataloader))): + for batch in dataloader: + lr_scheduler_lora.step() + + optimizer.zero_grad() + + loss = loss_step( + batch, + unet, + vae, + text_encoder, + scheduler, + train_inpainting=train_inpainting, + t_mutliplier=0.8, + mixed_precision=True, + mask_temperature=mask_temperature, + cached_latents=cached_latents, + ) + loss_sum += loss.detach().item() + + loss.backward() + torch.nn.utils.clip_grad_norm_( + itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0 + ) + optimizer.step() + progress_bar.update(1) + logs = { + "loss": loss.detach().item(), + "lr": lr_scheduler_lora.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + + global_step += 1 + + if global_step % save_steps == 0: + save_all( + unet, + text_encoder, + placeholder_token_ids=placeholder_token_ids, + placeholder_tokens=placeholder_tokens, + save_path=os.path.join( + save_path, f"step_{global_step}.safetensors" + ), + target_replace_module_text=lora_clip_target_modules, + target_replace_module_unet=lora_unet_target_modules, + ) + moved = ( + torch.tensor(list(itertools.chain(*inspect_lora(unet).values()))) + .mean() + .item() + ) + + print("LORA Unet Moved", moved) + moved = ( + torch.tensor( + list(itertools.chain(*inspect_lora(text_encoder).values())) + ) + .mean() + .item() + ) + + print("LORA CLIP Moved", moved) + + if log_wandb: + with torch.no_grad(): + pipe = StableDiffusionPipeline( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + ) + + # open all images in test_image_path + images = [] + for file in os.listdir(test_image_path): + if file.endswith(".png") or file.endswith(".jpg"): + images.append( + Image.open(os.path.join(test_image_path, file)) + ) + + wandb.log({"loss": loss_sum / save_steps}) + loss_sum = 0.0 + wandb.log( + evaluate_pipe( + pipe, + target_images=images, + 
class_token=class_token, + learnt_token="".join(placeholder_tokens), + n_test=wandb_log_prompt_cnt, + n_step=50, + clip_model_sets=preped_clip, + ) + ) + + if global_step >= num_steps: + break + + save_all( + unet, + text_encoder, + placeholder_token_ids=placeholder_token_ids, + placeholder_tokens=placeholder_tokens, + save_path=os.path.join(save_path, f"{out_name}.safetensors"), + target_replace_module_text=lora_clip_target_modules, + target_replace_module_unet=lora_unet_target_modules, + ) + + +def train( + instance_data_dir: str, + pretrained_model_name_or_path: str, + output_dir: str, + train_text_encoder: bool = True, + pretrained_vae_name_or_path: str = None, + revision: Optional[str] = None, + perform_inversion: bool = True, + use_template: Literal[None, "object", "style"] = None, + train_inpainting: bool = False, + placeholder_tokens: str = "", + placeholder_token_at_data: Optional[str] = None, + initializer_tokens: Optional[str] = None, + seed: int = 42, + resolution: int = 512, + color_jitter: bool = True, + train_batch_size: int = 1, + sample_batch_size: int = 1, + max_train_steps_tuning: int = 1000, + max_train_steps_ti: int = 1000, + save_steps: int = 100, + gradient_accumulation_steps: int = 4, + gradient_checkpointing: bool = False, + lora_rank: int = 4, + lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"}, + lora_clip_target_modules={"CLIPAttention"}, + lora_dropout_p: float = 0.0, + lora_scale: float = 1.0, + use_extended_lora: bool = False, + clip_ti_decay: bool = True, + learning_rate_unet: float = 1e-4, + learning_rate_text: float = 1e-5, + learning_rate_ti: float = 5e-4, + continue_inversion: bool = False, + continue_inversion_lr: Optional[float] = None, + use_face_segmentation_condition: bool = False, + cached_latents: bool = True, + use_mask_captioned_data: bool = False, + mask_temperature: float = 1.0, + scale_lr: bool = False, + lr_scheduler: str = "linear", + lr_warmup_steps: int = 0, + lr_scheduler_lora: str = "linear", + lr_warmup_steps_lora: int = 0, + weight_decay_ti: float = 0.00, + weight_decay_lora: float = 0.001, + use_8bit_adam: bool = False, + device="cuda:0", + extra_args: Optional[dict] = None, + log_wandb: bool = False, + wandb_log_prompt_cnt: int = 10, + wandb_project_name: str = "new_pti_project", + wandb_entity: str = "new_pti_entity", + proxy_token: str = "person", + enable_xformers_memory_efficient_attention: bool = False, + out_name: str = "final_lora", +): + torch.manual_seed(seed) + + if log_wandb: + wandb.init( + project=wandb_project_name, + entity=wandb_entity, + name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}", + reinit=True, + config={ + **(extra_args if extra_args is not None else {}), + }, + ) + + if output_dir is not None: + os.makedirs(output_dir, exist_ok=True) + # print(placeholder_tokens, initializer_tokens) + if len(placeholder_tokens) == 0: + placeholder_tokens = [] + print("PTI : Placeholder Tokens not given, using null token") + else: + placeholder_tokens = placeholder_tokens.split("|") + + assert ( + sorted(placeholder_tokens) == placeholder_tokens + ), f"Placeholder tokens should be sorted. 
Use something like {'|'.join(sorted(placeholder_tokens))}'" + + if initializer_tokens is None: + print("PTI : Initializer Tokens not given, doing random inits") + initializer_tokens = [""] * len(placeholder_tokens) + else: + initializer_tokens = initializer_tokens.split("|") + + assert len(initializer_tokens) == len( + placeholder_tokens + ), "Unequal Initializer token for Placeholder tokens." + + if proxy_token is not None: + class_token = proxy_token + class_token = "".join(initializer_tokens) + + if placeholder_token_at_data is not None: + tok, pat = placeholder_token_at_data.split("|") + token_map = {tok: pat} + + else: + token_map = {"DUMMY": "".join(placeholder_tokens)} + + print("PTI : Placeholder Tokens", placeholder_tokens) + print("PTI : Initializer Tokens", initializer_tokens) + + # get the models + text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models( + pretrained_model_name_or_path, + pretrained_vae_name_or_path, + revision, + placeholder_tokens, + initializer_tokens, + device=device, + ) + + noise_scheduler = DDPMScheduler.from_config( + pretrained_model_name_or_path, subfolder="scheduler" + ) + + if gradient_checkpointing: + unet.enable_gradient_checkpointing() + + if enable_xformers_memory_efficient_attention: + from diffusers.utils.import_utils import is_xformers_available + + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) + + if scale_lr: + unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size + text_encoder_lr = ( + learning_rate_text * gradient_accumulation_steps * train_batch_size + ) + ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size + else: + unet_lr = learning_rate_unet + text_encoder_lr = learning_rate_text + ti_lr = learning_rate_ti + + train_dataset = PivotalTuningDatasetCapation( + instance_data_root=instance_data_dir, + token_map=token_map, + use_template=use_template, + tokenizer=tokenizer, + size=resolution, + color_jitter=color_jitter, + use_face_segmentation_condition=use_face_segmentation_condition, + use_mask_captioned_data=use_mask_captioned_data, + train_inpainting=train_inpainting, + ) + + train_dataset.blur_amount = 200 + + if train_inpainting: + assert not cached_latents, "Cached latents not supported for inpainting" + + train_dataloader = inpainting_dataloader( + train_dataset, train_batch_size, tokenizer, vae, text_encoder + ) + else: + train_dataloader = text2img_dataloader( + train_dataset, + train_batch_size, + tokenizer, + vae, + text_encoder, + cached_latents=cached_latents, + ) + + index_no_updates = torch.arange(len(tokenizer)) != -1 + + for tok_id in placeholder_token_ids: + index_no_updates[tok_id] = False + + unet.requires_grad_(False) + vae.requires_grad_(False) + + params_to_freeze = itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + text_encoder.text_model.embeddings.position_embedding.parameters(), + ) + for param in params_to_freeze: + param.requires_grad = False + + if cached_latents: + vae = None + # STEP 1 : Perform Inversion + if perform_inversion: + ti_optimizer = optim.AdamW( + text_encoder.get_input_embeddings().parameters(), + lr=ti_lr, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=weight_decay_ti, + ) + + lr_scheduler = get_scheduler( + lr_scheduler, + optimizer=ti_optimizer, + num_warmup_steps=lr_warmup_steps, + num_training_steps=max_train_steps_ti, + 
) + + train_inversion( + unet, + vae, + text_encoder, + train_dataloader, + max_train_steps_ti, + cached_latents=cached_latents, + accum_iter=gradient_accumulation_steps, + scheduler=noise_scheduler, + index_no_updates=index_no_updates, + optimizer=ti_optimizer, + lr_scheduler=lr_scheduler, + save_steps=save_steps, + placeholder_tokens=placeholder_tokens, + placeholder_token_ids=placeholder_token_ids, + save_path=output_dir, + test_image_path=instance_data_dir, + log_wandb=log_wandb, + wandb_log_prompt_cnt=wandb_log_prompt_cnt, + class_token=class_token, + train_inpainting=train_inpainting, + mixed_precision=False, + tokenizer=tokenizer, + clip_ti_decay=clip_ti_decay, + ) + + del ti_optimizer + + # Next perform Tuning with LoRA: + if not use_extended_lora: + unet_lora_params, _ = inject_trainable_lora( + unet, + r=lora_rank, + target_replace_module=lora_unet_target_modules, + dropout_p=lora_dropout_p, + scale=lora_scale, + ) + else: + print("PTI : USING EXTENDED UNET!!!") + lora_unet_target_modules = ( + lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE + ) + print("PTI : Will replace modules: ", lora_unet_target_modules) + + unet_lora_params, _ = inject_trainable_lora_extended( + unet, r=lora_rank, target_replace_module=lora_unet_target_modules + ) + print(f"PTI : has {len(unet_lora_params)} lora") + + print("PTI : Before training:") + inspect_lora(unet) + + params_to_optimize = [ + {"params": itertools.chain(*unet_lora_params), "lr": unet_lr}, + ] + + text_encoder.requires_grad_(False) + + if continue_inversion: + params_to_optimize += [ + { + "params": text_encoder.get_input_embeddings().parameters(), + "lr": continue_inversion_lr + if continue_inversion_lr is not None + else ti_lr, + } + ] + text_encoder.requires_grad_(True) + params_to_freeze = itertools.chain( + text_encoder.text_model.encoder.parameters(), + text_encoder.text_model.final_layer_norm.parameters(), + text_encoder.text_model.embeddings.position_embedding.parameters(), + ) + for param in params_to_freeze: + param.requires_grad = False + else: + text_encoder.requires_grad_(False) + if train_text_encoder: + text_encoder_lora_params, _ = inject_trainable_lora( + text_encoder, + target_replace_module=lora_clip_target_modules, + r=lora_rank, + ) + params_to_optimize += [ + { + "params": itertools.chain(*text_encoder_lora_params), + "lr": text_encoder_lr, + } + ] + inspect_lora(text_encoder) + + lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora) + + unet.train() + if train_text_encoder: + text_encoder.train() + + train_dataset.blur_amount = 70 + + lr_scheduler_lora = get_scheduler( + lr_scheduler_lora, + optimizer=lora_optimizers, + num_warmup_steps=lr_warmup_steps_lora, + num_training_steps=max_train_steps_tuning, + ) + + perform_tuning( + unet, + vae, + text_encoder, + train_dataloader, + max_train_steps_tuning, + cached_latents=cached_latents, + scheduler=noise_scheduler, + optimizer=lora_optimizers, + save_steps=save_steps, + placeholder_tokens=placeholder_tokens, + placeholder_token_ids=placeholder_token_ids, + save_path=output_dir, + lr_scheduler_lora=lr_scheduler_lora, + lora_unet_target_modules=lora_unet_target_modules, + lora_clip_target_modules=lora_clip_target_modules, + mask_temperature=mask_temperature, + tokenizer=tokenizer, + out_name=out_name, + test_image_path=instance_data_dir, + log_wandb=log_wandb, + wandb_log_prompt_cnt=wandb_log_prompt_cnt, + class_token=class_token, + train_inpainting=train_inpainting, + ) + + +def main(): + fire.Fire(train) diff --git 
a/lora_diffusion/cli_pt_to_safetensors.py b/lora_diffusion/cli_pt_to_safetensors.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4be40d6f950bc24b771db31d9c0e00934a5e71 --- /dev/null +++ b/lora_diffusion/cli_pt_to_safetensors.py @@ -0,0 +1,85 @@ +import os + +import fire +import torch +from lora_diffusion import ( + DEFAULT_TARGET_REPLACE, + TEXT_ENCODER_DEFAULT_TARGET_REPLACE, + UNET_DEFAULT_TARGET_REPLACE, + convert_loras_to_safeloras_with_embeds, + safetensors_available, +) + +_target_by_name = { + "unet": UNET_DEFAULT_TARGET_REPLACE, + "text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE, +} + + +def convert(*paths, outpath, overwrite=False, **settings): + """ + Converts one or more pytorch Lora and/or Textual Embedding pytorch files + into a safetensor file. + + Pass all the input paths as arguments. Whether they are Textual Embedding + or Lora models will be auto-detected. + + For Lora models, their name will be taken from the path, i.e. + "lora_weight.pt" => unet + "lora_weight.text_encoder.pt" => text_encoder + + You can also set target_modules and/or rank by providing an argument prefixed + by the name. + + So a complete example might be something like: + + ``` + python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8 + ``` + """ + modelmap = {} + embeds = {} + + if os.path.exists(outpath) and not overwrite: + raise ValueError( + f"Output path {outpath} already exists, and overwrite is not True" + ) + + for path in paths: + data = torch.load(path) + + if isinstance(data, dict): + print(f"Loading textual inversion embeds {data.keys()} from {path}") + embeds.update(data) + + else: + name_parts = os.path.split(path)[1].split(".") + name = name_parts[-2] if len(name_parts) > 2 else "unet" + + model_settings = { + "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE), + "rank": 4, + } + + prefix = f"{name}." 
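+            # Keyword settings prefixed with the model name (e.g. `--unet.rank 8`) are stripped
+            # of that prefix here and override the defaults for this particular model.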
+ + arg_settings = { k[len(prefix) :]: v for k, v in settings.items() if k.startswith(prefix) } + model_settings = { **model_settings, **arg_settings } + + print(f"Loading Lora for {name} from {path} with settings {model_settings}") + + modelmap[name] = ( + path, + model_settings["target_modules"], + model_settings["rank"], + ) + + convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath) + + +def main(): + fire.Fire(convert) + + +if __name__ == "__main__": + main() diff --git a/lora_diffusion/cli_svd.py b/lora_diffusion/cli_svd.py new file mode 100644 index 0000000000000000000000000000000000000000..cf52aa0b87314f86ac64f1ef7cd457a34fffabd7 --- /dev/null +++ b/lora_diffusion/cli_svd.py @@ -0,0 +1,146 @@ +import fire +from diffusers import StableDiffusionPipeline +import torch +import torch.nn as nn + +from .lora import ( + save_all, + _find_modules, + LoraInjectedConv2d, + LoraInjectedLinear, + inject_trainable_lora, + inject_trainable_lora_extended, +) + + +def _iter_lora(model): + for module in model.modules(): + if isinstance(module, LoraInjectedConv2d) or isinstance( + module, LoraInjectedLinear + ): + yield module + + +def overwrite_base(base_model, tuned_model, rank, clamp_quantile): + device = base_model.device + dtype = base_model.dtype + + for lor_base, lor_tune in zip(_iter_lora(base_model), _iter_lora(tuned_model)): + + if isinstance(lor_base, LoraInjectedLinear): + residual = lor_tune.linear.weight.data - lor_base.linear.weight.data + # SVD on residual + print("Distill Linear shape ", residual.shape) + residual = residual.float() + U, S, Vh = torch.linalg.svd(residual) + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + assert lor_base.lora_up.weight.shape == U.shape + assert lor_base.lora_down.weight.shape == Vh.shape + + lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype) + lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype) + + if isinstance(lor_base, LoraInjectedConv2d): + residual = lor_tune.conv.weight.data - lor_base.conv.weight.data + print("Distill Conv shape ", residual.shape) + + residual = residual.float() + residual = residual.flatten(start_dim=1) + + # SVD on residual + U, S, Vh = torch.linalg.svd(residual) + U = U[:, :rank] + S = S[:rank] + U = U @ torch.diag(S) + + Vh = Vh[:rank, :] + + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp_quantile) + low_val = -hi_val + + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + + # U is (out_channels, rank) with 1x1 conv. 
So, + U = U.reshape(U.shape[0], U.shape[1], 1, 1) + # V is (rank, in_channels * kernel_size1 * kernel_size2) + # now reshape: + Vh = Vh.reshape( + Vh.shape[0], + lor_base.conv.in_channels, + lor_base.conv.kernel_size[0], + lor_base.conv.kernel_size[1], + ) + + assert lor_base.lora_up.weight.shape == U.shape + assert lor_base.lora_down.weight.shape == Vh.shape + + lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype) + lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype) + + +def svd_distill( + target_model: str, + base_model: str, + rank: int = 4, + clamp_quantile: float = 0.99, + device: str = "cuda:0", + save_path: str = "svd_distill.safetensors", +): + pipe_base = StableDiffusionPipeline.from_pretrained( + base_model, torch_dtype=torch.float16 + ).to(device) + + pipe_tuned = StableDiffusionPipeline.from_pretrained( + target_model, torch_dtype=torch.float16 + ).to(device) + + # Inject unet + _ = inject_trainable_lora_extended(pipe_base.unet, r=rank) + _ = inject_trainable_lora_extended(pipe_tuned.unet, r=rank) + + overwrite_base( + pipe_base.unet, pipe_tuned.unet, rank=rank, clamp_quantile=clamp_quantile + ) + + # Inject text encoder + _ = inject_trainable_lora( + pipe_base.text_encoder, r=rank, target_replace_module={"CLIPAttention"} + ) + _ = inject_trainable_lora( + pipe_tuned.text_encoder, r=rank, target_replace_module={"CLIPAttention"} + ) + + overwrite_base( + pipe_base.text_encoder, + pipe_tuned.text_encoder, + rank=rank, + clamp_quantile=clamp_quantile, + ) + + save_all( + unet=pipe_base.unet, + text_encoder=pipe_base.text_encoder, + placeholder_token_ids=None, + placeholder_tokens=None, + save_path=save_path, + save_lora=True, + save_ti=False, + ) + + +def main(): + fire.Fire(svd_distill) diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f1c28fd719e12442ecd5100da588370a0e5617bc --- /dev/null +++ b/lora_diffusion/dataset.py @@ -0,0 +1,311 @@ +import random +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +from PIL import Image +from torch import zeros_like +from torch.utils.data import Dataset +from torchvision import transforms +import glob +from .preprocess_files import face_mask_google_mediapipe + +OBJECT_TEMPLATE = [ + "a photo of a {}", + "a rendering of a {}", + "a cropped photo of the {}", + "the photo of a {}", + "a photo of a clean {}", + "a photo of a dirty {}", + "a dark photo of the {}", + "a photo of my {}", + "a photo of the cool {}", + "a close-up photo of a {}", + "a bright photo of the {}", + "a cropped photo of a {}", + "a photo of the {}", + "a good photo of the {}", + "a photo of one {}", + "a close-up photo of the {}", + "a rendition of the {}", + "a photo of the clean {}", + "a rendition of a {}", + "a photo of a nice {}", + "a good photo of a {}", + "a photo of the nice {}", + "a photo of the small {}", + "a photo of the weird {}", + "a photo of the large {}", + "a photo of a cool {}", + "a photo of a small {}", +] + +STYLE_TEMPLATE = [ + "a painting in the style of {}", + "a rendering in the style of {}", + "a cropped painting in the style of {}", + "the painting in the style of {}", + "a clean painting in the style of {}", + "a dirty painting in the style of {}", + "a dark painting in the style of {}", + "a picture in the style of {}", + "a cool painting in the style of {}", + "a close-up painting in the style of {}", + "a bright painting in the style of {}", + "a cropped painting in the style of {}", + "a good 
painting in the style of {}", + "a close-up painting in the style of {}", + "a rendition in the style of {}", + "a nice painting in the style of {}", + "a small painting in the style of {}", + "a weird painting in the style of {}", + "a large painting in the style of {}", +] + +NULL_TEMPLATE = ["{}"] + +TEMPLATE_MAP = { + "object": OBJECT_TEMPLATE, + "style": STYLE_TEMPLATE, + "null": NULL_TEMPLATE, +} + + +def _randomset(lis): + ret = [] + for i in range(len(lis)): + if random.random() < 0.5: + ret.append(lis[i]) + return ret + + +def _shuffle(lis): + + return random.sample(lis, len(lis)) + + +def _get_cutout_holes( + height, + width, + min_holes=8, + max_holes=32, + min_height=16, + max_height=128, + min_width=16, + max_width=128, +): + holes = [] + for _n in range(random.randint(min_holes, max_holes)): + hole_height = random.randint(min_height, max_height) + hole_width = random.randint(min_width, max_width) + y1 = random.randint(0, height - hole_height) + x1 = random.randint(0, width - hole_width) + y2 = y1 + hole_height + x2 = x1 + hole_width + holes.append((x1, y1, x2, y2)) + return holes + + +def _generate_random_mask(image): + mask = zeros_like(image[:1]) + holes = _get_cutout_holes(mask.shape[1], mask.shape[2]) + for (x1, y1, x2, y2) in holes: + mask[:, y1:y2, x1:x2] = 1.0 + if random.uniform(0, 1) < 0.25: + mask.fill_(1.0) + masked_image = image * (mask < 0.5) + return mask, masked_image + + +class PivotalTuningDatasetCapation(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + tokenizer, + token_map: Optional[dict] = None, + use_template: Optional[str] = None, + size=512, + h_flip=True, + color_jitter=False, + resize=True, + use_mask_captioned_data=False, + use_face_segmentation_condition=False, + train_inpainting=False, + blur_amount: int = 70, + ): + self.size = size + self.tokenizer = tokenizer + self.resize = resize + self.train_inpainting = train_inpainting + + instance_data_root = Path(instance_data_root) + if not instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = [] + self.mask_path = [] + + assert not ( + use_mask_captioned_data and use_template + ), "Can't use both mask caption data and template." + + # Prepare the instance images + if use_mask_captioned_data: + src_imgs = glob.glob(str(instance_data_root) + "/*src.jpg") + for f in src_imgs: + idx = int(str(Path(f).stem).split(".")[0]) + mask_path = f"{instance_data_root}/{idx}.mask.png" + + if Path(mask_path).exists(): + self.instance_images_path.append(f) + self.mask_path.append(mask_path) + else: + print(f"Mask not found for {f}") + + self.captions = open(f"{instance_data_root}/caption.txt").readlines() + + else: + possibily_src_images = ( + glob.glob(str(instance_data_root) + "/*.jpg") + + glob.glob(str(instance_data_root) + "/*.png") + + glob.glob(str(instance_data_root) + "/*.jpeg") + ) + possibily_src_images = ( + set(possibily_src_images) + - set(glob.glob(str(instance_data_root) + "/*mask.png")) + - set([str(instance_data_root) + "/caption.txt"]) + ) + + self.instance_images_path = list(set(possibily_src_images)) + self.captions = [ + x.split("/")[-1].split(".")[0] for x in self.instance_images_path + ] + + assert ( + len(self.instance_images_path) > 0 + ), "No images found in the instance data root." 
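+        # Keep a deterministic ordering so that the index used for masks and
+        # captions refers to the same image across runs.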
+ + self.instance_images_path = sorted(self.instance_images_path) + + self.use_mask = use_face_segmentation_condition or use_mask_captioned_data + self.use_mask_captioned_data = use_mask_captioned_data + + if use_face_segmentation_condition: + + for idx in range(len(self.instance_images_path)): + targ = f"{instance_data_root}/{idx}.mask.png" + # see if the mask exists + if not Path(targ).exists(): + print(f"Mask not found for {targ}") + + print( + "Warning : this will pre-process all the images in the instance data root." + ) + + if len(self.mask_path) > 0: + print( + "Warning : masks already exists, but will be overwritten." + ) + + masks = face_mask_google_mediapipe( + [ + Image.open(f).convert("RGB") + for f in self.instance_images_path + ] + ) + for idx, mask in enumerate(masks): + mask.save(f"{instance_data_root}/{idx}.mask.png") + + break + + for idx in range(len(self.instance_images_path)): + self.mask_path.append(f"{instance_data_root}/{idx}.mask.png") + + self.num_instance_images = len(self.instance_images_path) + self.token_map = token_map + + self.use_template = use_template + if use_template is not None: + self.templates = TEMPLATE_MAP[use_template] + + self._length = self.num_instance_images + + self.h_flip = h_flip + self.image_transforms = transforms.Compose( + [ + transforms.Resize( + size, interpolation=transforms.InterpolationMode.BILINEAR + ) + if resize + else transforms.Lambda(lambda x: x), + transforms.ColorJitter(0.1, 0.1) + if color_jitter + else transforms.Lambda(lambda x: x), + transforms.CenterCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.blur_amount = blur_amount + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open( + self.instance_images_path[index % self.num_instance_images] + ) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.train_inpainting: + ( + example["instance_masks"], + example["instance_masked_images"], + ) = _generate_random_mask(example["instance_images"]) + + if self.use_template: + assert self.token_map is not None + input_tok = list(self.token_map.values())[0] + + text = random.choice(self.templates).format(input_tok) + else: + text = self.captions[index % self.num_instance_images].strip() + + if self.token_map is not None: + for token, value in self.token_map.items(): + text = text.replace(token, value) + + print(text) + + if self.use_mask: + example["mask"] = ( + self.image_transforms( + Image.open(self.mask_path[index % self.num_instance_images]) + ) + * 0.5 + + 1.0 + ) + + if self.h_flip and random.random() > 0.5: + hflip = transforms.RandomHorizontalFlip(p=1) + + example["instance_images"] = hflip(example["instance_images"]) + if self.use_mask: + example["mask"] = hflip(example["mask"]) + + example["instance_prompt_ids"] = self.tokenizer( + text, + padding="do_not_pad", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids + + return example diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py new file mode 100644 index 0000000000000000000000000000000000000000..8753f15f7a49c145053b1b1ea0fca1c0442ef7f5 --- /dev/null +++ b/lora_diffusion/lora.py @@ -0,0 +1,1110 @@ +import json +import math +from itertools import groupby +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import numpy as np +import PIL +import torch +import torch.nn as nn +import 
torch.nn.functional as F + +try: + from safetensors.torch import safe_open + from safetensors.torch import save_file as safe_save + + safetensors_available = True +except ImportError: + from .safe_open import safe_open + + def safe_save( + tensors: Dict[str, torch.Tensor], + filename: str, + metadata: Optional[Dict[str, str]] = None, + ) -> None: + raise EnvironmentError( + "Saving safetensors requires the safetensors library. Please install with pip or similar." + ) + + safetensors_available = False + + +class LoraInjectedLinear(nn.Module): + def __init__( + self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0 + ): + super().__init__() + + if r > min(in_features, out_features): + raise ValueError( + f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" + ) + self.r = r + self.linear = nn.Linear(in_features, out_features, bias) + self.lora_down = nn.Linear(in_features, r, bias=False) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Linear(r, out_features, bias=False) + self.scale = scale + self.selector = nn.Identity() + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.linear(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Linear(self.r, self.r, bias=False) + self.selector.weight.data = torch.diag(diag) + self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +class LoraInjectedConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups: int = 1, + bias: bool = True, + r: int = 4, + dropout_p: float = 0.1, + scale: float = 1.0, + ): + super().__init__() + if r > min(in_channels, out_channels): + raise ValueError( + f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}" + ) + self.r = r + self.conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + + self.lora_down = nn.Conv2d( + in_channels=in_channels, + out_channels=r, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=False, + ) + self.dropout = nn.Dropout(dropout_p) + self.lora_up = nn.Conv2d( + in_channels=r, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector = nn.Identity() + self.scale = scale + + nn.init.normal_(self.lora_down.weight, std=1 / r) + nn.init.zeros_(self.lora_up.weight) + + def forward(self, input): + return ( + self.conv(input) + + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) + * self.scale + ) + + def realize_as_lora(self): + return self.lora_up.weight.data * self.scale, self.lora_down.weight.data + + def set_selector_from_diag(self, diag: torch.Tensor): + # diag is a 1D tensor of size (r,) + assert diag.shape == (self.r,) + self.selector = nn.Conv2d( + in_channels=self.r, + out_channels=self.r, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.selector.weight.data = torch.diag(diag) + + # same device + dtype as lora_up + 
self.selector.weight.data = self.selector.weight.data.to( + self.lora_up.weight.device + ).to(self.lora_up.weight.dtype) + + +UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} + +UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"} + +TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"} + +TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"} + +DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE + +EMBED_FLAG = "" + + +def _find_children( + model, + search_class: List[Type[nn.Module]] = [nn.Linear], +): + """ + Find all modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for parent in model.modules(): + for name, module in parent.named_children(): + if any([isinstance(module, _class) for _class in search_class]): + yield parent, name, module + + +def _find_modules_v2( + model, + ancestor_class: Optional[Set[str]] = None, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [ + LoraInjectedLinear, + LoraInjectedConv2d, + ], +): + """ + Find all modules of a certain class (or union of classes) that are direct or + indirect descendants of other modules of a certain class (or union of classes). + + Returns all matching modules, along with the parent of those moduless and the + names they are referenced by. + """ + + # Get the targets we should replace all linears under + if ancestor_class is not None: + ancestors = ( + module + for module in model.modules() + if module.__class__.__name__ in ancestor_class + ) + else: + # this, incase you want to naively iterate over all modules. + ancestors = [module for module in model.modules()] + + # For each target find every linear_class module that isn't a child of a LoraInjectedLinear + for ancestor in ancestors: + for fullname, module in ancestor.named_modules(): + if any([isinstance(module, _class) for _class in search_class]): + # Find the direct parent if this is a descendant, not a child, of target + *path, name = fullname.split(".") + parent = ancestor + while path: + parent = parent.get_submodule(path.pop(0)) + # Skip this linear if it's a child of a LoraInjectedLinear + if exclude_children_of and any( + [isinstance(parent, _class) for _class in exclude_children_of] + ): + continue + # Otherwise, yield it + yield parent, name, module + + +def _find_modules_old( + model, + ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE, + search_class: List[Type[nn.Module]] = [nn.Linear], + exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear], +): + ret = [] + for _module in model.modules(): + if _module.__class__.__name__ in ancestor_class: + + for name, _child_module in _module.named_modules(): + if _child_module.__class__ in search_class: + ret.append((_module, name, _child_module)) + print(ret) + return ret + + +_find_modules = _find_modules_v2 + + +def inject_trainable_lora( + model: nn.Module, + target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE, + r: int = 4, + loras=None, # path to lora .pt + verbose: bool = False, + dropout_p: float = 0.0, + scale: float = 1.0, +): + """ + inject lora into model, and returns lora parameter groups. 
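+
+    Example (illustrative): the returned parameter groups are generators, so a
+    typical training setup chains them into the optimizer, e.g.
+
+        lora_params, names = inject_trainable_lora(unet, r=4)
+        optimizer = torch.optim.AdamW(itertools.chain(*lora_params), lr=1e-4)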
+ """ + + require_grad_params = [] + names = [] + + if loras != None: + loras = torch.load(loras) + + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear] + ): + weight = _child_module.weight + bias = _child_module.bias + if verbose: + print("LoRA Injection : injecting lora into ", name) + print("LoRA Injection : weight shape", weight.shape) + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + dropout_p=dropout_p, + scale=scale, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + _module._modules[name] = _tmp + + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + + return require_grad_params, names + + +def inject_trainable_lora_extended( + model: nn.Module, + target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE, + r: int = 4, + loras=None, # path to lora .pt +): + """ + inject lora into model, and returns lora parameter groups. + """ + + require_grad_params = [] + names = [] + + if loras != None: + loras = torch.load(loras) + + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, nn.Conv2d] + ): + if _child_module.__class__ == nn.Linear: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedLinear( + _child_module.in_features, + _child_module.out_features, + _child_module.bias is not None, + r=r, + ) + _tmp.linear.weight = weight + if bias is not None: + _tmp.linear.bias = bias + elif _child_module.__class__ == nn.Conv2d: + weight = _child_module.weight + bias = _child_module.bias + _tmp = LoraInjectedConv2d( + _child_module.in_channels, + _child_module.out_channels, + _child_module.kernel_size, + _child_module.stride, + _child_module.padding, + _child_module.dilation, + _child_module.groups, + _child_module.bias is not None, + r=r, + ) + + _tmp.conv.weight = weight + if bias is not None: + _tmp.conv.bias = bias + + # switch the module + _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) + if bias is not None: + _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype) + + _module._modules[name] = _tmp + + require_grad_params.append(_module._modules[name].lora_up.parameters()) + require_grad_params.append(_module._modules[name].lora_down.parameters()) + + if loras != None: + _module._modules[name].lora_up.weight = loras.pop(0) + _module._modules[name].lora_down.weight = loras.pop(0) + + _module._modules[name].lora_up.weight.requires_grad = True + _module._modules[name].lora_down.weight.requires_grad = True + names.append(name) + + return require_grad_params, names + + +def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE): + + loras = [] + + for _m, _n, _child_module in _find_modules( + model, + target_replace_module, + search_class=[LoraInjectedLinear, LoraInjectedConv2d], + ): + loras.append((_child_module.lora_up, _child_module.lora_down)) + + if len(loras) == 0: + raise ValueError("No lora injected.") + + return loras + 
+ +def extract_lora_as_tensor( + model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True +): + + loras = [] + + for _m, _n, _child_module in _find_modules( + model, + target_replace_module, + search_class=[LoraInjectedLinear, LoraInjectedConv2d], + ): + up, down = _child_module.realize_as_lora() + if as_fp16: + up = up.to(torch.float16) + down = down.to(torch.float16) + + loras.append((up, down)) + + if len(loras) == 0: + raise ValueError("No lora injected.") + + return loras + + +def save_lora_weight( + model, + path="./lora.pt", + target_replace_module=DEFAULT_TARGET_REPLACE, +): + weights = [] + for _up, _down in extract_lora_ups_down( + model, target_replace_module=target_replace_module + ): + weights.append(_up.weight.to("cpu").to(torch.float16)) + weights.append(_down.weight.to("cpu").to(torch.float16)) + + torch.save(weights, path) + + +def save_lora_as_json(model, path="./lora.json"): + weights = [] + for _up, _down in extract_lora_ups_down(model): + weights.append(_up.weight.detach().cpu().numpy().tolist()) + weights.append(_down.weight.detach().cpu().numpy().tolist()) + + import json + + with open(path, "w") as f: + json.dump(weights, f) + + +def save_safeloras_with_embeds( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + embeds: Dict[str, torch.Tensor] = {}, + outpath="./lora.safetensors", +): + """ + Saves the Lora from multiple modules in a single safetensor file. + + modelmap is a dictionary of { + "module name": (module, target_replace_module) + } + """ + weights = {} + metadata = {} + + for name, (model, target_replace_module) in modelmap.items(): + metadata[name] = json.dumps(list(target_replace_module)) + + for i, (_up, _down) in enumerate( + extract_lora_as_tensor(model, target_replace_module) + ): + rank = _down.shape[0] + + metadata[f"{name}:{i}:rank"] = str(rank) + weights[f"{name}:{i}:up"] = _up + weights[f"{name}:{i}:down"] = _down + + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def save_safeloras( + modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {}, + outpath="./lora.safetensors", +): + return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def convert_loras_to_safeloras_with_embeds( + modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + embeds: Dict[str, torch.Tensor] = {}, + outpath="./lora.safetensors", +): + """ + Converts the Lora from multiple pytorch .pt files into a single safetensor file. 
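+
+    Each weight is stored under a "<module name>:<index>:up" or
+    "<module name>:<index>:down" key, and the rank of every pair is recorded in
+    the safetensors metadata.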
+ + modelmap is a dictionary of { + "module name": (pytorch_model_path, target_replace_module, rank) + } + """ + + weights = {} + metadata = {} + + for name, (path, target_replace_module, r) in modelmap.items(): + metadata[name] = json.dumps(list(target_replace_module)) + + lora = torch.load(path) + for i, weight in enumerate(lora): + is_up = i % 2 == 0 + i = i // 2 + + if is_up: + metadata[f"{name}:{i}:rank"] = str(r) + weights[f"{name}:{i}:up"] = weight + else: + weights[f"{name}:{i}:down"] = weight + + for token, tensor in embeds.items(): + metadata[token] = EMBED_FLAG + weights[token] = tensor + + print(f"Saving weights to {outpath}") + safe_save(weights, outpath, metadata) + + +def convert_loras_to_safeloras( + modelmap: Dict[str, Tuple[str, Set[str], int]] = {}, + outpath="./lora.safetensors", +): + convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath) + + +def parse_safeloras( + safeloras, +) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]: + """ + Converts a loaded safetensor file that contains a set of module Loras + into Parameters and other information + + Output is a dictionary of { + "module name": ( + [list of weights], + [list of ranks], + target_replacement_modules + ) + } + """ + loras = {} + metadata = safeloras.metadata() + + get_name = lambda k: k.split(":")[0] + + keys = list(safeloras.keys()) + keys.sort(key=get_name) + + for name, module_keys in groupby(keys, get_name): + info = metadata.get(name) + + if not info: + raise ValueError( + f"Tensor {name} has no metadata - is this a Lora safetensor?" + ) + + # Skip Textual Inversion embeds + if info == EMBED_FLAG: + continue + + # Handle Loras + # Extract the targets + target = json.loads(info) + + # Build the result lists - Python needs us to preallocate lists to insert into them + module_keys = list(module_keys) + ranks = [4] * (len(module_keys) // 2) + weights = [None] * len(module_keys) + + for key in module_keys: + # Split the model name and index out of the key + _, idx, direction = key.split(":") + idx = int(idx) + + # Add the rank + ranks[idx] = int(metadata[f"{name}:{idx}:rank"]) + + # Insert the weight into the list + idx = idx * 2 + (1 if direction == "down" else 0) + weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key)) + + loras[name] = (weights, ranks, target) + + return loras + + +def parse_safeloras_embeds( + safeloras, +) -> Dict[str, torch.Tensor]: + """ + Converts a loaded safetensor file that contains Textual Inversion embeds into + a dictionary of embed_token: Tensor + """ + embeds = {} + metadata = safeloras.metadata() + + for key in safeloras.keys(): + # Only handle Textual Inversion embeds + meta = metadata.get(key) + if not meta or meta != EMBED_FLAG: + continue + + embeds[key] = safeloras.get_tensor(key) + + return embeds + + +def load_safeloras(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras) + + +def load_safeloras_embeds(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras_embeds(safeloras) + + +def load_safeloras_both(path, device="cpu"): + safeloras = safe_open(path, framework="pt", device=device) + return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras) + + +def collapse_lora(model, alpha=1.0): + + for _module, name, _child_module in _find_modules( + model, + UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE, + search_class=[LoraInjectedLinear, LoraInjectedConv2d], + ): + + if 
isinstance(_child_module, LoraInjectedLinear): + print("Collapsing Lin Lora in", name) + + _child_module.linear.weight = nn.Parameter( + _child_module.linear.weight.data + + alpha + * ( + _child_module.lora_up.weight.data + @ _child_module.lora_down.weight.data + ) + .type(_child_module.linear.weight.dtype) + .to(_child_module.linear.weight.device) + ) + + else: + print("Collapsing Conv Lora in", name) + _child_module.conv.weight = nn.Parameter( + _child_module.conv.weight.data + + alpha + * ( + _child_module.lora_up.weight.data.flatten(start_dim=1) + @ _child_module.lora_down.weight.data.flatten(start_dim=1) + ) + .reshape(_child_module.conv.weight.data.shape) + .type(_child_module.conv.weight.dtype) + .to(_child_module.conv.weight.device) + ) + + +def monkeypatch_or_replace_lora( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: Union[int, List[int]] = 4, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear] + ): + _source = ( + _child_module.linear + if isinstance(_child_module, LoraInjectedLinear) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedLinear( + _source.in_features, + _source.out_features, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + _tmp.linear.weight = weight + + if bias is not None: + _tmp.linear.bias = bias + + # switch the module + _module._modules[name] = _tmp + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype) + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype) + ) + + _module._modules[name].to(weight.device) + + +def monkeypatch_or_replace_lora_extended( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + r: Union[int, List[int]] = 4, +): + for _module, name, _child_module in _find_modules( + model, + target_replace_module, + search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d], + ): + + if (_child_module.__class__ == nn.Linear) or ( + _child_module.__class__ == LoraInjectedLinear + ): + if len(loras[0].shape) != 2: + continue + + _source = ( + _child_module.linear + if isinstance(_child_module, LoraInjectedLinear) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedLinear( + _source.in_features, + _source.out_features, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + _tmp.linear.weight = weight + + if bias is not None: + _tmp.linear.bias = bias + + elif (_child_module.__class__ == nn.Conv2d) or ( + _child_module.__class__ == LoraInjectedConv2d + ): + if len(loras[0].shape) != 4: + continue + _source = ( + _child_module.conv + if isinstance(_child_module, LoraInjectedConv2d) + else _child_module + ) + + weight = _source.weight + bias = _source.bias + _tmp = LoraInjectedConv2d( + _source.in_channels, + _source.out_channels, + _source.kernel_size, + _source.stride, + _source.padding, + _source.dilation, + _source.groups, + _source.bias is not None, + r=r.pop(0) if isinstance(r, list) else r, + ) + + _tmp.conv.weight = weight + + if bias is not None: + _tmp.conv.bias = bias + + # switch the module + _module._modules[name] = _tmp + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype) + ) + _module._modules[name].lora_down.weight = nn.Parameter( + 
down_weight.type(weight.dtype) + ) + + _module._modules[name].to(weight.device) + + +def monkeypatch_or_replace_safeloras(models, safeloras): + loras = parse_safeloras(safeloras) + + for name, (lora, ranks, target) in loras.items(): + model = getattr(models, name, None) + + if not model: + print(f"No model provided for {name}, contained in Lora") + continue + + monkeypatch_or_replace_lora_extended(model, lora, target, ranks) + + +def monkeypatch_remove_lora(model): + for _module, name, _child_module in _find_modules( + model, search_class=[LoraInjectedLinear, LoraInjectedConv2d] + ): + if isinstance(_child_module, LoraInjectedLinear): + _source = _child_module.linear + weight, bias = _source.weight, _source.bias + + _tmp = nn.Linear( + _source.in_features, _source.out_features, bias is not None + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + else: + _source = _child_module.conv + weight, bias = _source.weight, _source.bias + + _tmp = nn.Conv2d( + in_channels=_source.in_channels, + out_channels=_source.out_channels, + kernel_size=_source.kernel_size, + stride=_source.stride, + padding=_source.padding, + dilation=_source.dilation, + groups=_source.groups, + bias=bias is not None, + ) + + _tmp.weight = weight + if bias is not None: + _tmp.bias = bias + + _module._modules[name] = _tmp + + +def monkeypatch_add_lora( + model, + loras, + target_replace_module=DEFAULT_TARGET_REPLACE, + alpha: float = 1.0, + beta: float = 1.0, +): + for _module, name, _child_module in _find_modules( + model, target_replace_module, search_class=[LoraInjectedLinear] + ): + weight = _child_module.linear.weight + + up_weight = loras.pop(0) + down_weight = loras.pop(0) + + _module._modules[name].lora_up.weight = nn.Parameter( + up_weight.type(weight.dtype).to(weight.device) * alpha + + _module._modules[name].lora_up.weight.to(weight.device) * beta + ) + _module._modules[name].lora_down.weight = nn.Parameter( + down_weight.type(weight.dtype).to(weight.device) * alpha + + _module._modules[name].lora_down.weight.to(weight.device) * beta + ) + + _module._modules[name].to(weight.device) + + +def tune_lora_scale(model, alpha: float = 1.0): + for _module in model.modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]: + _module.scale = alpha + + +def set_lora_diag(model, diag: torch.Tensor): + for _module in model.modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]: + _module.set_selector_from_diag(diag) + + +def _text_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"]) + + +def _ti_lora_path(path: str) -> str: + assert path.endswith(".pt"), "Only .pt files are supported" + return ".".join(path.split(".")[:-1] + ["ti", "pt"]) + + +def apply_learned_embed_in_clip( + learned_embeds, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + if isinstance(token, str): + trained_tokens = [token] + elif isinstance(token, list): + assert len(learned_embeds.keys()) == len( + token + ), "The number of tokens and the number of embeds should be the same" + trained_tokens = token + else: + trained_tokens = list(learned_embeds.keys()) + + for token in trained_tokens: + print(token) + embeds = learned_embeds[token] + + # cast to dtype of text_encoder + dtype = text_encoder.get_input_embeddings().weight.dtype + num_added_tokens = tokenizer.add_tokens(token) + + i = 1 + if not idempotent: + 
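+            # Keep renaming the token (e.g. "<s1>" becomes "<s1-1>") until the
+            # tokenizer accepts it as a genuinely new token.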
while num_added_tokens == 0: + print(f"The tokenizer already contains the token {token}.") + token = f"{token[:-1]}-{i}>" + print(f"Attempting to add the token {token}.") + num_added_tokens = tokenizer.add_tokens(token) + i += 1 + elif num_added_tokens == 0 and idempotent: + print(f"The tokenizer already contains the token {token}.") + print(f"Replacing {token} embedding.") + + # resize the token embeddings + text_encoder.resize_token_embeddings(len(tokenizer)) + + # get the id for the token and assign the embeds + token_id = tokenizer.convert_tokens_to_ids(token) + text_encoder.get_input_embeddings().weight.data[token_id] = embeds + return token + + +def load_learned_embed_in_clip( + learned_embeds_path, + text_encoder, + tokenizer, + token: Optional[Union[str, List[str]]] = None, + idempotent=False, +): + learned_embeds = torch.load(learned_embeds_path) + apply_learned_embed_in_clip( + learned_embeds, text_encoder, tokenizer, token, idempotent + ) + + +def patch_pipe( + pipe, + maybe_unet_path, + token: Optional[str] = None, + r: int = 4, + patch_unet=True, + patch_text=True, + patch_ti=True, + idempotent_token=True, + unet_target_replace_module=DEFAULT_TARGET_REPLACE, + text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, +): + if maybe_unet_path.endswith(".pt"): + # torch format + + if maybe_unet_path.endswith(".ti.pt"): + unet_path = maybe_unet_path[:-6] + ".pt" + elif maybe_unet_path.endswith(".text_encoder.pt"): + unet_path = maybe_unet_path[:-16] + ".pt" + else: + unet_path = maybe_unet_path + + ti_path = _ti_lora_path(unet_path) + text_path = _text_lora_path(unet_path) + + if patch_unet: + print("LoRA : Patching Unet") + monkeypatch_or_replace_lora( + pipe.unet, + torch.load(unet_path), + r=r, + target_replace_module=unet_target_replace_module, + ) + + if patch_text: + print("LoRA : Patching text encoder") + monkeypatch_or_replace_lora( + pipe.text_encoder, + torch.load(text_path), + target_replace_module=text_target_replace_module, + r=r, + ) + if patch_ti: + print("LoRA : Patching token input") + token = load_learned_embed_in_clip( + ti_path, + pipe.text_encoder, + pipe.tokenizer, + token=token, + idempotent=idempotent_token, + ) + + elif maybe_unet_path.endswith(".safetensors"): + safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu") + monkeypatch_or_replace_safeloras(pipe, safeloras) + tok_dict = parse_safeloras_embeds(safeloras) + if patch_ti: + apply_learned_embed_in_clip( + tok_dict, + pipe.text_encoder, + pipe.tokenizer, + token=token, + idempotent=idempotent_token, + ) + return tok_dict + + +@torch.no_grad() +def inspect_lora(model): + moved = {} + + for name, _module in model.named_modules(): + if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]: + ups = _module.lora_up.weight.data.clone() + downs = _module.lora_down.weight.data.clone() + + wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1) + + dist = wght.flatten().abs().mean().item() + if name in moved: + moved[name].append(dist) + else: + moved[name] = [dist] + + return moved + + +def save_all( + unet, + text_encoder, + save_path, + placeholder_token_ids=None, + placeholder_tokens=None, + save_lora=True, + save_ti=True, + target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, + target_replace_module_unet=DEFAULT_TARGET_REPLACE, + safe_form=True, +): + if not safe_form: + # save ti + if save_ti: + ti_path = _ti_lora_path(save_path) + learned_embeds_dict = {} + for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): + learned_embeds = 
text_encoder.get_input_embeddings().weight[tok_id] + print( + f"Current Learned Embeddings for {tok}:, id {tok_id} ", + learned_embeds[:4], + ) + learned_embeds_dict[tok] = learned_embeds.detach().cpu() + + torch.save(learned_embeds_dict, ti_path) + print("Ti saved to ", ti_path) + + # save text encoder + if save_lora: + + save_lora_weight( + unet, save_path, target_replace_module=target_replace_module_unet + ) + print("Unet saved to ", save_path) + + save_lora_weight( + text_encoder, + _text_lora_path(save_path), + target_replace_module=target_replace_module_text, + ) + print("Text Encoder saved to ", _text_lora_path(save_path)) + + else: + assert save_path.endswith( + ".safetensors" + ), f"Save path : {save_path} should end with .safetensors" + + loras = {} + embeds = {} + + if save_lora: + + loras["unet"] = (unet, target_replace_module_unet) + loras["text_encoder"] = (text_encoder, target_replace_module_text) + + if save_ti: + for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): + learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] + print( + f"Current Learned Embeddings for {tok}:, id {tok_id} ", + learned_embeds[:4], + ) + embeds[tok] = learned_embeds.detach().cpu() + + save_safeloras_with_embeds(loras, embeds, save_path) diff --git a/lora_diffusion/lora_manager.py b/lora_diffusion/lora_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..9d8306e43cc6e730f7b14f2ba546584fbf81adec --- /dev/null +++ b/lora_diffusion/lora_manager.py @@ -0,0 +1,144 @@ +from typing import List +import torch +from safetensors import safe_open +from diffusers import StableDiffusionPipeline +from .lora import ( + monkeypatch_or_replace_safeloras, + apply_learned_embed_in_clip, + set_lora_diag, + parse_safeloras_embeds, +) + + +def lora_join(lora_safetenors: list): + metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors] + _total_metadata = {} + total_metadata = {} + total_tensor = {} + total_rank = 0 + ranklist = [] + for _metadata in metadatas: + rankset = [] + for k, v in _metadata.items(): + if k.endswith("rank"): + rankset.append(int(v)) + + assert len(set(rankset)) <= 1, "Rank should be the same per model" + if len(rankset) == 0: + rankset = [0] + + total_rank += rankset[0] + _total_metadata.update(_metadata) + ranklist.append(rankset[0]) + + # remove metadata about tokens + for k, v in _total_metadata.items(): + if v != "": + total_metadata[k] = v + + tensorkeys = set() + for safelora in lora_safetenors: + tensorkeys.update(safelora.keys()) + + for keys in tensorkeys: + if keys.startswith("text_encoder") or keys.startswith("unet"): + tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors] + + is_down = keys.endswith("down") + + if is_down: + _tensor = torch.cat(tensorset, dim=0) + assert _tensor.shape[0] == total_rank + else: + _tensor = torch.cat(tensorset, dim=1) + assert _tensor.shape[1] == total_rank + + total_tensor[keys] = _tensor + keys_rank = ":".join(keys.split(":")[:-1]) + ":rank" + total_metadata[keys_rank] = str(total_rank) + token_size_list = [] + for idx, safelora in enumerate(lora_safetenors): + tokens = [k for k, v in safelora.metadata().items() if v == ""] + for jdx, token in enumerate(sorted(tokens)): + + total_tensor[f""] = safelora.get_tensor(token) + total_metadata[f""] = "" + + print(f"Embedding {token} replaced to ") + + token_size_list.append(len(tokens)) + + return total_tensor, total_metadata, ranklist, token_size_list + + +class DummySafeTensorObject: + def __init__(self, tensor: dict, 
metadata): + self.tensor = tensor + self._metadata = metadata + + def keys(self): + return self.tensor.keys() + + def metadata(self): + return self._metadata + + def get_tensor(self, key): + return self.tensor[key] + + +class LoRAManager: + def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline): + + self.lora_paths_list = lora_paths_list + self.pipe = pipe + self._setup() + + def _setup(self): + + self._lora_safetenors = [ + safe_open(path, framework="pt", device="cpu") + for path in self.lora_paths_list + ] + + ( + total_tensor, + total_metadata, + self.ranklist, + self.token_size_list, + ) = lora_join(self._lora_safetenors) + + self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata) + + monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora) + tok_dict = parse_safeloras_embeds(self.total_safelora) + + apply_learned_embed_in_clip( + tok_dict, + self.pipe.text_encoder, + self.pipe.tokenizer, + token=None, + idempotent=True, + ) + + def tune(self, scales): + + assert len(scales) == len( + self.ranklist + ), "Scale list should be the same length as ranklist" + + diags = [] + for scale, rank in zip(scales, self.ranklist): + diags = diags + [scale] * rank + + set_lora_diag(self.pipe.unet, torch.tensor(diags)) + + def prompt(self, prompt): + if prompt is not None: + for idx, tok_size in enumerate(self.token_size_list): + prompt = prompt.replace( + f"<{idx + 1}>", + "".join([f"" for jdx in range(tok_size)]), + ) + # TODO : Rescale LoRA + Text inputs based on prompt scale params + + return prompt diff --git a/lora_diffusion/preprocess_files.py b/lora_diffusion/preprocess_files.py new file mode 100644 index 0000000000000000000000000000000000000000..bedb89f54dd8ad2b2a5b8b3f3eb69ffb763d38b4 --- /dev/null +++ b/lora_diffusion/preprocess_files.py @@ -0,0 +1,327 @@ +# Have SwinIR upsample +# Have BLIP auto caption +# Have CLIPSeg auto mask concept + +from typing import List, Literal, Union, Optional, Tuple +import os +from PIL import Image, ImageFilter +import torch +import numpy as np +import fire +from tqdm import tqdm +import glob +from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation + + +@torch.no_grad() +def swin_ir_sr( + images: List[Image.Image], + model_id: Literal[ + "caidas/swin2SR-classical-sr-x2-64", "caidas/swin2SR-classical-sr-x4-48" + ] = "caidas/swin2SR-classical-sr-x2-64", + target_size: Optional[Tuple[int, int]] = None, + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + **kwargs, +) -> List[Image.Image]: + """ + Upscales images using SwinIR. Returns a list of PIL images. + """ + # So this is currently in main branch, so this can be used in the future I guess? 
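+    # Importing lazily keeps the rest of this module usable with a transformers
+    # version that does not ship Swin2SR.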
+ from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor + + model = Swin2SRForImageSuperResolution.from_pretrained( + model_id, + ).to(device) + processor = Swin2SRImageProcessor() + + out_images = [] + + for image in tqdm(images): + + ori_w, ori_h = image.size + if target_size is not None: + if ori_w >= target_size[0] and ori_h >= target_size[1]: + out_images.append(image) + continue + + inputs = processor(image, return_tensors="pt").to(device) + with torch.no_grad(): + outputs = model(**inputs) + + output = ( + outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy() + ) + output = np.moveaxis(output, source=0, destination=-1) + output = (output * 255.0).round().astype(np.uint8) + output = Image.fromarray(output) + + out_images.append(output) + + return out_images + + +@torch.no_grad() +def clipseg_mask_generator( + images: List[Image.Image], + target_prompts: Union[List[str], str], + model_id: Literal[ + "CIDAS/clipseg-rd64-refined", "CIDAS/clipseg-rd16" + ] = "CIDAS/clipseg-rd64-refined", + device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), + bias: float = 0.01, + temp: float = 1.0, + **kwargs, +) -> List[Image.Image]: + """ + Returns a greyscale mask for each image, where the mask is the probability of the target prompt being present in the image + """ + + if isinstance(target_prompts, str): + print( + f'Warning: only one target prompt "{target_prompts}" was given, so it will be used for all images' + ) + + target_prompts = [target_prompts] * len(images) + + processor = CLIPSegProcessor.from_pretrained(model_id) + model = CLIPSegForImageSegmentation.from_pretrained(model_id).to(device) + + masks = [] + + for image, prompt in tqdm(zip(images, target_prompts)): + + original_size = image.size + + inputs = processor( + text=[prompt, ""], + images=[image] * 2, + padding="max_length", + truncation=True, + return_tensors="pt", + ).to(device) + + outputs = model(**inputs) + + logits = outputs.logits + probs = torch.nn.functional.softmax(logits / temp, dim=0)[0] + probs = (probs + bias).clamp_(0, 1) + probs = 255 * probs / probs.max() + + # make mask greyscale + mask = Image.fromarray(probs.cpu().numpy()).convert("L") + + # resize mask to original size + mask = mask.resize(original_size) + + masks.append(mask) + + return masks + + +@torch.no_grad() +def blip_captioning_dataset( + images: List[Image.Image], + text: Optional[str] = None, + model_id: Literal[ + "Salesforce/blip-image-captioning-large", + "Salesforce/blip-image-captioning-base", + ] = "Salesforce/blip-image-captioning-large", + device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + **kwargs, +) -> List[str]: + """ + Returns a list of captions for the given images + """ + + from transformers import BlipProcessor, BlipForConditionalGeneration + + processor = BlipProcessor.from_pretrained(model_id) + model = BlipForConditionalGeneration.from_pretrained(model_id).to(device) + captions = [] + + for image in tqdm(images): + inputs = processor(image, text=text, return_tensors="pt").to("cuda") + out = model.generate( + **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7 + ) + caption = processor.decode(out[0], skip_special_tokens=True) + + captions.append(caption) + + return captions + + +def face_mask_google_mediapipe( + images: List[Image.Image], blur_amount: float = 80.0, bias: float = 0.05 +) -> List[Image.Image]: + """ + Returns a list of images with mask on the face parts. 
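+
+    Detected face bounding boxes are painted white (255) on a near-black
+    background, and each mask is returned as a greyscale PIL image.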
+ """ + import mediapipe as mp + + mp_face_detection = mp.solutions.face_detection + + face_detection = mp_face_detection.FaceDetection( + model_selection=1, min_detection_confidence=0.5 + ) + + masks = [] + for image in tqdm(images): + + image = np.array(image) + + results = face_detection.process(image) + black_image = np.ones((image.shape[0], image.shape[1]), dtype=np.uint8) + + if results.detections: + + for detection in results.detections: + + x_min = int( + detection.location_data.relative_bounding_box.xmin * image.shape[1] + ) + y_min = int( + detection.location_data.relative_bounding_box.ymin * image.shape[0] + ) + width = int( + detection.location_data.relative_bounding_box.width * image.shape[1] + ) + height = int( + detection.location_data.relative_bounding_box.height + * image.shape[0] + ) + + # draw the colored rectangle + black_image[y_min : y_min + height, x_min : x_min + width] = 255 + + black_image = Image.fromarray(black_image) + masks.append(black_image) + + return masks + + +def _crop_to_square( + image: Image.Image, com: List[Tuple[int, int]], resize_to: Optional[int] = None +): + cx, cy = com + width, height = image.size + if width > height: + left_possible = max(cx - height / 2, 0) + left = min(left_possible, width - height) + right = left + height + top = 0 + bottom = height + else: + left = 0 + right = width + top_possible = max(cy - width / 2, 0) + top = min(top_possible, height - width) + bottom = top + width + + image = image.crop((left, top, right, bottom)) + + if resize_to: + image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS) + + return image + + +def _center_of_mass(mask: Image.Image): + """ + Returns the center of mass of the mask + """ + x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1])) + + x_ = x * np.array(mask) + y_ = y * np.array(mask) + + x = np.sum(x_) / np.sum(mask) + y = np.sum(y_) / np.sum(mask) + + return x, y + + +def load_and_save_masks_and_captions( + files: Union[str, List[str]], + output_dir: str, + caption_text: Optional[str] = None, + target_prompts: Optional[Union[List[str], str]] = None, + target_size: int = 512, + crop_based_on_salience: bool = True, + use_face_detection_instead: bool = False, + temp: float = 1.0, + n_length: int = -1, +): + """ + Loads images from the given files, generates masks for them, and saves the masks and captions and upscale images + to output dir. + """ + os.makedirs(output_dir, exist_ok=True) + + # load images + if isinstance(files, str): + # check if it is a directory + if os.path.isdir(files): + # get all the .png .jpg in the directory + files = glob.glob(os.path.join(files, "*.png")) + glob.glob( + os.path.join(files, "*.jpg") + ) + + if len(files) == 0: + raise Exception( + f"No files found in {files}. Either {files} is not a directory or it does not contain any .png or .jpg files." 
+ ) + if n_length == -1: + n_length = len(files) + files = sorted(files)[:n_length] + + images = [Image.open(file) for file in files] + + # captions + print(f"Generating {len(images)} captions...") + captions = blip_captioning_dataset(images, text=caption_text) + + if target_prompts is None: + target_prompts = captions + + print(f"Generating {len(images)} masks...") + if not use_face_detection_instead: + seg_masks = clipseg_mask_generator( + images=images, target_prompts=target_prompts, temp=temp + ) + else: + seg_masks = face_mask_google_mediapipe(images=images) + + # find the center of mass of the mask + if crop_based_on_salience: + coms = [_center_of_mass(mask) for mask in seg_masks] + else: + coms = [(image.size[0] / 2, image.size[1] / 2) for image in images] + # based on the center of mass, crop the image to a square + images = [ + _crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms) + ] + + print(f"Upscaling {len(images)} images...") + # upscale images anyways + images = swin_ir_sr(images, target_size=(target_size, target_size)) + images = [ + image.resize((target_size, target_size), Image.Resampling.LANCZOS) + for image in images + ] + + seg_masks = [ + _crop_to_square(mask, com, resize_to=target_size) + for mask, com in zip(seg_masks, coms) + ] + with open(os.path.join(output_dir, "caption.txt"), "w") as f: + # save images and masks + for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)): + image.save(os.path.join(output_dir, f"{idx}.src.jpg"), quality=99) + mask.save(os.path.join(output_dir, f"{idx}.mask.png")) + + f.write(caption + "\n") + + +def main(): + fire.Fire(load_and_save_masks_and_captions) diff --git a/lora_diffusion/safe_open.py b/lora_diffusion/safe_open.py new file mode 100644 index 0000000000000000000000000000000000000000..77ada821b6dd22e1ea08f4c4eaf6c30a8d43161b --- /dev/null +++ b/lora_diffusion/safe_open.py @@ -0,0 +1,68 @@ +""" +Pure python version of Safetensors safe_open +From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282 +""" + +import json +import mmap +import os + +import torch + + +class SafetensorsWrapper: + def __init__(self, metadata, tensors): + self._metadata = metadata + self._tensors = tensors + + def metadata(self): + return self._metadata + + def keys(self): + return self._tensors.keys() + + def get_tensor(self, k): + return self._tensors[k] + + +DTYPES = { + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, +} + + +def create_tensor(storage, info, offset): + dtype = DTYPES[info["dtype"]] + shape = info["shape"] + start, stop = info["data_offsets"] + return ( + torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8) + .view(dtype=dtype) + .reshape(shape) + ) + + +def safe_open(filename, framework="pt", device="cpu"): + if framework != "pt": + raise ValueError("`framework` must be 'pt'") + + with open(filename, mode="r", encoding="utf8") as file_obj: + with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m: + header = m.read(8) + n = int.from_bytes(header, "little") + metadata_bytes = m.read(n) + metadata = json.loads(metadata_bytes) + + size = os.stat(filename).st_size + storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped() + offset = n + 8 + + return SafetensorsWrapper( + metadata=metadata.get("__metadata__", {}), + tensors={ + name: create_tensor(storage, info, offset).to(device) + for name, info in metadata.items() + if name != "__metadata__" + }, + ) diff --git 
a/lora_diffusion/to_ckpt_v2.py b/lora_diffusion/to_ckpt_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..15f3947118e3be01f7f68630a858dc95da220f2d --- /dev/null +++ b/lora_diffusion/to_ckpt_v2.py @@ -0,0 +1,232 @@ +# from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05 +# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. +# *Only* converts the UNet, VAE, and Text Encoder. +# Does not convert optimizer state or any other thing. +# Written by jachiam +import argparse +import os.path as osp + +import torch + + +# =================# +# UNet Conversion # +# =================# + +unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), +] + +unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), +] + +unet_conversion_map_layer = [] +# hardcoded number of downblocks and resnets/attentions... +# would need smarter logic for other networks. +for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + +hf_mid_atn_prefix = "mid_block.attentions.0." +sd_mid_atn_prefix = "middle_block.1." +unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + +for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." 
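+    # middle_block is laid out as [resnet, attention, resnet], so its resnets sit
+    # at indices 0 and 2 while the attention block mapped above sits at index 1.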
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + +def convert_unet_state_dict(unet_state_dict): + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + +# ================# +# VAE Conversion # +# ================# + +vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), +] + +for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + +# this part accounts for mid blocks in both the encoder and the decoder +for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." 
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + +vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), +] + + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + return w.reshape(*w.shape, 1, 1) + + +def convert_vae_state_dict(vae_state_dict): + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + print(f"Reshaping {k} for SD format") + new_state_dict[k] = reshape_weight_for_sd(v) + return new_state_dict + + +# =========================# +# Text Encoder Conversion # +# =========================# +# pretty much a no-op + + +def convert_text_enc_state_dict(text_enc_dict): + return text_enc_dict + + +def convert_to_ckpt(model_path, checkpoint_path, as_half): + + assert model_path is not None, "Must provide a model path!" + + assert checkpoint_path is not None, "Must provide a checkpoint path!" + + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") + text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") + + # Convert the UNet model + unet_state_dict = torch.load(unet_path, map_location="cpu") + unet_state_dict = convert_unet_state_dict(unet_state_dict) + unet_state_dict = { + "model.diffusion_model." + k: v for k, v in unet_state_dict.items() + } + + # Convert the VAE model + vae_state_dict = torch.load(vae_path, map_location="cpu") + vae_state_dict = convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Convert the text encoder model + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + text_enc_dict = convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = { + "cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items() + } + + # Put together new checkpoint + state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + if as_half: + state_dict = {k: v.half() for k, v in state_dict.items()} + state_dict = {"state_dict": state_dict} + torch.save(state_dict, checkpoint_path) diff --git a/lora_diffusion/utils.py b/lora_diffusion/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a3410df7eb5f929a4e2ae662c0a1f563a3f760 --- /dev/null +++ b/lora_diffusion/utils.py @@ -0,0 +1,214 @@ +from typing import List, Union + +import torch +from PIL import Image +from transformers import ( + CLIPProcessor, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers import StableDiffusionPipeline +from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path +import os +import glob +import math + +EXAMPLE_PROMPTS = [ + " swimming in a pool", + " at a beach with a view of seashore", + " in times square", + " wearing sunglasses", + " in a construction outfit", + " playing with a ball", + " wearing headphones", + " oil painting ghibli inspired", + " working on the laptop", + " with mountains and sunset in background", + "Painting of at a beach by artist claude monet", + " digital painting 3d render geometric style", + "A screaming ", + "A depressed ", + "A sleeping ", + "A sad ", + "A joyous ", + "A frowning ", + "A sculpture of ", + " near a pool", + " at a beach with a view of seashore", + " in a garden", + " in grand canyon", + " floating in ocean", + " and an armchair", + "A maple tree on the side of ", + " and an orange sofa", + " with chocolate cake on it", + " with a vase of rose flowers on it", + "A digital illustration of ", + "Georgia O'Keeffe style painting", + "A watercolor painting of on a beach", +] + + +def image_grid(_imgs, rows=None, cols=None): + + if rows is None and cols is None: + rows = cols = math.ceil(len(_imgs) ** 0.5) + + if rows is None: + rows = math.ceil(len(_imgs) / cols) + if cols is None: + cols = math.ceil(len(_imgs) / rows) + + w, h = _imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + grid_w, grid_h = grid.size + + for i, img in enumerate(_imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +def text_img_alignment(img_embeds, text_embeds, target_img_embeds): + # evaluation inspired from textual inversion paper + # https://arxiv.org/abs/2208.01618 + + # text alignment + assert img_embeds.shape[0] == text_embeds.shape[0] + text_img_sim = (img_embeds * text_embeds).sum(dim=-1) / ( + img_embeds.norm(dim=-1) * text_embeds.norm(dim=-1) + ) + + # image alignment + img_embed_normalized = img_embeds / img_embeds.norm(dim=-1, keepdim=True) + + avg_target_img_embed = ( + (target_img_embeds / target_img_embeds.norm(dim=-1, keepdim=True)) + .mean(dim=0) + .unsqueeze(0) + .repeat(img_embeds.shape[0], 1) + ) + + img_img_sim = (img_embed_normalized * avg_target_img_embed).sum(dim=-1) + + return { + "text_alignment_avg": text_img_sim.mean().item(), + "image_alignment_avg": img_img_sim.mean().item(), + "text_alignment_all": text_img_sim.tolist(), + "image_alignment_all": img_img_sim.tolist(), + } + + +def prepare_clip_model_sets(eval_clip_id: str = "openai/clip-vit-large-patch14"): + text_model = CLIPTextModelWithProjection.from_pretrained(eval_clip_id) + tokenizer = CLIPTokenizer.from_pretrained(eval_clip_id) + vis_model = CLIPVisionModelWithProjection.from_pretrained(eval_clip_id) + processor = 
CLIPProcessor.from_pretrained(eval_clip_id) + + return text_model, tokenizer, vis_model, processor + + +def evaluate_pipe( + pipe, + target_images: List[Image.Image], + class_token: str = "", + learnt_token: str = "", + guidance_scale: float = 5.0, + seed=0, + clip_model_sets=None, + eval_clip_id: str = "openai/clip-vit-large-patch14", + n_test: int = 10, + n_step: int = 50, +): + + if clip_model_sets is not None: + text_model, tokenizer, vis_model, processor = clip_model_sets + else: + text_model, tokenizer, vis_model, processor = prepare_clip_model_sets( + eval_clip_id + ) + + images = [] + img_embeds = [] + text_embeds = [] + for prompt in EXAMPLE_PROMPTS[:n_test]: + prompt = prompt.replace("", learnt_token) + torch.manual_seed(seed) + with torch.autocast("cuda"): + img = pipe( + prompt, num_inference_steps=n_step, guidance_scale=guidance_scale + ).images[0] + images.append(img) + + # image + inputs = processor(images=img, return_tensors="pt") + img_embed = vis_model(**inputs).image_embeds + img_embeds.append(img_embed) + + prompt = prompt.replace(learnt_token, class_token) + # prompts + inputs = tokenizer([prompt], padding=True, return_tensors="pt") + outputs = text_model(**inputs) + text_embed = outputs.text_embeds + text_embeds.append(text_embed) + + # target images + inputs = processor(images=target_images, return_tensors="pt") + target_img_embeds = vis_model(**inputs).image_embeds + + img_embeds = torch.cat(img_embeds, dim=0) + text_embeds = torch.cat(text_embeds, dim=0) + + return text_img_alignment(img_embeds, text_embeds, target_img_embeds) + + +def visualize_progress( + path_alls: Union[str, List[str]], + prompt: str, + model_id: str = "runwayml/stable-diffusion-v1-5", + device="cuda:0", + patch_unet=True, + patch_text=True, + patch_ti=True, + unet_scale=1.0, + text_sclae=1.0, + num_inference_steps=50, + guidance_scale=5.0, + offset: int = 0, + limit: int = 10, + seed: int = 0, +): + + imgs = [] + if isinstance(path_alls, str): + alls = list(set(glob.glob(path_alls))) + + alls.sort(key=os.path.getmtime) + else: + alls = path_alls + + pipe = StableDiffusionPipeline.from_pretrained( + model_id, torch_dtype=torch.float16 + ).to(device) + + print(f"Found {len(alls)} checkpoints") + for path in alls[offset:limit]: + print(path) + + patch_pipe( + pipe, path, patch_unet=patch_unet, patch_text=patch_text, patch_ti=patch_ti + ) + + tune_lora_scale(pipe.unet, unet_scale) + tune_lora_scale(pipe.text_encoder, text_sclae) + + torch.manual_seed(seed) + image = pipe( + prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + ).images[0] + imgs.append(image) + + return imgs diff --git a/lora_diffusion/xformers_utils.py b/lora_diffusion/xformers_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fdabf665e185147cfd0ef8662db48d05b94ea7f0 --- /dev/null +++ b/lora_diffusion/xformers_utils.py @@ -0,0 +1,70 @@ +import functools + +import torch +from diffusers.models.attention import BasicTransformerBlock +from diffusers.utils.import_utils import is_xformers_available + +from .lora import LoraInjectedLinear + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +@functools.cache +def test_xformers_backwards(size): + @torch.enable_grad() + def _grad(size): + q = torch.randn((1, 4, size), device="cuda") + k = torch.randn((1, 4, size), device="cuda") + v = torch.randn((1, 4, size), device="cuda") + + q = q.detach().requires_grad_() + k = k.detach().requires_grad_() + v = v.detach().requires_grad_() + + 
out = xformers.ops.memory_efficient_attention(q, k, v) + loss = out.sum(2).mean(0).sum() + + return torch.autograd.grad(loss, v) + + try: + _grad(size) + print(size, "pass") + return True + except Exception as e: + print(size, "fail") + return False + + +def set_use_memory_efficient_attention_xformers( + module: torch.nn.Module, valid: bool +) -> None: + def fn_test_dim_head(module: torch.nn.Module): + if isinstance(module, BasicTransformerBlock): + # dim_head isn't stored anywhere, so back-calculate + source = module.attn1.to_v + if isinstance(source, LoraInjectedLinear): + source = source.linear + + dim_head = source.out_features // module.attn1.heads + + result = test_xformers_backwards(dim_head) + + # If dim_head > dim_head_max, turn xformers off + if not result: + module.set_use_memory_efficient_attention_xformers(False) + + for child in module.children(): + fn_test_dim_head(child) + + if not is_xformers_available() and valid: + print("XFormers is not available. Skipping.") + return + + module.set_use_memory_efficient_attention_xformers(valid) + + if valid: + fn_test_dim_head(module) diff --git a/scene/__init__.py b/scene/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1cecab8cc26043248b718303ed74663987849717 --- /dev/null +++ b/scene/__init__.py @@ -0,0 +1,98 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import random +import json +from utils.system_utils import searchForMaxIteration +from scene.dataset_readers import sceneLoadTypeCallbacks,GenerateRandomCameras,GeneratePurnCameras,GenerateCircleCameras +from scene.gaussian_model import GaussianModel +from arguments import ModelParams, GenerateCamParams +from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON, cameraList_from_RcamInfos + +class Scene: + + gaussians : GaussianModel + + def __init__(self, args : ModelParams, pose_args : GenerateCamParams, gaussians : GaussianModel, load_iteration=None, shuffle=False, resolution_scales=[1.0]): + """b + :param path: Path to colmap scene main folder. 
+ """ + self.model_path = args._model_path + self.pretrained_model_path = args.pretrained_model_path + self.loaded_iter = None + self.gaussians = gaussians + self.resolution_scales = resolution_scales + self.pose_args = pose_args + self.args = args + if load_iteration: + if load_iteration == -1: + self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud")) + else: + self.loaded_iter = load_iteration + print("Loading trained model at iteration {}".format(self.loaded_iter)) + + self.test_cameras = {} + scene_info = sceneLoadTypeCallbacks["RandomCam"](self.model_path ,pose_args) + + json_cams = [] + camlist = [] + if scene_info.test_cameras: + camlist.extend(scene_info.test_cameras) + for id, cam in enumerate(camlist): + json_cams.append(camera_to_JSON(id, cam)) + with open(os.path.join(self.model_path, "cameras.json"), 'w') as file: + json.dump(json_cams, file) + + if shuffle: + random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling + self.cameras_extent = pose_args.default_radius # scene_info.nerf_normalization["radius"] + for resolution_scale in resolution_scales: + self.test_cameras[resolution_scale] = cameraList_from_RcamInfos(scene_info.test_cameras, resolution_scale, self.pose_args) + if self.loaded_iter: + self.gaussians.load_ply(os.path.join(self.model_path, + "point_cloud", + "iteration_" + str(self.loaded_iter), + "point_cloud.ply")) + elif self.pretrained_model_path is not None: + self.gaussians.load_ply(self.pretrained_model_path) + else: + self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent) + + def save(self, iteration): + point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration)) + self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply")) + + def getRandTrainCameras(self, scale=1.0): + rand_train_cameras = GenerateRandomCameras(self.pose_args, self.args.batch, SSAA=True) + train_cameras = {} + for resolution_scale in self.resolution_scales: + train_cameras[resolution_scale] = cameraList_from_RcamInfos(rand_train_cameras, resolution_scale, self.pose_args, SSAA=True) + return train_cameras[scale] + + + def getPurnTrainCameras(self, scale=1.0): + rand_train_cameras = GeneratePurnCameras(self.pose_args) + train_cameras = {} + for resolution_scale in self.resolution_scales: + train_cameras[resolution_scale] = cameraList_from_RcamInfos(rand_train_cameras, resolution_scale, self.pose_args) + return train_cameras[scale] + + + def getTestCameras(self, scale=1.0): + return self.test_cameras[scale] + + def getCircleVideoCameras(self, scale=1.0,batch_size=120, render45 = True): + video_circle_cameras = GenerateCircleCameras(self.pose_args,batch_size,render45) + video_cameras = {} + for resolution_scale in self.resolution_scales: + video_cameras[resolution_scale] = cameraList_from_RcamInfos(video_circle_cameras, resolution_scale, self.pose_args) + return video_cameras[scale] \ No newline at end of file diff --git a/scene/cameras.py b/scene/cameras.py new file mode 100644 index 0000000000000000000000000000000000000000..e286fcb4ba89c84e5f19caf9cdc3f3bae7676c50 --- /dev/null +++ b/scene/cameras.py @@ -0,0 +1,138 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +from torch import nn +import numpy as np +from utils.graphics_utils import getWorld2View2, getProjectionMatrix, fov2focal + +def get_rays_torch(focal, c2w, H=64,W=64): + """Computes rays using a General Pinhole Camera Model + Assumes self.h, self.w, self.focal, and self.cam_to_world exist + """ + x, y = torch.meshgrid( + torch.arange(W), # X-Axis (columns) + torch.arange(H), # Y-Axis (rows) + indexing='xy') + camera_directions = torch.stack( + [(x - W * 0.5 + 0.5) / focal, + -(y - H * 0.5 + 0.5) / focal, + -torch.ones_like(x)], + dim=-1).to(c2w) + + # Rotate ray directions from camera frame to the world frame + directions = ((camera_directions[ None,..., None, :] * c2w[None,None, None, :3, :3]).sum(axis=-1)) # Translate camera frame's origin to the world frame + origins = torch.broadcast_to(c2w[ None,None, None, :3, -1], directions.shape) + viewdirs = directions / torch.linalg.norm(directions, axis=-1, keepdims=True) + + return torch.cat((origins,viewdirs),dim=-1) + + +class Camera(nn.Module): + def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, + image_name, uid, + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" + ): + super(Camera, self).__init__() + + self.uid = uid + self.colmap_id = colmap_id + self.R = R + self.T = T + self.FoVx = FoVx + self.FoVy = FoVy + self.image_name = image_name + + try: + self.data_device = torch.device(data_device) + except Exception as e: + print(e) + print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) + self.data_device = torch.device("cuda") + + self.original_image = image.clamp(0.0, 1.0).to(self.data_device) + self.image_width = self.original_image.shape[2] + self.image_height = self.original_image.shape[1] + + if gt_alpha_mask is not None: + self.original_image *= gt_alpha_mask.to(self.data_device) + else: + self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) + + self.zfar = 100.0 + self.znear = 0.01 + + self.trans = trans + self.scale = scale + + self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() + self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() + self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) + self.camera_center = self.world_view_transform.inverse()[3, :3] + + +class RCamera(nn.Module): + def __init__(self, colmap_id, R, T, FoVx, FoVy, uid, delta_polar, delta_azimuth, delta_radius, opt, + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", SSAA=False + ): + super(RCamera, self).__init__() + + self.uid = uid + self.colmap_id = colmap_id + self.R = R + self.T = T + self.FoVx = FoVx + self.FoVy = FoVy + self.delta_polar = delta_polar + self.delta_azimuth = delta_azimuth + self.delta_radius = delta_radius + try: + self.data_device = torch.device(data_device) + except Exception as e: + print(e) + print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) + self.data_device = torch.device("cuda") + + self.zfar = 100.0 + self.znear = 0.01 + + if SSAA: + ssaa = opt.SSAA + else: + ssaa = 1 + + self.image_width = opt.image_w * ssaa + self.image_height = opt.image_h * ssaa + + self.trans = trans + self.scale = scale + + RT = torch.tensor(getWorld2View2(R, T, trans, scale)) + self.world_view_transform = RT.transpose(0, 
1).cuda() + self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() + self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) + self.camera_center = self.world_view_transform.inverse()[3, :3] + # self.rays = get_rays_torch(fov2focal(FoVx, 64), RT).cuda() + self.rays = get_rays_torch(fov2focal(FoVx, self.image_width//8), RT, H=self.image_height//8, W=self.image_width//8).cuda() + +class MiniCam: + def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): + self.image_width = width + self.image_height = height + self.FoVy = fovy + self.FoVx = fovx + self.znear = znear + self.zfar = zfar + self.world_view_transform = world_view_transform + self.full_proj_transform = full_proj_transform + view_inv = torch.inverse(self.world_view_transform) + self.camera_center = view_inv[3][:3] + diff --git a/scene/dataset_readers.py b/scene/dataset_readers.py new file mode 100644 index 0000000000000000000000000000000000000000..0a87fed7f65e9b903d8ecdba130c3eb9ccf28163 --- /dev/null +++ b/scene/dataset_readers.py @@ -0,0 +1,466 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import os +import sys +import torch +import random +import torch.nn.functional as F +from PIL import Image +from typing import NamedTuple +from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal +import numpy as np +import json +from pathlib import Path +from utils.pointe_utils import init_from_pointe +from plyfile import PlyData, PlyElement +from utils.sh_utils import SH2RGB +from utils.general_utils import inverse_sigmoid_np +from scene.gaussian_model import BasicPointCloud + + +class RandCameraInfo(NamedTuple): + uid: int + R: np.array + T: np.array + FovY: np.array + FovX: np.array + width: int + height: int + delta_polar : np.array + delta_azimuth : np.array + delta_radius : np.array + + +class SceneInfo(NamedTuple): + point_cloud: BasicPointCloud + train_cameras: list + test_cameras: list + nerf_normalization: dict + ply_path: str + + +class RSceneInfo(NamedTuple): + point_cloud: BasicPointCloud + test_cameras: list + ply_path: str + +# def getNerfppNorm(cam_info): +# def get_center_and_diag(cam_centers): +# cam_centers = np.hstack(cam_centers) +# avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True) +# center = avg_cam_center +# dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True) +# diagonal = np.max(dist) +# return center.flatten(), diagonal + +# cam_centers = [] + +# for cam in cam_info: +# W2C = getWorld2View2(cam.R, cam.T) +# C2W = np.linalg.inv(W2C) +# cam_centers.append(C2W[:3, 3:4]) + +# center, diagonal = get_center_and_diag(cam_centers) +# radius = diagonal * 1.1 + +# translate = -center + +# return {"translate": translate, "radius": radius} + + + +def fetchPly(path): + plydata = PlyData.read(path) + vertices = plydata['vertex'] + positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T + colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0 + normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T + return BasicPointCloud(points=positions, colors=colors, normals=normals) + +def storePly(path, xyz, 
rgb): + # Define the dtype for the structured array + dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), + ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'), + ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] + + normals = np.zeros_like(xyz) + + elements = np.empty(xyz.shape[0], dtype=dtype) + attributes = np.concatenate((xyz, normals, rgb), axis=1) + elements[:] = list(map(tuple, attributes)) + + # Create the PlyData object and write to file + vertex_element = PlyElement.describe(elements, 'vertex') + ply_data = PlyData([vertex_element]) + ply_data.write(path) + +#only test_camera +def readCircleCamInfo(path,opt): + print("Reading Test Transforms") + test_cam_infos = GenerateCircleCameras(opt,render45 = opt.render_45) + ply_path = os.path.join(path, "init_points3d.ply") + if not os.path.exists(ply_path): + # Since this data set has no colmap data, we start with random points + num_pts = opt.init_num_pts + if opt.init_shape == 'sphere': + thetas = np.random.rand(num_pts)*np.pi + phis = np.random.rand(num_pts)*2*np.pi + radius = np.random.rand(num_pts)*0.5 + # We create random points inside the bounds of sphere + xyz = np.stack([ + radius * np.sin(thetas) * np.sin(phis), + radius * np.sin(thetas) * np.cos(phis), + radius * np.cos(thetas), + ], axis=-1) # [B, 3] + elif opt.init_shape == 'box': + xyz = np.random.random((num_pts, 3)) * 1.0 - 0.5 + elif opt.init_shape == 'rectangle_x': + xyz = np.random.random((num_pts, 3)) + xyz[:, 0] = xyz[:, 0] * 0.6 - 0.3 + xyz[:, 1] = xyz[:, 1] * 1.2 - 0.6 + xyz[:, 2] = xyz[:, 2] * 0.5 - 0.25 + elif opt.init_shape == 'rectangle_z': + xyz = np.random.random((num_pts, 3)) + xyz[:, 0] = xyz[:, 0] * 0.8 - 0.4 + xyz[:, 1] = xyz[:, 1] * 0.6 - 0.3 + xyz[:, 2] = xyz[:, 2] * 1.2 - 0.6 + elif opt.init_shape == 'pointe': + num_pts = int(num_pts/5000) + xyz,rgb = init_from_pointe(opt.init_prompt) + xyz[:,1] = - xyz[:,1] + xyz[:,2] = xyz[:,2] + 0.15 + thetas = np.random.rand(num_pts)*np.pi + phis = np.random.rand(num_pts)*2*np.pi + radius = np.random.rand(num_pts)*0.05 + # We create random points inside the bounds of sphere + xyz_ball = np.stack([ + radius * np.sin(thetas) * np.sin(phis), + radius * np.sin(thetas) * np.cos(phis), + radius * np.cos(thetas), + ], axis=-1) # [B, 3]expend_dims + rgb_ball = np.random.random((4096, num_pts, 3))*0.0001 + rgb = (np.expand_dims(rgb,axis=1)+rgb_ball).reshape(-1,3) + xyz = (np.expand_dims(xyz,axis=1)+np.expand_dims(xyz_ball,axis=0)).reshape(-1,3) + xyz = xyz * 1. 
+ num_pts = xyz.shape[0] + elif opt.init_shape == 'scene': + thetas = np.random.rand(num_pts)*np.pi + phis = np.random.rand(num_pts)*2*np.pi + radius = np.random.rand(num_pts) + opt.radius_range[-1]*3 + # We create random points inside the bounds of sphere + xyz = np.stack([ + radius * np.sin(thetas) * np.sin(phis), + radius * np.sin(thetas) * np.cos(phis), + radius * np.cos(thetas), + ], axis=-1) # [B, 3] + else: + raise NotImplementedError() + print(f"Generating random point cloud ({num_pts})...") + + shs = np.random.random((num_pts, 3)) / 255.0 + + if opt.init_shape == 'pointe' and opt.use_pointe_rgb: + pcd = BasicPointCloud(points=xyz, colors=rgb, normals=np.zeros((num_pts, 3))) + storePly(ply_path, xyz, rgb * 255) + else: + pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))) + storePly(ply_path, xyz, SH2RGB(shs) * 255) + try: + pcd = fetchPly(ply_path) + except: + pcd = None + + scene_info = RSceneInfo(point_cloud=pcd, + test_cameras=test_cam_infos, + ply_path=ply_path) + return scene_info +#borrow from https://github.com/ashawkey/stable-dreamfusion + +def safe_normalize(x, eps=1e-20): + return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps)) + +# def circle_poses(radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), angle_overhead=30, angle_front=60): + +# theta = theta / 180 * np.pi +# phi = phi / 180 * np.pi +# angle_overhead = angle_overhead / 180 * np.pi +# angle_front = angle_front / 180 * np.pi + +# centers = torch.stack([ +# radius * torch.sin(theta) * torch.sin(phi), +# radius * torch.cos(theta), +# radius * torch.sin(theta) * torch.cos(phi), +# ], dim=-1) # [B, 3] + +# # lookat +# forward_vector = safe_normalize(centers) +# up_vector = torch.FloatTensor([0, 1, 0]).unsqueeze(0).repeat(len(centers), 1) +# right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1)) +# up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1)) + +# poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(len(centers), 1, 1) +# poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) +# poses[:, :3, 3] = centers + +# return poses.numpy() + +def circle_poses(radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), angle_overhead=30, angle_front=60): + + theta = theta / 180 * np.pi + phi = phi / 180 * np.pi + angle_overhead = angle_overhead / 180 * np.pi + angle_front = angle_front / 180 * np.pi + + centers = torch.stack([ + radius * torch.sin(theta) * torch.sin(phi), + radius * torch.sin(theta) * torch.cos(phi), + radius * torch.cos(theta), + ], dim=-1) # [B, 3] + + # lookat + forward_vector = safe_normalize(centers) + up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(len(centers), 1) + right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1)) + up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1)) + + poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(len(centers), 1, 1) + poses[:, :3, :3] = torch.stack((-right_vector, up_vector, forward_vector), dim=-1) + poses[:, :3, 3] = centers + + return poses.numpy() + +def gen_random_pos(size, param_range, gamma=1): + lower, higher = param_range[0], param_range[1] + + mid = lower + (higher - lower) * 0.5 + radius = (higher - lower) * 0.5 + + rand_ = torch.rand(size) # 0, 1 + sign = torch.where(torch.rand(size) > 0.5, torch.ones(size) * -1., torch.ones(size)) + rand_ = sign * (rand_ ** gamma) + + return (rand_ * radius) + mid + + 
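+# Added illustrative sketch (not part of the original patch): gen_random_pos above
+# draws `size` samples that stay inside [param_range[0], param_range[1]] and are
+# symmetric about the midpoint of that range; gamma > 1 concentrates samples near
+# the midpoint, because rand_ ** gamma shrinks values in (0, 1) toward 0. The demo
+# below is hypothetical (the [60, 120] polar range is only an example) and is not
+# called anywhere in this repository.
+def _demo_gen_random_pos():
+    samples = gen_random_pos(1000, [60, 120], gamma=2)  # e.g. a polar-angle range in degrees
+    assert (samples >= 60).all() and (samples <= 120).all()  # samples never leave the range
+    return samples.mean()  # close to the midpoint (90) when gamma > 1
+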
+def rand_poses(size, opt, radius_range=[1, 1.5], theta_range=[0, 120], phi_range=[0, 360], angle_overhead=30, angle_front=60, uniform_sphere_rate=0.5, rand_cam_gamma=1): + ''' generate random poses from an orbit camera + Args: + size: batch size of generated poses. + device: where to allocate the output. + radius: camera radius + theta_range: [min, max], should be in [0, pi] + phi_range: [min, max], should be in [0, 2 * pi] + Return: + poses: [size, 4, 4] + ''' + + theta_range = np.array(theta_range) / 180 * np.pi + phi_range = np.array(phi_range) / 180 * np.pi + angle_overhead = angle_overhead / 180 * np.pi + angle_front = angle_front / 180 * np.pi + + # radius = torch.rand(size) * (radius_range[1] - radius_range[0]) + radius_range[0] + radius = gen_random_pos(size, radius_range) + + if random.random() < uniform_sphere_rate: + unit_centers = F.normalize( + torch.stack([ + torch.randn(size), + torch.abs(torch.randn(size)), + torch.randn(size), + ], dim=-1), p=2, dim=1 + ) + thetas = torch.acos(unit_centers[:,1]) + phis = torch.atan2(unit_centers[:,0], unit_centers[:,2]) + phis[phis < 0] += 2 * np.pi + centers = unit_centers * radius.unsqueeze(-1) + else: + # thetas = torch.rand(size) * (theta_range[1] - theta_range[0]) + theta_range[0] + # phis = torch.rand(size) * (phi_range[1] - phi_range[0]) + phi_range[0] + # phis[phis < 0] += 2 * np.pi + + # centers = torch.stack([ + # radius * torch.sin(thetas) * torch.sin(phis), + # radius * torch.cos(thetas), + # radius * torch.sin(thetas) * torch.cos(phis), + # ], dim=-1) # [B, 3] + # thetas = torch.rand(size) * (theta_range[1] - theta_range[0]) + theta_range[0] + # phis = torch.rand(size) * (phi_range[1] - phi_range[0]) + phi_range[0] + thetas = gen_random_pos(size, theta_range, rand_cam_gamma) + phis = gen_random_pos(size, phi_range, rand_cam_gamma) + phis[phis < 0] += 2 * np.pi + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + radius * torch.sin(thetas) * torch.cos(phis), + radius * torch.cos(thetas), + ], dim=-1) # [B, 3] + + targets = 0 + + # jitters + if opt.jitter_pose: + jit_center = opt.jitter_center # 0.015 # was 0.2 + jit_target = opt.jitter_target + centers += torch.rand_like(centers) * jit_center - jit_center/2.0 + targets += torch.randn_like(centers) * jit_target + + # lookat + forward_vector = safe_normalize(centers - targets) + up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1) + #up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1) + right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1)) + + if opt.jitter_pose: + up_noise = torch.randn_like(up_vector) * opt.jitter_up + else: + up_noise = 0 + + up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise) #forward_vector + + poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(size, 1, 1) + poses[:, :3, :3] = torch.stack((-right_vector, up_vector, forward_vector), dim=-1) #up_vector + poses[:, :3, 3] = centers + + + # back to degree + thetas = thetas / np.pi * 180 + phis = phis / np.pi * 180 + + return poses.numpy(), thetas.numpy(), phis.numpy(), radius.numpy() + +def GenerateCircleCameras(opt, size=8, render45 = False): + # random focal + fov = opt.default_fovy + cam_infos = [] + #generate specific data structure + for idx in range(size): + thetas = torch.FloatTensor([opt.default_polar]) + phis = torch.FloatTensor([(idx / size) * 360]) + radius = torch.FloatTensor([opt.default_radius]) + # random pose on the fly + poses = circle_poses(radius=radius, 
theta=thetas, phi=phis, angle_overhead=opt.angle_overhead, angle_front=opt.angle_front) + matrix = np.linalg.inv(poses[0]) + R = -np.transpose(matrix[:3,:3]) + R[:,0] = -R[:,0] + T = -matrix[:3, 3] + fovy = focal2fov(fov2focal(fov, opt.image_h), opt.image_w) + FovY = fovy + FovX = fov + + # delta polar/azimuth/radius to default view + delta_polar = thetas - opt.default_polar + delta_azimuth = phis - opt.default_azimuth + delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180] + delta_radius = radius - opt.default_radius + cam_infos.append(RandCameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,width=opt.image_w, + height = opt.image_h, delta_polar = delta_polar,delta_azimuth = delta_azimuth, delta_radius = delta_radius)) + if render45: + for idx in range(size): + thetas = torch.FloatTensor([opt.default_polar*2//3]) + phis = torch.FloatTensor([(idx / size) * 360]) + radius = torch.FloatTensor([opt.default_radius]) + # random pose on the fly + poses = circle_poses(radius=radius, theta=thetas, phi=phis, angle_overhead=opt.angle_overhead, angle_front=opt.angle_front) + matrix = np.linalg.inv(poses[0]) + R = -np.transpose(matrix[:3,:3]) + R[:,0] = -R[:,0] + T = -matrix[:3, 3] + fovy = focal2fov(fov2focal(fov, opt.image_h), opt.image_w) + FovY = fovy + FovX = fov + + # delta polar/azimuth/radius to default view + delta_polar = thetas - opt.default_polar + delta_azimuth = phis - opt.default_azimuth + delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180] + delta_radius = radius - opt.default_radius + cam_infos.append(RandCameraInfo(uid=idx+size, R=R, T=T, FovY=FovY, FovX=FovX,width=opt.image_w, + height = opt.image_h, delta_polar = delta_polar,delta_azimuth = delta_azimuth, delta_radius = delta_radius)) + return cam_infos + + +def GenerateRandomCameras(opt, size=2000, SSAA=True): + # random pose on the fly + poses, thetas, phis, radius = rand_poses(size, opt, radius_range=opt.radius_range, theta_range=opt.theta_range, phi_range=opt.phi_range, + angle_overhead=opt.angle_overhead, angle_front=opt.angle_front, uniform_sphere_rate=opt.uniform_sphere_rate, + rand_cam_gamma=opt.rand_cam_gamma) + # delta polar/azimuth/radius to default view + delta_polar = thetas - opt.default_polar + delta_azimuth = phis - opt.default_azimuth + delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180] + delta_radius = radius - opt.default_radius + # random focal + fov = random.random() * (opt.fovy_range[1] - opt.fovy_range[0]) + opt.fovy_range[0] + + cam_infos = [] + + if SSAA: + ssaa = opt.SSAA + else: + ssaa = 1 + + image_h = opt.image_h * ssaa + image_w = opt.image_w * ssaa + + #generate specific data structure + for idx in range(size): + matrix = np.linalg.inv(poses[idx]) + R = -np.transpose(matrix[:3,:3]) + R[:,0] = -R[:,0] + T = -matrix[:3, 3] + # matrix = poses[idx] + # R = matrix[:3,:3] + # T = matrix[:3, 3] + fovy = focal2fov(fov2focal(fov, image_h), image_w) + FovY = fovy + FovX = fov + + cam_infos.append(RandCameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,width=image_w, + height=image_h, delta_polar = delta_polar[idx], + delta_azimuth = delta_azimuth[idx], delta_radius = delta_radius[idx])) + return cam_infos + +def GeneratePurnCameras(opt, size=300): + # random pose on the fly + poses, thetas, phis, radius = rand_poses(size, opt, radius_range=[opt.default_radius,opt.default_radius+0.1], theta_range=opt.theta_range, phi_range=opt.phi_range, angle_overhead=opt.angle_overhead, angle_front=opt.angle_front, uniform_sphere_rate=opt.uniform_sphere_rate) + # delta polar/azimuth/radius to 
default view + delta_polar = thetas - opt.default_polar + delta_azimuth = phis - opt.default_azimuth + delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180] + delta_radius = radius - opt.default_radius + # random focal + #fov = random.random() * (opt.fovy_range[1] - opt.fovy_range[0]) + opt.fovy_range[0] + fov = opt.default_fovy + cam_infos = [] + #generate specific data structure + for idx in range(size): + matrix = np.linalg.inv(poses[idx]) + R = -np.transpose(matrix[:3,:3]) + R[:,0] = -R[:,0] + T = -matrix[:3, 3] + # matrix = poses[idx] + # R = matrix[:3,:3] + # T = matrix[:3, 3] + fovy = focal2fov(fov2focal(fov, opt.image_h), opt.image_w) + FovY = fovy + FovX = fov + + cam_infos.append(RandCameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,width=opt.image_w, + height = opt.image_h, delta_polar = delta_polar[idx],delta_azimuth = delta_azimuth[idx], delta_radius = delta_radius[idx])) + return cam_infos + +sceneLoadTypeCallbacks = { + # "Colmap": readColmapSceneInfo, + # "Blender" : readNerfSyntheticInfo, + "RandomCam" : readCircleCamInfo +} \ No newline at end of file diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py new file mode 100644 index 0000000000000000000000000000000000000000..4f5af969c885aef73d3c0fa505a1f98b28b073b3 --- /dev/null +++ b/scene/gaussian_model.py @@ -0,0 +1,458 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import numpy as np +from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation +from torch import nn +import os +from utils.system_utils import mkdir_p +from plyfile import PlyData, PlyElement +from utils.sh_utils import RGB2SH,SH2RGB +from simple_knn._C import distCUDA2 +from utils.graphics_utils import BasicPointCloud +from utils.general_utils import strip_symmetric, build_scaling_rotation +# from .resnet import * + +class GaussianModel: + + def setup_functions(self): + def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): + L = build_scaling_rotation(scaling_modifier * scaling, rotation) + actual_covariance = L @ L.transpose(1, 2) + symm = strip_symmetric(actual_covariance) + return symm + + self.scaling_activation = torch.exp + self.scaling_inverse_activation = torch.log + + self.covariance_activation = build_covariance_from_scaling_rotation + + self.opacity_activation = torch.sigmoid + self.inverse_opacity_activation = inverse_sigmoid + + self.rotation_activation = torch.nn.functional.normalize + + + def __init__(self, sh_degree : int): + self.active_sh_degree = 0 + self.max_sh_degree = sh_degree + self._xyz = torch.empty(0) + self._features_dc = torch.empty(0) + self._features_rest = torch.empty(0) + self._scaling = torch.empty(0) + self._rotation = torch.empty(0) + self._opacity = torch.empty(0) + self._background = torch.empty(0) + self.max_radii2D = torch.empty(0) + self.xyz_gradient_accum = torch.empty(0) + self.denom = torch.empty(0) + self.optimizer = None + self.percent_dense = 0 + self.spatial_lr_scale = 0 + self.setup_functions() + + def capture(self): + return ( + self.active_sh_degree, + self._xyz, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + self.xyz_gradient_accum, + self.denom, + self.optimizer.state_dict(), + self.spatial_lr_scale, + ) 
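+
+    # Added illustrative helper (not part of the original patch): capture() packs the
+    # parameters, the densification accumulators and optimizer.state_dict() into one
+    # tuple so a run can be checkpointed; restore() below unpacks the same tuple,
+    # matching the usage in train.py ((model_params, first_iter) = torch.load(checkpoint)
+    # followed by gaussians.restore(model_params, opt)). The symmetric save sketched
+    # here is an assumption; the method name `save_checkpoint` and its path argument
+    # are new and only shown for illustration.
+    def save_checkpoint(self, path, iteration):
+        torch.save((self.capture(), iteration), path)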
+ + def restore(self, model_args, training_args): + (self.active_sh_degree, + self._xyz, + self._features_dc, + self._features_rest, + self._scaling, + self._rotation, + self._opacity, + self.max_radii2D, + xyz_gradient_accum, + denom, + opt_dict, + self.spatial_lr_scale) = model_args + self.training_setup(training_args) + self.xyz_gradient_accum = xyz_gradient_accum + self.denom = denom + self.optimizer.load_state_dict(opt_dict) + + @property + def get_scaling(self): + return self.scaling_activation(self._scaling) + + @property + def get_rotation(self): + return self.rotation_activation(self._rotation) + + @property + def get_xyz(self): + return self._xyz + + @property + def get_background(self): + return torch.sigmoid(self._background) + + @property + def get_features(self): + features_dc = self._features_dc + features_rest = self._features_rest + return torch.cat((features_dc, features_rest), dim=1) + + @property + def get_opacity(self): + return self.opacity_activation(self._opacity) + + def get_covariance(self, scaling_modifier = 1): + return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) + + def oneupSHdegree(self): + if self.active_sh_degree < self.max_sh_degree: + self.active_sh_degree += 1 + + def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): + self.spatial_lr_scale = spatial_lr_scale + fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() + fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors))).float().cuda() #RGB2SH( + features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() + features[:, :3, 0 ] = fused_color + features[:, 3:, 1:] = 0.0 + + print("Number of points at initialisation : ", fused_point_cloud.shape[0]) + + dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) + scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) + rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") + rots[:, 0] = 1 + + opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) + + self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) + self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) + self._scaling = nn.Parameter(scales.requires_grad_(True)) + self._rotation = nn.Parameter(rots.requires_grad_(True)) + self._opacity = nn.Parameter(opacities.requires_grad_(True)) + self._background = nn.Parameter(torch.zeros((3,1,1), device="cuda").requires_grad_(True)) + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def training_setup(self, training_args): + self.percent_dense = training_args.percent_dense + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + + l = [ + {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, + {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, + {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, + {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, + {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, + {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": 
"rotation"}, + {'params': [self._background], 'lr': training_args.feature_lr, "name": "background"}, + ] + + self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) + self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, + lr_final=training_args.position_lr_final*self.spatial_lr_scale, + lr_delay_mult=training_args.position_lr_delay_mult, + max_steps=training_args.iterations) + + + self.rotation_scheduler_args = get_expon_lr_func(lr_init=training_args.rotation_lr, + lr_final=training_args.rotation_lr_final, + lr_delay_mult=training_args.position_lr_delay_mult, + max_steps=training_args.iterations) + + self.scaling_scheduler_args = get_expon_lr_func(lr_init=training_args.scaling_lr, + lr_final=training_args.scaling_lr_final, + lr_delay_mult=training_args.position_lr_delay_mult, + max_steps=training_args.iterations) + + self.feature_scheduler_args = get_expon_lr_func(lr_init=training_args.feature_lr, + lr_final=training_args.feature_lr_final, + lr_delay_mult=training_args.position_lr_delay_mult, + max_steps=training_args.iterations) + def update_learning_rate(self, iteration): + ''' Learning rate scheduling per step ''' + for param_group in self.optimizer.param_groups: + if param_group["name"] == "xyz": + lr = self.xyz_scheduler_args(iteration) + param_group['lr'] = lr + return lr + + def update_feature_learning_rate(self, iteration): + ''' Learning rate scheduling per step ''' + for param_group in self.optimizer.param_groups: + if param_group["name"] == "f_dc": + lr = self.feature_scheduler_args(iteration) + param_group['lr'] = lr + return lr + + def update_rotation_learning_rate(self, iteration): + ''' Learning rate scheduling per step ''' + for param_group in self.optimizer.param_groups: + if param_group["name"] == "rotation": + lr = self.rotation_scheduler_args(iteration) + param_group['lr'] = lr + return lr + + def update_scaling_learning_rate(self, iteration): + ''' Learning rate scheduling per step ''' + for param_group in self.optimizer.param_groups: + if param_group["name"] == "scaling": + lr = self.scaling_scheduler_args(iteration) + param_group['lr'] = lr + return lr + + + def construct_list_of_attributes(self): + l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] + # All channels except the 3 DC + for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): + l.append('f_dc_{}'.format(i)) + for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): + l.append('f_rest_{}'.format(i)) + l.append('opacity') + for i in range(self._scaling.shape[1]): + l.append('scale_{}'.format(i)) + for i in range(self._rotation.shape[1]): + l.append('rot_{}'.format(i)) + return l + + def save_ply(self, path): + mkdir_p(os.path.dirname(path)) + + xyz = self._xyz.detach().cpu().numpy() + normals = np.zeros_like(xyz) + f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() + opacities = self._opacity.detach().cpu().numpy() + scale = self._scaling.detach().cpu().numpy() + rotation = self._rotation.detach().cpu().numpy() + + dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] + + elements = np.empty(xyz.shape[0], dtype=dtype_full) + attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) + elements[:] = list(map(tuple, attributes)) + el = PlyElement.describe(elements, 'vertex') + PlyData([el]).write(path) + 
np.savetxt(os.path.join(os.path.split(path)[0],"point_cloud_rgb.txt"),np.concatenate((xyz, SH2RGB(f_dc)), axis=1)) + + def reset_opacity(self): + opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01)) + optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity") + self._opacity = optimizable_tensors["opacity"] + + def load_ply(self, path): + plydata = PlyData.read(path) + + xyz = np.stack((np.asarray(plydata.elements[0]["x"]), + np.asarray(plydata.elements[0]["y"]), + np.asarray(plydata.elements[0]["z"])), axis=1) + opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] + + features_dc = np.zeros((xyz.shape[0], 3, 1)) + features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) + features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) + features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) + + extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")] + extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1])) + assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 + features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) + for idx, attr_name in enumerate(extra_f_names): + features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) + # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) + features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) + + scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] + scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1])) + scales = np.zeros((xyz.shape[0], len(scale_names))) + for idx, attr_name in enumerate(scale_names): + scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] + rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1])) + rots = np.zeros((xyz.shape[0], len(rot_names))) + for idx, attr_name in enumerate(rot_names): + rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) + + self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) + self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) + self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) + self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + self.active_sh_degree = self.max_sh_degree + + def replace_tensor_to_optimizer(self, tensor, name): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if group["name"] not in ['background']: + if group["name"] == name: + stored_state = self.optimizer.state.get(group['params'][0], None) + stored_state["exp_avg"] = torch.zeros_like(tensor) + stored_state["exp_avg_sq"] = torch.zeros_like(tensor) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state 
+ + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def _prune_optimizer(self, mask): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + stored_state = self.optimizer.state.get(group['params'][0], None) + if group["name"] not in ['background']: + if stored_state is not None: + stored_state["exp_avg"] = stored_state["exp_avg"][mask] + stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + return optimizable_tensors + + def prune_points(self, mask): + valid_points_mask = ~mask + optimizable_tensors = self._prune_optimizer(valid_points_mask) + + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + + self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] + + self.denom = self.denom[valid_points_mask] + self.max_radii2D = self.max_radii2D[valid_points_mask] + + def cat_tensors_to_optimizer(self, tensors_dict): + optimizable_tensors = {} + for group in self.optimizer.param_groups: + if group["name"] not in ['background']: + assert len(group["params"]) == 1 + extension_tensor = tensors_dict[group["name"]] + stored_state = self.optimizer.state.get(group['params'][0], None) + if stored_state is not None: + stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) + stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) + + del self.optimizer.state[group['params'][0]] + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + self.optimizer.state[group['params'][0]] = stored_state + + optimizable_tensors[group["name"]] = group["params"][0] + else: + group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) + optimizable_tensors[group["name"]] = group["params"][0] + + return optimizable_tensors + + def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation): + d = {"xyz": new_xyz, + "f_dc": new_features_dc, + "f_rest": new_features_rest, + "opacity": new_opacities, + "scaling" : new_scaling, + "rotation" : new_rotation} + + optimizable_tensors = self.cat_tensors_to_optimizer(d) + self._xyz = optimizable_tensors["xyz"] + self._features_dc = optimizable_tensors["f_dc"] + self._features_rest = optimizable_tensors["f_rest"] + self._opacity = optimizable_tensors["opacity"] + self._scaling = optimizable_tensors["scaling"] + self._rotation = optimizable_tensors["rotation"] + + self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + + def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): + n_init_points = self.get_xyz.shape[0] + # Extract 
points that satisfy the gradient condition + padded_grad = torch.zeros((n_init_points), device="cuda") + padded_grad[:grads.shape[0]] = grads.squeeze() + selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) + + stds = self.get_scaling[selected_pts_mask].repeat(N,1) + means =torch.zeros((stds.size(0), 3),device="cuda") + samples = torch.normal(mean=means, std=stds) + rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) + new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) + new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) + new_rotation = self._rotation[selected_pts_mask].repeat(N,1) + new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) + new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) + new_opacity = self._opacity[selected_pts_mask].repeat(N,1) + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation) + + prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) + self.prune_points(prune_filter) + + def densify_and_clone(self, grads, grad_threshold, scene_extent): + # Extract points that satisfy the gradient condition + selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) + selected_pts_mask = torch.logical_and(selected_pts_mask, + torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) + + new_xyz = self._xyz[selected_pts_mask] + new_features_dc = self._features_dc[selected_pts_mask] + new_features_rest = self._features_rest[selected_pts_mask] + new_opacities = self._opacity[selected_pts_mask] + new_scaling = self._scaling[selected_pts_mask] + new_rotation = self._rotation[selected_pts_mask] + + self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation) + + def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size): + grads = self.xyz_gradient_accum / self.denom + grads[grads.isnan()] = 0.0 + + self.densify_and_clone(grads, max_grad, extent) + self.densify_and_split(grads, max_grad, extent) + + prune_mask = (self.get_opacity < min_opacity).squeeze() + if max_screen_size: + big_points_vs = self.max_radii2D > max_screen_size + big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent + prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws) + self.prune_points(prune_mask) + + torch.cuda.empty_cache() + + def add_densification_stats(self, viewspace_point_tensor, update_filter): + self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True) + self.denom[update_filter] += 1 \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d68ab5f19938faeefea80c187c6c9d9555b351 --- /dev/null +++ b/train.py @@ -0,0 +1,553 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
+# +# For inquiries contact george.drettakis@inria.fr +# + +import random +import imageio +import os +import torch +import torch.nn as nn +from random import randint +from utils.loss_utils import l1_loss, ssim, tv_loss +from gaussian_renderer import render, network_gui +import sys +from scene import Scene, GaussianModel +from utils.general_utils import safe_state +import uuid +from tqdm import tqdm +from utils.image_utils import psnr +from argparse import ArgumentParser, Namespace +from arguments import ModelParams, PipelineParams, OptimizationParams, GenerateCamParams, GuidanceParams +import math +import yaml +from torchvision.utils import save_image +import torchvision.transforms as T + +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_FOUND = True +except ImportError: + TENSORBOARD_FOUND = False + +sys.path.append('/root/yangxin/codebase/3D_Playground/GSDF') + + +def adjust_text_embeddings(embeddings, azimuth, guidance_opt): + #TODO: add prenerg functions + text_z_list = [] + weights_list = [] + K = 0 + #for b in range(azimuth): + text_z_, weights_ = get_pos_neg_text_embeddings(embeddings, azimuth, guidance_opt) + K = max(K, weights_.shape[0]) + text_z_list.append(text_z_) + weights_list.append(weights_) + + # Interleave text_embeddings from different dirs to form a batch + text_embeddings = [] + for i in range(K): + for text_z in text_z_list: + # if uneven length, pad with the first embedding + text_embeddings.append(text_z[i] if i < len(text_z) else text_z[0]) + text_embeddings = torch.stack(text_embeddings, dim=0) # [B * K, 77, 768] + + # Interleave weights from different dirs to form a batch + weights = [] + for i in range(K): + for weights_ in weights_list: + weights.append(weights_[i] if i < len(weights_) else torch.zeros_like(weights_[0])) + weights = torch.stack(weights, dim=0) # [B * K] + return text_embeddings, weights + +def get_pos_neg_text_embeddings(embeddings, azimuth_val, opt): + if azimuth_val >= -90 and azimuth_val < 90: + if azimuth_val >= 0: + r = 1 - azimuth_val / 90 + else: + r = 1 + azimuth_val / 90 + start_z = embeddings['front'] + end_z = embeddings['side'] + # if random.random() < 0.3: + # r = r + random.gauss(0, 0.08) + pos_z = r * start_z + (1 - r) * end_z + text_z = torch.cat([pos_z, embeddings['front'], embeddings['side']], dim=0) + if r > 0.8: + front_neg_w = 0.0 + else: + front_neg_w = math.exp(-r * opt.front_decay_factor) * opt.negative_w + if r < 0.2: + side_neg_w = 0.0 + else: + side_neg_w = math.exp(-(1-r) * opt.side_decay_factor) * opt.negative_w + + weights = torch.tensor([1.0, front_neg_w, side_neg_w]) + else: + if azimuth_val >= 0: + r = 1 - (azimuth_val - 90) / 90 + else: + r = 1 + (azimuth_val + 90) / 90 + start_z = embeddings['side'] + end_z = embeddings['back'] + # if random.random() < 0.3: + # r = r + random.gauss(0, 0.08) + pos_z = r * start_z + (1 - r) * end_z + text_z = torch.cat([pos_z, embeddings['side'], embeddings['front']], dim=0) + front_neg_w = opt.negative_w + if r > 0.8: + side_neg_w = 0.0 + else: + side_neg_w = math.exp(-r * opt.side_decay_factor) * opt.negative_w / 2 + + weights = torch.tensor([1.0, side_neg_w, front_neg_w]) + return text_z, weights.to(text_z.device) + +def prepare_embeddings(guidance_opt, guidance): + embeddings = {} + # text embeddings (stable-diffusion) and (IF) + embeddings['default'] = guidance.get_text_embeds([guidance_opt.text]) + embeddings['uncond'] = guidance.get_text_embeds([guidance_opt.negative]) + + for d in ['front', 'side', 'back']: + embeddings[d] = 
guidance.get_text_embeds([f"{guidance_opt.text}, {d} view"]) + embeddings['inverse_text'] = guidance.get_text_embeds(guidance_opt.inverse_text) + return embeddings + +def guidance_setup(guidance_opt): + if guidance_opt.guidance=="SD": + from guidance.sd_utils import StableDiffusion + guidance = StableDiffusion(guidance_opt.g_device, guidance_opt.fp16, guidance_opt.vram_O, + guidance_opt.t_range, guidance_opt.max_t_range, + num_train_timesteps=guidance_opt.num_train_timesteps, + ddim_inv=guidance_opt.ddim_inv, + textual_inversion_path = guidance_opt.textual_inversion_path, + LoRA_path = guidance_opt.LoRA_path, + guidance_opt=guidance_opt) + else: + raise ValueError(f'{guidance_opt.guidance} not supported.') + if guidance is not None: + for p in guidance.parameters(): + p.requires_grad = False + embeddings = prepare_embeddings(guidance_opt, guidance) + return guidance, embeddings + + +def training(dataset, opt, pipe, gcams, guidance_opt, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, save_video): + first_iter = 0 + tb_writer = prepare_output_and_logger(dataset) + gaussians = GaussianModel(dataset.sh_degree) + scene = Scene(dataset, gcams, gaussians) + gaussians.training_setup(opt) + if checkpoint: + (model_params, first_iter) = torch.load(checkpoint) + gaussians.restore(model_params, opt) + + bg_color = [1, 1, 1] if dataset._white_background else [0, 0, 0] + background = torch.tensor(bg_color, dtype=torch.float32, device=dataset.data_device) + iter_start = torch.cuda.Event(enable_timing = True) + iter_end = torch.cuda.Event(enable_timing = True) + + # + save_folder = os.path.join(dataset._model_path,"train_process/") + if not os.path.exists(save_folder): + os.makedirs(save_folder) # makedirs + print('train_process is in :', save_folder) + #controlnet + use_control_net = False + #set up pretrain diffusion models and text_embedings + guidance, embeddings = guidance_setup(guidance_opt) + viewpoint_stack = None + viewpoint_stack_around = None + ema_loss_for_log = 0.0 + progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress") + first_iter += 1 + + if opt.save_process: + save_folder_proc = os.path.join(scene.args._model_path,"process_videos/") + if not os.path.exists(save_folder_proc): + os.makedirs(save_folder_proc) # makedirs + process_view_points = scene.getCircleVideoCameras(batch_size=opt.pro_frames_num,render45=opt.pro_render_45).copy() + save_process_iter = opt.iterations // len(process_view_points) + pro_img_frames = [] + + for iteration in range(first_iter, opt.iterations + 1): + #TODO: DEBUG NETWORK_GUI + if network_gui.conn == None: + network_gui.try_connect() + while network_gui.conn != None: + try: + net_image_bytes = None + custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive() + if custom_cam != None: + net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"] + net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy()) + network_gui.send(net_image_bytes, guidance_opt.text) + if do_training and ((iteration < int(opt.iterations)) or not keep_alive): + break + except Exception as e: + network_gui.conn = None + + iter_start.record() + + gaussians.update_learning_rate(iteration) + gaussians.update_feature_learning_rate(iteration) + gaussians.update_rotation_learning_rate(iteration) + gaussians.update_scaling_learning_rate(iteration) + # Every 500 its we 
increase the levels of SH up to a maximum degree + if iteration % 500 == 0: + gaussians.oneupSHdegree() + + # progressively relaxing view range + if not opt.use_progressive: + if iteration >= opt.progressive_view_iter and iteration % opt.scale_up_cameras_iter == 0: + scene.pose_args.fovy_range[0] = max(scene.pose_args.max_fovy_range[0], scene.pose_args.fovy_range[0] * opt.fovy_scale_up_factor[0]) + scene.pose_args.fovy_range[1] = min(scene.pose_args.max_fovy_range[1], scene.pose_args.fovy_range[1] * opt.fovy_scale_up_factor[1]) + + scene.pose_args.radius_range[1] = max(scene.pose_args.max_radius_range[1], scene.pose_args.radius_range[1] * opt.scale_up_factor) + scene.pose_args.radius_range[0] = max(scene.pose_args.max_radius_range[0], scene.pose_args.radius_range[0] * opt.scale_up_factor) + + scene.pose_args.theta_range[1] = min(scene.pose_args.max_theta_range[1], scene.pose_args.theta_range[1] * opt.phi_scale_up_factor) + scene.pose_args.theta_range[0] = max(scene.pose_args.max_theta_range[0], scene.pose_args.theta_range[0] * 1/opt.phi_scale_up_factor) + + # opt.reset_resnet_iter = max(500, opt.reset_resnet_iter // 1.25) + scene.pose_args.phi_range[0] = max(scene.pose_args.max_phi_range[0] , scene.pose_args.phi_range[0] * opt.phi_scale_up_factor) + scene.pose_args.phi_range[1] = min(scene.pose_args.max_phi_range[1], scene.pose_args.phi_range[1] * opt.phi_scale_up_factor) + + print('scale up theta_range to:', scene.pose_args.theta_range) + print('scale up radius_range to:', scene.pose_args.radius_range) + print('scale up phi_range to:', scene.pose_args.phi_range) + print('scale up fovy_range to:', scene.pose_args.fovy_range) + + # Pick a random Camera + if not viewpoint_stack: + viewpoint_stack = scene.getRandTrainCameras().copy() + + C_batch_size = guidance_opt.C_batch_size + viewpoint_cams = [] + images = [] + text_z_ = [] + weights_ = [] + depths = [] + alphas = [] + scales = [] + + text_z_inverse =torch.cat([embeddings['uncond'],embeddings['inverse_text']], dim=0) + + for i in range(C_batch_size): + try: + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + except: + viewpoint_stack = scene.getRandTrainCameras().copy() + viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1)) + + #pred text_z + azimuth = viewpoint_cam.delta_azimuth + text_z = [embeddings['uncond']] + + + if guidance_opt.perpneg: + text_z_comp, weights = adjust_text_embeddings(embeddings, azimuth, guidance_opt) + text_z.append(text_z_comp) + weights_.append(weights) + + else: + if azimuth >= -90 and azimuth < 90: + if azimuth >= 0: + r = 1 - azimuth / 90 + else: + r = 1 + azimuth / 90 + start_z = embeddings['front'] + end_z = embeddings['side'] + else: + if azimuth >= 0: + r = 1 - (azimuth - 90) / 90 + else: + r = 1 + (azimuth + 90) / 90 + start_z = embeddings['side'] + end_z = embeddings['back'] + text_z.append(r * start_z + (1 - r) * end_z) + + text_z = torch.cat(text_z, dim=0) + text_z_.append(text_z) + + # Render + if (iteration - 1) == debug_from: + pipe.debug = True + render_pkg = render(viewpoint_cam, gaussians, pipe, background, + sh_deg_aug_ratio = dataset.sh_deg_aug_ratio, + bg_aug_ratio = dataset.bg_aug_ratio, + shs_aug_ratio = dataset.shs_aug_ratio, + scale_aug_ratio = dataset.scale_aug_ratio) + image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"] + depth, alpha = render_pkg["depth"], render_pkg["alpha"] + + scales.append(render_pkg["scales"]) + 
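# A standalone sketch of the azimuth-conditioned prompt blending used in the camera loop
# above: for |azimuth| < 90 deg the text embedding is interpolated between the "front" and
# "side" prompts, otherwise between "side" and "back", so SDS guidance sees a view-consistent
# description. interpolate_view_embedding and the random stand-in embeddings are hypothetical.
import torch

def interpolate_view_embedding(embeddings, azimuth_deg):
    # embeddings: dict with 'front', 'side', 'back' tensors of identical shape
    if -90 <= azimuth_deg < 90:
        r = 1 - azimuth_deg / 90 if azimuth_deg >= 0 else 1 + azimuth_deg / 90
        start, end = embeddings['front'], embeddings['side']
    else:
        r = 1 - (azimuth_deg - 90) / 90 if azimuth_deg >= 0 else 1 + (azimuth_deg + 90) / 90
        start, end = embeddings['side'], embeddings['back']
    return r * start + (1 - r) * end

emb = {k: torch.randn(1, 77, 768) for k in ('front', 'side', 'back')}   # toy text embeddings
blended = interpolate_view_embedding(emb, azimuth_deg=45.0)             # halfway between front and side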
images.append(image) + depths.append(depth) + alphas.append(alpha) + viewpoint_cams.append(viewpoint_cams) + + images = torch.stack(images, dim=0) + depths = torch.stack(depths, dim=0) + alphas = torch.stack(alphas, dim=0) + + # Loss + warm_up_rate = 1. - min(iteration/opt.warmup_iter,1.) + guidance_scale = guidance_opt.guidance_scale + _aslatent = False + if iteration < opt.geo_iter or random.random()< opt.as_latent_ratio: + _aslatent=True + if iteration > opt.use_control_net_iter and (random.random() < guidance_opt.controlnet_ratio): + use_control_net = True + if guidance_opt.perpneg: + loss = guidance.train_step_perpneg(torch.stack(text_z_, dim=1), images, + pred_depth=depths, pred_alpha=alphas, + grad_scale=guidance_opt.lambda_guidance, + use_control_net = use_control_net ,save_folder = save_folder, iteration = iteration, warm_up_rate=warm_up_rate, + weights = torch.stack(weights_, dim=1), resolution=(gcams.image_h, gcams.image_w), + guidance_opt=guidance_opt,as_latent=_aslatent, embedding_inverse = text_z_inverse) + else: + loss = guidance.train_step(torch.stack(text_z_, dim=1), images, + pred_depth=depths, pred_alpha=alphas, + grad_scale=guidance_opt.lambda_guidance, + use_control_net = use_control_net ,save_folder = save_folder, iteration = iteration, warm_up_rate=warm_up_rate, + resolution=(gcams.image_h, gcams.image_w), + guidance_opt=guidance_opt,as_latent=_aslatent, embedding_inverse = text_z_inverse) + #raise ValueError(f'original version not supported.') + scales = torch.stack(scales, dim=0) + + loss_scale = torch.mean(scales,dim=-1).mean() + loss_tv = tv_loss(images) + tv_loss(depths) + # loss_bin = torch.mean(torch.min(alphas - 0.0001, 1 - alphas)) + + loss = loss + opt.lambda_tv * loss_tv + opt.lambda_scale * loss_scale #opt.lambda_tv * loss_tv + opt.lambda_bin * loss_bin + opt.lambda_scale * loss_scale + + loss.backward() + iter_end.record() + + with torch.no_grad(): + # Progress bar + ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log + if opt.save_process: + if iteration % save_process_iter == 0 and len(process_view_points) > 0: + viewpoint_cam_p = process_view_points.pop(0) + render_p = render(viewpoint_cam_p, gaussians, pipe, background, test=True) + img_p = torch.clamp(render_p["render"], 0.0, 1.0) + img_p = img_p.detach().cpu().permute(1,2,0).numpy() + img_p = (img_p * 255).round().astype('uint8') + pro_img_frames.append(img_p) + + if iteration % 10 == 0: + progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"}) + progress_bar.update(10) + if iteration == opt.iterations: + progress_bar.close() + + # Log and save + training_report(tb_writer, iteration, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background)) + if (iteration in testing_iterations): + if save_video: + video_path = video_inference(iteration, scene, render, (pipe, background)) + + if (iteration in saving_iterations): + print("\n[ITER {}] Saving Gaussians".format(iteration)) + scene.save(iteration) + + # Densification + if iteration < opt.densify_until_iter: + # Keep track of max radii in image-space for pruning + gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) + gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter) + + if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0: + size_threshold = 20 if iteration > opt.opacity_reset_interval else None + gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, 
scene.cameras_extent, size_threshold) + + if iteration % opt.opacity_reset_interval == 0: #or (dataset._white_background and iteration == opt.densify_from_iter) + gaussians.reset_opacity() + + # Optimizer step + if iteration < opt.iterations: + gaussians.optimizer.step() + gaussians.optimizer.zero_grad(set_to_none = True) + + if (iteration in checkpoint_iterations): + print("\n[ITER {}] Saving Checkpoint".format(iteration)) + torch.save((gaussians.capture(), iteration), scene._model_path + "/chkpnt" + str(iteration) + ".pth") + + if opt.save_process: + imageio.mimwrite(os.path.join(save_folder_proc, "video_rgb.mp4"), pro_img_frames, fps=30, quality=8) + return video_path + + +def prepare_output_and_logger(args): + if not args._model_path: + if os.getenv('OAR_JOB_ID'): + unique_str=os.getenv('OAR_JOB_ID') + else: + unique_str = str(uuid.uuid4()) + args._model_path = os.path.join("./output/", args.workspace) + + # Set up output folder + print("Output folder: {}".format(args._model_path)) + os.makedirs(args._model_path, exist_ok = True) + + # copy configs + if args.opt_path is not None: + os.system(' '.join(['cp', args.opt_path, os.path.join(args._model_path, 'config.yaml')])) + + with open(os.path.join(args._model_path, "cfg_args"), 'w') as cfg_log_f: + cfg_log_f.write(str(Namespace(**vars(args)))) + + # Create Tensorboard writer + tb_writer = None + if TENSORBOARD_FOUND: + tb_writer = SummaryWriter(args._model_path) + else: + print("Tensorboard not available: not logging progress") + return tb_writer + +def training_report(tb_writer, iteration, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs): + if tb_writer: + tb_writer.add_scalar('iter_time', elapsed, iteration) + # Report test and samples of training set + if iteration in testing_iterations: + save_folder = os.path.join(scene.args._model_path,"test_six_views/{}_iteration".format(iteration)) + if not os.path.exists(save_folder): + os.makedirs(save_folder) # makedirs 创建文件时如果路径不存在会创建这个路径 + print('test views is in :', save_folder) + torch.cuda.empty_cache() + config = ({'name': 'test', 'cameras' : scene.getTestCameras()}) + if config['cameras'] and len(config['cameras']) > 0: + for idx, viewpoint in enumerate(config['cameras']): + render_out = renderFunc(viewpoint, scene.gaussians, *renderArgs, test=True) + rgb, depth = render_out["render"],render_out["depth"] + if depth is not None: + depth_norm = depth/depth.max() + save_image(depth_norm,os.path.join(save_folder,"render_depth_{}.png".format(viewpoint.uid))) + + image = torch.clamp(rgb, 0.0, 1.0) + save_image(image,os.path.join(save_folder,"render_view_{}.png".format(viewpoint.uid))) + if tb_writer: + tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.uid), image[None], global_step=iteration) + print("\n[ITER {}] Eval Done!".format(iteration)) + if tb_writer: + tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration) + tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration) + torch.cuda.empty_cache() + +def video_inference(iteration, scene : Scene, renderFunc, renderArgs): + sharp = T.RandomAdjustSharpness(3, p=1.0) + + save_folder = os.path.join(scene.args._model_path,"videos/{}_iteration".format(iteration)) + if not os.path.exists(save_folder): + os.makedirs(save_folder) # makedirs + print('videos is in :', save_folder) + torch.cuda.empty_cache() + config = ({'name': 'test', 'cameras' : scene.getCircleVideoCameras()}) + if config['cameras'] and len(config['cameras']) > 0: + img_frames = [] 
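# A toy sketch of how the total training loss above is assembled: the SDS guidance term is
# combined with a total-variation penalty on the rendered RGB/depth batch and a mean-scale
# penalty that discourages overly large Gaussians. total_variation here is a simplified
# mean-based variant of the repo's tv_loss, and all weights and tensors are hypothetical.
import torch

def total_variation(x):                                  # x: (B, C, H, W)
    h_tv = (x[:, :, 1:, :] - x[:, :, :-1, :]).pow(2).mean()
    w_tv = (x[:, :, :, 1:] - x[:, :, :, :-1]).pow(2).mean()
    return h_tv + w_tv

def combine_losses(sds_loss, images, depths, scales, lambda_tv=1.0, lambda_scale=1.0):
    loss_tv = total_variation(images) + total_variation(depths)
    loss_scale = scales.mean(dim=-1).mean()              # average per-Gaussian scale magnitude
    return sds_loss + lambda_tv * loss_tv + lambda_scale * loss_scale

sds = torch.tensor(0.3)                                  # stand-in for guidance.train_step(...)
imgs = torch.rand(4, 3, 64, 64)
deps = torch.rand(4, 1, 64, 64)
scales = torch.rand(4, 1000, 3)
loss = combine_losses(sds, imgs, deps, scales)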
+ depth_frames = [] + print("Generating Video using", len(config['cameras']), "different view points") + for idx, viewpoint in enumerate(config['cameras']): + render_out = renderFunc(viewpoint, scene.gaussians, *renderArgs, test=True) + rgb,depth = render_out["render"],render_out["depth"] + if depth is not None: + depth_norm = depth/depth.max() + depths = torch.clamp(depth_norm, 0.0, 1.0) + depths = depths.detach().cpu().permute(1,2,0).numpy() + depths = (depths * 255).round().astype('uint8') + depth_frames.append(depths) + + image = torch.clamp(rgb, 0.0, 1.0) + image = image.detach().cpu().permute(1,2,0).numpy() + image = (image * 255).round().astype('uint8') + img_frames.append(image) + #save_image(image,os.path.join(save_folder,"lora_view_{}.jpg".format(viewpoint.uid))) + # Img to Numpy + imageio.mimwrite(os.path.join(save_folder, "video_rgb_{}.mp4".format(iteration)), img_frames, fps=30, quality=8) + if len(depth_frames) > 0: + imageio.mimwrite(os.path.join(save_folder, "video_depth_{}.mp4".format(iteration)), depth_frames, fps=30, quality=8) + print("\n[ITER {}] Video Save Done!".format(iteration)) + torch.cuda.empty_cache() + return os.path.join(save_folder, "video_rgb_{}.mp4".format(iteration)) + +def args_parser(default_opt=None): + # Set up command line argument parser + parser = ArgumentParser(description="Training script parameters") + + parser.add_argument('--opt', type=str, default=default_opt) + parser.add_argument('--ip', type=str, default="127.0.0.1") + parser.add_argument('--port', type=int, default=6009) + parser.add_argument('--debug_from', type=int, default=-1) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--detect_anomaly', action='store_true', default=False) + parser.add_argument("--test_ratio", type=int, default=5) # [2500,5000,7500,10000,12000] + parser.add_argument("--save_ratio", type=int, default=2) # [10000,12000] + parser.add_argument("--save_video", type=bool, default=False) + parser.add_argument("--quiet", action="store_true") + parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[]) + parser.add_argument("--start_checkpoint", type=str, default = None) + parser.add_argument("--cuda", type=str, default='0') + + lp = ModelParams(parser) + op = OptimizationParams(parser) + pp = PipelineParams(parser) + gcp = GenerateCamParams(parser) + gp = GuidanceParams(parser) + + args = parser.parse_args(sys.argv[1:]) + + os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda + if args.opt is not None: + with open(args.opt) as f: + opts = yaml.load(f, Loader=yaml.FullLoader) + lp.load_yaml(opts.get('ModelParams', None)) + op.load_yaml(opts.get('OptimizationParams', None)) + pp.load_yaml(opts.get('PipelineParams', None)) + gcp.load_yaml(opts.get('GenerateCamParams', None)) + gp.load_yaml(opts.get('GuidanceParams', None)) + + lp.opt_path = args.opt + args.port = opts['port'] + args.save_video = opts.get('save_video', True) + args.seed = opts.get('seed', 0) + args.device = opts.get('device', 'cuda') + + # override device + gp.g_device = args.device + lp.data_device = args.device + gcp.device = args.device + return args, lp, op, pp, gcp, gp + +def start_training(args, lp, op, pp, gcp, gp): + # save iterations + test_iter = [1] + [k * op.iterations // args.test_ratio for k in range(1, args.test_ratio)] + [op.iterations] + args.test_iterations = test_iter + + save_iter = [k * op.iterations // args.save_ratio for k in range(1, args.save_ratio)] + [op.iterations] + args.save_iterations = save_iter + + print('Test iter:', 
args.test_iterations) + print('Save iter:', args.save_iterations) + + print("Optimizing " + lp._model_path) + + # Initialize system state (RNG) + safe_state(args.quiet, seed=args.seed) + # Start GUI server, configure and run training + network_gui.init(args.ip, args.port) + torch.autograd.set_detect_anomaly(args.detect_anomaly) + video_path = training(lp, op, pp, gcp, gp, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args.save_video) + # All done + print("\nTraining complete.") + return video_path + +if __name__ == "__main__": + args, lp, op, pp, gcp, gp = args_parser() + start_training(args, lp, op, pp, gcp, gp) diff --git a/train.sh b/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..cfa1a8c7a8daf93e77ca27b75d39da87cc18cffd --- /dev/null +++ b/train.sh @@ -0,0 +1 @@ +python train.py --opt 'configs/bagel.yaml' --cuda 4 \ No newline at end of file diff --git a/utils/camera_utils.py b/utils/camera_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5aba9884a3d709089e2b31d6d12f7d14940993b4 --- /dev/null +++ b/utils/camera_utils.py @@ -0,0 +1,98 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from scene.cameras import Camera, RCamera +import numpy as np +from utils.general_utils import PILtoTorch +from utils.graphics_utils import fov2focal + +WARNED = False + +def loadCam(args, id, cam_info, resolution_scale): + orig_w, orig_h = cam_info.image.size + + if args.resolution in [1, 2, 4, 8]: + resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution)) + else: # should be a type that converts to float + if args.resolution == -1: + if orig_w > 1600: + global WARNED + if not WARNED: + print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n " + "If this is not desired, please explicitly specify '--resolution/-r' as 1") + WARNED = True + global_down = orig_w / 1600 + else: + global_down = 1 + else: + global_down = orig_w / args.resolution + + scale = float(global_down) * float(resolution_scale) + resolution = (int(orig_w / scale), int(orig_h / scale)) + + resized_image_rgb = PILtoTorch(cam_info.image, resolution) + + gt_image = resized_image_rgb[:3, ...] + loaded_mask = None + + if resized_image_rgb.shape[1] == 4: + loaded_mask = resized_image_rgb[3:4, ...] 
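# A quick sketch (hypothetical values) of how start_training above derives its evaluation
# and checkpoint schedules from --test_ratio / --save_ratio: the run is split into equal
# fractions of the total iteration count, with the final iteration always included.
def make_schedules(iterations, test_ratio, save_ratio):
    test_iters = [1] + [k * iterations // test_ratio for k in range(1, test_ratio)] + [iterations]
    save_iters = [k * iterations // save_ratio for k in range(1, save_ratio)] + [iterations]
    return test_iters, save_iters

# e.g. 10000 iterations, test_ratio=5, save_ratio=2
# -> test at [1, 2000, 4000, 6000, 8000, 10000], save at [5000, 10000]
print(make_schedules(10000, 5, 2))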
+ + return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, + image=gt_image, gt_alpha_mask=loaded_mask, + image_name=cam_info.image_name, uid=id, data_device=args.data_device) + + +def loadRandomCam(opt, id, cam_info, resolution_scale, SSAA=False): + return RCamera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, + FoVx=cam_info.FovX, FoVy=cam_info.FovY, delta_polar=cam_info.delta_polar, + delta_azimuth=cam_info.delta_azimuth , delta_radius=cam_info.delta_radius, opt=opt, + uid=id, data_device=opt.device, SSAA=SSAA) + +def cameraList_from_camInfos(cam_infos, resolution_scale, args): + camera_list = [] + + for id, c in enumerate(cam_infos): + camera_list.append(loadCam(args, id, c, resolution_scale)) + + return camera_list + + +def cameraList_from_RcamInfos(cam_infos, resolution_scale, opt, SSAA=False): + camera_list = [] + + for id, c in enumerate(cam_infos): + camera_list.append(loadRandomCam(opt, id, c, resolution_scale, SSAA=SSAA)) + + return camera_list + +def camera_to_JSON(id, camera : Camera): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = camera.R.transpose() + Rt[:3, 3] = camera.T + Rt[3, 3] = 1.0 + + W2C = np.linalg.inv(Rt) + pos = W2C[:3, 3] + rot = W2C[:3, :3] + serializable_array_2d = [x.tolist() for x in rot] + camera_entry = { + 'id' : id, + 'img_name' : id, + 'width' : camera.width, + 'height' : camera.height, + 'position': pos.tolist(), + 'rotation': serializable_array_2d, + 'fy' : fov2focal(camera.FovY, camera.height), + 'fx' : fov2focal(camera.FovX, camera.width) + } + return camera_entry diff --git a/utils/general_utils.py b/utils/general_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b0edde1a02b8014e75d208c4607fb08db880cbe0 --- /dev/null +++ b/utils/general_utils.py @@ -0,0 +1,141 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import sys +from datetime import datetime +import numpy as np +import random + +def inverse_sigmoid(x): + return torch.log(x/(1-x)) + +def inverse_sigmoid_np(x): + return np.log(x/(1-x)) + +def PILtoTorch(pil_image, resolution): + resized_image_PIL = pil_image.resize(resolution) + resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + if len(resized_image.shape) == 3: + return resized_image.permute(2, 0, 1) + else: + return resized_image.unsqueeze(dim=-1).permute(2, 0, 1) + +def get_expon_lr_func( + lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000 +): + """ + Copied from Plenoxels + + Continuous learning rate decay function. Adapted from JaxNeRF + The returned rate is lr_init when step=0 and lr_final when step=max_steps, and + is log-linearly interpolated elsewhere (equivalent to exponential decay). + If lr_delay_steps>0 then the learning rate will be scaled by some smooth + function of lr_delay_mult, such that the initial learning rate is + lr_init*lr_delay_mult at the beginning of optimization but will be eased back + to the normal learning rate when steps>lr_delay_steps. + :param conf: config subtree 'lr' or similar + :param max_steps: int, the number of steps during optimization. 
+ :return HoF which takes step as input + """ + + def helper(step): + if step < 0 or (lr_init == 0.0 and lr_final == 0.0): + # Disable this parameter + return 0.0 + if lr_delay_steps > 0: + # A kind of reverse cosine decay. + delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( + 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) + ) + else: + delay_rate = 1.0 + t = np.clip(step / max_steps, 0, 1) + log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) + return delay_rate * log_lerp + + return helper + +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") + + uncertainty[:, 0] = L[:, 0, 0] + uncertainty[:, 1] = L[:, 0, 1] + uncertainty[:, 2] = L[:, 0, 2] + uncertainty[:, 3] = L[:, 1, 1] + uncertainty[:, 4] = L[:, 1, 2] + uncertainty[:, 5] = L[:, 2, 2] + return uncertainty + +def strip_symmetric(sym): + return strip_lowerdiag(sym) + +def build_rotation(r): + norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) + + q = r / norm[:, None] + + R = torch.zeros((q.size(0), 3, 3), device='cuda') + + r = q[:, 0] + x = q[:, 1] + y = q[:, 2] + z = q[:, 3] + + R[:, 0, 0] = 1 - 2 * (y*y + z*z) + R[:, 0, 1] = 2 * (x*y - r*z) + R[:, 0, 2] = 2 * (x*z + r*y) + R[:, 1, 0] = 2 * (x*y + r*z) + R[:, 1, 1] = 1 - 2 * (x*x + z*z) + R[:, 1, 2] = 2 * (y*z - r*x) + R[:, 2, 0] = 2 * (x*z - r*y) + R[:, 2, 1] = 2 * (y*z + r*x) + R[:, 2, 2] = 1 - 2 * (x*x + y*y) + return R + +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") + R = build_rotation(r) + + L[:,0,0] = s[:,0] + L[:,1,1] = s[:,1] + L[:,2,2] = s[:,2] + + L = R @ L + return L + +def safe_state(silent, seed=0): + old_f = sys.stdout + class F: + def __init__(self, silent): + self.silent = silent + + def write(self, x): + if not self.silent: + if x.endswith("\n"): + old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S"))))) + else: + old_f.write(x) + + def flush(self): + old_f.flush() + + sys.stdout = F(silent) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # if seed == 0: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + # torch.cuda.set_device(torch.device("cuda:0")) diff --git a/utils/graphics_utils.py b/utils/graphics_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a77b2ebd09f774bf38648cf7fc6d5a9dfe170c5a --- /dev/null +++ b/utils/graphics_utils.py @@ -0,0 +1,81 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
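# A standalone numeric sketch of the log-linear learning-rate decay implemented by
# get_expon_lr_func above: the rate interpolates between lr_init and lr_final in log space
# over max_steps, optionally eased in over lr_delay_steps. The function name and sample
# values are hypothetical.
import numpy as np

def expon_lr(step, lr_init, lr_final, max_steps, lr_delay_steps=0, lr_delay_mult=1.0):
    if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
        return 0.0
    if lr_delay_steps > 0:
        delay = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
            0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
    else:
        delay = 1.0
    t = np.clip(step / max_steps, 0, 1)
    return delay * np.exp((1 - t) * np.log(lr_init) + t * np.log(lr_final))

print(expon_lr(0, 1e-2, 1e-4, 1000))      # 1e-2 at the start
print(expon_lr(500, 1e-2, 1e-4, 1000))    # ~1e-3, the geometric midpoint
print(expon_lr(1000, 1e-2, 1e-4, 1000))   # 1e-4 at the end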
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import math +import numpy as np +from typing import NamedTuple +import torch.nn as nn +import torch.nn.functional as F +from torch.functional import norm + + +class BasicPointCloud(NamedTuple): + points : np.array + colors : np.array + normals : np.array + +def geom_transform_points(points, transf_matrix): + P, _ = points.shape + ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) + points_hom = torch.cat([points, ones], dim=1) + points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) + + denom = points_out[..., 3:] + 0.0000001 + return (points_out[..., :3] / denom).squeeze(dim=0) + +def getWorld2View(R, t): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + return np.float32(Rt) + +def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): + Rt = np.zeros((4, 4)) + Rt[:3, :3] = R.transpose() + Rt[:3, 3] = t + Rt[3, 3] = 1.0 + + C2W = np.linalg.inv(Rt) + cam_center = C2W[:3, 3] + cam_center = (cam_center + translate) * scale + C2W[:3, 3] = cam_center + Rt = np.linalg.inv(C2W) + return np.float32(Rt) + +def getProjectionMatrix(znear, zfar, fovX, fovY): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + +def fov2focal(fov, pixels): + return pixels / (2 * math.tan(fov / 2)) + +def focal2fov(focal, pixels): + return 2*math.atan(pixels/(2*focal)) diff --git a/utils/image_utils.py b/utils/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cdeaa1b6d250e549181ab165070f82ccd31b3eb9 --- /dev/null +++ b/utils/image_utils.py @@ -0,0 +1,19 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +import torch + +def mse(img1, img2): + return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + +def psnr(img1, img2): + mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True) + return 20 * torch.log10(1.0 / torch.sqrt(mse)) diff --git a/utils/loss_utils.py b/utils/loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f4571dc5898809d88aa718d1612636528fda07ba --- /dev/null +++ b/utils/loss_utils.py @@ -0,0 +1,79 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. 
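# A short numeric check of the pinhole relations above: focal = pixels / (2 * tan(fov / 2))
# and its inverse are exact round trips, and getProjectionMatrix derives the near-plane
# extents from the same tangents. The sample FoV and image height are hypothetical.
import math

def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))

def focal2fov(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))

fovy = math.radians(60)
H = 800
fy = fov2focal(fovy, H)                       # ~692.8 px
assert abs(focal2fov(fy, H) - fovy) < 1e-9    # inverse relation holds

znear = 0.01
top = math.tan(fovy / 2) * znear              # half-height of the near plane, as in getProjectionMatrix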
+# +# For inquiries contact george.drettakis@inria.fr +# + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp + +def l1_loss(network_output, gt): + return torch.abs((network_output - gt)).mean() + +def l2_loss(network_output, gt): + return ((network_output - gt) ** 2).mean() + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + +def create_window(window_size, channel, sigma=1.5): + _1D_window = gaussian(window_size, sigma).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def ssim(img1, img2, window_size=11, sigma=1.5, size_average=True, reduce=True): + channel = img1.size(-3) + window = create_window(window_size, channel, sigma) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average, reduce) + +def _ssim(img1, img2, window, window_size, channel, size_average=True, reduce=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if reduce: + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + else: + return ssim_map + +def _tensor_size(t): + return t.size()[1]*t.size()[2]*t.size()[3] + +def tv_loss(x): + batch_size = x.size()[0] + h_x = x.size()[2] + w_x = x.size()[3] + count_h = _tensor_size(x[:,:,1:,:]) + count_w = _tensor_size(x[:,:,:,1:]) + h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() + w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() + return 2*(h_tv/count_h+w_tv/count_w)/batch_size \ No newline at end of file diff --git a/utils/pointe_utils.py b/utils/pointe_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db19fb97fc39880a2ca75f9beb2ed06418cb517d --- /dev/null +++ b/utils/pointe_utils.py @@ -0,0 +1,44 @@ +import torch +from tqdm.auto import tqdm + +from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config +from point_e.diffusion.sampler import PointCloudSampler +from point_e.models.download import load_checkpoint +from point_e.models.configs import MODEL_CONFIGS, model_from_config +from point_e.util.plotting import plot_point_cloud +import numpy as np + +def init_from_pointe(prompt): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print('creating base model...') + base_name = 'base40M-textvec' + base_model = model_from_config(MODEL_CONFIGS[base_name], device) + base_model.eval() + base_diffusion = diffusion_from_config(DIFFUSION_CONFIGS[base_name]) + print('creating upsample model...') + upsampler_model = model_from_config(MODEL_CONFIGS['upsample'], device) + upsampler_model.eval() + upsampler_diffusion = 
diffusion_from_config(DIFFUSION_CONFIGS['upsample']) + print('downloading base checkpoint...') + base_model.load_state_dict(load_checkpoint(base_name, device)) + print('downloading upsampler checkpoint...') + upsampler_model.load_state_dict(load_checkpoint('upsample', device)) + sampler = PointCloudSampler( + device=device, + models=[base_model, upsampler_model], + diffusions=[base_diffusion, upsampler_diffusion], + num_points=[1024, 4096 - 1024], + aux_channels=['R', 'G', 'B'], + guidance_scale=[3.0, 0.0], + model_kwargs_key_filter=('texts', ''), # Do not condition the upsampler at all + ) + # Produce a sample from the model. + samples = None + for x in tqdm(sampler.sample_batch_progressive(batch_size=1, model_kwargs=dict(texts=[prompt]))): + samples = x + + pc = sampler.output_to_point_clouds(samples)[0] + xyz = pc.coords + rgb = np.zeros_like(xyz) + rgb[:,0],rgb[:,1],rgb[:,2] = pc.channels['R'],pc.channels['G'],pc.channels['B'] + return xyz,rgb \ No newline at end of file diff --git a/utils/sh_utils.py b/utils/sh_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bbca7d192aa3a7edf8c5b2d24dee535eac765785 --- /dev/null +++ b/utils/sh_utils.py @@ -0,0 +1,118 @@ +# Copyright 2021 The PlenOctree Authors. +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +import torch + +C0 = 0.28209479177387814 +C1 = 0.4886025119029199 +C2 = [ + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 +] +C3 = [ + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 +] +C4 = [ + 2.5033429417967046, + -1.7701307697799304, + 0.9461746957575601, + -0.6690465435572892, + 0.10578554691520431, + -0.6690465435572892, + 0.47308734787878004, + -1.7701307697799304, + 0.6258357354491761, +] + + +def eval_sh(deg, sh, dirs): + """ + Evaluate spherical harmonics at unit directions + using hardcoded SH polynomials. + Works with torch/np/jnp. + ... Can be 0 or more batch dimensions. + Args: + deg: int SH deg. 
Currently, 0-3 supported + sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] + dirs: jnp.ndarray unit directions [..., 3] + Returns: + [..., C] + """ + assert deg <= 4 and deg >= 0 + coeff = (deg + 1) ** 2 + assert sh.shape[-1] >= coeff + + result = C0 * sh[..., 0] + if deg > 0: + x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] + result = (result - + C1 * y * sh[..., 1] + + C1 * z * sh[..., 2] - + C1 * x * sh[..., 3]) + + if deg > 1: + xx, yy, zz = x * x, y * y, z * z + xy, yz, xz = x * y, y * z, x * z + result = (result + + C2[0] * xy * sh[..., 4] + + C2[1] * yz * sh[..., 5] + + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + + C2[3] * xz * sh[..., 7] + + C2[4] * (xx - yy) * sh[..., 8]) + + if deg > 2: + result = (result + + C3[0] * y * (3 * xx - yy) * sh[..., 9] + + C3[1] * xy * z * sh[..., 10] + + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + + C3[5] * z * (xx - yy) * sh[..., 14] + + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) + + if deg > 3: + result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + + C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) + return result + +def RGB2SH(rgb): + return (rgb - 0.5) / C0 + +def SH2RGB(sh): + return sh * C0 + 0.5 \ No newline at end of file diff --git a/utils/system_utils.py b/utils/system_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..90ca6d7f77610c967affe313398777cd86920e8e --- /dev/null +++ b/utils/system_utils.py @@ -0,0 +1,28 @@ +# +# Copyright (C) 2023, Inria +# GRAPHDECO research group, https://team.inria.fr/graphdeco +# All rights reserved. +# +# This software is free for non-commercial, research and evaluation use +# under the terms of the LICENSE.md file. +# +# For inquiries contact george.drettakis@inria.fr +# + +from errno import EEXIST +from os import makedirs, path +import os + +def mkdir_p(folder_path): + # Creates a directory. equivalent to using mkdir -p on the command line + try: + makedirs(folder_path) + except OSError as exc: # Python >2.5 + if exc.errno == EEXIST and path.isdir(folder_path): + pass + else: + raise + +def searchForMaxIteration(folder): + saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] + return max(saved_iters)
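# A brief numeric check (illustrative only) of the DC spherical-harmonics conversion above:
# RGB2SH and SH2RGB are exact inverses built from the degree-0 constant C0, and a degree-0
# eval_sh call reduces to C0 * sh[..., 0], i.e. SH2RGB without the 0.5 offset.
import torch

C0 = 0.28209479177387814

def RGB2SH(rgb):
    return (rgb - 0.5) / C0

def SH2RGB(sh):
    return sh * C0 + 0.5

rgb = torch.tensor([0.2, 0.5, 0.9])
sh_dc = RGB2SH(rgb)
assert torch.allclose(SH2RGB(sh_dc), rgb)
assert torch.allclose(C0 * sh_dc + 0.5, rgb)   # matches the deg=0 branch of eval_sh plus the offset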