haodongli committed on
Commit
916b126
1 Parent(s): 3e15bf8
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +5 -0
  2. .gitignore +6 -0
  3. .gitmodules +6 -0
  4. GAUSSIAN_SPLATTING_LICENSE.md +83 -0
  5. LICENSE.txt +21 -0
  6. arguments/__init__.py +258 -0
  7. configs/axe.yaml +76 -0
  8. configs/bagel.yaml +74 -0
  9. configs/cat_armor.yaml +74 -0
  10. configs/crown.yaml +74 -0
  11. configs/football_helmet.yaml +75 -0
  12. configs/hamburger.yaml +75 -0
  13. configs/ts_lora.yaml +76 -0
  14. configs/white_hair_ironman.yaml +73 -0
  15. configs/zombie_joker.yaml +75 -0
  16. environment.yml +29 -0
  17. example/Donut.mp4 +3 -0
  18. example/boots.mp4 +3 -0
  19. example/durian.mp4 +3 -0
  20. example/pillow_huskies.mp4 +3 -0
  21. example/wooden_car.mp4 +3 -0
  22. gaussian_renderer/__init__.py +168 -0
  23. gaussian_renderer/network_gui.py +95 -0
  24. gradio_demo.py +62 -0
  25. guidance/perpneg_utils.py +48 -0
  26. guidance/sd_step.py +264 -0
  27. guidance/sd_utils.py +487 -0
  28. lora_diffusion/__init__.py +5 -0
  29. lora_diffusion/cli_lora_add.py +187 -0
  30. lora_diffusion/cli_lora_pti.py +1040 -0
  31. lora_diffusion/cli_pt_to_safetensors.py +85 -0
  32. lora_diffusion/cli_svd.py +146 -0
  33. lora_diffusion/dataset.py +311 -0
  34. lora_diffusion/lora.py +1110 -0
  35. lora_diffusion/lora_manager.py +144 -0
  36. lora_diffusion/preprocess_files.py +327 -0
  37. lora_diffusion/safe_open.py +68 -0
  38. lora_diffusion/to_ckpt_v2.py +232 -0
  39. lora_diffusion/utils.py +214 -0
  40. lora_diffusion/xformers_utils.py +70 -0
  41. scene/__init__.py +98 -0
  42. scene/cameras.py +138 -0
  43. scene/dataset_readers.py +466 -0
  44. scene/gaussian_model.py +458 -0
  45. train.py +553 -0
  46. train.sh +1 -0
  47. utils/camera_utils.py +98 -0
  48. utils/general_utils.py +141 -0
  49. utils/graphics_utils.py +81 -0
  50. utils/image_utils.py +19 -0
.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example/boots.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ example/Donut.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ example/durian.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ example/pillow_huskies.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ example/wooden_car.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
1
+ *.pyc
2
+ .vscode
3
+ output
4
+ build
5
+ output/
6
+ point_e_model_cache/
.gitmodules ADDED
@@ -0,0 +1,6 @@
1
+ [submodule "submodules/diff-gaussian-rasterization"]
2
+ path = submodules/diff-gaussian-rasterization
3
+ url = https://github.com/YixunLiang/diff-gaussian-rasterization.git
4
+ [submodule "submodules/simple-knn"]
5
+ path = submodules/simple-knn
6
+ url = https://github.com/YixunLiang/simple-knn.git
GAUSSIAN_SPLATTING_LICENSE.md ADDED
@@ -0,0 +1,83 @@
1
+ Gaussian-Splatting License
2
+ ===========================
3
+
4
+ **Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
5
+ The *Software* is in the process of being registered with the Agence pour la Protection des
6
+ Programmes (APP).
7
+
8
+ The *Software* is still being developed by the *Licensor*.
9
+
10
+ *Licensor*'s goal is to allow the research community to use, test and evaluate
11
+ the *Software*.
12
+
13
+ ## 1. Definitions
14
+
15
+ *Licensee* means any person or entity that uses the *Software* and distributes
16
+ its *Work*.
17
+
18
+ *Licensor* means the owners of the *Software*, i.e Inria and MPII
19
+
20
+ *Software* means the original work of authorship made available under this
21
+ License ie gaussian-splatting.
22
+
23
+ *Work* means the *Software* and any additions to or derivative works of the
24
+ *Software* that are made available under this License.
25
+
26
+
27
+ ## 2. Purpose
28
+ This license is intended to define the rights granted to the *Licensee* by
29
+ Licensors under the *Software*.
30
+
31
+ ## 3. Rights granted
32
+
33
+ For the above reasons Licensors have decided to distribute the *Software*.
34
+ Licensors grant non-exclusive rights to use the *Software* for research purposes
35
+ to research users (both academic and industrial), free of charge, without right
36
+ to sublicense.. The *Software* may be used "non-commercially", i.e., for research
37
+ and/or evaluation purposes only.
38
+
39
+ Subject to the terms and conditions of this License, you are granted a
40
+ non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
41
+ publicly display, publicly perform and distribute its *Work* and any resulting
42
+ derivative works in any form.
43
+
44
+ ## 4. Limitations
45
+
46
+ **4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
47
+ so under this License, (b) you include a complete copy of this License with
48
+ your distribution, and (c) you retain without modification any copyright,
49
+ patent, trademark, or attribution notices that are present in the *Work*.
50
+
51
+ **4.2 Derivative Works.** You may specify that additional or different terms apply
52
+ to the use, reproduction, and distribution of your derivative works of the *Work*
53
+ ("Your Terms") only if (a) Your Terms provide that the use limitation in
54
+ Section 2 applies to your derivative works, and (b) you identify the specific
55
+ derivative works that are subject to Your Terms. Notwithstanding Your Terms,
56
+ this License (including the redistribution requirements in Section 3.1) will
57
+ continue to apply to the *Work* itself.
58
+
59
+ **4.3** Any other use without of prior consent of Licensors is prohibited. Research
60
+ users explicitly acknowledge having received from Licensors all information
61
+ allowing to appreciate the adequacy between of the *Software* and their needs and
62
+ to undertake all necessary precautions for its execution and use.
63
+
64
+ **4.4** The *Software* is provided both as a compiled library file and as source
65
+ code. In case of using the *Software* for a publication or other results obtained
66
+ through the use of the *Software*, users are strongly encouraged to cite the
67
+ corresponding publications as explained in the documentation of the *Software*.
68
+
69
+ ## 5. Disclaimer
70
+
71
+ THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
72
+ WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
73
+ UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
74
+ CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
75
+ OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
76
+ USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
77
+ ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
78
+ AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
79
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
80
+ GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
81
+ HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
82
+ LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
83
+ IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
LICENSE.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 dreamgaussian
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
arguments/__init__.py ADDED
@@ -0,0 +1,258 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ from argparse import ArgumentParser, Namespace
13
+ import sys
14
+ import os
15
+
16
+ class GroupParams:
17
+ pass
18
+
19
+ class ParamGroup:
20
+ def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
21
+ group = parser.add_argument_group(name)
22
+ for key, value in vars(self).items():
23
+ shorthand = False
24
+ if key.startswith("_"):
25
+ shorthand = True
26
+ key = key[1:]
27
+ t = type(value)
28
+ value = value if not fill_none else None
29
+ if shorthand:
30
+ if t == bool:
31
+ group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
32
+ else:
33
+ group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
34
+ else:
35
+ if t == bool:
36
+ group.add_argument("--" + key, default=value, action="store_true")
37
+ else:
38
+ group.add_argument("--" + key, default=value, type=t)
39
+
40
+ def extract(self, args):
41
+ group = GroupParams()
42
+ for arg in vars(args).items():
43
+ if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
44
+ setattr(group, arg[0], arg[1])
45
+ return group
46
+
47
+ def load_yaml(self, opts=None):
48
+ if opts is None:
49
+ return
50
+ else:
51
+ for key, value in opts.items():
52
+ try:
53
+ setattr(self, key, value)
54
+ except:
55
+ raise Exception(f'Unknown attribute {key}')
56
+
57
+ class GuidanceParams(ParamGroup):
58
+ def __init__(self, parser, opts=None):
59
+ self.guidance = "SD"
60
+ self.g_device = "cuda"
61
+
62
+ self.model_key = None
63
+ self.is_safe_tensor = False
64
+ self.base_model_key = None
65
+
66
+ self.controlnet_model_key = None
67
+
68
+ self.perpneg = True
69
+ self.negative_w = -2.
70
+ self.front_decay_factor = 2.
71
+ self.side_decay_factor = 10.
72
+
73
+ self.vram_O = False
74
+ self.fp16 = True
75
+ self.hf_key = None
76
+ self.t_range = [0.02, 0.5]
77
+ self.max_t_range = 0.98
78
+
79
+ self.scheduler_type = 'DDIM'
80
+ self.num_train_timesteps = None
81
+
82
+ self.sds = False
83
+ self.fix_noise = False
84
+ self.noise_seed = 0
85
+
86
+ self.ddim_inv = False
87
+ self.delta_t = 80
88
+ self.delta_t_start = 100
89
+ self.annealing_intervals = True
90
+ self.text = ''
91
+ self.inverse_text = ''
92
+ self.textual_inversion_path = None
93
+ self.LoRA_path = None
94
+ self.controlnet_ratio = 0.5
95
+ self.negative = ""
96
+ self.guidance_scale = 7.5
97
+ self.denoise_guidance_scale = 1.0
98
+ self.lambda_guidance = 1.
99
+
100
+ self.xs_delta_t = 200
101
+ self.xs_inv_steps = 5
102
+ self.xs_eta = 0.0
103
+
104
+ # multi-batch
105
+ self.C_batch_size = 1
106
+
107
+ self.vis_interval = 100
108
+
109
+ super().__init__(parser, "Guidance Model Parameters")
110
+
111
+
112
+ class ModelParams(ParamGroup):
113
+ def __init__(self, parser, sentinel=False, opts=None):
114
+ self.sh_degree = 0
115
+ self._source_path = ""
116
+ self._model_path = ""
117
+ self.pretrained_model_path = None
118
+ self._images = "images"
119
+ self.workspace = "debug"
120
+ self.batch = 10
121
+ self._resolution = -1
122
+ self._white_background = True
123
+ self.data_device = "cuda"
124
+ self.eval = False
125
+ self.opt_path = None
126
+
127
+ # augmentation
128
+ self.sh_deg_aug_ratio = 0.1
129
+ self.bg_aug_ratio = 0.5
130
+ self.shs_aug_ratio = 0.0
131
+ self.scale_aug_ratio = 1.0
132
+ super().__init__(parser, "Loading Parameters", sentinel)
133
+
134
+ def extract(self, args):
135
+ g = super().extract(args)
136
+ g.source_path = os.path.abspath(g.source_path)
137
+ return g
138
+
139
+
140
+ class PipelineParams(ParamGroup):
141
+ def __init__(self, parser, opts=None):
142
+ self.convert_SHs_python = False
143
+ self.compute_cov3D_python = False
144
+ self.debug = False
145
+ super().__init__(parser, "Pipeline Parameters")
146
+
147
+
148
+ class OptimizationParams(ParamGroup):
149
+ def __init__(self, parser, opts=None):
150
+ self.iterations = 5000# 10_000
151
+ self.position_lr_init = 0.00016
152
+ self.position_lr_final = 0.0000016
153
+ self.position_lr_delay_mult = 0.01
154
+ self.position_lr_max_steps = 30_000
155
+ self.feature_lr = 0.0050
156
+ self.feature_lr_final = 0.0030
157
+
158
+ self.opacity_lr = 0.05
159
+ self.scaling_lr = 0.005
160
+ self.rotation_lr = 0.001
161
+
162
+
163
+ self.geo_iter = 0
164
+ self.as_latent_ratio = 0.2
165
+ # dense
166
+
167
+ self.resnet_lr = 1e-4
168
+ self.resnet_lr_init = 2e-3
169
+ self.resnet_lr_final = 5e-5
170
+
171
+
172
+ self.scaling_lr_final = 0.001
173
+ self.rotation_lr_final = 0.0002
174
+
175
+ self.percent_dense = 0.003
176
+ self.densify_grad_threshold = 0.00075
177
+
178
+ self.lambda_tv = 1.0 # 0.1
179
+ self.lambda_bin = 10.0
180
+ self.lambda_scale = 1.0
181
+ self.lambda_sat = 1.0
182
+ self.lambda_radius = 1.0
183
+ self.densification_interval = 100
184
+ self.opacity_reset_interval = 300
185
+ self.densify_from_iter = 100
186
+ self.densify_until_iter = 30_00
187
+
188
+ self.use_control_net_iter = 10000000
189
+ self.warmup_iter = 1500
190
+
191
+ self.use_progressive = False
192
+ self.save_process = True
193
+ self.pro_frames_num = 600
194
+ self.pro_render_45 = False
195
+ self.progressive_view_iter = 500
196
+ self.progressive_view_init_ratio = 0.2
197
+
198
+ self.scale_up_cameras_iter = 500
199
+ self.scale_up_factor = 0.95
200
+ self.fovy_scale_up_factor = [0.75, 1.1]
201
+ self.phi_scale_up_factor = 1.5
202
+ super().__init__(parser, "Optimization Parameters")
203
+
204
+
205
+ class GenerateCamParams(ParamGroup):
206
+ def __init__(self, parser):
207
+ self.init_shape = 'sphere'
208
+ self.init_prompt = ''
209
+ self.use_pointe_rgb = False
210
+ self.radius_range = [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
211
+ self.max_radius_range = [3.5, 5.0]
212
+ self.default_radius = 3.5
213
+ self.theta_range = [45, 105]
214
+ self.max_theta_range = [45, 105]
215
+ self.phi_range = [-180, 180]
216
+ self.max_phi_range = [-180, 180]
217
+ self.fovy_range = [0.32, 0.60] #[0.3, 1.5] #[0.5, 0.8] #[10, 30]
218
+ self.max_fovy_range = [0.16, 0.60]
219
+ self.rand_cam_gamma = 1.0
220
+ self.angle_overhead = 30
221
+ self.angle_front =60
222
+ self.render_45 = True
223
+ self.uniform_sphere_rate = 0
224
+ self.image_w = 512
225
+ self.image_h = 512 # 512
226
+ self.SSAA = 1
227
+ self.init_num_pts = 100_000
228
+ self.default_polar = 90
229
+ self.default_azimuth = 0
230
+ self.default_fovy = 0.55 #20
231
+ self.jitter_pose = True
232
+ self.jitter_center = 0.05
233
+ self.jitter_target = 0.05
234
+ self.jitter_up = 0.01
235
+ self.device = "cuda"
236
+ super().__init__(parser, "Generate Cameras Parameters")
237
+
238
+ def get_combined_args(parser : ArgumentParser):
239
+ cmdlne_string = sys.argv[1:]
240
+ cfgfile_string = "Namespace()"
241
+ args_cmdline = parser.parse_args(cmdlne_string)
242
+
243
+ try:
244
+ cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
245
+ print("Looking for config file in", cfgfilepath)
246
+ with open(cfgfilepath) as cfg_file:
247
+ print("Config file found: {}".format(cfgfilepath))
248
+ cfgfile_string = cfg_file.read()
249
+ except TypeError:
250
+ print("Config file not found at")
251
+ pass
252
+ args_cfgfile = eval(cfgfile_string)
253
+
254
+ merged_dict = vars(args_cfgfile).copy()
255
+ for k,v in vars(args_cmdline).items():
256
+ if v != None:
257
+ merged_dict[k] = v
258
+ return Namespace(**merged_dict)
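Note: the sketch below illustrates how the parameter groups defined in this file are typically consumed: attributes set in `__init__` become argparse options, `load_yaml()` overrides them from a config dict, and `extract()` copies the parsed values back into a plain namespace. It is a hypothetical usage example assuming it is run from the repository root, not the actual wiring in `train.py`.

```python
# Hypothetical usage sketch (run from the repository root so that the
# `arguments` package above is importable); not part of this commit.
from argparse import ArgumentParser

import yaml

from arguments import (GenerateCamParams, GuidanceParams, ModelParams,
                       OptimizationParams, PipelineParams)

parser = ArgumentParser(description="sketch of the parameter-group wiring")
lp = ModelParams(parser)         # registers --workspace, --batch, -s/--source_path, ...
op = OptimizationParams(parser)  # registers --iterations, --position_lr_init, ...
pp = PipelineParams(parser)
gcp = GenerateCamParams(parser)
gp = GuidanceParams(parser)

args = parser.parse_args(["--workspace", "demo"])  # CLI values override the defaults

# A YAML config such as configs/axe.yaml can then override whole groups.
with open("configs/axe.yaml") as f:
    cfg = yaml.safe_load(f)
gp.load_yaml(cfg.get("GuidanceParams"))
op.load_yaml(cfg.get("OptimizationParams"))

# extract() pulls the values belonging to one group out of the parsed namespace.
model_opts = lp.extract(args)
print(model_opts.workspace, model_opts.batch)  # -> demo 10
```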
configs/axe.yaml ADDED
@@ -0,0 +1,76 @@
1
+ port: 2355
2
+ save_video: true
3
+ seed: 0
4
+
5
+ PipelineParams:
6
+ convert_SHs_python: False #true = using direct rgb
7
+ ModelParams:
8
+ workspace: viking_axe
9
+ sh_degree: 0
10
+ bg_aug_ratio: 0.66
11
+
12
+ GuidanceParams:
13
+ model_key: 'stabilityai/stable-diffusion-2-1-base'
14
+ text: 'Viking axe, fantasy, weapon, blender, 8k, HDR.'
15
+ negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution, oversaturation.'
16
+ inverse_text: ''
17
+ perpneg: false
18
+ C_batch_size: 4
19
+
20
+ t_range: [0.02, 0.5]
21
+ max_t_range: 0.98
22
+ lambda_guidance: 0.1
23
+ guidance_scale: 7.5
24
+ denoise_guidance_scale: 1.0
25
+ noise_seed: 0
26
+
27
+ ddim_inv: true
28
+ accum: false
29
+ annealing_intervals: true
30
+
31
+ xs_delta_t: 200
32
+ xs_inv_steps: 5
33
+ xs_eta: 0.0
34
+
35
+ delta_t: 25
36
+ delta_t_start: 100
37
+
38
+ GenerateCamParams:
39
+ init_shape: 'pointe'
40
+ init_prompt: 'A flag.'
41
+ use_pointe_rgb: false
42
+ init_num_pts: 100_000
43
+ phi_range: [-180, 180]
44
+ max_phi_range: [-180, 180]
45
+ rand_cam_gamma: 1.
46
+
47
+ theta_range: [45, 105]
48
+ max_theta_range: [45, 105]
49
+
50
+ radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
51
+ max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
52
+ default_radius: 3.5
53
+
54
+ default_fovy: 0.55
55
+ fovy_range: [0.32, 0.60]
56
+ max_fovy_range: [0.16, 0.60]
57
+
58
+ OptimizationParams:
59
+ iterations: 5000
60
+ save_process: True
61
+ pro_frames_num: 600
62
+ pro_render_45: False
63
+ warmup_iter: 1500 # 2500
64
+
65
+ as_latent_ratio : 0.2
66
+ geo_iter : 0
67
+ densify_from_iter: 100
68
+ densify_until_iter: 3000
69
+ percent_dense: 0.003
70
+ densify_grad_threshold: 0.00075
71
+ progressive_view_iter: 500 #1500
72
+ opacity_reset_interval: 300 #500
73
+
74
+ scale_up_cameras_iter: 500
75
+ fovy_scale_up_factor: [0.75, 1.1]
76
+ phi_scale_up_factor: 1.5
configs/bagel.yaml ADDED
@@ -0,0 +1,74 @@
1
+ port: 2355
2
+ save_video: true
3
+ seed: 0
4
+
5
+ PipelineParams:
6
+ convert_SHs_python: False #true = using direct rgb
7
+ ModelParams:
8
+ workspace: bagel
9
+ sh_degree: 0
10
+ bg_aug_ratio: 0.66
11
+
12
+ GuidanceParams:
13
+ model_key: 'stabilityai/stable-diffusion-2-1-base'
14
+ text: 'a DSLR photo of a bagel filled with cream cheese and lox.'
15
+ negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.'
16
+ inverse_text: ''
17
+ perpneg: false
18
+ C_batch_size: 4
19
+ t_range: [0.02, 0.5]
20
+ max_t_range: 0.98
21
+ lambda_guidance: 0.1
22
+ guidance_scale: 7.5
23
+ denoise_guidance_scale: 1.0
24
+ noise_seed: 0
25
+
26
+ ddim_inv: true
27
+ annealing_intervals: true
28
+
29
+ xs_delta_t: 200
30
+ xs_inv_steps: 5
31
+ xs_eta: 0.0
32
+
33
+ delta_t: 80
34
+ delta_t_start: 100
35
+
36
+ GenerateCamParams:
37
+ init_shape: 'pointe'
38
+ init_prompt: 'a bagel.'
39
+ use_pointe_rgb: false
40
+ init_num_pts: 100_000
41
+ phi_range: [-180, 180]
42
+ max_phi_range: [-180, 180]
43
+ rand_cam_gamma: 1.
44
+
45
+ theta_range: [45, 105]
46
+ max_theta_range: [45, 105]
47
+
48
+ radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
49
+ max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
50
+ default_radius: 3.5
51
+
52
+ default_fovy: 0.55
53
+ fovy_range: [0.32, 0.60]
54
+ max_fovy_range: [0.16, 0.60]
55
+
56
+ OptimizationParams:
57
+ iterations: 5000
58
+ save_process: True
59
+ pro_frames_num: 600
60
+ pro_render_45: False
61
+ warmup_iter: 1500 # 2500
62
+
63
+ as_latent_ratio : 0.2
64
+ geo_iter : 0
65
+ densify_from_iter: 100
66
+ densify_until_iter: 3000
67
+ percent_dense: 0.003
68
+ densify_grad_threshold: 0.00075
69
+ progressive_view_iter: 500 #1500
70
+ opacity_reset_interval: 300 #500
71
+
72
+ scale_up_cameras_iter: 500
73
+ fovy_scale_up_factor: [0.75, 1.1]
74
+ phi_scale_up_factor: 1.5
configs/cat_armor.yaml ADDED
@@ -0,0 +1,74 @@
1
+ port: 2355
2
+ save_video: true
3
+ seed: 0
4
+
5
+ PipelineParams:
6
+ convert_SHs_python: False #true = using direct rgb
7
+ ModelParams:
8
+ workspace: cat_armor
9
+ sh_degree: 0
10
+ bg_aug_ratio: 0.66
11
+
12
+ GuidanceParams:
13
+ model_key: 'stabilityai/stable-diffusion-2-1-base'
14
+ text: 'a DSLR photo of a cat wearing armor.'
15
+ negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution, oversaturation.'
16
+ inverse_text: ''
17
+ perpneg: true
18
+ C_batch_size: 4
19
+ t_range: [0.02, 0.5]
20
+ max_t_range: 0.98
21
+ lambda_guidance: 0.1
22
+ guidance_scale: 7.5
23
+ denoise_guidance_scale: 1.0
24
+ noise_seed: 0
25
+
26
+ ddim_inv: true
27
+ annealing_intervals: true
28
+
29
+ xs_delta_t: 200
30
+ xs_inv_steps: 5
31
+ xs_eta: 0.0
32
+
33
+ delta_t: 80
34
+ delta_t_start: 100
35
+
36
+ GenerateCamParams:
37
+ init_shape: 'pointe'
38
+ init_prompt: 'a cat.'
39
+ use_pointe_rgb: false
40
+ init_num_pts: 100_000
41
+ phi_range: [-180, 180]
42
+ max_phi_range: [-180, 180]
43
+ rand_cam_gamma: 1.5
44
+
45
+ theta_range: [60, 90]
46
+ max_theta_range: [60, 90]
47
+
48
+ radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
49
+ max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
50
+ default_radius: 3.5
51
+
52
+ default_fovy: 0.55
53
+ fovy_range: [0.32, 0.60]
54
+ max_fovy_range: [0.16, 0.60]
55
+
56
+ OptimizationParams:
57
+ iterations: 5000
58
+ save_process: True
59
+ pro_frames_num: 600
60
+ pro_render_45: False
61
+ warmup_iter: 1500 # 2500
62
+
63
+ as_latent_ratio : 0.2
64
+ geo_iter : 0
65
+ densify_from_iter: 100
66
+ densify_until_iter: 3000
67
+ percent_dense: 0.003
68
+ densify_grad_threshold: 0.00075
69
+ progressive_view_iter: 500 #1500
70
+ opacity_reset_interval: 300 #500
71
+
72
+ scale_up_cameras_iter: 500
73
+ fovy_scale_up_factor: [0.75, 1.1]
74
+ phi_scale_up_factor: 1.5
configs/crown.yaml ADDED
@@ -0,0 +1,74 @@
1
+ port: 2355
2
+ save_video: true
3
+ seed: 0
4
+
5
+ PipelineParams:
6
+ convert_SHs_python: False #true = using direct rgb
7
+ ModelParams:
8
+ workspace: crown
9
+ sh_degree: 0
10
+ bg_aug_ratio: 0.66
11
+
12
+ GuidanceParams:
13
+ model_key: 'stabilityai/stable-diffusion-2-1-base'
14
+ text: 'a DSLR photo of the Imperial State Crown of England.'
15
+ negative: 'unrealistic, blurry, low quality.'
16
+ inverse_text: ''
17
+ perpneg: false
18
+ C_batch_size: 4
19
+ t_range: [0.02, 0.5]
20
+ max_t_range: 0.98
21
+ lambda_guidance: 0.1
22
+ guidance_scale: 7.5
23
+ denoise_guidance_scale: 1.0
24
+ noise_seed: 0
25
+
26
+ ddim_inv: true
27
+ annealing_intervals: true
28
+
29
+ xs_delta_t: 200
30
+ xs_inv_steps: 5
31
+ xs_eta: 0.0
32
+
33
+ delta_t: 80
34
+ delta_t_start: 100
35
+
36
+ GenerateCamParams:
37
+ init_shape: 'pointe'
38
+ init_prompt: 'the Imperial State Crown of England.'
39
+ use_pointe_rgb: false
40
+ init_num_pts: 100_000
41
+ phi_range: [-180, 180]
42
+ max_phi_range: [-180, 180]
43
+ rand_cam_gamma: 1.
44
+
45
+ theta_range: [45, 105]
46
+ max_theta_range: [45, 105]
47
+
48
+ radius_range: [5.2, 5.5]
49
+ max_radius_range: [3.5, 5.0]
50
+ default_radius: 3.5
51
+
52
+ default_fovy: 0.55
53
+ fovy_range: [0.32, 0.60]
54
+ max_fovy_range: [0.16, 0.60]
55
+
56
+ OptimizationParams:
57
+ iterations: 5000
58
+ save_process: True
59
+ pro_frames_num: 600
60
+ pro_render_45: False
61
+ warmup_iter: 1500 # 2500
62
+
63
+ as_latent_ratio : 0.2
64
+ geo_iter : 0
65
+ densify_from_iter: 100
66
+ densify_until_iter: 3000
67
+ percent_dense: 0.003
68
+ densify_grad_threshold: 0.00075
69
+ progressive_view_iter: 500
70
+ opacity_reset_interval: 300
71
+
72
+ scale_up_cameras_iter: 500
73
+ fovy_scale_up_factor: [0.75, 1.1]
74
+ phi_scale_up_factor: 1.5
configs/football_helmet.yaml ADDED
@@ -0,0 +1,75 @@
1
+ port: 2355
2
+ save_video: true
3
+ seed: 0
4
+
5
+ PipelineParams:
6
+ convert_SHs_python: False #true = using direct rgb
7
+ ModelParams:
8
+ workspace: football_helmet
9
+ sh_degree: 0
10
+ bg_aug_ratio: 0.66
11
+
12
+ GuidanceParams:
13
+ model_key: 'stabilityai/stable-diffusion-2-1-base'
14
+ text: 'a DSLR photo of a football helmet.'
15
+ negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution, oversaturation.'
16
+ inverse_text: ''
17
+ perpneg: false
18
+ C_batch_size: 4
19
+ t_range: [0.02, 0.5]
20
+ max_t_range: 0.98
21
+ lambda_guidance: 0.1
22
+ guidance_scale: 7.5
23
+ denoise_guidance_scale: 1.0
24
+
25
+ noise_seed: 0
26
+
27
+ ddim_inv: true
28
+ accum: false
29
+ annealing_intervals: true
30
+
31
+ xs_delta_t: 200
32
+ xs_inv_steps: 5
33
+ xs_eta: 0.0
34
+
35
+ delta_t: 50
36
+ delta_t_start: 100
37
+
38
+ GenerateCamParams:
39
+ init_shape: 'pointe'
40
+ init_prompt: 'a football helmet.'
41
+ use_pointe_rgb: false
42
+ init_num_pts: 100_000
43
+ phi_range: [-180, 180]
44
+ max_phi_range: [-180, 180]
45
+
46
+ theta_range: [45, 90]
47
+ max_theta_range: [45, 90]
48
+
49
+ radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
50
+ max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
51
+ default_radius: 3.5
52
+
53
+ default_fovy: 0.55
54
+ fovy_range: [0.32, 0.60]
55
+ max_fovy_range: [0.16, 0.60]
56
+
57
+ OptimizationParams:
58
+ iterations: 5000
59
+ save_process: True
60
+ pro_frames_num: 600
61
+ pro_render_45: False
62
+ warmup_iter: 1500 # 2500
63
+
64
+ as_latent_ratio : 0.2
65
+ geo_iter : 0
66
+ densify_from_iter: 100
67
+ densify_until_iter: 3000
68
+ percent_dense: 0.003
69
+ densify_grad_threshold: 0.00075
70
+ progressive_view_iter: 500 #1500
71
+ opacity_reset_interval: 300 #500
72
+
73
+ scale_up_cameras_iter: 500
74
+ fovy_scale_up_factor: [0.75, 1.1]
75
+ phi_scale_up_factor: 1.5
configs/hamburger.yaml ADDED
@@ -0,0 +1,75 @@
1
+ port: 2355
2
+ save_video: true
3
+ seed: 0
4
+
5
+ PipelineParams:
6
+ convert_SHs_python: False #true = using direct rgb
7
+ ModelParams:
8
+ workspace: hamburger
9
+ sh_degree: 0
10
+ bg_aug_ratio: 0.66
11
+
12
+ GuidanceParams:
13
+ model_key: 'stabilityai/stable-diffusion-2-1-base'
14
+ text: 'A delicious hamburger.'
15
+ negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.'
16
+ inverse_text: ''
17
+ perpneg: false
18
+ C_batch_size: 4
19
+ t_range: [0.02, 0.5]
20
+ max_t_range: 0.98
21
+ lambda_guidance: 0.1
22
+ guidance_scale: 7.5
23
+ denoise_guidance_scale: 1.0
24
+
25
+ noise_seed: 0
26
+
27
+ ddim_inv: true
28
+ annealing_intervals: true
29
+
30
+ xs_delta_t: 200
31
+ xs_inv_steps: 5
32
+ xs_eta: 0.0
33
+
34
+ delta_t: 50
35
+ delta_t_start: 100
36
+
37
+ GenerateCamParams:
38
+ init_shape: 'sphere'
39
+ init_prompt: '.'
40
+ use_pointe_rgb: false
41
+ init_num_pts: 100_000
42
+ phi_range: [-180, 180]
43
+ max_phi_range: [-180, 180]
44
+ rand_cam_gamma: 1.
45
+
46
+ theta_range: [45, 105]
47
+ max_theta_range: [45, 105]
48
+
49
+ radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
50
+ max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
51
+ default_radius: 3.5
52
+
53
+ default_fovy: 0.55
54
+ fovy_range: [0.32, 0.60]
55
+ max_fovy_range: [0.16, 0.60]
56
+
57
+ OptimizationParams:
58
+ iterations: 5000
59
+ save_process: True
60
+ pro_frames_num: 600
61
+ pro_render_45: False
62
+ warmup_iter: 1500 # 2500
63
+
64
+ as_latent_ratio : 0.2
65
+ geo_iter : 0
66
+ densify_from_iter: 100
67
+ densify_until_iter: 3000
68
+ percent_dense: 0.003
69
+ densify_grad_threshold: 0.00075
70
+ progressive_view_iter: 500 #1500
71
+ opacity_reset_interval: 300 #500
72
+
73
+ scale_up_cameras_iter: 500
74
+ fovy_scale_up_factor: [0.75, 1.1]
75
+ phi_scale_up_factor: 1.5
configs/ts_lora.yaml ADDED
@@ -0,0 +1,76 @@
1
+ port: 2355
2
+ save_video: true
3
+ seed: 0
4
+
5
+ PipelineParams:
6
+ convert_SHs_python: False #true = using direct rgb
7
+ ModelParams:
8
+ workspace: TS_lora
9
+ sh_degree: 0
10
+ bg_aug_ratio: 0.66
11
+
12
+ GuidanceParams:
13
+ model_key: 'stabilityai/stable-diffusion-2-1-base'
14
+ text: 'A <Taylor_Swift> wearing sunglasses.'
15
+ LoRA_path: "./custom_example/lora/Taylor_Swift/step_inv_1000.safetensors"
16
+ negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.'
17
+ inverse_text: ''
18
+ perpneg: true
19
+ C_batch_size: 4
20
+ t_range: [0.02, 0.5]
21
+ max_t_range: 0.98
22
+ lambda_guidance: 0.1
23
+ guidance_scale: 7.5
24
+ denoise_guidance_scale: 1.0
25
+
26
+ noise_seed: 0
27
+
28
+ ddim_inv: true
29
+ annealing_intervals: true
30
+
31
+ xs_delta_t: 200
32
+ xs_inv_steps: 5
33
+ xs_eta: 0.0
34
+
35
+ delta_t: 80
36
+ delta_t_start: 100
37
+
38
+ GenerateCamParams:
39
+ init_shape: 'pointe'
40
+ init_prompt: 'a girl head.'
41
+ use_pointe_rgb: false
42
+ init_num_pts: 100_000
43
+ phi_range: [-80, 80]
44
+ max_phi_range: [-180, 180]
45
+ rand_cam_gamma: 1.5
46
+
47
+ theta_range: [60, 120]
48
+ max_theta_range: [60, 120]
49
+
50
+ radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
51
+ max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
52
+ default_radius: 3.5
53
+
54
+ default_fovy: 0.55
55
+ fovy_range: [0.32, 0.60]
56
+ max_fovy_range: [0.16, 0.60]
57
+
58
+ OptimizationParams:
59
+ iterations: 5000
60
+ save_process: True
61
+ pro_frames_num: 600
62
+ pro_render_45: False
63
+ warmup_iter: 1500 # 2500
64
+
65
+ as_latent_ratio : 0.2
66
+ geo_iter : 0
67
+ densify_from_iter: 100
68
+ densify_until_iter: 3000
69
+ percent_dense: 0.003
70
+ densify_grad_threshold: 0.00075
71
+ progressive_view_iter: 500 #1500
72
+ opacity_reset_interval: 300 #500
73
+
74
+ scale_up_cameras_iter: 500
75
+ fovy_scale_up_factor: [0.75, 1.1]
76
+ phi_scale_up_factor: 1.5
configs/white_hair_ironman.yaml ADDED
@@ -0,0 +1,73 @@
1
+ port: 2355
2
+ save_video: true
3
+ seed: 0
4
+
5
+ PipelineParams:
6
+ convert_SHs_python: False #true = using direct rgb
7
+ ModelParams:
8
+ workspace: white_hair_IRONMAN
9
+ sh_degree: 0
10
+ bg_aug_ratio: 0.66
11
+
12
+ GuidanceParams:
13
+ model_key: 'stabilityai/stable-diffusion-2-1-base'
14
+ text: 'A portrait of IRONMAN, white hair, head, photorealistic, 8K, HDR.'
15
+ negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution.'
16
+ inverse_text: ''
17
+ perpneg: true
18
+ C_batch_size: 4
19
+ max_t_range: 0.98
20
+ lambda_guidance: 0.1
21
+ guidance_scale: 7.5
22
+ denoise_guidance_scale: 1.0
23
+ noise_seed: 0
24
+
25
+ ddim_inv: true
26
+ annealing_intervals: true
27
+
28
+ xs_delta_t: 200
29
+ xs_inv_steps: 5
30
+ xs_eta: 0.0
31
+
32
+ delta_t: 50
33
+ delta_t_start: 100
34
+
35
+ GenerateCamParams:
36
+ init_shape: 'pointe'
37
+ init_prompt: 'a man head.'
38
+ use_pointe_rgb: false
39
+ init_num_pts: 100_000
40
+ phi_range: [-80, 80]
41
+ max_phi_range: [-180, 180]
42
+ rand_cam_gamma: 1.5
43
+
44
+ theta_range: [45, 90]
45
+ max_theta_range: [45, 90]
46
+
47
+ radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
48
+ max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
49
+ default_radius: 3.5
50
+
51
+ default_fovy: 0.55
52
+ fovy_range: [0.32, 0.60]
53
+ max_fovy_range: [0.16, 0.60]
54
+
55
+ OptimizationParams:
56
+ iterations: 5000
57
+ save_process: True
58
+ pro_frames_num: 600
59
+ pro_render_45: False
60
+ warmup_iter: 1500 # 2500
61
+
62
+ as_latent_ratio : 0.2
63
+ geo_iter : 0
64
+ densify_from_iter: 100
65
+ densify_until_iter: 3000
66
+ percent_dense: 0.003
67
+ densify_grad_threshold: 0.00075
68
+ progressive_view_iter: 500 #1500
69
+ opacity_reset_interval: 300 #500
70
+
71
+ scale_up_cameras_iter: 500
72
+ fovy_scale_up_factor: [0.75, 1.1]
73
+ phi_scale_up_factor: 1.5
configs/zombie_joker.yaml ADDED
@@ -0,0 +1,75 @@
1
+ port: 2355
2
+ save_video: true
3
+ seed: 0
4
+
5
+ PipelineParams:
6
+ convert_SHs_python: False #true = using direct rgb
7
+ ModelParams:
8
+ workspace: zombie_joker
9
+ sh_degree: 0
10
+ bg_aug_ratio: 0.66
11
+
12
+ GuidanceParams:
13
+ model_key: 'stabilityai/stable-diffusion-2-1-base'
14
+ text: 'Zombie JOKER, head, photorealistic, 8K, HDR.'
15
+ negative: 'unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, dark, low-resolution, oversaturation.'
16
+ inverse_text: ''
17
+ perpneg: true
18
+ C_batch_size: 4
19
+
20
+ t_range: [0.02, 0.5]
21
+ max_t_range: 0.98
22
+ lambda_guidance: 0.1
23
+ guidance_scale: 7.5
24
+ denoise_guidance_scale: 1.0
25
+ noise_seed: 0
26
+
27
+ ddim_inv: true
28
+ annealing_intervals: true
29
+
30
+ xs_delta_t: 200
31
+ xs_inv_steps: 5
32
+ xs_eta: 0.0
33
+
34
+ delta_t: 50
35
+ delta_t_start: 100
36
+
37
+ GenerateCamParams:
38
+ init_shape: 'pointe'
39
+ init_prompt: 'a man head.'
40
+ use_pointe_rgb: false
41
+ init_num_pts: 100_000
42
+ phi_range: [-80, 80]
43
+ max_phi_range: [-180, 180]
44
+ rand_cam_gamma: 1.5
45
+
46
+ theta_range: [45, 90]
47
+ max_theta_range: [45, 90]
48
+
49
+ radius_range: [5.2, 5.5] #[3.8, 4.5] #[3.0, 3.5]
50
+ max_radius_range: [3.5, 5.0] #[3.8, 4.5] #[3.0, 3.5]
51
+ default_radius: 3.5
52
+
53
+ default_fovy: 0.55
54
+ fovy_range: [0.32, 0.60]
55
+ max_fovy_range: [0.16, 0.60]
56
+
57
+ OptimizationParams:
58
+ iterations: 5000
59
+ save_process: True
60
+ pro_frames_num: 600
61
+ pro_render_45: False
62
+ warmup_iter: 1500 # 2500
63
+
64
+ as_latent_ratio : 0.2
65
+ geo_iter : 0
66
+ densify_from_iter: 100
67
+ densify_until_iter: 3000
68
+ percent_dense: 0.003
69
+ densify_grad_threshold: 0.00075
70
+ progressive_view_iter: 500 #1500
71
+ opacity_reset_interval: 300 #500
72
+
73
+ scale_up_cameras_iter: 500
74
+ fovy_scale_up_factor: [0.75, 1.1]
75
+ phi_scale_up_factor: 1.5
environment.yml ADDED
@@ -0,0 +1,29 @@
1
+ name: LucidDreamer
2
+ channels:
3
+ - pytorch
4
+ - conda-forge
5
+ - defaults
6
+ dependencies:
7
+ - cudatoolkit=11.6
8
+ - plyfile=0.8.1
9
+ - python=3.9
10
+ - pip=22.3.1
11
+ - pytorch=1.12.1
12
+ - torchaudio=0.12.1
13
+ - torchvision=0.15.2
14
+ - tqdm
15
+ - pip:
16
+ - mediapipe
17
+ - Pillow
18
+ - diffusers==0.18.2
19
+ - xformers==0.0.20
20
+ - transformers==4.30.2
21
+ - fire==0.5.0
22
+ - huggingface_hub==0.16.4
23
+ - imageio==2.31.1
24
+ - imageio-ffmpeg
25
+ - PyYAML
26
+ - safetensors
27
+ - wandb
28
+ - accelerate
29
+ - triton
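Note: assuming a standard conda setup, this environment is typically created with `conda env create -f environment.yml` and activated with `conda activate LucidDreamer` (the `name:` field above). The rasterizer and k-NN submodules declared in `.gitmodules` are not listed here and would still need to be built separately (for example with `pip install ./submodules/diff-gaussian-rasterization ./submodules/simple-knn` after `git submodule update --init --recursive`). Also note that `torchvision=0.15.2` does not normally pair with `pytorch=1.12.1` (0.13.1 is the matching release), so the solver may need these pins adjusted.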
example/Donut.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4633e31ff1ff161e0bd015c166c507cad140e14aef616eecef95c32da5dd1902
3
+ size 2264633
example/boots.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f117d721a095ae913d17072ee5ed4373c95f1a8851ca6e9e254bf5efeaf56cb
3
+ size 5358683
example/durian.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da35c90e1212627da08180fcb513a0d402dc189c613c00d18a3b3937c992b47d
3
+ size 9316285
example/pillow_huskies.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc53845fdca59e413765833aed51a9a93e2962719a63f4471e9fa7943e217cf6
3
+ size 3586741
example/wooden_car.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4e3b6b31a1d2c9e3791c4c2c7b278d1eb3ae209c94b5d5f3834b4ea5d6d3c16
3
+ size 1660564
gaussian_renderer/__init__.py ADDED
@@ -0,0 +1,168 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import torch
13
+ import math
14
+ from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
15
+ from scene.gaussian_model import GaussianModel
16
+ from utils.sh_utils import eval_sh, SH2RGB
17
+ from utils.graphics_utils import fov2focal
18
+ import random
19
+
20
+
21
+ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, scaling_modifier = 1.0, black_video = False,
22
+ override_color = None, sh_deg_aug_ratio = 0.1, bg_aug_ratio = 0.3, shs_aug_ratio=1.0, scale_aug_ratio=1.0, test = False):
23
+ """
24
+ Render the scene.
25
+
26
+ Background tensor (bg_color) must be on GPU!
27
+ """
28
+
29
+ # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
30
+ screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
31
+ try:
32
+ screenspace_points.retain_grad()
33
+ except:
34
+ pass
35
+
36
+ if black_video:
37
+ bg_color = torch.zeros_like(bg_color)
38
+ #Aug
39
+ if random.random() < sh_deg_aug_ratio and not test:
40
+ act_SH = 0
41
+ else:
42
+ act_SH = pc.active_sh_degree
43
+
44
+ if random.random() < bg_aug_ratio and not test:
45
+ if random.random() < 0.5:
46
+ bg_color = torch.rand_like(bg_color)
47
+ else:
48
+ bg_color = torch.zeros_like(bg_color)
49
+ # bg_color = torch.zeros_like(bg_color)
50
+
51
+ #bg_color = torch.zeros_like(bg_color)
52
+ # Set up rasterization configuration
53
+ tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
54
+ tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
55
+ try:
56
+ raster_settings = GaussianRasterizationSettings(
57
+ image_height=int(viewpoint_camera.image_height),
58
+ image_width=int(viewpoint_camera.image_width),
59
+ tanfovx=tanfovx,
60
+ tanfovy=tanfovy,
61
+ bg=bg_color,
62
+ scale_modifier=scaling_modifier,
63
+ viewmatrix=viewpoint_camera.world_view_transform,
64
+ projmatrix=viewpoint_camera.full_proj_transform,
65
+ sh_degree=act_SH,
66
+ campos=viewpoint_camera.camera_center,
67
+ prefiltered=False
68
+ )
69
+ except TypeError as e:
70
+ raster_settings = GaussianRasterizationSettings(
71
+ image_height=int(viewpoint_camera.image_height),
72
+ image_width=int(viewpoint_camera.image_width),
73
+ tanfovx=tanfovx,
74
+ tanfovy=tanfovy,
75
+ bg=bg_color,
76
+ scale_modifier=scaling_modifier,
77
+ viewmatrix=viewpoint_camera.world_view_transform,
78
+ projmatrix=viewpoint_camera.full_proj_transform,
79
+ sh_degree=act_SH,
80
+ campos=viewpoint_camera.camera_center,
81
+ prefiltered=False,
82
+ debug=False
83
+ )
84
+
85
+
86
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
87
+
88
+ means3D = pc.get_xyz
89
+ means2D = screenspace_points
90
+ opacity = pc.get_opacity
91
+
92
+ # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
93
+ # scaling / rotation by the rasterizer.
94
+ scales = None
95
+ rotations = None
96
+ cov3D_precomp = None
97
+ if pipe.compute_cov3D_python:
98
+ cov3D_precomp = pc.get_covariance(scaling_modifier)
99
+ else:
100
+ scales = pc.get_scaling
101
+ rotations = pc.get_rotation
102
+
103
+ # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
104
+ # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
105
+ shs = None
106
+ colors_precomp = None
107
+ if colors_precomp is None:
108
+ if pipe.convert_SHs_python:
109
+ raw_rgb = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2).squeeze()[:,:3]
110
+ rgb = torch.sigmoid(raw_rgb)
111
+ colors_precomp = rgb
112
+ else:
113
+ shs = pc.get_features
114
+ else:
115
+ colors_precomp = override_color
116
+
117
+ if random.random() < shs_aug_ratio and not test:
118
+ variance = (0.2 ** 0.5) * shs
119
+ shs = shs + (torch.randn_like(shs) * variance)
120
+
121
+ # add noise to scales
122
+ if random.random() < scale_aug_ratio and not test:
123
+ variance = (0.2 ** 0.5) * scales / 4
124
+ scales = torch.clamp(scales + (torch.randn_like(scales) * variance), 0.0)
125
+
126
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
127
+
128
+ rendered_image, radii, depth_alpha = rasterizer(
129
+ means3D = means3D,
130
+ means2D = means2D,
131
+ shs = shs,
132
+ colors_precomp = colors_precomp,
133
+ opacities = opacity,
134
+ scales = scales,
135
+ rotations = rotations,
136
+ cov3D_precomp = cov3D_precomp)
137
+ depth, alpha = torch.chunk(depth_alpha, 2)
138
+ # bg_train = pc.get_background
139
+ # rendered_image = bg_train*alpha.repeat(3,1,1) + rendered_image
140
+ # focal = 1 / (2 * math.tan(viewpoint_camera.FoVx / 2)) #torch.tan(torch.tensor(viewpoint_camera.FoVx) / 2) * (2. / 2
141
+ # disparity = focal / (depth + 1e-9)
142
+ # max_disp = torch.max(disparity)
143
+ # min_disp = torch.min(disparity[disparity > 0])
144
+ # norm_disparity = (disparity - min_disp) / (max_disp - min_disp)
145
+ # # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
146
+ # # They will be excluded from value updates used in the splitting criteria.
147
+ # return {"render": rendered_image,
148
+ # "depth": norm_disparity,
149
+
150
+ focal = 1 / (2 * math.tan(viewpoint_camera.FoVx / 2))
151
+ disp = focal / (depth + (alpha * 10) + 1e-5)
152
+
153
+ try:
154
+ min_d = disp[alpha <= 0.1].min()
155
+ except Exception:
156
+ min_d = disp.min()
157
+
158
+ disp = torch.clamp((disp - min_d) / (disp.max() - min_d), 0.0, 1.0)
159
+
160
+ # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
161
+ # They will be excluded from value updates used in the splitting criteria.
162
+ return {"render": rendered_image,
163
+ "depth": disp,
164
+ "alpha": alpha,
165
+ "viewspace_points": screenspace_points,
166
+ "visibility_filter" : radii > 0,
167
+ "radii": radii,
168
+ "scales": scales}
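Note: the last few lines of `render()` turn the rasterizer's depth/alpha output into a normalized disparity map. The self-contained snippet below mirrors just that arithmetic on dummy tensors so the normalization is easier to inspect in isolation; it is an illustration, not code from the commit.

```python
import math

import torch


def normalized_disparity(depth: torch.Tensor, alpha: torch.Tensor, fovx: float) -> torch.Tensor:
    """Mirror of the disparity normalization at the end of render() above."""
    focal = 1 / (2 * math.tan(fovx / 2))
    disp = focal / (depth + (alpha * 10) + 1e-5)
    try:
        # floor taken over the (near-)transparent pixels, if any exist
        min_d = disp[alpha <= 0.1].min()
    except Exception:
        min_d = disp.min()
    return torch.clamp((disp - min_d) / (disp.max() - min_d), 0.0, 1.0)


# Dummy 1x64x64 depth and alpha maps, FoVx of roughly 0.86 rad (~49 degrees).
depth = torch.rand(1, 64, 64) * 5 + 1
alpha = torch.rand(1, 64, 64)
print(normalized_disparity(depth, alpha, fovx=0.86).shape)  # torch.Size([1, 64, 64])
```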
gaussian_renderer/network_gui.py ADDED
@@ -0,0 +1,95 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import torch
13
+ import traceback
14
+ import socket
15
+ import json
16
+ from scene.cameras import MiniCam
17
+
18
+ host = "127.0.0.1"
19
+ port = 6009
20
+
21
+ conn = None
22
+ addr = None
23
+
24
+ listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
25
+
26
+ def init(wish_host, wish_port):
27
+ global host, port, listener
28
+ host = wish_host
29
+ port = wish_port
30
+ cnt = 0
31
+ while True:
32
+ try:
33
+ listener.bind((host, port))
34
+ break
35
+ except:
36
+ if cnt == 10:
37
+ break
38
+ cnt += 1
39
+ port += 1
40
+ listener.listen()
41
+ listener.settimeout(0)
42
+
43
+ def try_connect():
44
+ global conn, addr, listener
45
+ try:
46
+ conn, addr = listener.accept()
47
+ print(f"\nConnected by {addr}")
48
+ conn.settimeout(None)
49
+ except Exception as inst:
50
+ pass
51
+
52
+ def read():
53
+ global conn
54
+ messageLength = conn.recv(4)
55
+ messageLength = int.from_bytes(messageLength, 'little')
56
+ message = conn.recv(messageLength)
57
+ return json.loads(message.decode("utf-8"))
58
+
59
+ def send(message_bytes, verify):
60
+ global conn
61
+ if message_bytes != None:
62
+ conn.sendall(message_bytes)
63
+ conn.sendall(len(verify).to_bytes(4, 'little'))
64
+ conn.sendall(bytes(verify, 'ascii'))
65
+
66
+ def receive():
67
+ message = read()
68
+
69
+ width = message["resolution_x"]
70
+ height = message["resolution_y"]
71
+
72
+ if width != 0 and height != 0:
73
+ try:
74
+ do_training = bool(message["train"])
75
+ fovy = message["fov_y"]
76
+ fovx = message["fov_x"]
77
+ znear = message["z_near"]
78
+ zfar = message["z_far"]
79
+ do_shs_python = bool(message["shs_python"])
80
+ do_rot_scale_python = bool(message["rot_scale_python"])
81
+ keep_alive = bool(message["keep_alive"])
82
+ scaling_modifier = message["scaling_modifier"]
83
+ world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
84
+ world_view_transform[:,1] = -world_view_transform[:,1]
85
+ world_view_transform[:,2] = -world_view_transform[:,2]
86
+ full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
87
+ full_proj_transform[:,1] = -full_proj_transform[:,1]
88
+ custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
89
+ except Exception as e:
90
+ print("")
91
+ traceback.print_exc()
92
+ raise e
93
+ return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
94
+ else:
95
+ return None, None, None, None, None, None
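Note: `read()` above expects each client message to be a 4-byte little-endian length prefix followed by a UTF-8 JSON payload (in the upstream gaussian-splatting project this socket is driven by the SIBR remote viewer). A minimal, hypothetical client-side helper matching that framing might look like the sketch below; the host and port are the module defaults shown above.

```python
import json
import socket
import struct


def send_json(sock: socket.socket, payload: dict) -> None:
    """Frame a JSON message the way network_gui.read() expects it."""
    data = json.dumps(payload).encode("utf-8")
    sock.sendall(struct.pack("<I", len(data)) + data)  # 4-byte little-endian length, then body


# Assumes a training process with the GUI server listening on the default address.
with socket.create_connection(("127.0.0.1", 6009)) as sock:
    # A 0x0 resolution makes the server's receive() skip camera construction entirely.
    send_json(sock, {"resolution_x": 0, "resolution_y": 0})
```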
gradio_demo.py ADDED
@@ -0,0 +1,62 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ from train import *
4
+
5
+ example_inputs = [[
6
+ "A DSLR photo of a Rugged, vintage-inspired hiking boots with a weathered leather finish, best quality, 4K, HD.",
7
+ "Rugged, vintage-inspired hiking boots with a weathered leather finish."
8
+ ], [
9
+ "a DSLR photo of a Cream Cheese Donut.",
10
+ "a Donut."
11
+ ], [
12
+ "A durian, 8k, HDR.",
13
+ "A durian"
14
+ ], [
15
+ "A pillow with huskies printed on it",
16
+ "A pillow"
17
+ ], [
18
+ "A DSLR photo of a wooden car, super detailed, best quality, 4K, HD.",
19
+ "a wooden car."
20
+ ]]
21
+ example_outputs = [
22
+ gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/boots.mp4'), autoplay=True),
23
+ gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/Donut.mp4'), autoplay=True),
24
+ gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/durian.mp4'), autoplay=True),
25
+ gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/pillow_huskies.mp4'), autoplay=True),
26
+ gr.Video(value=os.path.join(os.path.dirname(__file__), 'example/wooden_car.mp4'), autoplay=True)
27
+ ]
28
+
29
+ def main(prompt, init_prompt, negative_prompt, num_iter, CFG, seed):
30
+ if [prompt, init_prompt] in example_inputs:
31
+ return example_outputs[example_inputs.index([prompt, init_prompt])]
32
+ args, lp, op, pp, gcp, gp = args_parser(default_opt=os.path.join(os.path.dirname(__file__), 'configs/white_hair_ironman.yaml'))
33
+ gp.text = prompt
34
+ gp.negative = negative_prompt
35
+ if len(init_prompt) > 1:
36
+ gcp.init_shape = 'pointe'
37
+ gcp.init_prompt = init_prompt
38
+ else:
39
+ gcp.init_shape = 'sphere'
40
+ gcp.init_prompt = '.'
41
+ op.iterations = num_iter
42
+ gp.guidance_scale = CFG
43
+ gp.noise_seed = int(seed)
44
+ lp.workspace = 'gradio_demo'
45
+ video_path = start_training(args, lp, op, pp, gcp, gp)
46
+ return gr.Video(value=video_path, autoplay=True)
47
+
48
+ with gr.Blocks() as demo:
49
+ gr.Markdown("# <center>LucidDreamer: Towards High-Fidelity Text-to-3D Generation via Interval Score Matching</center>")
50
+ gr.Markdown("<center>Yixun Liang*, Xin Yang*, Jiantao Lin, Haodong Li, Xiaogang Xu, Yingcong Chen**</center>")
51
+ gr.Markdown("<center>*: Equal contribution. **: Corresponding author.</center>")
52
+ gr.Markdown("We present a text-to-3D generation framework, named the *LucidDreamer*, to distill high-fidelity textures and shapes from pretrained 2D diffusion models.")
53
+ gr.Markdown("<details><summary><strong>CLICK for the full abstract</strong></summary>The recent advancements in text-to-3D generation mark a significant milestone in generative models, unlocking new possibilities for creating imaginative 3D assets across various real-world scenarios. While recent advancements in text-to-3D generation have shown promise, they often fall short in rendering detailed and high-quality 3D models. This problem is especially prevalent as many methods base themselves on Score Distillation Sampling (SDS). This paper identifies a notable deficiency in SDS, that it brings inconsistent and low-quality updating direction for the 3D model, causing the over-smoothing effect. To address this, we propose a novel approach called Interval Score Matching (ISM). ISM employs deterministic diffusing trajectories and utilizes interval-based score matching to counteract over-smoothing. Furthermore, we incorporate 3D Gaussian Splatting into our text-to-3D generation pipeline. Extensive experiments show that our model largely outperforms the state-of-the-art in quality and training efficiency.</details>")
54
+ gr.Interface(fn=main, inputs=[gr.Textbox(lines=2, value="A portrait of IRONMAN, white hair, head, photorealistic, 8K, HDR.", label="Your prompt"),
55
+ gr.Textbox(lines=1, value="a man head.", label="Point-E init prompt (optional)"),
56
+ gr.Textbox(lines=2, value="unrealistic, blurry, low quality, out of focus, ugly, low contrast, dull, low-resolution.", label="Negative prompt (optional)"),
57
+ gr.Slider(1000, 5000, value=5000, label="Number of iterations"),
58
+ gr.Slider(7.5, 100, value=7.5, label="CFG"),
59
+ gr.Number(value=0, label="Seed")],
60
+ outputs="playable_video",
61
+ examples=example_inputs)
62
+ demo.launch()
guidance/perpneg_utils.py ADDED
@@ -0,0 +1,48 @@
1
+ import torch
2
+
3
+ # Please refer to the https://perp-neg.github.io/ for details about the paper and algorithm
4
+ def get_perpendicular_component(x, y):
5
+ assert x.shape == y.shape
6
+ return x - ((torch.mul(x, y).sum())/max(torch.norm(y)**2, 1e-6)) * y
7
+
8
+
9
+ def batch_get_perpendicular_component(x, y):
10
+ assert x.shape == y.shape
11
+ result = []
12
+ for i in range(x.shape[0]):
13
+ result.append(get_perpendicular_component(x[i], y[i]))
14
+ return torch.stack(result)
15
+
16
+
17
+ def weighted_perpendicular_aggregator(delta_noise_preds, weights, batch_size):
18
+ """
19
+ Notes:
20
+ - weights: an array with the weights for combining the noise predictions
21
+ - delta_noise_preds: [B x K, 4, 64, 64], K = max_prompts_per_dir
22
+ """
23
+ delta_noise_preds = delta_noise_preds.split(batch_size, dim=0) # K x [B, 4, 64, 64]
24
+ weights = weights.split(batch_size, dim=0) # K x [B]
25
+ # print(f"{weights[0].shape = } {weights = }")
26
+
27
+ assert torch.all(weights[0] == 1.0)
28
+
29
+ main_positive = delta_noise_preds[0] # [B, 4, 64, 64]
30
+
31
+ accumulated_output = torch.zeros_like(main_positive)
32
+ for i, complementary_noise_pred in enumerate(delta_noise_preds[1:], start=1):
33
+ # print(f"\n{i = }, {weights[i] = }, {weights[i].shape = }\n")
34
+
35
+ idx_non_zero = torch.abs(weights[i]) > 1e-4
36
+
37
+ # print(f"{idx_non_zero.shape = }, {idx_non_zero = }")
38
+ # print(f"{weights[i][idx_non_zero].shape = }, {weights[i][idx_non_zero] = }")
39
+ # print(f"{complementary_noise_pred.shape = }, {complementary_noise_pred[idx_non_zero].shape = }")
40
+ # print(f"{main_positive.shape = }, {main_positive[idx_non_zero].shape = }")
41
+ if sum(idx_non_zero) == 0:
42
+ continue
43
+ accumulated_output[idx_non_zero] += weights[i][idx_non_zero].reshape(-1, 1, 1, 1) * batch_get_perpendicular_component(complementary_noise_pred[idx_non_zero], main_positive[idx_non_zero])
44
+
45
+ #assert accumulated_output.shape == main_positive.shape,# f"{accumulated_output.shape = }, {main_positive.shape = }"
46
+
47
+
48
+ return accumulated_output + main_positive
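Note: a small, self-contained example of how `weighted_perpendicular_aggregator()` is shaped to be called: the K per-direction delta noise predictions are stacked along the batch axis with the main prompt first, and the main direction's weights must all be 1.0 (the assert above). This is an illustration with random tensors, not the actual call site in `guidance/sd_utils.py`.

```python
# Assumes it is run from the repository root so the guidance package is importable.
import torch

from guidance.perpneg_utils import weighted_perpendicular_aggregator

B, K = 2, 3  # batch size, prompt directions per view (1 main + 2 auxiliary)
delta_noise_preds = torch.randn(B * K, 4, 64, 64)  # [B*K, 4, 64, 64], main prompt first
weights = torch.cat([
    torch.ones(B),            # main (positive) direction: weight fixed at 1.0
    torch.full((B,), 0.3),    # auxiliary direction 1
    torch.full((B,), -0.2),   # auxiliary direction 2
])                            # [B*K]

noise_pred = weighted_perpendicular_aggregator(delta_noise_preds, weights, B)
print(noise_pred.shape)  # torch.Size([2, 4, 64, 64])
```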
guidance/sd_step.py ADDED
@@ -0,0 +1,264 @@
1
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
2
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline, DDPMScheduler, DDIMScheduler, EulerDiscreteScheduler, \
3
+ EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ControlNetModel, \
4
+ DDIMInverseScheduler
5
+ from diffusers.utils import BaseOutput, deprecate
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms as T
12
+
13
+ from typing import List, Optional, Tuple, Union
14
+ from dataclasses import dataclass
15
+
16
+ from diffusers.utils import BaseOutput, randn_tensor
17
+
18
+
19
+ @dataclass
20
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
21
+ class DDIMSchedulerOutput(BaseOutput):
22
+ """
23
+ Output class for the scheduler's `step` function output.
24
+
25
+ Args:
26
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
27
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
28
+ denoising loop.
29
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
30
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
31
+ `pred_original_sample` can be used to preview progress or for guidance.
32
+ """
33
+
34
+ prev_sample: torch.FloatTensor
35
+ pred_original_sample: Optional[torch.FloatTensor] = None
36
+
37
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
38
+ def ddim_add_noise(
39
+ self,
40
+ original_samples: torch.FloatTensor,
41
+ noise: torch.FloatTensor,
42
+ timesteps: torch.IntTensor,
43
+ ) -> torch.FloatTensor:
44
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
45
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
46
+ timesteps = timesteps.to(original_samples.device)
47
+
48
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
49
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
50
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
51
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
52
+
53
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
54
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
55
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
56
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
57
+
58
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
59
+ return noisy_samples
60
+
61
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.step
62
+ def ddim_step(
63
+ self,
64
+ model_output: torch.FloatTensor,
65
+ timestep: int,
66
+ sample: torch.FloatTensor,
67
+ delta_timestep: int = None,
68
+ eta: float = 0.0,
69
+ use_clipped_model_output: bool = False,
70
+ generator=None,
71
+ variance_noise: Optional[torch.FloatTensor] = None,
72
+ return_dict: bool = True,
73
+ **kwargs
74
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
75
+ """
76
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
77
+ process from the learned model outputs (most often the predicted noise).
78
+
79
+ Args:
80
+ model_output (`torch.FloatTensor`):
81
+ The direct output from learned diffusion model.
82
+ timestep (`float`):
83
+ The current discrete timestep in the diffusion chain.
84
+ sample (`torch.FloatTensor`):
85
+ A current instance of a sample created by the diffusion process.
86
+ eta (`float`):
87
+ The weight of noise for added noise in diffusion step.
88
+ use_clipped_model_output (`bool`, defaults to `False`):
89
+ If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
90
+ because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
91
+ clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
92
+ `use_clipped_model_output` has no effect.
93
+ generator (`torch.Generator`, *optional*):
94
+ A random number generator.
95
+ variance_noise (`torch.FloatTensor`):
96
+ Alternative to generating noise with `generator` by directly providing the noise for the variance
97
+ itself. Useful for methods such as [`CycleDiffusion`].
98
+ return_dict (`bool`, *optional*, defaults to `True`):
99
+ Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`.
100
+
101
+ Returns:
102
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
103
+ If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a
104
+ tuple is returned where the first element is the sample tensor.
105
+
106
+ """
107
+ if self.num_inference_steps is None:
108
+ raise ValueError(
109
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
110
+ )
111
+
112
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
113
+ # Ideally, read the DDIM paper for an in-detail understanding
114
+
115
+ # Notation (<variable name> -> <name in paper>
116
+ # - pred_noise_t -> e_theta(x_t, t)
117
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
118
+ # - std_dev_t -> sigma_t
119
+ # - eta -> η
120
+ # - pred_sample_direction -> "direction pointing to x_t"
121
+ # - pred_prev_sample -> "x_t-1"
122
+
123
+
124
+ if delta_timestep is None:
125
+ # 1. get previous step value (=t+1)
126
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
127
+ else:
128
+ prev_timestep = timestep - delta_timestep
129
+
130
+ # 2. compute alphas, betas
131
+ alpha_prod_t = self.alphas_cumprod[timestep]
132
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
133
+
134
+ beta_prod_t = 1 - alpha_prod_t
135
+
136
+ # 3. compute predicted original sample from predicted noise also called
137
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
138
+ if self.config.prediction_type == "epsilon":
139
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
140
+ pred_epsilon = model_output
141
+ elif self.config.prediction_type == "sample":
142
+ pred_original_sample = model_output
143
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
144
+ elif self.config.prediction_type == "v_prediction":
145
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
146
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
147
+ else:
148
+ raise ValueError(
149
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
150
+ " `v_prediction`"
151
+ )
152
+
153
+ # 4. Clip or threshold "predicted x_0"
154
+ if self.config.thresholding:
155
+ pred_original_sample = self._threshold_sample(pred_original_sample)
156
+ elif self.config.clip_sample:
157
+ pred_original_sample = pred_original_sample.clamp(
158
+ -self.config.clip_sample_range, self.config.clip_sample_range
159
+ )
160
+
161
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
162
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
163
+ # if prev_timestep < timestep:
164
+ # else:
165
+ # variance = abs(self._get_variance(prev_timestep, timestep))
166
+
167
+ variance = abs(self._get_variance(timestep, prev_timestep))
168
+
169
+ std_dev_t = eta * variance
170
+ std_dev_t = min((1 - alpha_prod_t_prev) / 2, std_dev_t) ** 0.5
171
+
172
+ if use_clipped_model_output:
173
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
174
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
175
+
176
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
177
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
178
+
179
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
180
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
181
+
182
+ if eta > 0:
183
+ if variance_noise is not None and generator is not None:
184
+ raise ValueError(
185
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
186
+ " `variance_noise` stays `None`."
187
+ )
188
+
189
+ if variance_noise is None:
190
+ variance_noise = randn_tensor(
191
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
192
+ )
193
+ variance = std_dev_t * variance_noise
194
+
195
+ prev_sample = prev_sample + variance
196
+
197
+ prev_sample = torch.nan_to_num(prev_sample)
198
+
199
+ if not return_dict:
200
+ return (prev_sample,)
201
+
202
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
203
+
204
+ def pred_original(
205
+ self,
206
+ model_output: torch.FloatTensor,
207
+ timesteps: int,
208
+ sample: torch.FloatTensor,
209
+ ):
210
+ if isinstance(self, DDPMScheduler) or isinstance(self, DDIMScheduler):
211
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
212
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
213
+ timesteps = timesteps.to(sample.device)
214
+
215
+ # 1. compute alphas, betas
216
+ alpha_prod_t = alphas_cumprod[timesteps]
217
+ while len(alpha_prod_t.shape) < len(sample.shape):
218
+ alpha_prod_t = alpha_prod_t.unsqueeze(-1)
219
+
220
+ beta_prod_t = 1 - alpha_prod_t
221
+
222
+ # 2. compute predicted original sample from predicted noise also called
223
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
224
+ if self.config.prediction_type == "epsilon":
225
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
226
+ elif self.config.prediction_type == "sample":
227
+ pred_original_sample = model_output
228
+ elif self.config.prediction_type == "v_prediction":
229
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
230
+ else:
231
+ raise ValueError(
232
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
233
+ " `v_prediction` for the DDPMScheduler."
234
+ )
235
+
236
+ # 3. Clip or threshold "predicted x_0"
237
+ if self.config.thresholding:
238
+ pred_original_sample = self._threshold_sample(pred_original_sample)
239
+ elif self.config.clip_sample:
240
+ pred_original_sample = pred_original_sample.clamp(
241
+ -self.config.clip_sample_range, self.config.clip_sample_range
242
+ )
243
+ elif isinstance(self, EulerAncestralDiscreteScheduler) or isinstance(self, EulerDiscreteScheduler):
244
+ timestep = timesteps.to(self.timesteps.device)
245
+
246
+ step_index = (self.timesteps == timestep).nonzero().item()
247
+ sigma = self.sigmas[step_index].to(device=sample.device, dtype=sample.dtype)
248
+
249
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
250
+ if self.config.prediction_type == "epsilon":
251
+ pred_original_sample = sample - sigma * model_output
252
+ elif self.config.prediction_type == "v_prediction":
253
+ # * c_out + input * c_skip
254
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
255
+ elif self.config.prediction_type == "sample":
256
+ raise NotImplementedError("prediction_type not implemented yet: sample")
257
+ else:
258
+ raise ValueError(
259
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
260
+ )
261
+ else:
262
+ raise NotImplementedError
263
+
264
+ return pred_original_sample
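
As a quick sanity check of the epsilon branch in pred_original, a self-contained sketch that recovers a clean latent from a noised one; the scheduler settings are illustrative defaults, not necessarily the ones this repo loads:

    import torch
    from diffusers import DDIMScheduler

    # Illustrative scheduler with default settings (not necessarily the repo's config).
    scheduler = DDIMScheduler(num_train_timesteps=1000)

    x0 = torch.randn(1, 4, 64, 64)
    noise = torch.randn_like(x0)
    t = torch.tensor([500])

    # Noise the clean latent, then invert the closed form used by the epsilon branch:
    # x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
    x_t = scheduler.add_noise(x0, noise, t)
    alpha_bar = scheduler.alphas_cumprod[t].view(-1, 1, 1, 1)
    x0_rec = (x_t - (1 - alpha_bar).sqrt() * noise) / alpha_bar.sqrt()
    print(torch.allclose(x0, x0_rec, atol=1e-4))  # True up to float32 rounding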
guidance/sd_utils.py ADDED
@@ -0,0 +1,487 @@
1
+ from audioop import mul
2
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
3
+ from diffusers import StableDiffusionPipeline, DiffusionPipeline, DDPMScheduler, DDIMScheduler, EulerDiscreteScheduler, \
4
+ EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, ControlNetModel, \
5
+ DDIMInverseScheduler, UNet2DConditionModel
6
+ from diffusers.utils.import_utils import is_xformers_available
7
+ from os.path import isfile
8
+ from pathlib import Path
9
+ import os
10
+ import random
11
+
12
+ import torchvision.transforms as T
13
+ # suppress partial model loading warning
14
+ logging.set_verbosity_error()
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torchvision.transforms as T
21
+ from torchvision.utils import save_image
22
+ from torch.cuda.amp import custom_bwd, custom_fwd
23
+ from .perpneg_utils import weighted_perpendicular_aggregator
24
+
25
+ from .sd_step import *
26
+
27
+ def rgb2sat(img, T=None):
28
+ max_ = torch.max(img, dim=1, keepdim=True).values + 1e-5
29
+ min_ = torch.min(img, dim=1, keepdim=True).values
30
+ sat = (max_ - min_) / max_
31
+ if T is not None:
32
+ sat = (1 - T) * sat
33
+ return sat
34
+
35
+ class SpecifyGradient(torch.autograd.Function):
36
+ @staticmethod
37
+ @custom_fwd
38
+ def forward(ctx, input_tensor, gt_grad):
39
+ ctx.save_for_backward(gt_grad)
40
+ # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
41
+ return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
42
+
43
+ @staticmethod
44
+ @custom_bwd
45
+ def backward(ctx, grad_scale):
46
+ gt_grad, = ctx.saved_tensors
47
+ gt_grad = gt_grad * grad_scale
48
+ return gt_grad, None
49
+
50
+ def seed_everything(seed):
51
+ torch.manual_seed(seed)
52
+ torch.cuda.manual_seed(seed)
53
+ #torch.backends.cudnn.deterministic = True
54
+ #torch.backends.cudnn.benchmark = True
55
+
56
+ class StableDiffusion(nn.Module):
57
+ def __init__(self, device, fp16, vram_O, t_range=[0.02, 0.98], max_t_range=0.98, num_train_timesteps=None,
58
+ ddim_inv=False, use_control_net=False, textual_inversion_path = None,
59
+ LoRA_path = None, guidance_opt=None):
60
+ super().__init__()
61
+
62
+ self.device = device
63
+ self.precision_t = torch.float16 if fp16 else torch.float32
64
+
65
+ print(f'[INFO] loading stable diffusion...')
66
+
67
+ model_key = guidance_opt.model_key
68
+ assert model_key is not None
69
+
70
+ is_safe_tensor = guidance_opt.is_safe_tensor
71
+ base_model_key = "stabilityai/stable-diffusion-v1-5" if guidance_opt.base_model_key is None else guidance_opt.base_model_key # for finetuned model only
72
+
73
+ if is_safe_tensor:
74
+ pipe = StableDiffusionPipeline.from_single_file(model_key, use_safetensors=True, torch_dtype=self.precision_t, load_safety_checker=False)
75
+ else:
76
+ pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.precision_t)
77
+
78
+ self.ism = not guidance_opt.sds
79
+ self.scheduler = DDIMScheduler.from_pretrained(model_key if not is_safe_tensor else base_model_key, subfolder="scheduler", torch_dtype=self.precision_t)
80
+ self.sche_func = ddim_step
81
+
82
+ if use_control_net:
83
+ controlnet_model_key = guidance_opt.controlnet_model_key
84
+ self.controlnet_depth = ControlNetModel.from_pretrained(controlnet_model_key,torch_dtype=self.precision_t).to(device)
85
+
86
+ if vram_O:
87
+ pipe.enable_sequential_cpu_offload()
88
+ pipe.enable_vae_slicing()
89
+ pipe.unet.to(memory_format=torch.channels_last)
90
+ pipe.enable_attention_slicing(1)
91
+ pipe.enable_model_cpu_offload()
92
+
93
+ pipe.enable_xformers_memory_efficient_attention()
94
+
95
+ pipe = pipe.to(self.device)
96
+ if textual_inversion_path is not None:
97
+ pipe.load_textual_inversion(textual_inversion_path)
98
+ print("Loaded textual inversion from: {}".format(textual_inversion_path))
99
+
100
+ if LoRA_path is not None:
101
+ from lora_diffusion import tune_lora_scale, patch_pipe
102
+ print("Loading LoRA from: {}".format(LoRA_path))
103
+ patch_pipe(
104
+ pipe,
105
+ LoRA_path,
106
+ patch_text=True,
107
+ patch_ti=True,
108
+ patch_unet=True,
109
+ )
110
+ tune_lora_scale(pipe.unet, 1.00)
111
+ tune_lora_scale(pipe.text_encoder, 1.00)
112
+
113
+ self.pipe = pipe
114
+ self.vae = pipe.vae
115
+ self.tokenizer = pipe.tokenizer
116
+ self.text_encoder = pipe.text_encoder
117
+ self.unet = pipe.unet
118
+
119
+ self.num_train_timesteps = num_train_timesteps if num_train_timesteps is not None else self.scheduler.config.num_train_timesteps
120
+ self.scheduler.set_timesteps(self.num_train_timesteps, device=device)
121
+
122
+ self.timesteps = torch.flip(self.scheduler.timesteps, dims=(0, ))
123
+ self.min_step = int(self.num_train_timesteps * t_range[0])
124
+ self.max_step = int(self.num_train_timesteps * t_range[1])
125
+ self.warmup_step = int(self.num_train_timesteps*(max_t_range-t_range[1]))
126
+
127
+ self.noise_temp = None
128
+ self.noise_gen = torch.Generator(self.device)
129
+ self.noise_gen.manual_seed(guidance_opt.noise_seed)
130
+
131
+ self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
132
+ self.rgb_latent_factors = torch.tensor([
133
+ # R G B
134
+ [ 0.298, 0.207, 0.208],
135
+ [ 0.187, 0.286, 0.173],
136
+ [-0.158, 0.189, 0.264],
137
+ [-0.184, -0.271, -0.473]
138
+ ], device=self.device)
139
+
140
+
141
+ print(f'[INFO] loaded stable diffusion!')
142
+
143
+ def augmentation(self, *tensors):
144
+ augs = T.Compose([
145
+ T.RandomHorizontalFlip(p=0.5),
146
+ ])
147
+
148
+ channels = [ten.shape[1] for ten in tensors]
149
+ tensors_concat = torch.concat(tensors, dim=1)
150
+ tensors_concat = augs(tensors_concat)
151
+
152
+ results = []
153
+ cur_c = 0
154
+ for i in range(len(channels)):
155
+ results.append(tensors_concat[:, cur_c:cur_c + channels[i], ...])
156
+ cur_c += channels[i]
157
+ return (ten for ten in results)
158
+
159
+ def add_noise_with_cfg(self, latents, noise,
160
+ ind_t, ind_prev_t,
161
+ text_embeddings=None, cfg=1.0,
162
+ delta_t=1, inv_steps=1,
163
+ is_noisy_latent=False,
164
+ eta=0.0):
165
+
166
+ text_embeddings = text_embeddings.to(self.precision_t)
167
+ if cfg <= 1.0:
168
+ uncond_text_embedding = text_embeddings.reshape(2, -1, text_embeddings.shape[-2], text_embeddings.shape[-1])[1]
169
+
170
+ unet = self.unet
171
+
172
+ if is_noisy_latent:
173
+ prev_noisy_lat = latents
174
+ else:
175
+ prev_noisy_lat = self.scheduler.add_noise(latents, noise, self.timesteps[ind_prev_t])
176
+
177
+ cur_ind_t = ind_prev_t
178
+ cur_noisy_lat = prev_noisy_lat
179
+
180
+ pred_scores = []
181
+
182
+ for i in range(inv_steps):
183
+ # pred noise
184
+ cur_noisy_lat_ = self.scheduler.scale_model_input(cur_noisy_lat, self.timesteps[cur_ind_t]).to(self.precision_t)
185
+
186
+ if cfg > 1.0:
187
+ latent_model_input = torch.cat([cur_noisy_lat_, cur_noisy_lat_])
188
+ timestep_model_input = self.timesteps[cur_ind_t].reshape(1, 1).repeat(latent_model_input.shape[0], 1).reshape(-1)
189
+ unet_output = unet(latent_model_input, timestep_model_input,
190
+ encoder_hidden_states=text_embeddings).sample
191
+
192
+ uncond, cond = torch.chunk(unet_output, chunks=2)
193
+
194
+ unet_output = cond + cfg * (uncond - cond) # reverse cfg to enhance the distillation
195
+ else:
196
+ timestep_model_input = self.timesteps[cur_ind_t].reshape(1, 1).repeat(cur_noisy_lat_.shape[0], 1).reshape(-1)
197
+ unet_output = unet(cur_noisy_lat_, timestep_model_input,
198
+ encoder_hidden_states=uncond_text_embedding).sample
199
+
200
+ pred_scores.append((cur_ind_t, unet_output))
201
+
202
+ next_ind_t = min(cur_ind_t + delta_t, ind_t)
203
+ cur_t, next_t = self.timesteps[cur_ind_t], self.timesteps[next_ind_t]
204
+ delta_t_ = next_t-cur_t if isinstance(self.scheduler, DDIMScheduler) else next_ind_t-cur_ind_t
205
+
206
+ cur_noisy_lat = self.sche_func(self.scheduler, unet_output, cur_t, cur_noisy_lat, -delta_t_, eta).prev_sample
207
+ cur_ind_t = next_ind_t
208
+
209
+ del unet_output
210
+ torch.cuda.empty_cache()
211
+
212
+ if cur_ind_t == ind_t:
213
+ break
214
+
215
+ return prev_noisy_lat, cur_noisy_lat, pred_scores[::-1]
216
+
217
+
218
+ @torch.no_grad()
219
+ def get_text_embeds(self, prompt, resolution=(512, 512)):
220
+ inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, truncation=True, return_tensors='pt')
221
+ embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]
222
+ return embeddings
223
+
224
+ def train_step_perpneg(self, text_embeddings, pred_rgb, pred_depth=None, pred_alpha=None,
225
+ grad_scale=1,use_control_net=False,
226
+ save_folder:Path=None, iteration=0, warm_up_rate = 0, weights = 0,
227
+ resolution=(512, 512), guidance_opt=None,as_latent=False, embedding_inverse = None):
228
+
229
+
230
+ # flip aug
231
+ pred_rgb, pred_depth, pred_alpha = self.augmentation(pred_rgb, pred_depth, pred_alpha)
232
+
233
+ B = pred_rgb.shape[0]
234
+ K = text_embeddings.shape[0] - 1
235
+
236
+ if as_latent:
237
+ latents,_ = self.encode_imgs(pred_depth.repeat(1,3,1,1).to(self.precision_t))
238
+ else:
239
+ latents,_ = self.encode_imgs(pred_rgb.to(self.precision_t))
240
+ # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
241
+
242
+ weights = weights.reshape(-1)
243
+ noise = torch.randn((latents.shape[0], 4, resolution[0] // 8, resolution[1] // 8, ), dtype=latents.dtype, device=latents.device, generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1)
244
+
245
+ inverse_text_embeddings = embedding_inverse.unsqueeze(1).repeat(1, B, 1, 1).reshape(-1, embedding_inverse.shape[-2], embedding_inverse.shape[-1])
246
+
247
+ text_embeddings = text_embeddings.reshape(-1, text_embeddings.shape[-2], text_embeddings.shape[-1]) # make it k+1, c * t, ...
248
+
249
+ if guidance_opt.annealing_intervals:
250
+ current_delta_t = int(guidance_opt.delta_t + np.ceil((warm_up_rate)*(guidance_opt.delta_t_start - guidance_opt.delta_t)))
251
+ else:
252
+ current_delta_t = guidance_opt.delta_t
253
+
254
+ ind_t = torch.randint(self.min_step, self.max_step + int(self.warmup_step*warm_up_rate), (1, ), dtype=torch.long, generator=self.noise_gen, device=self.device)[0]
255
+ ind_prev_t = max(ind_t - current_delta_t, torch.ones_like(ind_t) * 0)
256
+
257
+ t = self.timesteps[ind_t]
258
+ prev_t = self.timesteps[ind_prev_t]
259
+
260
+ with torch.no_grad():
261
+ # step unroll via ddim inversion
262
+ if not self.ism:
263
+ prev_latents_noisy = self.scheduler.add_noise(latents, noise, prev_t)
264
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
265
+ target = noise
266
+ else:
267
+ # Step 1: sample x_s with larger steps
268
+ xs_delta_t = guidance_opt.xs_delta_t if guidance_opt.xs_delta_t is not None else current_delta_t
269
+ xs_inv_steps = guidance_opt.xs_inv_steps if guidance_opt.xs_inv_steps is not None else int(np.ceil(ind_prev_t / xs_delta_t))
270
+ starting_ind = max(ind_prev_t - xs_delta_t * xs_inv_steps, torch.ones_like(ind_t) * 0)
271
+
272
+ _, prev_latents_noisy, pred_scores_xs = self.add_noise_with_cfg(latents, noise, ind_prev_t, starting_ind, inverse_text_embeddings,
273
+ guidance_opt.denoise_guidance_scale, xs_delta_t, xs_inv_steps, eta=guidance_opt.xs_eta)
274
+ # Step 2: sample x_t
275
+ _, latents_noisy, pred_scores_xt = self.add_noise_with_cfg(prev_latents_noisy, noise, ind_t, ind_prev_t, inverse_text_embeddings,
276
+ guidance_opt.denoise_guidance_scale, current_delta_t, 1, is_noisy_latent=True)
277
+
278
+ pred_scores = pred_scores_xt + pred_scores_xs
279
+ target = pred_scores[0][1]
280
+
281
+
282
+ with torch.no_grad():
283
+ latent_model_input = latents_noisy[None, :, ...].repeat(1 + K, 1, 1, 1, 1).reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, )
284
+ tt = t.reshape(1, 1).repeat(latent_model_input.shape[0], 1).reshape(-1)
285
+
286
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, tt[0])
287
+ if use_control_net:
288
+ pred_depth_input = pred_depth[None, :, ...].repeat(1 + K, 1, 3, 1, 1).reshape(-1, 3, 512, 512).half()
289
+ down_block_res_samples, mid_block_res_sample = self.controlnet_depth(
290
+ latent_model_input,
291
+ tt,
292
+ encoder_hidden_states=text_embeddings,
293
+ controlnet_cond=pred_depth_input,
294
+ return_dict=False,
295
+ )
296
+ unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings,
297
+ down_block_additional_residuals=down_block_res_samples,
298
+ mid_block_additional_residual=mid_block_res_sample).sample
299
+ else:
300
+ unet_output = self.unet(latent_model_input.to(self.precision_t), tt.to(self.precision_t), encoder_hidden_states=text_embeddings.to(self.precision_t)).sample
301
+
302
+ unet_output = unet_output.reshape(1 + K, -1, 4, resolution[0] // 8, resolution[1] // 8, )
303
+ noise_pred_uncond, noise_pred_text = unet_output[:1].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, ), unet_output[1:].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, )
304
+ delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)
305
+ delta_DSD = weighted_perpendicular_aggregator(delta_noise_preds,\
306
+ weights,\
307
+ B)
308
+
309
+ pred_noise = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD
310
+ w = lambda alphas: (((1 - alphas) / alphas) ** 0.5)
311
+
312
+ grad = w(self.alphas[t]) * (pred_noise - target)
313
+
314
+ grad = torch.nan_to_num(grad_scale * grad)
315
+ loss = SpecifyGradient.apply(latents, grad)
316
+
317
+ if iteration % guidance_opt.vis_interval == 0:
318
+ noise_pred_post = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD
319
+ lat2rgb = lambda x: torch.clip((x.permute(0,2,3,1) @ self.rgb_latent_factors.to(x.dtype)).permute(0,3,1,2), 0., 1.)
320
+ save_path_iter = os.path.join(save_folder,"iter_{}_step_{}.jpg".format(iteration,prev_t.item()))
321
+ with torch.no_grad():
322
+ pred_x0_latent_sp = pred_original(self.scheduler, noise_pred_uncond, prev_t, prev_latents_noisy)
323
+ pred_x0_latent_pos = pred_original(self.scheduler, noise_pred_post, prev_t, prev_latents_noisy)
324
+ pred_x0_pos = self.decode_latents(pred_x0_latent_pos.type(self.precision_t))
325
+ pred_x0_sp = self.decode_latents(pred_x0_latent_sp.type(self.precision_t))
326
+
327
+ grad_abs = torch.abs(grad.detach())
328
+ norm_grad = F.interpolate((grad_abs / grad_abs.max()).mean(dim=1,keepdim=True), (resolution[0], resolution[1]), mode='bilinear', align_corners=False).repeat(1,3,1,1)
329
+
330
+ latents_rgb = F.interpolate(lat2rgb(latents), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)
331
+ latents_sp_rgb = F.interpolate(lat2rgb(pred_x0_latent_sp), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)
332
+
333
+ viz_images = torch.cat([pred_rgb,
334
+ pred_depth.repeat(1, 3, 1, 1),
335
+ pred_alpha.repeat(1, 3, 1, 1),
336
+ rgb2sat(pred_rgb, pred_alpha).repeat(1, 3, 1, 1),
337
+ latents_rgb, latents_sp_rgb,
338
+ norm_grad,
339
+ pred_x0_sp, pred_x0_pos],dim=0)
340
+ save_image(viz_images, save_path_iter)
341
+
342
+
343
+ return loss
344
+
345
+
346
+ def train_step(self, text_embeddings, pred_rgb, pred_depth=None, pred_alpha=None,
347
+ grad_scale=1,use_control_net=False,
348
+ save_folder:Path=None, iteration=0, warm_up_rate = 0,
349
+ resolution=(512, 512), guidance_opt=None,as_latent=False, embedding_inverse = None):
350
+
351
+ pred_rgb, pred_depth, pred_alpha = self.augmentation(pred_rgb, pred_depth, pred_alpha)
352
+
353
+ B = pred_rgb.shape[0]
354
+ K = text_embeddings.shape[0] - 1
355
+
356
+ if as_latent:
357
+ latents,_ = self.encode_imgs(pred_depth.repeat(1,3,1,1).to(self.precision_t))
358
+ else:
359
+ latents,_ = self.encode_imgs(pred_rgb.to(self.precision_t))
360
+ # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
361
+
362
+ if self.noise_temp is None:
363
+ self.noise_temp = torch.randn((latents.shape[0], 4, resolution[0] // 8, resolution[1] // 8, ), dtype=latents.dtype, device=latents.device, generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1)
364
+
365
+ if guidance_opt.fix_noise:
366
+ noise = self.noise_temp
367
+ else:
368
+ noise = torch.randn((latents.shape[0], 4, resolution[0] // 8, resolution[1] // 8, ), dtype=latents.dtype, device=latents.device, generator=self.noise_gen) + 0.1 * torch.randn((1, 4, 1, 1), device=latents.device).repeat(latents.shape[0], 1, 1, 1)
369
+
370
+ text_embeddings = text_embeddings[:, :, ...]
371
+ text_embeddings = text_embeddings.reshape(-1, text_embeddings.shape[-2], text_embeddings.shape[-1]) # make it k+1, c * t, ...
372
+
373
+ inverse_text_embeddings = embedding_inverse.unsqueeze(1).repeat(1, B, 1, 1).reshape(-1, embedding_inverse.shape[-2], embedding_inverse.shape[-1])
374
+
375
+ if guidance_opt.annealing_intervals:
376
+ current_delta_t = int(guidance_opt.delta_t + (warm_up_rate)*(guidance_opt.delta_t_start - guidance_opt.delta_t))
377
+ else:
378
+ current_delta_t = guidance_opt.delta_t
379
+
380
+ ind_t = torch.randint(self.min_step, self.max_step + int(self.warmup_step*warm_up_rate), (1, ), dtype=torch.long, generator=self.noise_gen, device=self.device)[0]
381
+ ind_prev_t = max(ind_t - current_delta_t, torch.ones_like(ind_t) * 0)
382
+
383
+ t = self.timesteps[ind_t]
384
+ prev_t = self.timesteps[ind_prev_t]
385
+
386
+ with torch.no_grad():
387
+ # step unroll via ddim inversion
388
+ if not self.ism:
389
+ prev_latents_noisy = self.scheduler.add_noise(latents, noise, prev_t)
390
+ latents_noisy = self.scheduler.add_noise(latents, noise, t)
391
+ target = noise
392
+ else:
393
+ # Step 1: sample x_s with larger steps
394
+ xs_delta_t = guidance_opt.xs_delta_t if guidance_opt.xs_delta_t is not None else current_delta_t
395
+ xs_inv_steps = guidance_opt.xs_inv_steps if guidance_opt.xs_inv_steps is not None else int(np.ceil(ind_prev_t / xs_delta_t))
396
+ starting_ind = max(ind_prev_t - xs_delta_t * xs_inv_steps, torch.ones_like(ind_t) * 0)
397
+
398
+ _, prev_latents_noisy, pred_scores_xs = self.add_noise_with_cfg(latents, noise, ind_prev_t, starting_ind, inverse_text_embeddings,
399
+ guidance_opt.denoise_guidance_scale, xs_delta_t, xs_inv_steps, eta=guidance_opt.xs_eta)
400
+ # Step 2: sample x_t
401
+ _, latents_noisy, pred_scores_xt = self.add_noise_with_cfg(prev_latents_noisy, noise, ind_t, ind_prev_t, inverse_text_embeddings,
402
+ guidance_opt.denoise_guidance_scale, current_delta_t, 1, is_noisy_latent=True)
403
+
404
+ pred_scores = pred_scores_xt + pred_scores_xs
405
+ target = pred_scores[0][1]
406
+
407
+
408
+ with torch.no_grad():
409
+ latent_model_input = latents_noisy[None, :, ...].repeat(2, 1, 1, 1, 1).reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, )
410
+ tt = t.reshape(1, 1).repeat(latent_model_input.shape[0], 1).reshape(-1)
411
+
412
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, tt[0])
413
+ if use_control_net:
414
+ pred_depth_input = pred_depth[None, :, ...].repeat(1 + K, 1, 3, 1, 1).reshape(-1, 3, 512, 512).half()
415
+ down_block_res_samples, mid_block_res_sample = self.controlnet_depth(
416
+ latent_model_input,
417
+ tt,
418
+ encoder_hidden_states=text_embeddings,
419
+ controlnet_cond=pred_depth_input,
420
+ return_dict=False,
421
+ )
422
+ unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings,
423
+ down_block_additional_residuals=down_block_res_samples,
424
+ mid_block_additional_residual=mid_block_res_sample).sample
425
+ else:
426
+ unet_output = self.unet(latent_model_input.to(self.precision_t), tt.to(self.precision_t), encoder_hidden_states=text_embeddings.to(self.precision_t)).sample
427
+
428
+ unet_output = unet_output.reshape(2, -1, 4, resolution[0] // 8, resolution[1] // 8, )
429
+ noise_pred_uncond, noise_pred_text = unet_output[:1].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, ), unet_output[1:].reshape(-1, 4, resolution[0] // 8, resolution[1] // 8, )
430
+ delta_DSD = noise_pred_text - noise_pred_uncond
431
+
432
+ pred_noise = noise_pred_uncond + guidance_opt.guidance_scale * delta_DSD
433
+
434
+ w = lambda alphas: (((1 - alphas) / alphas) ** 0.5)
435
+
436
+ grad = w(self.alphas[t]) * (pred_noise - target)
437
+
438
+ grad = torch.nan_to_num(grad_scale * grad)
439
+ loss = SpecifyGradient.apply(latents, grad)
440
+
441
+ if iteration % guidance_opt.vis_interval == 0:
442
+ noise_pred_post = noise_pred_uncond + 7.5 * delta_DSD
443
+ lat2rgb = lambda x: torch.clip((x.permute(0,2,3,1) @ self.rgb_latent_factors.to(x.dtype)).permute(0,3,1,2), 0., 1.)
444
+ save_path_iter = os.path.join(save_folder,"iter_{}_step_{}.jpg".format(iteration,prev_t.item()))
445
+ with torch.no_grad():
446
+ pred_x0_latent_sp = pred_original(self.scheduler, noise_pred_uncond, prev_t, prev_latents_noisy)
447
+ pred_x0_latent_pos = pred_original(self.scheduler, noise_pred_post, prev_t, prev_latents_noisy)
448
+ pred_x0_pos = self.decode_latents(pred_x0_latent_pos.type(self.precision_t))
449
+ pred_x0_sp = self.decode_latents(pred_x0_latent_sp.type(self.precision_t))
450
+ # pred_x0_uncond = pred_x0_sp[:1, ...]
451
+
452
+ grad_abs = torch.abs(grad.detach())
453
+ norm_grad = F.interpolate((grad_abs / grad_abs.max()).mean(dim=1,keepdim=True), (resolution[0], resolution[1]), mode='bilinear', align_corners=False).repeat(1,3,1,1)
454
+
455
+ latents_rgb = F.interpolate(lat2rgb(latents), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)
456
+ latents_sp_rgb = F.interpolate(lat2rgb(pred_x0_latent_sp), (resolution[0], resolution[1]), mode='bilinear', align_corners=False)
457
+
458
+ viz_images = torch.cat([pred_rgb,
459
+ pred_depth.repeat(1, 3, 1, 1),
460
+ pred_alpha.repeat(1, 3, 1, 1),
461
+ rgb2sat(pred_rgb, pred_alpha).repeat(1, 3, 1, 1),
462
+ latents_rgb, latents_sp_rgb, norm_grad,
463
+ pred_x0_sp, pred_x0_pos],dim=0)
464
+ save_image(viz_images, save_path_iter)
465
+
466
+ return loss
467
+
468
+ def decode_latents(self, latents):
469
+ target_dtype = latents.dtype
470
+ latents = latents / self.vae.config.scaling_factor
471
+
472
+ imgs = self.vae.decode(latents.to(self.vae.dtype)).sample
473
+ imgs = (imgs / 2 + 0.5).clamp(0, 1)
474
+
475
+ return imgs.to(target_dtype)
476
+
477
+ def encode_imgs(self, imgs):
478
+ target_dtype = imgs.dtype
479
+ # imgs: [B, 3, H, W]
480
+ imgs = 2 * imgs - 1
481
+
482
+ posterior = self.vae.encode(imgs.to(self.vae.dtype)).latent_dist
483
+ kl_divergence = posterior.kl()
484
+
485
+ latents = posterior.sample() * self.vae.config.scaling_factor
486
+
487
+ return latents.to(target_dtype), kl_divergence
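
The SpecifyGradient trick defined near the top of this file is what lets the guidance residual w(t) * (pred_noise - target) flow back into the rendered latents as a gradient; below is a minimal standalone illustration of the same autograd pattern (the class and variable names are illustrative, not this repo's):

    import torch
    from torch.autograd import Function

    class InjectGrad(Function):
        # Forward returns a dummy scalar loss; backward hands back a precomputed gradient.
        @staticmethod
        def forward(ctx, x, target_grad):
            ctx.save_for_backward(target_grad)
            return torch.ones([1], device=x.device, dtype=x.dtype)

        @staticmethod
        def backward(ctx, grad_scale):
            (target_grad,) = ctx.saved_tensors
            return grad_scale * target_grad, None

    latents = torch.randn(1, 4, 8, 8, requires_grad=True)
    grad = torch.full_like(latents, 0.5)   # stands in for w(t) * (pred_noise - target)
    loss = InjectGrad.apply(latents, grad)
    loss.backward()
    print(torch.allclose(latents.grad, grad))  # True: latents receive exactly the injected gradient

In the training steps above, the injected gradient is the guidance residual, so the latents (and, through the VAE encoder, the rendered image) receive exactly that gradient.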
lora_diffusion/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .lora import *
2
+ from .dataset import *
3
+ from .utils import *
4
+ from .preprocess_files import *
5
+ from .lora_manager import *
lora_diffusion/cli_lora_add.py ADDED
@@ -0,0 +1,187 @@
1
+ from typing import Literal, Union, Dict
2
+ import os
3
+ import shutil
4
+ import fire
5
+ from diffusers import StableDiffusionPipeline
6
+ from safetensors.torch import safe_open, save_file
7
+
8
+ import torch
9
+ from .lora import (
10
+ tune_lora_scale,
11
+ patch_pipe,
12
+ collapse_lora,
13
+ monkeypatch_remove_lora,
14
+ )
15
+ from .lora_manager import lora_join
16
+ from .to_ckpt_v2 import convert_to_ckpt
17
+
18
+
19
+ def _text_lora_path(path: str) -> str:
20
+ assert path.endswith(".pt"), "Only .pt files are supported"
21
+ return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
22
+
23
+
24
+ def add(
25
+ path_1: str,
26
+ path_2: str,
27
+ output_path: str,
28
+ alpha_1: float = 0.5,
29
+ alpha_2: float = 0.5,
30
+ mode: Literal[
31
+ "lpl",
32
+ "upl",
33
+ "upl-ckpt-v2",
34
+ ] = "lpl",
35
+ with_text_lora: bool = False,
36
+ ):
37
+ print("Lora Add, mode " + mode)
38
+ if mode == "lpl":
39
+ if path_1.endswith(".pt") and path_2.endswith(".pt"):
40
+ for _path_1, _path_2, opt in [(path_1, path_2, "unet")] + (
41
+ [(_text_lora_path(path_1), _text_lora_path(path_2), "text_encoder")]
42
+ if with_text_lora
43
+ else []
44
+ ):
45
+ print("Loading", _path_1, _path_2)
46
+ out_list = []
47
+ if opt == "text_encoder":
48
+ if not os.path.exists(_path_1):
49
+ print(f"No text encoder found in {_path_1}, skipping...")
50
+ continue
51
+ if not os.path.exists(_path_2):
52
+ print(f"No text encoder found in {_path_2}, skipping...")
53
+ continue
54
+
55
+ l1 = torch.load(_path_1)
56
+ l2 = torch.load(_path_2)
57
+
58
+ l1pairs = zip(l1[::2], l1[1::2])
59
+ l2pairs = zip(l2[::2], l2[1::2])
60
+
61
+ for (x1, y1), (x2, y2) in zip(l1pairs, l2pairs):
62
+ # print("Merging", x1.shape, y1.shape, x2.shape, y2.shape)
63
+ x1.data = alpha_1 * x1.data + alpha_2 * x2.data
64
+ y1.data = alpha_1 * y1.data + alpha_2 * y2.data
65
+
66
+ out_list.append(x1)
67
+ out_list.append(y1)
68
+
69
+ if opt == "unet":
70
+
71
+ print("Saving merged UNET to", output_path)
72
+ torch.save(out_list, output_path)
73
+
74
+ elif opt == "text_encoder":
75
+ print("Saving merged text encoder to", _text_lora_path(output_path))
76
+ torch.save(
77
+ out_list,
78
+ _text_lora_path(output_path),
79
+ )
80
+
81
+ elif path_1.endswith(".safetensors") and path_2.endswith(".safetensors"):
82
+ safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
83
+ safeloras_2 = safe_open(path_2, framework="pt", device="cpu")
84
+
85
+ metadata = dict(safeloras_1.metadata())
86
+ metadata.update(dict(safeloras_2.metadata()))
87
+
88
+ ret_tensor = {}
89
+
90
+ for keys in set(list(safeloras_1.keys()) + list(safeloras_2.keys())):
91
+ if keys.startswith("text_encoder") or keys.startswith("unet"):
92
+
93
+ tens1 = safeloras_1.get_tensor(keys)
94
+ tens2 = safeloras_2.get_tensor(keys)
95
+
96
+ tens = alpha_1 * tens1 + alpha_2 * tens2
97
+ ret_tensor[keys] = tens
98
+ else:
99
+ if keys in safeloras_1.keys():
100
+
101
+ tens1 = safeloras_1.get_tensor(keys)
102
+ else:
103
+ tens1 = safeloras_2.get_tensor(keys)
104
+
105
+ ret_tensor[keys] = tens1
106
+
107
+ save_file(ret_tensor, output_path, metadata)
108
+
109
+ elif mode == "upl":
110
+
111
+ print(
112
+ f"Merging UNET/CLIP from {path_1} with LoRA from {path_2} to {output_path}. Merging ratio : {alpha_1}."
113
+ )
114
+
115
+ loaded_pipeline = StableDiffusionPipeline.from_pretrained(
116
+ path_1,
117
+ ).to("cpu")
118
+
119
+ patch_pipe(loaded_pipeline, path_2)
120
+
121
+ collapse_lora(loaded_pipeline.unet, alpha_1)
122
+ collapse_lora(loaded_pipeline.text_encoder, alpha_1)
123
+
124
+ monkeypatch_remove_lora(loaded_pipeline.unet)
125
+ monkeypatch_remove_lora(loaded_pipeline.text_encoder)
126
+
127
+ loaded_pipeline.save_pretrained(output_path)
128
+
129
+ elif mode == "upl-ckpt-v2":
130
+
131
+ assert output_path.endswith(".ckpt"), "Only .ckpt files are supported"
132
+ name = os.path.basename(output_path)[0:-5]
133
+
134
+ print(
135
+ f"You will be using {name} as the token in the A1111 webui. Make sure {name} is a sufficiently unique token."
136
+ )
137
+
138
+ loaded_pipeline = StableDiffusionPipeline.from_pretrained(
139
+ path_1,
140
+ ).to("cpu")
141
+
142
+ tok_dict = patch_pipe(loaded_pipeline, path_2, patch_ti=False)
143
+
144
+ collapse_lora(loaded_pipeline.unet, alpha_1)
145
+ collapse_lora(loaded_pipeline.text_encoder, alpha_1)
146
+
147
+ monkeypatch_remove_lora(loaded_pipeline.unet)
148
+ monkeypatch_remove_lora(loaded_pipeline.text_encoder)
149
+
150
+ _tmp_output = output_path + ".tmp"
151
+
152
+ loaded_pipeline.save_pretrained(_tmp_output)
153
+ convert_to_ckpt(_tmp_output, output_path, as_half=True)
154
+ # remove the tmp_output folder
155
+ shutil.rmtree(_tmp_output)
156
+
157
+ keys = sorted(tok_dict.keys())
158
+ tok_catted = torch.stack([tok_dict[k] for k in keys])
159
+ ret = {
160
+ "string_to_token": {"*": torch.tensor(265)},
161
+ "string_to_param": {"*": tok_catted},
162
+ "name": name,
163
+ }
164
+
165
+ torch.save(ret, output_path[:-5] + ".pt")
166
+ print(
167
+ f"Textual embedding saved as {output_path[:-5]}.pt. Put it in the embeddings folder and use it as {name} in the A1111 repo."
168
+ )
169
+ elif mode == "ljl":
170
+ print("Using Join mode : alpha will not have an effect here.")
171
+ assert path_1.endswith(".safetensors") and path_2.endswith(
172
+ ".safetensors"
173
+ ), "Only .safetensors files are supported"
174
+
175
+ safeloras_1 = safe_open(path_1, framework="pt", device="cpu")
176
+ safeloras_2 = safe_open(path_2, framework="pt", device="cpu")
177
+
178
+ total_tensor, total_metadata, _, _ = lora_join([safeloras_1, safeloras_2])
179
+ save_file(total_tensor, output_path, total_metadata)
180
+
181
+ else:
182
+ print("Unknown mode", mode)
183
+ raise ValueError(f"Unknown mode {mode}")
184
+
185
+
186
+ def main():
187
+ fire.Fire(add)
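
A hedged usage sketch for the add entry point above; the file paths are placeholders, and this assumes lora_diffusion is importable from the repository root:

    # Merge two LoRA .safetensors files with equal weights ("lpl" mode).
    from lora_diffusion.cli_lora_add import add

    add(
        path_1="loras/style_a.safetensors",    # placeholder path
        path_2="loras/style_b.safetensors",    # placeholder path
        output_path="loras/merged.safetensors",
        alpha_1=0.5,
        alpha_2=0.5,
        mode="lpl",
    )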
lora_diffusion/cli_lora_pti.py ADDED
@@ -0,0 +1,1040 @@
1
+ # Bootstrapped from:
2
+ # https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
3
+
4
+ import argparse
5
+ import hashlib
6
+ import inspect
7
+ import itertools
8
+ import math
9
+ import os
10
+ import random
11
+ import re
12
+ from pathlib import Path
13
+ from typing import Optional, List, Literal
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ import torch.optim as optim
18
+ import torch.utils.checkpoint
19
+ from diffusers import (
20
+ AutoencoderKL,
21
+ DDPMScheduler,
22
+ StableDiffusionPipeline,
23
+ UNet2DConditionModel,
24
+ )
25
+ from diffusers.optimization import get_scheduler
26
+ from huggingface_hub import HfFolder, Repository, whoami
27
+ from PIL import Image
28
+ from torch.utils.data import Dataset
29
+ from torchvision import transforms
30
+ from tqdm.auto import tqdm
31
+ from transformers import CLIPTextModel, CLIPTokenizer
32
+ import wandb
33
+ import fire
34
+
35
+ from lora_diffusion import (
36
+ PivotalTuningDatasetCapation,
37
+ extract_lora_ups_down,
38
+ inject_trainable_lora,
39
+ inject_trainable_lora_extended,
40
+ inspect_lora,
41
+ save_lora_weight,
42
+ save_all,
43
+ prepare_clip_model_sets,
44
+ evaluate_pipe,
45
+ UNET_EXTENDED_TARGET_REPLACE,
46
+ )
47
+
48
+
49
+ def get_models(
50
+ pretrained_model_name_or_path,
51
+ pretrained_vae_name_or_path,
52
+ revision,
53
+ placeholder_tokens: List[str],
54
+ initializer_tokens: List[str],
55
+ device="cuda:0",
56
+ ):
57
+
58
+ tokenizer = CLIPTokenizer.from_pretrained(
59
+ pretrained_model_name_or_path,
60
+ subfolder="tokenizer",
61
+ revision=revision,
62
+ )
63
+
64
+ text_encoder = CLIPTextModel.from_pretrained(
65
+ pretrained_model_name_or_path,
66
+ subfolder="text_encoder",
67
+ revision=revision,
68
+ )
69
+
70
+ placeholder_token_ids = []
71
+
72
+ for token, init_tok in zip(placeholder_tokens, initializer_tokens):
73
+ num_added_tokens = tokenizer.add_tokens(token)
74
+ if num_added_tokens == 0:
75
+ raise ValueError(
76
+ f"The tokenizer already contains the token {token}. Please pass a different"
77
+ " `placeholder_token` that is not already in the tokenizer."
78
+ )
79
+
80
+ placeholder_token_id = tokenizer.convert_tokens_to_ids(token)
81
+
82
+ placeholder_token_ids.append(placeholder_token_id)
83
+
84
+ # Load models and create wrapper for stable diffusion
85
+
86
+ text_encoder.resize_token_embeddings(len(tokenizer))
87
+ token_embeds = text_encoder.get_input_embeddings().weight.data
88
+ if init_tok.startswith("<rand"):
89
+ # <rand-"sigma">, e.g. <rand-0.5>
90
+ sigma_val = float(re.findall(r"<rand-(.*)>", init_tok)[0])
91
+
92
+ token_embeds[placeholder_token_id] = (
93
+ torch.randn_like(token_embeds[0]) * sigma_val
94
+ )
95
+ print(
96
+ f"Initialized {token} with random noise (sigma={sigma_val}), empirically {token_embeds[placeholder_token_id].mean().item():.3f} +- {token_embeds[placeholder_token_id].std().item():.3f}"
97
+ )
98
+ print(f"Norm : {token_embeds[placeholder_token_id].norm():.4f}")
99
+
100
+ elif init_tok == "<zero>":
101
+ token_embeds[placeholder_token_id] = torch.zeros_like(token_embeds[0])
102
+ else:
103
+ token_ids = tokenizer.encode(init_tok, add_special_tokens=False)
104
+ # Check if initializer_token is a single token or a sequence of tokens
105
+ if len(token_ids) > 1:
106
+ raise ValueError("The initializer token must be a single token.")
107
+
108
+ initializer_token_id = token_ids[0]
109
+ token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
110
+
111
+ vae = AutoencoderKL.from_pretrained(
112
+ pretrained_vae_name_or_path or pretrained_model_name_or_path,
113
+ subfolder=None if pretrained_vae_name_or_path else "vae",
114
+ revision=None if pretrained_vae_name_or_path else revision,
115
+ )
116
+ unet = UNet2DConditionModel.from_pretrained(
117
+ pretrained_model_name_or_path,
118
+ subfolder="unet",
119
+ revision=revision,
120
+ )
121
+
122
+ return (
123
+ text_encoder.to(device),
124
+ vae.to(device),
125
+ unet.to(device),
126
+ tokenizer,
127
+ placeholder_token_ids,
128
+ )
129
+
130
+
131
+ @torch.no_grad()
132
+ def text2img_dataloader(
133
+ train_dataset,
134
+ train_batch_size,
135
+ tokenizer,
136
+ vae,
137
+ text_encoder,
138
+ cached_latents: bool = False,
139
+ ):
140
+
141
+ if cached_latents:
142
+ cached_latents_dataset = []
143
+ for idx in tqdm(range(len(train_dataset))):
144
+ batch = train_dataset[idx]
145
+ # print(batch)
146
+ latents = vae.encode(
147
+ batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
148
+ ).latent_dist.sample()
149
+ latents = latents * 0.18215
150
+ batch["instance_images"] = latents.squeeze(0)
151
+ cached_latents_dataset.append(batch)
152
+
153
+ def collate_fn(examples):
154
+ input_ids = [example["instance_prompt_ids"] for example in examples]
155
+ pixel_values = [example["instance_images"] for example in examples]
156
+ pixel_values = torch.stack(pixel_values)
157
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
158
+
159
+ input_ids = tokenizer.pad(
160
+ {"input_ids": input_ids},
161
+ padding="max_length",
162
+ max_length=tokenizer.model_max_length,
163
+ return_tensors="pt",
164
+ ).input_ids
165
+
166
+ batch = {
167
+ "input_ids": input_ids,
168
+ "pixel_values": pixel_values,
169
+ }
170
+
171
+ if examples[0].get("mask", None) is not None:
172
+ batch["mask"] = torch.stack([example["mask"] for example in examples])
173
+
174
+ return batch
175
+
176
+ if cached_latents:
177
+
178
+ train_dataloader = torch.utils.data.DataLoader(
179
+ cached_latents_dataset,
180
+ batch_size=train_batch_size,
181
+ shuffle=True,
182
+ collate_fn=collate_fn,
183
+ )
184
+
185
+ print("PTI : Using cached latent.")
186
+
187
+ else:
188
+ train_dataloader = torch.utils.data.DataLoader(
189
+ train_dataset,
190
+ batch_size=train_batch_size,
191
+ shuffle=True,
192
+ collate_fn=collate_fn,
193
+ )
194
+
195
+ return train_dataloader
196
+
197
+
198
+ def inpainting_dataloader(
199
+ train_dataset, train_batch_size, tokenizer, vae, text_encoder
200
+ ):
201
+ def collate_fn(examples):
202
+ input_ids = [example["instance_prompt_ids"] for example in examples]
203
+ pixel_values = [example["instance_images"] for example in examples]
204
+ mask_values = [example["instance_masks"] for example in examples]
205
+ masked_image_values = [
206
+ example["instance_masked_images"] for example in examples
207
+ ]
208
+
209
+ # Concat class and instance examples for prior preservation.
210
+ # We do this to avoid doing two forward passes.
211
+ if examples[0].get("class_prompt_ids", None) is not None:
212
+ input_ids += [example["class_prompt_ids"] for example in examples]
213
+ pixel_values += [example["class_images"] for example in examples]
214
+ mask_values += [example["class_masks"] for example in examples]
215
+ masked_image_values += [
216
+ example["class_masked_images"] for example in examples
217
+ ]
218
+
219
+ pixel_values = (
220
+ torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
221
+ )
222
+ mask_values = (
223
+ torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
224
+ )
225
+ masked_image_values = (
226
+ torch.stack(masked_image_values)
227
+ .to(memory_format=torch.contiguous_format)
228
+ .float()
229
+ )
230
+
231
+ input_ids = tokenizer.pad(
232
+ {"input_ids": input_ids},
233
+ padding="max_length",
234
+ max_length=tokenizer.model_max_length,
235
+ return_tensors="pt",
236
+ ).input_ids
237
+
238
+ batch = {
239
+ "input_ids": input_ids,
240
+ "pixel_values": pixel_values,
241
+ "mask_values": mask_values,
242
+ "masked_image_values": masked_image_values,
243
+ }
244
+
245
+ if examples[0].get("mask", None) is not None:
246
+ batch["mask"] = torch.stack([example["mask"] for example in examples])
247
+
248
+ return batch
249
+
250
+ train_dataloader = torch.utils.data.DataLoader(
251
+ train_dataset,
252
+ batch_size=train_batch_size,
253
+ shuffle=True,
254
+ collate_fn=collate_fn,
255
+ )
256
+
257
+ return train_dataloader
258
+
259
+
260
+ def loss_step(
261
+ batch,
262
+ unet,
263
+ vae,
264
+ text_encoder,
265
+ scheduler,
266
+ train_inpainting=False,
267
+ t_mutliplier=1.0,
268
+ mixed_precision=False,
269
+ mask_temperature=1.0,
270
+ cached_latents: bool = False,
271
+ ):
272
+ weight_dtype = torch.float32
273
+ if not cached_latents:
274
+ latents = vae.encode(
275
+ batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
276
+ ).latent_dist.sample()
277
+ latents = latents * 0.18215
278
+
279
+ if train_inpainting:
280
+ masked_image_latents = vae.encode(
281
+ batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
282
+ ).latent_dist.sample()
283
+ masked_image_latents = masked_image_latents * 0.18215
284
+ mask = F.interpolate(
285
+ batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
286
+ scale_factor=1 / 8,
287
+ )
288
+ else:
289
+ latents = batch["pixel_values"]
290
+
291
+ if train_inpainting:
292
+ masked_image_latents = batch["masked_image_latents"]
293
+ mask = batch["mask_values"]
294
+
295
+ noise = torch.randn_like(latents)
296
+ bsz = latents.shape[0]
297
+
298
+ timesteps = torch.randint(
299
+ 0,
300
+ int(scheduler.config.num_train_timesteps * t_mutliplier),
301
+ (bsz,),
302
+ device=latents.device,
303
+ )
304
+ timesteps = timesteps.long()
305
+
306
+ noisy_latents = scheduler.add_noise(latents, noise, timesteps)
307
+
308
+ if train_inpainting:
309
+ latent_model_input = torch.cat(
310
+ [noisy_latents, mask, masked_image_latents], dim=1
311
+ )
312
+ else:
313
+ latent_model_input = noisy_latents
314
+
315
+ if mixed_precision:
316
+ with torch.cuda.amp.autocast():
317
+
318
+ encoder_hidden_states = text_encoder(
319
+ batch["input_ids"].to(text_encoder.device)
320
+ )[0]
321
+
322
+ model_pred = unet(
323
+ latent_model_input, timesteps, encoder_hidden_states
324
+ ).sample
325
+ else:
326
+
327
+ encoder_hidden_states = text_encoder(
328
+ batch["input_ids"].to(text_encoder.device)
329
+ )[0]
330
+
331
+ model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
332
+
333
+ if scheduler.config.prediction_type == "epsilon":
334
+ target = noise
335
+ elif scheduler.config.prediction_type == "v_prediction":
336
+ target = scheduler.get_velocity(latents, noise, timesteps)
337
+ else:
338
+ raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
339
+
340
+ if batch.get("mask", None) is not None:
341
+
342
+ mask = (
343
+ batch["mask"]
344
+ .to(model_pred.device)
345
+ .reshape(
346
+ model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
347
+ )
348
+ )
349
+ # resize to match model_pred
350
+ mask = F.interpolate(
351
+ mask.float(),
352
+ size=model_pred.shape[-2:],
353
+ mode="nearest",
354
+ )
355
+
356
+ mask = (mask + 0.01).pow(mask_temperature)
357
+
358
+ mask = mask / mask.max()
359
+
360
+ model_pred = model_pred * mask
361
+
362
+ target = target * mask
363
+
364
+ loss = (
365
+ F.mse_loss(model_pred.float(), target.float(), reduction="none")
366
+ .mean([1, 2, 3])
367
+ .mean()
368
+ )
369
+
370
+ return loss
371
+
372
+
373
+ def train_inversion(
374
+ unet,
375
+ vae,
376
+ text_encoder,
377
+ dataloader,
378
+ num_steps: int,
379
+ scheduler,
380
+ index_no_updates,
381
+ optimizer,
382
+ save_steps: int,
383
+ placeholder_token_ids,
384
+ placeholder_tokens,
385
+ save_path: str,
386
+ tokenizer,
387
+ lr_scheduler,
388
+ test_image_path: str,
389
+ cached_latents: bool,
390
+ accum_iter: int = 1,
391
+ log_wandb: bool = False,
392
+ wandb_log_prompt_cnt: int = 10,
393
+ class_token: str = "person",
394
+ train_inpainting: bool = False,
395
+ mixed_precision: bool = False,
396
+ clip_ti_decay: bool = True,
397
+ ):
398
+
399
+ progress_bar = tqdm(range(num_steps))
400
+ progress_bar.set_description("Steps")
401
+ global_step = 0
402
+
403
+ # Original Emb for TI
404
+ orig_embeds_params = text_encoder.get_input_embeddings().weight.data.clone()
405
+
406
+ if log_wandb:
407
+ preped_clip = prepare_clip_model_sets()
408
+
409
+ index_updates = ~index_no_updates
410
+ loss_sum = 0.0
411
+
412
+ for epoch in range(math.ceil(num_steps / len(dataloader))):
413
+ unet.eval()
414
+ text_encoder.train()
415
+ for batch in dataloader:
416
+
417
+ lr_scheduler.step()
418
+
419
+ with torch.set_grad_enabled(True):
420
+ loss = (
421
+ loss_step(
422
+ batch,
423
+ unet,
424
+ vae,
425
+ text_encoder,
426
+ scheduler,
427
+ train_inpainting=train_inpainting,
428
+ mixed_precision=mixed_precision,
429
+ cached_latents=cached_latents,
430
+ )
431
+ / accum_iter
432
+ )
433
+
434
+ loss.backward()
435
+ loss_sum += loss.detach().item()
436
+
437
+ if global_step % accum_iter == 0:
438
+ # print gradient of text encoder embedding
439
+ print(
440
+ text_encoder.get_input_embeddings()
441
+ .weight.grad[index_updates, :]
442
+ .norm(dim=-1)
443
+ .mean()
444
+ )
445
+ optimizer.step()
446
+ optimizer.zero_grad()
447
+
448
+ with torch.no_grad():
449
+
450
+ # normalize embeddings
451
+ if clip_ti_decay:
452
+ pre_norm = (
453
+ text_encoder.get_input_embeddings()
454
+ .weight[index_updates, :]
455
+ .norm(dim=-1, keepdim=True)
456
+ )
457
+
458
+ lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
459
+ text_encoder.get_input_embeddings().weight[
460
+ index_updates
461
+ ] = F.normalize(
462
+ text_encoder.get_input_embeddings().weight[
463
+ index_updates, :
464
+ ],
465
+ dim=-1,
466
+ ) * (
467
+ pre_norm + lambda_ * (0.4 - pre_norm)
468
+ )
469
+ print(pre_norm)
470
+
471
+ current_norm = (
472
+ text_encoder.get_input_embeddings()
473
+ .weight[index_updates, :]
474
+ .norm(dim=-1)
475
+ )
476
+
477
+ text_encoder.get_input_embeddings().weight[
478
+ index_no_updates
479
+ ] = orig_embeds_params[index_no_updates]
480
+
481
+ print(f"Current Norm : {current_norm}")
482
+
483
+ global_step += 1
484
+ progress_bar.update(1)
485
+
486
+ logs = {
487
+ "loss": loss.detach().item(),
488
+ "lr": lr_scheduler.get_last_lr()[0],
489
+ }
490
+ progress_bar.set_postfix(**logs)
491
+
492
+ if global_step % save_steps == 0:
493
+ save_all(
494
+ unet=unet,
495
+ text_encoder=text_encoder,
496
+ placeholder_token_ids=placeholder_token_ids,
497
+ placeholder_tokens=placeholder_tokens,
498
+ save_path=os.path.join(
499
+ save_path, f"step_inv_{global_step}.safetensors"
500
+ ),
501
+ save_lora=False,
502
+ )
503
+ if log_wandb:
504
+ with torch.no_grad():
505
+ pipe = StableDiffusionPipeline(
506
+ vae=vae,
507
+ text_encoder=text_encoder,
508
+ tokenizer=tokenizer,
509
+ unet=unet,
510
+ scheduler=scheduler,
511
+ safety_checker=None,
512
+ feature_extractor=None,
513
+ )
514
+
515
+ # open all images in test_image_path
516
+ images = []
517
+ for file in os.listdir(test_image_path):
518
+ if (
519
+ file.lower().endswith(".png")
520
+ or file.lower().endswith(".jpg")
521
+ or file.lower().endswith(".jpeg")
522
+ ):
523
+ images.append(
524
+ Image.open(os.path.join(test_image_path, file))
525
+ )
526
+
527
+ wandb.log({"loss": loss_sum / save_steps})
528
+ loss_sum = 0.0
529
+ wandb.log(
530
+ evaluate_pipe(
531
+ pipe,
532
+ target_images=images,
533
+ class_token=class_token,
534
+ learnt_token="".join(placeholder_tokens),
535
+ n_test=wandb_log_prompt_cnt,
536
+ n_step=50,
537
+ clip_model_sets=preped_clip,
538
+ )
539
+ )
540
+
541
+ if global_step >= num_steps:
542
+ return
543
+
544
+
545
+ def perform_tuning(
546
+ unet,
547
+ vae,
548
+ text_encoder,
549
+ dataloader,
550
+ num_steps,
551
+ scheduler,
552
+ optimizer,
553
+ save_steps: int,
554
+ placeholder_token_ids,
555
+ placeholder_tokens,
556
+ save_path,
557
+ lr_scheduler_lora,
558
+ lora_unet_target_modules,
559
+ lora_clip_target_modules,
560
+ mask_temperature,
561
+ out_name: str,
562
+ tokenizer,
563
+ test_image_path: str,
564
+ cached_latents: bool,
565
+ log_wandb: bool = False,
566
+ wandb_log_prompt_cnt: int = 10,
567
+ class_token: str = "person",
568
+ train_inpainting: bool = False,
569
+ ):
570
+
571
+ progress_bar = tqdm(range(num_steps))
572
+ progress_bar.set_description("Steps")
573
+ global_step = 0
574
+
575
+ weight_dtype = torch.float16
576
+
577
+ unet.train()
578
+ text_encoder.train()
579
+
580
+ if log_wandb:
581
+ preped_clip = prepare_clip_model_sets()
582
+
583
+ loss_sum = 0.0
584
+
585
+ for epoch in range(math.ceil(num_steps / len(dataloader))):
586
+ for batch in dataloader:
587
+ lr_scheduler_lora.step()
588
+
589
+ optimizer.zero_grad()
590
+
591
+ loss = loss_step(
592
+ batch,
593
+ unet,
594
+ vae,
595
+ text_encoder,
596
+ scheduler,
597
+ train_inpainting=train_inpainting,
598
+ t_mutliplier=0.8,
599
+ mixed_precision=True,
600
+ mask_temperature=mask_temperature,
601
+ cached_latents=cached_latents,
602
+ )
603
+ loss_sum += loss.detach().item()
604
+
605
+ loss.backward()
606
+ torch.nn.utils.clip_grad_norm_(
607
+ itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
608
+ )
609
+ optimizer.step()
610
+ progress_bar.update(1)
611
+ logs = {
612
+ "loss": loss.detach().item(),
613
+ "lr": lr_scheduler_lora.get_last_lr()[0],
614
+ }
615
+ progress_bar.set_postfix(**logs)
616
+
617
+ global_step += 1
618
+
619
+ if global_step % save_steps == 0:
620
+ save_all(
621
+ unet,
622
+ text_encoder,
623
+ placeholder_token_ids=placeholder_token_ids,
624
+ placeholder_tokens=placeholder_tokens,
625
+ save_path=os.path.join(
626
+ save_path, f"step_{global_step}.safetensors"
627
+ ),
628
+ target_replace_module_text=lora_clip_target_modules,
629
+ target_replace_module_unet=lora_unet_target_modules,
630
+ )
631
+ moved = (
632
+ torch.tensor(list(itertools.chain(*inspect_lora(unet).values())))
633
+ .mean()
634
+ .item()
635
+ )
636
+
637
+ print("LORA Unet Moved", moved)
638
+ moved = (
639
+ torch.tensor(
640
+ list(itertools.chain(*inspect_lora(text_encoder).values()))
641
+ )
642
+ .mean()
643
+ .item()
644
+ )
645
+
646
+ print("LORA CLIP Moved", moved)
647
+
648
+ if log_wandb:
649
+ with torch.no_grad():
650
+ pipe = StableDiffusionPipeline(
651
+ vae=vae,
652
+ text_encoder=text_encoder,
653
+ tokenizer=tokenizer,
654
+ unet=unet,
655
+ scheduler=scheduler,
656
+ safety_checker=None,
657
+ feature_extractor=None,
658
+ )
659
+
660
+ # open all images in test_image_path
661
+ images = []
662
+ for file in os.listdir(test_image_path):
663
+ if file.endswith(".png") or file.endswith(".jpg"):
664
+ images.append(
665
+ Image.open(os.path.join(test_image_path, file))
666
+ )
667
+
668
+ wandb.log({"loss": loss_sum / save_steps})
669
+ loss_sum = 0.0
670
+ wandb.log(
671
+ evaluate_pipe(
672
+ pipe,
673
+ target_images=images,
674
+ class_token=class_token,
675
+ learnt_token="".join(placeholder_tokens),
676
+ n_test=wandb_log_prompt_cnt,
677
+ n_step=50,
678
+ clip_model_sets=preped_clip,
679
+ )
680
+ )
681
+
682
+ if global_step >= num_steps:
683
+ break
684
+
685
+ save_all(
686
+ unet,
687
+ text_encoder,
688
+ placeholder_token_ids=placeholder_token_ids,
689
+ placeholder_tokens=placeholder_tokens,
690
+ save_path=os.path.join(save_path, f"{out_name}.safetensors"),
691
+ target_replace_module_text=lora_clip_target_modules,
692
+ target_replace_module_unet=lora_unet_target_modules,
693
+ )
694
+
695
+
696
+ def train(
697
+ instance_data_dir: str,
698
+ pretrained_model_name_or_path: str,
699
+ output_dir: str,
700
+ train_text_encoder: bool = True,
701
+ pretrained_vae_name_or_path: str = None,
702
+ revision: Optional[str] = None,
703
+ perform_inversion: bool = True,
704
+ use_template: Literal[None, "object", "style"] = None,
705
+ train_inpainting: bool = False,
706
+ placeholder_tokens: str = "",
707
+ placeholder_token_at_data: Optional[str] = None,
708
+ initializer_tokens: Optional[str] = None,
709
+ seed: int = 42,
710
+ resolution: int = 512,
711
+ color_jitter: bool = True,
712
+ train_batch_size: int = 1,
713
+ sample_batch_size: int = 1,
714
+ max_train_steps_tuning: int = 1000,
715
+ max_train_steps_ti: int = 1000,
716
+ save_steps: int = 100,
717
+ gradient_accumulation_steps: int = 4,
718
+ gradient_checkpointing: bool = False,
719
+ lora_rank: int = 4,
720
+ lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
721
+ lora_clip_target_modules={"CLIPAttention"},
722
+ lora_dropout_p: float = 0.0,
723
+ lora_scale: float = 1.0,
724
+ use_extended_lora: bool = False,
725
+ clip_ti_decay: bool = True,
726
+ learning_rate_unet: float = 1e-4,
727
+ learning_rate_text: float = 1e-5,
728
+ learning_rate_ti: float = 5e-4,
729
+ continue_inversion: bool = False,
730
+ continue_inversion_lr: Optional[float] = None,
731
+ use_face_segmentation_condition: bool = False,
732
+ cached_latents: bool = True,
733
+ use_mask_captioned_data: bool = False,
734
+ mask_temperature: float = 1.0,
735
+ scale_lr: bool = False,
736
+ lr_scheduler: str = "linear",
737
+ lr_warmup_steps: int = 0,
738
+ lr_scheduler_lora: str = "linear",
739
+ lr_warmup_steps_lora: int = 0,
740
+ weight_decay_ti: float = 0.00,
741
+ weight_decay_lora: float = 0.001,
742
+ use_8bit_adam: bool = False,
743
+ device="cuda:0",
744
+ extra_args: Optional[dict] = None,
745
+ log_wandb: bool = False,
746
+ wandb_log_prompt_cnt: int = 10,
747
+ wandb_project_name: str = "new_pti_project",
748
+ wandb_entity: str = "new_pti_entity",
749
+ proxy_token: str = "person",
750
+ enable_xformers_memory_efficient_attention: bool = False,
751
+ out_name: str = "final_lora",
752
+ ):
753
+ torch.manual_seed(seed)
754
+
755
+ if log_wandb:
756
+ wandb.init(
757
+ project=wandb_project_name,
758
+ entity=wandb_entity,
759
+ name=f"steps_{max_train_steps_ti}_lr_{learning_rate_ti}_{instance_data_dir.split('/')[-1]}",
760
+ reinit=True,
761
+ config={
762
+ **(extra_args if extra_args is not None else {}),
763
+ },
764
+ )
765
+
766
+ if output_dir is not None:
767
+ os.makedirs(output_dir, exist_ok=True)
768
+ # print(placeholder_tokens, initializer_tokens)
769
+ if len(placeholder_tokens) == 0:
770
+ placeholder_tokens = []
771
+ print("PTI : Placeholder Tokens not given, using null token")
772
+ else:
773
+ placeholder_tokens = placeholder_tokens.split("|")
774
+
775
+ assert (
776
+ sorted(placeholder_tokens) == placeholder_tokens
777
+ ), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}'"
778
+
779
+ if initializer_tokens is None:
780
+ print("PTI : Initializer Tokens not given, doing random inits")
781
+ initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens)
782
+ else:
783
+ initializer_tokens = initializer_tokens.split("|")
784
+
785
+ assert len(initializer_tokens) == len(
786
+ placeholder_tokens
787
+ ), "Unequal Initializer token for Placeholder tokens."
788
+
789
+ if proxy_token is not None:
790
+ class_token = proxy_token
791
+ class_token = "".join(initializer_tokens)
792
+
793
+ if placeholder_token_at_data is not None:
794
+ tok, pat = placeholder_token_at_data.split("|")
795
+ token_map = {tok: pat}
796
+
797
+ else:
798
+ token_map = {"DUMMY": "".join(placeholder_tokens)}
799
+
800
+ print("PTI : Placeholder Tokens", placeholder_tokens)
801
+ print("PTI : Initializer Tokens", initializer_tokens)
802
+
803
+ # get the models
804
+ text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models(
805
+ pretrained_model_name_or_path,
806
+ pretrained_vae_name_or_path,
807
+ revision,
808
+ placeholder_tokens,
809
+ initializer_tokens,
810
+ device=device,
811
+ )
812
+
813
+ noise_scheduler = DDPMScheduler.from_config(
814
+ pretrained_model_name_or_path, subfolder="scheduler"
815
+ )
816
+
817
+ if gradient_checkpointing:
818
+ unet.enable_gradient_checkpointing()
819
+
820
+ if enable_xformers_memory_efficient_attention:
821
+ from diffusers.utils.import_utils import is_xformers_available
822
+
823
+ if is_xformers_available():
824
+ unet.enable_xformers_memory_efficient_attention()
825
+ else:
826
+ raise ValueError(
827
+ "xformers is not available. Make sure it is installed correctly"
828
+ )
829
+
830
+ if scale_lr:
831
+ unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size
832
+ text_encoder_lr = (
833
+ learning_rate_text * gradient_accumulation_steps * train_batch_size
834
+ )
835
+ ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size
836
+ else:
837
+ unet_lr = learning_rate_unet
838
+ text_encoder_lr = learning_rate_text
839
+ ti_lr = learning_rate_ti
840
+
841
+ train_dataset = PivotalTuningDatasetCapation(
842
+ instance_data_root=instance_data_dir,
843
+ token_map=token_map,
844
+ use_template=use_template,
845
+ tokenizer=tokenizer,
846
+ size=resolution,
847
+ color_jitter=color_jitter,
848
+ use_face_segmentation_condition=use_face_segmentation_condition,
849
+ use_mask_captioned_data=use_mask_captioned_data,
850
+ train_inpainting=train_inpainting,
851
+ )
852
+
853
+ train_dataset.blur_amount = 200
854
+
855
+ if train_inpainting:
856
+ assert not cached_latents, "Cached latents not supported for inpainting"
857
+
858
+ train_dataloader = inpainting_dataloader(
859
+ train_dataset, train_batch_size, tokenizer, vae, text_encoder
860
+ )
861
+ else:
862
+ train_dataloader = text2img_dataloader(
863
+ train_dataset,
864
+ train_batch_size,
865
+ tokenizer,
866
+ vae,
867
+ text_encoder,
868
+ cached_latents=cached_latents,
869
+ )
870
+
871
+ index_no_updates = torch.arange(len(tokenizer)) != -1
872
+
873
+ for tok_id in placeholder_token_ids:
874
+ index_no_updates[tok_id] = False
875
+
876
+ unet.requires_grad_(False)
877
+ vae.requires_grad_(False)
878
+
879
+ params_to_freeze = itertools.chain(
880
+ text_encoder.text_model.encoder.parameters(),
881
+ text_encoder.text_model.final_layer_norm.parameters(),
882
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
883
+ )
884
+ for param in params_to_freeze:
885
+ param.requires_grad = False
886
+
887
+ if cached_latents:
888
+ vae = None
889
+ # STEP 1 : Perform Inversion
890
+ if perform_inversion:
891
+ ti_optimizer = optim.AdamW(
892
+ text_encoder.get_input_embeddings().parameters(),
893
+ lr=ti_lr,
894
+ betas=(0.9, 0.999),
895
+ eps=1e-08,
896
+ weight_decay=weight_decay_ti,
897
+ )
898
+
899
+ lr_scheduler = get_scheduler(
900
+ lr_scheduler,
901
+ optimizer=ti_optimizer,
902
+ num_warmup_steps=lr_warmup_steps,
903
+ num_training_steps=max_train_steps_ti,
904
+ )
905
+
906
+ train_inversion(
907
+ unet,
908
+ vae,
909
+ text_encoder,
910
+ train_dataloader,
911
+ max_train_steps_ti,
912
+ cached_latents=cached_latents,
913
+ accum_iter=gradient_accumulation_steps,
914
+ scheduler=noise_scheduler,
915
+ index_no_updates=index_no_updates,
916
+ optimizer=ti_optimizer,
917
+ lr_scheduler=lr_scheduler,
918
+ save_steps=save_steps,
919
+ placeholder_tokens=placeholder_tokens,
920
+ placeholder_token_ids=placeholder_token_ids,
921
+ save_path=output_dir,
922
+ test_image_path=instance_data_dir,
923
+ log_wandb=log_wandb,
924
+ wandb_log_prompt_cnt=wandb_log_prompt_cnt,
925
+ class_token=class_token,
926
+ train_inpainting=train_inpainting,
927
+ mixed_precision=False,
928
+ tokenizer=tokenizer,
929
+ clip_ti_decay=clip_ti_decay,
930
+ )
931
+
932
+ del ti_optimizer
933
+
934
+ # Next perform Tuning with LoRA:
935
+ if not use_extended_lora:
936
+ unet_lora_params, _ = inject_trainable_lora(
937
+ unet,
938
+ r=lora_rank,
939
+ target_replace_module=lora_unet_target_modules,
940
+ dropout_p=lora_dropout_p,
941
+ scale=lora_scale,
942
+ )
943
+ else:
944
+ print("PTI : USING EXTENDED UNET!!!")
945
+ lora_unet_target_modules = (
946
+ lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE
947
+ )
948
+ print("PTI : Will replace modules: ", lora_unet_target_modules)
949
+
950
+ unet_lora_params, _ = inject_trainable_lora_extended(
951
+ unet, r=lora_rank, target_replace_module=lora_unet_target_modules
952
+ )
953
+ print(f"PTI : has {len(unet_lora_params)} lora")
954
+
955
+ print("PTI : Before training:")
956
+ inspect_lora(unet)
957
+
958
+ params_to_optimize = [
959
+ {"params": itertools.chain(*unet_lora_params), "lr": unet_lr},
960
+ ]
961
+
962
+ text_encoder.requires_grad_(False)
963
+
964
+ if continue_inversion:
965
+ params_to_optimize += [
966
+ {
967
+ "params": text_encoder.get_input_embeddings().parameters(),
968
+ "lr": continue_inversion_lr
969
+ if continue_inversion_lr is not None
970
+ else ti_lr,
971
+ }
972
+ ]
973
+ text_encoder.requires_grad_(True)
974
+ params_to_freeze = itertools.chain(
975
+ text_encoder.text_model.encoder.parameters(),
976
+ text_encoder.text_model.final_layer_norm.parameters(),
977
+ text_encoder.text_model.embeddings.position_embedding.parameters(),
978
+ )
979
+ for param in params_to_freeze:
980
+ param.requires_grad = False
981
+ else:
982
+ text_encoder.requires_grad_(False)
983
+ if train_text_encoder:
984
+ text_encoder_lora_params, _ = inject_trainable_lora(
985
+ text_encoder,
986
+ target_replace_module=lora_clip_target_modules,
987
+ r=lora_rank,
988
+ )
989
+ params_to_optimize += [
990
+ {
991
+ "params": itertools.chain(*text_encoder_lora_params),
992
+ "lr": text_encoder_lr,
993
+ }
994
+ ]
995
+ inspect_lora(text_encoder)
996
+
997
+ lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora)
998
+
999
+ unet.train()
1000
+ if train_text_encoder:
1001
+ text_encoder.train()
1002
+
1003
+ train_dataset.blur_amount = 70
1004
+
1005
+ lr_scheduler_lora = get_scheduler(
1006
+ lr_scheduler_lora,
1007
+ optimizer=lora_optimizers,
1008
+ num_warmup_steps=lr_warmup_steps_lora,
1009
+ num_training_steps=max_train_steps_tuning,
1010
+ )
1011
+
1012
+ perform_tuning(
1013
+ unet,
1014
+ vae,
1015
+ text_encoder,
1016
+ train_dataloader,
1017
+ max_train_steps_tuning,
1018
+ cached_latents=cached_latents,
1019
+ scheduler=noise_scheduler,
1020
+ optimizer=lora_optimizers,
1021
+ save_steps=save_steps,
1022
+ placeholder_tokens=placeholder_tokens,
1023
+ placeholder_token_ids=placeholder_token_ids,
1024
+ save_path=output_dir,
1025
+ lr_scheduler_lora=lr_scheduler_lora,
1026
+ lora_unet_target_modules=lora_unet_target_modules,
1027
+ lora_clip_target_modules=lora_clip_target_modules,
1028
+ mask_temperature=mask_temperature,
1029
+ tokenizer=tokenizer,
1030
+ out_name=out_name,
1031
+ test_image_path=instance_data_dir,
1032
+ log_wandb=log_wandb,
1033
+ wandb_log_prompt_cnt=wandb_log_prompt_cnt,
1034
+ class_token=class_token,
1035
+ train_inpainting=train_inpainting,
1036
+ )
1037
+
1038
+
1039
+ def main():
1040
+ fire.Fire(train)
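
The `train` entry point above is exposed through `fire.Fire(train)`, so every keyword argument doubles as a CLI flag. A minimal sketch of calling it directly from Python; the model id, paths, and tokens below are illustrative assumptions, not values from this commit:

```
# Sketch only - paths, model id, and tokens are placeholders.
from lora_diffusion.cli_lora_pti import train

train(
    instance_data_dir="./data/my_object",                            # assumed folder of images
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",  # assumed base model
    output_dir="./output/lora_pti",
    use_template="object",
    placeholder_tokens="<s1>",
    initializer_tokens="sks",
    max_train_steps_ti=1000,
    max_train_steps_tuning=1000,
    save_steps=100,
)
```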
lora_diffusion/cli_pt_to_safetensors.py ADDED
@@ -0,0 +1,85 @@
1
+ import os
2
+
3
+ import fire
4
+ import torch
5
+ from lora_diffusion import (
6
+ DEFAULT_TARGET_REPLACE,
7
+ TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
8
+ UNET_DEFAULT_TARGET_REPLACE,
9
+ convert_loras_to_safeloras_with_embeds,
10
+ safetensors_available,
11
+ )
12
+
13
+ _target_by_name = {
14
+ "unet": UNET_DEFAULT_TARGET_REPLACE,
15
+ "text_encoder": TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
16
+ }
17
+
18
+
19
+ def convert(*paths, outpath, overwrite=False, **settings):
20
+ """
21
+ Converts one or more PyTorch LoRA and/or Textual Inversion embedding files
22
+ into a safetensor file.
23
+
24
+ Pass all the input paths as arguments. Whether they are Textual Embedding
25
+ or Lora models will be auto-detected.
26
+
27
+ For Lora models, their name will be taken from the path, i.e.
28
+ "lora_weight.pt" => unet
29
+ "lora_weight.text_encoder.pt" => text_encoder
30
+
31
+ You can also set target_modules and/or rank by providing an argument prefixed
32
+ by the name.
33
+
34
+ So a complete example might be something like:
35
+
36
+ ```
37
+ python -m lora_diffusion.cli_pt_to_safetensors lora_weight.* --outpath lora_weight.safetensor --unet.rank 8
38
+ ```
39
+ """
40
+ modelmap = {}
41
+ embeds = {}
42
+
43
+ if os.path.exists(outpath) and not overwrite:
44
+ raise ValueError(
45
+ f"Output path {outpath} already exists, and overwrite is not True"
46
+ )
47
+
48
+ for path in paths:
49
+ data = torch.load(path)
50
+
51
+ if isinstance(data, dict):
52
+ print(f"Loading textual inversion embeds {data.keys()} from {path}")
53
+ embeds.update(data)
54
+
55
+ else:
56
+ name_parts = os.path.split(path)[1].split(".")
57
+ name = name_parts[-2] if len(name_parts) > 2 else "unet"
58
+
59
+ model_settings = {
60
+ "target_modules": _target_by_name.get(name, DEFAULT_TARGET_REPLACE),
61
+ "rank": 4,
62
+ }
63
+
64
+ prefix = f"{name}."
65
+
66
+ arg_settings = { k[len(prefix) :]: v for k, v in settings.items() if k.startswith(prefix) }
67
+ model_settings = { **model_settings, **arg_settings }
68
+
69
+ print(f"Loading Lora for {name} from {path} with settings {model_settings}")
70
+
71
+ modelmap[name] = (
72
+ path,
73
+ model_settings["target_modules"],
74
+ model_settings["rank"],
75
+ )
76
+
77
+ convert_loras_to_safeloras_with_embeds(modelmap, embeds, outpath)
78
+
79
+
80
+ def main():
81
+ fire.Fire(convert)
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
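
For reference, a minimal sketch of calling `convert` from Python rather than from the CLI; the `.pt` filenames are hypothetical and follow the naming rule from the docstring ("lora_weight.pt" maps to unet, "lora_weight.text_encoder.pt" to text_encoder):

```
# Sketch only - the input files are assumed to exist and to have been produced
# by the PTI trainer / save_lora_weight above.
from lora_diffusion.cli_pt_to_safetensors import convert

convert(
    "lora_weight.pt",               # auto-detected as the "unet" LoRA
    "lora_weight.text_encoder.pt",  # auto-detected as the "text_encoder" LoRA
    outpath="lora_weight.safetensors",
    overwrite=True,
)
```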
lora_diffusion/cli_svd.py ADDED
@@ -0,0 +1,146 @@
1
+ import fire
2
+ from diffusers import StableDiffusionPipeline
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .lora import (
7
+ save_all,
8
+ _find_modules,
9
+ LoraInjectedConv2d,
10
+ LoraInjectedLinear,
11
+ inject_trainable_lora,
12
+ inject_trainable_lora_extended,
13
+ )
14
+
15
+
16
+ def _iter_lora(model):
17
+ for module in model.modules():
18
+ if isinstance(module, LoraInjectedConv2d) or isinstance(
19
+ module, LoraInjectedLinear
20
+ ):
21
+ yield module
22
+
23
+
24
+ def overwrite_base(base_model, tuned_model, rank, clamp_quantile):
25
+ device = base_model.device
26
+ dtype = base_model.dtype
27
+
28
+ for lor_base, lor_tune in zip(_iter_lora(base_model), _iter_lora(tuned_model)):
29
+
30
+ if isinstance(lor_base, LoraInjectedLinear):
31
+ residual = lor_tune.linear.weight.data - lor_base.linear.weight.data
32
+ # SVD on residual
33
+ print("Distill Linear shape ", residual.shape)
34
+ residual = residual.float()
35
+ U, S, Vh = torch.linalg.svd(residual)
36
+ U = U[:, :rank]
37
+ S = S[:rank]
38
+ U = U @ torch.diag(S)
39
+
40
+ Vh = Vh[:rank, :]
41
+
42
+ dist = torch.cat([U.flatten(), Vh.flatten()])
43
+ hi_val = torch.quantile(dist, clamp_quantile)
44
+ low_val = -hi_val
45
+
46
+ U = U.clamp(low_val, hi_val)
47
+ Vh = Vh.clamp(low_val, hi_val)
48
+
49
+ assert lor_base.lora_up.weight.shape == U.shape
50
+ assert lor_base.lora_down.weight.shape == Vh.shape
51
+
52
+ lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
53
+ lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)
54
+
55
+ if isinstance(lor_base, LoraInjectedConv2d):
56
+ residual = lor_tune.conv.weight.data - lor_base.conv.weight.data
57
+ print("Distill Conv shape ", residual.shape)
58
+
59
+ residual = residual.float()
60
+ residual = residual.flatten(start_dim=1)
61
+
62
+ # SVD on residual
63
+ U, S, Vh = torch.linalg.svd(residual)
64
+ U = U[:, :rank]
65
+ S = S[:rank]
66
+ U = U @ torch.diag(S)
67
+
68
+ Vh = Vh[:rank, :]
69
+
70
+ dist = torch.cat([U.flatten(), Vh.flatten()])
71
+ hi_val = torch.quantile(dist, clamp_quantile)
72
+ low_val = -hi_val
73
+
74
+ U = U.clamp(low_val, hi_val)
75
+ Vh = Vh.clamp(low_val, hi_val)
76
+
77
+ # U is (out_channels, rank) with 1x1 conv. So,
78
+ U = U.reshape(U.shape[0], U.shape[1], 1, 1)
79
+ # V is (rank, in_channels * kernel_size1 * kernel_size2)
80
+ # now reshape:
81
+ Vh = Vh.reshape(
82
+ Vh.shape[0],
83
+ lor_base.conv.in_channels,
84
+ lor_base.conv.kernel_size[0],
85
+ lor_base.conv.kernel_size[1],
86
+ )
87
+
88
+ assert lor_base.lora_up.weight.shape == U.shape
89
+ assert lor_base.lora_down.weight.shape == Vh.shape
90
+
91
+ lor_base.lora_up.weight.data = U.to(device=device, dtype=dtype)
92
+ lor_base.lora_down.weight.data = Vh.to(device=device, dtype=dtype)
93
+
94
+
95
+ def svd_distill(
96
+ target_model: str,
97
+ base_model: str,
98
+ rank: int = 4,
99
+ clamp_quantile: float = 0.99,
100
+ device: str = "cuda:0",
101
+ save_path: str = "svd_distill.safetensors",
102
+ ):
103
+ pipe_base = StableDiffusionPipeline.from_pretrained(
104
+ base_model, torch_dtype=torch.float16
105
+ ).to(device)
106
+
107
+ pipe_tuned = StableDiffusionPipeline.from_pretrained(
108
+ target_model, torch_dtype=torch.float16
109
+ ).to(device)
110
+
111
+ # Inject unet
112
+ _ = inject_trainable_lora_extended(pipe_base.unet, r=rank)
113
+ _ = inject_trainable_lora_extended(pipe_tuned.unet, r=rank)
114
+
115
+ overwrite_base(
116
+ pipe_base.unet, pipe_tuned.unet, rank=rank, clamp_quantile=clamp_quantile
117
+ )
118
+
119
+ # Inject text encoder
120
+ _ = inject_trainable_lora(
121
+ pipe_base.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
122
+ )
123
+ _ = inject_trainable_lora(
124
+ pipe_tuned.text_encoder, r=rank, target_replace_module={"CLIPAttention"}
125
+ )
126
+
127
+ overwrite_base(
128
+ pipe_base.text_encoder,
129
+ pipe_tuned.text_encoder,
130
+ rank=rank,
131
+ clamp_quantile=clamp_quantile,
132
+ )
133
+
134
+ save_all(
135
+ unet=pipe_base.unet,
136
+ text_encoder=pipe_base.text_encoder,
137
+ placeholder_token_ids=None,
138
+ placeholder_tokens=None,
139
+ save_path=save_path,
140
+ save_lora=True,
141
+ save_ti=False,
142
+ )
143
+
144
+
145
+ def main():
146
+ fire.Fire(svd_distill)
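
`overwrite_base` approximates each weight residual (tuned minus base) with a rank-`r` truncated SVD and clamps outliers at `clamp_quantile` before writing the factors into `lora_up`/`lora_down`. A standalone sketch of that low-rank step on a single random matrix (illustrative shapes only):

```
import torch

# Same decomposition as overwrite_base, without the quantile clamp.
rank = 4
W_base = torch.randn(320, 320)
W_tuned = W_base + 0.01 * torch.randn(320, 320)

residual = (W_tuned - W_base).float()
U, S, Vh = torch.linalg.svd(residual)
up = U[:, :rank] @ torch.diag(S[:rank])   # plays the role of lora_up.weight
down = Vh[:rank, :]                       # plays the role of lora_down.weight

# up @ down is the best rank-4 approximation of the residual.
print(torch.dist(residual, up @ down))
```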
lora_diffusion/dataset.py ADDED
@@ -0,0 +1,311 @@
1
+ import random
2
+ from pathlib import Path
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+
5
+ from PIL import Image
6
+ from torch import zeros_like
7
+ from torch.utils.data import Dataset
8
+ from torchvision import transforms
9
+ import glob
10
+ from .preprocess_files import face_mask_google_mediapipe
11
+
12
+ OBJECT_TEMPLATE = [
13
+ "a photo of a {}",
14
+ "a rendering of a {}",
15
+ "a cropped photo of the {}",
16
+ "the photo of a {}",
17
+ "a photo of a clean {}",
18
+ "a photo of a dirty {}",
19
+ "a dark photo of the {}",
20
+ "a photo of my {}",
21
+ "a photo of the cool {}",
22
+ "a close-up photo of a {}",
23
+ "a bright photo of the {}",
24
+ "a cropped photo of a {}",
25
+ "a photo of the {}",
26
+ "a good photo of the {}",
27
+ "a photo of one {}",
28
+ "a close-up photo of the {}",
29
+ "a rendition of the {}",
30
+ "a photo of the clean {}",
31
+ "a rendition of a {}",
32
+ "a photo of a nice {}",
33
+ "a good photo of a {}",
34
+ "a photo of the nice {}",
35
+ "a photo of the small {}",
36
+ "a photo of the weird {}",
37
+ "a photo of the large {}",
38
+ "a photo of a cool {}",
39
+ "a photo of a small {}",
40
+ ]
41
+
42
+ STYLE_TEMPLATE = [
43
+ "a painting in the style of {}",
44
+ "a rendering in the style of {}",
45
+ "a cropped painting in the style of {}",
46
+ "the painting in the style of {}",
47
+ "a clean painting in the style of {}",
48
+ "a dirty painting in the style of {}",
49
+ "a dark painting in the style of {}",
50
+ "a picture in the style of {}",
51
+ "a cool painting in the style of {}",
52
+ "a close-up painting in the style of {}",
53
+ "a bright painting in the style of {}",
54
+ "a cropped painting in the style of {}",
55
+ "a good painting in the style of {}",
56
+ "a close-up painting in the style of {}",
57
+ "a rendition in the style of {}",
58
+ "a nice painting in the style of {}",
59
+ "a small painting in the style of {}",
60
+ "a weird painting in the style of {}",
61
+ "a large painting in the style of {}",
62
+ ]
63
+
64
+ NULL_TEMPLATE = ["{}"]
65
+
66
+ TEMPLATE_MAP = {
67
+ "object": OBJECT_TEMPLATE,
68
+ "style": STYLE_TEMPLATE,
69
+ "null": NULL_TEMPLATE,
70
+ }
71
+
72
+
73
+ def _randomset(lis):
74
+ ret = []
75
+ for i in range(len(lis)):
76
+ if random.random() < 0.5:
77
+ ret.append(lis[i])
78
+ return ret
79
+
80
+
81
+ def _shuffle(lis):
82
+
83
+ return random.sample(lis, len(lis))
84
+
85
+
86
+ def _get_cutout_holes(
87
+ height,
88
+ width,
89
+ min_holes=8,
90
+ max_holes=32,
91
+ min_height=16,
92
+ max_height=128,
93
+ min_width=16,
94
+ max_width=128,
95
+ ):
96
+ holes = []
97
+ for _n in range(random.randint(min_holes, max_holes)):
98
+ hole_height = random.randint(min_height, max_height)
99
+ hole_width = random.randint(min_width, max_width)
100
+ y1 = random.randint(0, height - hole_height)
101
+ x1 = random.randint(0, width - hole_width)
102
+ y2 = y1 + hole_height
103
+ x2 = x1 + hole_width
104
+ holes.append((x1, y1, x2, y2))
105
+ return holes
106
+
107
+
108
+ def _generate_random_mask(image):
109
+ mask = zeros_like(image[:1])
110
+ holes = _get_cutout_holes(mask.shape[1], mask.shape[2])
111
+ for (x1, y1, x2, y2) in holes:
112
+ mask[:, y1:y2, x1:x2] = 1.0
113
+ if random.uniform(0, 1) < 0.25:
114
+ mask.fill_(1.0)
115
+ masked_image = image * (mask < 0.5)
116
+ return mask, masked_image
117
+
118
+
119
+ class PivotalTuningDatasetCapation(Dataset):
120
+ """
121
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
122
+ It pre-processes the images and the tokenizes prompts.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ instance_data_root,
128
+ tokenizer,
129
+ token_map: Optional[dict] = None,
130
+ use_template: Optional[str] = None,
131
+ size=512,
132
+ h_flip=True,
133
+ color_jitter=False,
134
+ resize=True,
135
+ use_mask_captioned_data=False,
136
+ use_face_segmentation_condition=False,
137
+ train_inpainting=False,
138
+ blur_amount: int = 70,
139
+ ):
140
+ self.size = size
141
+ self.tokenizer = tokenizer
142
+ self.resize = resize
143
+ self.train_inpainting = train_inpainting
144
+
145
+ instance_data_root = Path(instance_data_root)
146
+ if not instance_data_root.exists():
147
+ raise ValueError("Instance images root doesn't exists.")
148
+
149
+ self.instance_images_path = []
150
+ self.mask_path = []
151
+
152
+ assert not (
153
+ use_mask_captioned_data and use_template
154
+ ), "Can't use both mask caption data and template."
155
+
156
+ # Prepare the instance images
157
+ if use_mask_captioned_data:
158
+ src_imgs = glob.glob(str(instance_data_root) + "/*src.jpg")
159
+ for f in src_imgs:
160
+ idx = int(str(Path(f).stem).split(".")[0])
161
+ mask_path = f"{instance_data_root}/{idx}.mask.png"
162
+
163
+ if Path(mask_path).exists():
164
+ self.instance_images_path.append(f)
165
+ self.mask_path.append(mask_path)
166
+ else:
167
+ print(f"Mask not found for {f}")
168
+
169
+ self.captions = open(f"{instance_data_root}/caption.txt").readlines()
170
+
171
+ else:
172
+ possibily_src_images = (
173
+ glob.glob(str(instance_data_root) + "/*.jpg")
174
+ + glob.glob(str(instance_data_root) + "/*.png")
175
+ + glob.glob(str(instance_data_root) + "/*.jpeg")
176
+ )
177
+ possibily_src_images = (
178
+ set(possibily_src_images)
179
+ - set(glob.glob(str(instance_data_root) + "/*mask.png"))
180
+ - set([str(instance_data_root) + "/caption.txt"])
181
+ )
182
+
183
+ self.instance_images_path = list(set(possibily_src_images))
184
+ self.captions = [
185
+ x.split("/")[-1].split(".")[0] for x in self.instance_images_path
186
+ ]
187
+
188
+ assert (
189
+ len(self.instance_images_path) > 0
190
+ ), "No images found in the instance data root."
191
+
192
+ self.instance_images_path = sorted(self.instance_images_path)
193
+
194
+ self.use_mask = use_face_segmentation_condition or use_mask_captioned_data
195
+ self.use_mask_captioned_data = use_mask_captioned_data
196
+
197
+ if use_face_segmentation_condition:
198
+
199
+ for idx in range(len(self.instance_images_path)):
200
+ targ = f"{instance_data_root}/{idx}.mask.png"
201
+ # see if the mask exists
202
+ if not Path(targ).exists():
203
+ print(f"Mask not found for {targ}")
204
+
205
+ print(
206
+ "Warning : this will pre-process all the images in the instance data root."
207
+ )
208
+
209
+ if len(self.mask_path) > 0:
210
+ print(
211
+ "Warning : masks already exists, but will be overwritten."
212
+ )
213
+
214
+ masks = face_mask_google_mediapipe(
215
+ [
216
+ Image.open(f).convert("RGB")
217
+ for f in self.instance_images_path
218
+ ]
219
+ )
220
+ for idx, mask in enumerate(masks):
221
+ mask.save(f"{instance_data_root}/{idx}.mask.png")
222
+
223
+ break
224
+
225
+ for idx in range(len(self.instance_images_path)):
226
+ self.mask_path.append(f"{instance_data_root}/{idx}.mask.png")
227
+
228
+ self.num_instance_images = len(self.instance_images_path)
229
+ self.token_map = token_map
230
+
231
+ self.use_template = use_template
232
+ if use_template is not None:
233
+ self.templates = TEMPLATE_MAP[use_template]
234
+
235
+ self._length = self.num_instance_images
236
+
237
+ self.h_flip = h_flip
238
+ self.image_transforms = transforms.Compose(
239
+ [
240
+ transforms.Resize(
241
+ size, interpolation=transforms.InterpolationMode.BILINEAR
242
+ )
243
+ if resize
244
+ else transforms.Lambda(lambda x: x),
245
+ transforms.ColorJitter(0.1, 0.1)
246
+ if color_jitter
247
+ else transforms.Lambda(lambda x: x),
248
+ transforms.CenterCrop(size),
249
+ transforms.ToTensor(),
250
+ transforms.Normalize([0.5], [0.5]),
251
+ ]
252
+ )
253
+
254
+ self.blur_amount = blur_amount
255
+
256
+ def __len__(self):
257
+ return self._length
258
+
259
+ def __getitem__(self, index):
260
+ example = {}
261
+ instance_image = Image.open(
262
+ self.instance_images_path[index % self.num_instance_images]
263
+ )
264
+ if not instance_image.mode == "RGB":
265
+ instance_image = instance_image.convert("RGB")
266
+ example["instance_images"] = self.image_transforms(instance_image)
267
+
268
+ if self.train_inpainting:
269
+ (
270
+ example["instance_masks"],
271
+ example["instance_masked_images"],
272
+ ) = _generate_random_mask(example["instance_images"])
273
+
274
+ if self.use_template:
275
+ assert self.token_map is not None
276
+ input_tok = list(self.token_map.values())[0]
277
+
278
+ text = random.choice(self.templates).format(input_tok)
279
+ else:
280
+ text = self.captions[index % self.num_instance_images].strip()
281
+
282
+ if self.token_map is not None:
283
+ for token, value in self.token_map.items():
284
+ text = text.replace(token, value)
285
+
286
+ print(text)
287
+
288
+ if self.use_mask:
289
+ example["mask"] = (
290
+ self.image_transforms(
291
+ Image.open(self.mask_path[index % self.num_instance_images])
292
+ )
293
+ * 0.5
294
+ + 1.0
295
+ )
296
+
297
+ if self.h_flip and random.random() > 0.5:
298
+ hflip = transforms.RandomHorizontalFlip(p=1)
299
+
300
+ example["instance_images"] = hflip(example["instance_images"])
301
+ if self.use_mask:
302
+ example["mask"] = hflip(example["mask"])
303
+
304
+ example["instance_prompt_ids"] = self.tokenizer(
305
+ text,
306
+ padding="do_not_pad",
307
+ truncation=True,
308
+ max_length=self.tokenizer.model_max_length,
309
+ ).input_ids
310
+
311
+ return example
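
A short sketch of instantiating the dataset with an object template and a token map shaped like the one `train()` builds; the tokenizer id and image folder are assumptions:

```
# Sketch only - assumes ./data/my_object contains a few .jpg/.png images.
from transformers import CLIPTokenizer
from lora_diffusion.dataset import PivotalTuningDatasetCapation

tokenizer = CLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
)
dataset = PivotalTuningDatasetCapation(
    instance_data_root="./data/my_object",
    tokenizer=tokenizer,
    token_map={"DUMMY": "<s1>"},
    use_template="object",
    size=512,
)
example = dataset[0]
print(example["instance_images"].shape)  # expected: torch.Size([3, 512, 512])
```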
lora_diffusion/lora.py ADDED
@@ -0,0 +1,1110 @@
1
+ import json
2
+ import math
3
+ from itertools import groupby
4
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
5
+
6
+ import numpy as np
7
+ import PIL
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ try:
13
+ from safetensors.torch import safe_open
14
+ from safetensors.torch import save_file as safe_save
15
+
16
+ safetensors_available = True
17
+ except ImportError:
18
+ from .safe_open import safe_open
19
+
20
+ def safe_save(
21
+ tensors: Dict[str, torch.Tensor],
22
+ filename: str,
23
+ metadata: Optional[Dict[str, str]] = None,
24
+ ) -> None:
25
+ raise EnvironmentError(
26
+ "Saving safetensors requires the safetensors library. Please install with pip or similar."
27
+ )
28
+
29
+ safetensors_available = False
30
+
31
+
32
+ class LoraInjectedLinear(nn.Module):
33
+ def __init__(
34
+ self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
35
+ ):
36
+ super().__init__()
37
+
38
+ if r > min(in_features, out_features):
39
+ raise ValueError(
40
+ f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
41
+ )
42
+ self.r = r
43
+ self.linear = nn.Linear(in_features, out_features, bias)
44
+ self.lora_down = nn.Linear(in_features, r, bias=False)
45
+ self.dropout = nn.Dropout(dropout_p)
46
+ self.lora_up = nn.Linear(r, out_features, bias=False)
47
+ self.scale = scale
48
+ self.selector = nn.Identity()
49
+
50
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
51
+ nn.init.zeros_(self.lora_up.weight)
52
+
53
+ def forward(self, input):
54
+ return (
55
+ self.linear(input)
56
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
57
+ * self.scale
58
+ )
59
+
60
+ def realize_as_lora(self):
61
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
62
+
63
+ def set_selector_from_diag(self, diag: torch.Tensor):
64
+ # diag is a 1D tensor of size (r,)
65
+ assert diag.shape == (self.r,)
66
+ self.selector = nn.Linear(self.r, self.r, bias=False)
67
+ self.selector.weight.data = torch.diag(diag)
68
+ self.selector.weight.data = self.selector.weight.data.to(
69
+ self.lora_up.weight.device
70
+ ).to(self.lora_up.weight.dtype)
71
+
72
+
73
+ class LoraInjectedConv2d(nn.Module):
74
+ def __init__(
75
+ self,
76
+ in_channels: int,
77
+ out_channels: int,
78
+ kernel_size,
79
+ stride=1,
80
+ padding=0,
81
+ dilation=1,
82
+ groups: int = 1,
83
+ bias: bool = True,
84
+ r: int = 4,
85
+ dropout_p: float = 0.1,
86
+ scale: float = 1.0,
87
+ ):
88
+ super().__init__()
89
+ if r > min(in_channels, out_channels):
90
+ raise ValueError(
91
+ f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
92
+ )
93
+ self.r = r
94
+ self.conv = nn.Conv2d(
95
+ in_channels=in_channels,
96
+ out_channels=out_channels,
97
+ kernel_size=kernel_size,
98
+ stride=stride,
99
+ padding=padding,
100
+ dilation=dilation,
101
+ groups=groups,
102
+ bias=bias,
103
+ )
104
+
105
+ self.lora_down = nn.Conv2d(
106
+ in_channels=in_channels,
107
+ out_channels=r,
108
+ kernel_size=kernel_size,
109
+ stride=stride,
110
+ padding=padding,
111
+ dilation=dilation,
112
+ groups=groups,
113
+ bias=False,
114
+ )
115
+ self.dropout = nn.Dropout(dropout_p)
116
+ self.lora_up = nn.Conv2d(
117
+ in_channels=r,
118
+ out_channels=out_channels,
119
+ kernel_size=1,
120
+ stride=1,
121
+ padding=0,
122
+ bias=False,
123
+ )
124
+ self.selector = nn.Identity()
125
+ self.scale = scale
126
+
127
+ nn.init.normal_(self.lora_down.weight, std=1 / r)
128
+ nn.init.zeros_(self.lora_up.weight)
129
+
130
+ def forward(self, input):
131
+ return (
132
+ self.conv(input)
133
+ + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
134
+ * self.scale
135
+ )
136
+
137
+ def realize_as_lora(self):
138
+ return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
139
+
140
+ def set_selector_from_diag(self, diag: torch.Tensor):
141
+ # diag is a 1D tensor of size (r,)
142
+ assert diag.shape == (self.r,)
143
+ self.selector = nn.Conv2d(
144
+ in_channels=self.r,
145
+ out_channels=self.r,
146
+ kernel_size=1,
147
+ stride=1,
148
+ padding=0,
149
+ bias=False,
150
+ )
151
+ self.selector.weight.data = torch.diag(diag)
152
+
153
+ # same device + dtype as lora_up
154
+ self.selector.weight.data = self.selector.weight.data.to(
155
+ self.lora_up.weight.device
156
+ ).to(self.lora_up.weight.dtype)
157
+
158
+
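
Both injected modules keep the frozen base layer and add a low-rank bypass scaled by `scale`; because `lora_up` is zero-initialized, the wrapper reproduces the base layer exactly until training updates the LoRA weights. A quick sketch illustrating this:

```
import torch
from lora_diffusion.lora import LoraInjectedLinear

lora_lin = LoraInjectedLinear(64, 128, bias=False, r=4)
x = torch.randn(2, 64)

# lora_up starts at zero, so the LoRA branch contributes nothing yet and the
# output matches the wrapped nn.Linear.
assert torch.allclose(lora_lin(x), lora_lin.linear(x))
```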
159
+ UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
160
+
161
+ UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"}
162
+
163
+ TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}
164
+
165
+ TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}
166
+
167
+ DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
168
+
169
+ EMBED_FLAG = "<embed>"
170
+
171
+
172
+ def _find_children(
173
+ model,
174
+ search_class: List[Type[nn.Module]] = [nn.Linear],
175
+ ):
176
+ """
177
+ Find all modules of a certain class (or union of classes).
178
+
179
+ Returns all matching modules, along with the parent of those modules and the
180
+ names they are referenced by.
181
+ """
182
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
183
+ for parent in model.modules():
184
+ for name, module in parent.named_children():
185
+ if any([isinstance(module, _class) for _class in search_class]):
186
+ yield parent, name, module
187
+
188
+
189
+ def _find_modules_v2(
190
+ model,
191
+ ancestor_class: Optional[Set[str]] = None,
192
+ search_class: List[Type[nn.Module]] = [nn.Linear],
193
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [
194
+ LoraInjectedLinear,
195
+ LoraInjectedConv2d,
196
+ ],
197
+ ):
198
+ """
199
+ Find all modules of a certain class (or union of classes) that are direct or
200
+ indirect descendants of other modules of a certain class (or union of classes).
201
+
202
+ Returns all matching modules, along with the parent of those modules and the
203
+ names they are referenced by.
204
+ """
205
+
206
+ # Get the targets we should replace all linears under
207
+ if ancestor_class is not None:
208
+ ancestors = (
209
+ module
210
+ for module in model.modules()
211
+ if module.__class__.__name__ in ancestor_class
212
+ )
213
+ else:
214
+ # this, in case you want to naively iterate over all modules.
215
+ ancestors = [module for module in model.modules()]
216
+
217
+ # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
218
+ for ancestor in ancestors:
219
+ for fullname, module in ancestor.named_modules():
220
+ if any([isinstance(module, _class) for _class in search_class]):
221
+ # Find the direct parent if this is a descendant, not a child, of target
222
+ *path, name = fullname.split(".")
223
+ parent = ancestor
224
+ while path:
225
+ parent = parent.get_submodule(path.pop(0))
226
+ # Skip this linear if it's a child of a LoraInjectedLinear
227
+ if exclude_children_of and any(
228
+ [isinstance(parent, _class) for _class in exclude_children_of]
229
+ ):
230
+ continue
231
+ # Otherwise, yield it
232
+ yield parent, name, module
233
+
234
+
235
+ def _find_modules_old(
236
+ model,
237
+ ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
238
+ search_class: List[Type[nn.Module]] = [nn.Linear],
239
+ exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
240
+ ):
241
+ ret = []
242
+ for _module in model.modules():
243
+ if _module.__class__.__name__ in ancestor_class:
244
+
245
+ for name, _child_module in _module.named_modules():
246
+ if _child_module.__class__ in search_class:
247
+ ret.append((_module, name, _child_module))
248
+ print(ret)
249
+ return ret
250
+
251
+
252
+ _find_modules = _find_modules_v2
253
+
254
+
255
+ def inject_trainable_lora(
256
+ model: nn.Module,
257
+ target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
258
+ r: int = 4,
259
+ loras=None, # path to lora .pt
260
+ verbose: bool = False,
261
+ dropout_p: float = 0.0,
262
+ scale: float = 1.0,
263
+ ):
264
+ """
265
+ Inject LoRA into the model and return the LoRA parameter groups.
266
+ """
267
+
268
+ require_grad_params = []
269
+ names = []
270
+
271
+ if loras != None:
272
+ loras = torch.load(loras)
273
+
274
+ for _module, name, _child_module in _find_modules(
275
+ model, target_replace_module, search_class=[nn.Linear]
276
+ ):
277
+ weight = _child_module.weight
278
+ bias = _child_module.bias
279
+ if verbose:
280
+ print("LoRA Injection : injecting lora into ", name)
281
+ print("LoRA Injection : weight shape", weight.shape)
282
+ _tmp = LoraInjectedLinear(
283
+ _child_module.in_features,
284
+ _child_module.out_features,
285
+ _child_module.bias is not None,
286
+ r=r,
287
+ dropout_p=dropout_p,
288
+ scale=scale,
289
+ )
290
+ _tmp.linear.weight = weight
291
+ if bias is not None:
292
+ _tmp.linear.bias = bias
293
+
294
+ # switch the module
295
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
296
+ _module._modules[name] = _tmp
297
+
298
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
299
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
300
+
301
+ if loras != None:
302
+ _module._modules[name].lora_up.weight = loras.pop(0)
303
+ _module._modules[name].lora_down.weight = loras.pop(0)
304
+
305
+ _module._modules[name].lora_up.weight.requires_grad = True
306
+ _module._modules[name].lora_down.weight.requires_grad = True
307
+ names.append(name)
308
+
309
+ return require_grad_params, names
310
+
311
+
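
`inject_trainable_lora` swaps every matching `nn.Linear` under the target module classes for a `LoraInjectedLinear` and returns the new parameters as generator groups, which is how the PTI trainer above feeds them to its optimizer. A hedged sketch on a toy module (the class is named `CrossAttention` only so it matches `DEFAULT_TARGET_REPLACE`; it is not a real attention block):

```
import itertools
import torch
import torch.nn as nn
from lora_diffusion.lora import inject_trainable_lora


class CrossAttention(nn.Module):  # name chosen to match DEFAULT_TARGET_REPLACE
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(32, 32)


model = nn.Sequential(CrossAttention())
model.requires_grad_(False)

lora_params, names = inject_trainable_lora(model, r=4)
optimizer = torch.optim.AdamW(itertools.chain(*lora_params), lr=1e-4)
print(names)  # ["to_q"]
```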
312
+ def inject_trainable_lora_extended(
313
+ model: nn.Module,
314
+ target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
315
+ r: int = 4,
316
+ loras=None, # path to lora .pt
317
+ ):
318
+ """
319
+ Inject LoRA into the model and return the LoRA parameter groups.
320
+ """
321
+
322
+ require_grad_params = []
323
+ names = []
324
+
325
+ if loras != None:
326
+ loras = torch.load(loras)
327
+
328
+ for _module, name, _child_module in _find_modules(
329
+ model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
330
+ ):
331
+ if _child_module.__class__ == nn.Linear:
332
+ weight = _child_module.weight
333
+ bias = _child_module.bias
334
+ _tmp = LoraInjectedLinear(
335
+ _child_module.in_features,
336
+ _child_module.out_features,
337
+ _child_module.bias is not None,
338
+ r=r,
339
+ )
340
+ _tmp.linear.weight = weight
341
+ if bias is not None:
342
+ _tmp.linear.bias = bias
343
+ elif _child_module.__class__ == nn.Conv2d:
344
+ weight = _child_module.weight
345
+ bias = _child_module.bias
346
+ _tmp = LoraInjectedConv2d(
347
+ _child_module.in_channels,
348
+ _child_module.out_channels,
349
+ _child_module.kernel_size,
350
+ _child_module.stride,
351
+ _child_module.padding,
352
+ _child_module.dilation,
353
+ _child_module.groups,
354
+ _child_module.bias is not None,
355
+ r=r,
356
+ )
357
+
358
+ _tmp.conv.weight = weight
359
+ if bias is not None:
360
+ _tmp.conv.bias = bias
361
+
362
+ # switch the module
363
+ _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
364
+ if bias is not None:
365
+ _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)
366
+
367
+ _module._modules[name] = _tmp
368
+
369
+ require_grad_params.append(_module._modules[name].lora_up.parameters())
370
+ require_grad_params.append(_module._modules[name].lora_down.parameters())
371
+
372
+ if loras != None:
373
+ _module._modules[name].lora_up.weight = loras.pop(0)
374
+ _module._modules[name].lora_down.weight = loras.pop(0)
375
+
376
+ _module._modules[name].lora_up.weight.requires_grad = True
377
+ _module._modules[name].lora_down.weight.requires_grad = True
378
+ names.append(name)
379
+
380
+ return require_grad_params, names
381
+
382
+
383
+ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
384
+
385
+ loras = []
386
+
387
+ for _m, _n, _child_module in _find_modules(
388
+ model,
389
+ target_replace_module,
390
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d],
391
+ ):
392
+ loras.append((_child_module.lora_up, _child_module.lora_down))
393
+
394
+ if len(loras) == 0:
395
+ raise ValueError("No lora injected.")
396
+
397
+ return loras
398
+
399
+
400
+ def extract_lora_as_tensor(
401
+ model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
402
+ ):
403
+
404
+ loras = []
405
+
406
+ for _m, _n, _child_module in _find_modules(
407
+ model,
408
+ target_replace_module,
409
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d],
410
+ ):
411
+ up, down = _child_module.realize_as_lora()
412
+ if as_fp16:
413
+ up = up.to(torch.float16)
414
+ down = down.to(torch.float16)
415
+
416
+ loras.append((up, down))
417
+
418
+ if len(loras) == 0:
419
+ raise ValueError("No lora injected.")
420
+
421
+ return loras
422
+
423
+
424
+ def save_lora_weight(
425
+ model,
426
+ path="./lora.pt",
427
+ target_replace_module=DEFAULT_TARGET_REPLACE,
428
+ ):
429
+ weights = []
430
+ for _up, _down in extract_lora_ups_down(
431
+ model, target_replace_module=target_replace_module
432
+ ):
433
+ weights.append(_up.weight.to("cpu").to(torch.float16))
434
+ weights.append(_down.weight.to("cpu").to(torch.float16))
435
+
436
+ torch.save(weights, path)
437
+
438
+
439
+ def save_lora_as_json(model, path="./lora.json"):
440
+ weights = []
441
+ for _up, _down in extract_lora_ups_down(model):
442
+ weights.append(_up.weight.detach().cpu().numpy().tolist())
443
+ weights.append(_down.weight.detach().cpu().numpy().tolist())
444
+
445
+ import json
446
+
447
+ with open(path, "w") as f:
448
+ json.dump(weights, f)
449
+
450
+
451
+ def save_safeloras_with_embeds(
452
+ modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
453
+ embeds: Dict[str, torch.Tensor] = {},
454
+ outpath="./lora.safetensors",
455
+ ):
456
+ """
457
+ Saves the Lora from multiple modules in a single safetensor file.
458
+
459
+ modelmap is a dictionary of {
460
+ "module name": (module, target_replace_module)
461
+ }
462
+ """
463
+ weights = {}
464
+ metadata = {}
465
+
466
+ for name, (model, target_replace_module) in modelmap.items():
467
+ metadata[name] = json.dumps(list(target_replace_module))
468
+
469
+ for i, (_up, _down) in enumerate(
470
+ extract_lora_as_tensor(model, target_replace_module)
471
+ ):
472
+ rank = _down.shape[0]
473
+
474
+ metadata[f"{name}:{i}:rank"] = str(rank)
475
+ weights[f"{name}:{i}:up"] = _up
476
+ weights[f"{name}:{i}:down"] = _down
477
+
478
+ for token, tensor in embeds.items():
479
+ metadata[token] = EMBED_FLAG
480
+ weights[token] = tensor
481
+
482
+ print(f"Saving weights to {outpath}")
483
+ safe_save(weights, outpath, metadata)
484
+
485
+
486
+ def save_safeloras(
487
+ modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
488
+ outpath="./lora.safetensors",
489
+ ):
490
+ return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
491
+
492
+
493
+ def convert_loras_to_safeloras_with_embeds(
494
+ modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
495
+ embeds: Dict[str, torch.Tensor] = {},
496
+ outpath="./lora.safetensors",
497
+ ):
498
+ """
499
+ Converts the Lora from multiple pytorch .pt files into a single safetensor file.
500
+
501
+ modelmap is a dictionary of {
502
+ "module name": (pytorch_model_path, target_replace_module, rank)
503
+ }
504
+ """
505
+
506
+ weights = {}
507
+ metadata = {}
508
+
509
+ for name, (path, target_replace_module, r) in modelmap.items():
510
+ metadata[name] = json.dumps(list(target_replace_module))
511
+
512
+ lora = torch.load(path)
513
+ for i, weight in enumerate(lora):
514
+ is_up = i % 2 == 0
515
+ i = i // 2
516
+
517
+ if is_up:
518
+ metadata[f"{name}:{i}:rank"] = str(r)
519
+ weights[f"{name}:{i}:up"] = weight
520
+ else:
521
+ weights[f"{name}:{i}:down"] = weight
522
+
523
+ for token, tensor in embeds.items():
524
+ metadata[token] = EMBED_FLAG
525
+ weights[token] = tensor
526
+
527
+ print(f"Saving weights to {outpath}")
528
+ safe_save(weights, outpath, metadata)
529
+
530
+
531
+ def convert_loras_to_safeloras(
532
+ modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
533
+ outpath="./lora.safetensors",
534
+ ):
535
+ convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
536
+
537
+
538
+ def parse_safeloras(
539
+ safeloras,
540
+ ) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
541
+ """
542
+ Converts a loaded safetensor file that contains a set of module Loras
543
+ into Parameters and other information
544
+
545
+ Output is a dictionary of {
546
+ "module name": (
547
+ [list of weights],
548
+ [list of ranks],
549
+ target_replacement_modules
550
+ )
551
+ }
552
+ """
553
+ loras = {}
554
+ metadata = safeloras.metadata()
555
+
556
+ get_name = lambda k: k.split(":")[0]
557
+
558
+ keys = list(safeloras.keys())
559
+ keys.sort(key=get_name)
560
+
561
+ for name, module_keys in groupby(keys, get_name):
562
+ info = metadata.get(name)
563
+
564
+ if not info:
565
+ raise ValueError(
566
+ f"Tensor {name} has no metadata - is this a Lora safetensor?"
567
+ )
568
+
569
+ # Skip Textual Inversion embeds
570
+ if info == EMBED_FLAG:
571
+ continue
572
+
573
+ # Handle Loras
574
+ # Extract the targets
575
+ target = json.loads(info)
576
+
577
+ # Build the result lists - Python needs us to preallocate lists to insert into them
578
+ module_keys = list(module_keys)
579
+ ranks = [4] * (len(module_keys) // 2)
580
+ weights = [None] * len(module_keys)
581
+
582
+ for key in module_keys:
583
+ # Split the model name and index out of the key
584
+ _, idx, direction = key.split(":")
585
+ idx = int(idx)
586
+
587
+ # Add the rank
588
+ ranks[idx] = int(metadata[f"{name}:{idx}:rank"])
589
+
590
+ # Insert the weight into the list
591
+ idx = idx * 2 + (1 if direction == "down" else 0)
592
+ weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))
593
+
594
+ loras[name] = (weights, ranks, target)
595
+
596
+ return loras
597
+
598
+
599
+ def parse_safeloras_embeds(
600
+ safeloras,
601
+ ) -> Dict[str, torch.Tensor]:
602
+ """
603
+ Converts a loaded safetensor file that contains Textual Inversion embeds into
604
+ a dictionary of embed_token: Tensor
605
+ """
606
+ embeds = {}
607
+ metadata = safeloras.metadata()
608
+
609
+ for key in safeloras.keys():
610
+ # Only handle Textual Inversion embeds
611
+ meta = metadata.get(key)
612
+ if not meta or meta != EMBED_FLAG:
613
+ continue
614
+
615
+ embeds[key] = safeloras.get_tensor(key)
616
+
617
+ return embeds
618
+
619
+
620
+ def load_safeloras(path, device="cpu"):
621
+ safeloras = safe_open(path, framework="pt", device=device)
622
+ return parse_safeloras(safeloras)
623
+
624
+
625
+ def load_safeloras_embeds(path, device="cpu"):
626
+ safeloras = safe_open(path, framework="pt", device=device)
627
+ return parse_safeloras_embeds(safeloras)
628
+
629
+
630
+ def load_safeloras_both(path, device="cpu"):
631
+ safeloras = safe_open(path, framework="pt", device=device)
632
+ return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
633
+
634
+
635
+ def collapse_lora(model, alpha=1.0):
636
+
637
+ for _module, name, _child_module in _find_modules(
638
+ model,
639
+ UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
640
+ search_class=[LoraInjectedLinear, LoraInjectedConv2d],
641
+ ):
642
+
643
+ if isinstance(_child_module, LoraInjectedLinear):
644
+ print("Collapsing Lin Lora in", name)
645
+
646
+ _child_module.linear.weight = nn.Parameter(
647
+ _child_module.linear.weight.data
648
+ + alpha
649
+ * (
650
+ _child_module.lora_up.weight.data
651
+ @ _child_module.lora_down.weight.data
652
+ )
653
+ .type(_child_module.linear.weight.dtype)
654
+ .to(_child_module.linear.weight.device)
655
+ )
656
+
657
+ else:
658
+ print("Collapsing Conv Lora in", name)
659
+ _child_module.conv.weight = nn.Parameter(
660
+ _child_module.conv.weight.data
661
+ + alpha
662
+ * (
663
+ _child_module.lora_up.weight.data.flatten(start_dim=1)
664
+ @ _child_module.lora_down.weight.data.flatten(start_dim=1)
665
+ )
666
+ .reshape(_child_module.conv.weight.data.shape)
667
+ .type(_child_module.conv.weight.dtype)
668
+ .to(_child_module.conv.weight.device)
669
+ )
670
+
671
+
672
+ def monkeypatch_or_replace_lora(
673
+ model,
674
+ loras,
675
+ target_replace_module=DEFAULT_TARGET_REPLACE,
676
+ r: Union[int, List[int]] = 4,
677
+ ):
678
+ for _module, name, _child_module in _find_modules(
679
+ model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
680
+ ):
681
+ _source = (
682
+ _child_module.linear
683
+ if isinstance(_child_module, LoraInjectedLinear)
684
+ else _child_module
685
+ )
686
+
687
+ weight = _source.weight
688
+ bias = _source.bias
689
+ _tmp = LoraInjectedLinear(
690
+ _source.in_features,
691
+ _source.out_features,
692
+ _source.bias is not None,
693
+ r=r.pop(0) if isinstance(r, list) else r,
694
+ )
695
+ _tmp.linear.weight = weight
696
+
697
+ if bias is not None:
698
+ _tmp.linear.bias = bias
699
+
700
+ # switch the module
701
+ _module._modules[name] = _tmp
702
+
703
+ up_weight = loras.pop(0)
704
+ down_weight = loras.pop(0)
705
+
706
+ _module._modules[name].lora_up.weight = nn.Parameter(
707
+ up_weight.type(weight.dtype)
708
+ )
709
+ _module._modules[name].lora_down.weight = nn.Parameter(
710
+ down_weight.type(weight.dtype)
711
+ )
712
+
713
+ _module._modules[name].to(weight.device)
714
+
715
+
716
+ def monkeypatch_or_replace_lora_extended(
717
+ model,
718
+ loras,
719
+ target_replace_module=DEFAULT_TARGET_REPLACE,
720
+ r: Union[int, List[int]] = 4,
721
+ ):
722
+ for _module, name, _child_module in _find_modules(
723
+ model,
724
+ target_replace_module,
725
+ search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
726
+ ):
727
+
728
+ if (_child_module.__class__ == nn.Linear) or (
729
+ _child_module.__class__ == LoraInjectedLinear
730
+ ):
731
+ if len(loras[0].shape) != 2:
732
+ continue
733
+
734
+ _source = (
735
+ _child_module.linear
736
+ if isinstance(_child_module, LoraInjectedLinear)
737
+ else _child_module
738
+ )
739
+
740
+ weight = _source.weight
741
+ bias = _source.bias
742
+ _tmp = LoraInjectedLinear(
743
+ _source.in_features,
744
+ _source.out_features,
745
+ _source.bias is not None,
746
+ r=r.pop(0) if isinstance(r, list) else r,
747
+ )
748
+ _tmp.linear.weight = weight
749
+
750
+ if bias is not None:
751
+ _tmp.linear.bias = bias
752
+
753
+ elif (_child_module.__class__ == nn.Conv2d) or (
754
+ _child_module.__class__ == LoraInjectedConv2d
755
+ ):
756
+ if len(loras[0].shape) != 4:
757
+ continue
758
+ _source = (
759
+ _child_module.conv
760
+ if isinstance(_child_module, LoraInjectedConv2d)
761
+ else _child_module
762
+ )
763
+
764
+ weight = _source.weight
765
+ bias = _source.bias
766
+ _tmp = LoraInjectedConv2d(
767
+ _source.in_channels,
768
+ _source.out_channels,
769
+ _source.kernel_size,
770
+ _source.stride,
771
+ _source.padding,
772
+ _source.dilation,
773
+ _source.groups,
774
+ _source.bias is not None,
775
+ r=r.pop(0) if isinstance(r, list) else r,
776
+ )
777
+
778
+ _tmp.conv.weight = weight
779
+
780
+ if bias is not None:
781
+ _tmp.conv.bias = bias
782
+
783
+ # switch the module
784
+ _module._modules[name] = _tmp
785
+
786
+ up_weight = loras.pop(0)
787
+ down_weight = loras.pop(0)
788
+
789
+ _module._modules[name].lora_up.weight = nn.Parameter(
790
+ up_weight.type(weight.dtype)
791
+ )
792
+ _module._modules[name].lora_down.weight = nn.Parameter(
793
+ down_weight.type(weight.dtype)
794
+ )
795
+
796
+ _module._modules[name].to(weight.device)
797
+
798
+
799
+ def monkeypatch_or_replace_safeloras(models, safeloras):
800
+ loras = parse_safeloras(safeloras)
801
+
802
+ for name, (lora, ranks, target) in loras.items():
803
+ model = getattr(models, name, None)
804
+
805
+ if not model:
806
+ print(f"No model provided for {name}, contained in Lora")
807
+ continue
808
+
809
+ monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
810
+
811
+
812
+ def monkeypatch_remove_lora(model):
813
+ for _module, name, _child_module in _find_modules(
814
+ model, search_class=[LoraInjectedLinear, LoraInjectedConv2d]
815
+ ):
816
+ if isinstance(_child_module, LoraInjectedLinear):
817
+ _source = _child_module.linear
818
+ weight, bias = _source.weight, _source.bias
819
+
820
+ _tmp = nn.Linear(
821
+ _source.in_features, _source.out_features, bias is not None
822
+ )
823
+
824
+ _tmp.weight = weight
825
+ if bias is not None:
826
+ _tmp.bias = bias
827
+
828
+ else:
829
+ _source = _child_module.conv
830
+ weight, bias = _source.weight, _source.bias
831
+
832
+ _tmp = nn.Conv2d(
833
+ in_channels=_source.in_channels,
834
+ out_channels=_source.out_channels,
835
+ kernel_size=_source.kernel_size,
836
+ stride=_source.stride,
837
+ padding=_source.padding,
838
+ dilation=_source.dilation,
839
+ groups=_source.groups,
840
+ bias=bias is not None,
841
+ )
842
+
843
+ _tmp.weight = weight
844
+ if bias is not None:
845
+ _tmp.bias = bias
846
+
847
+ _module._modules[name] = _tmp
848
+
849
+
850
+ def monkeypatch_add_lora(
851
+ model,
852
+ loras,
853
+ target_replace_module=DEFAULT_TARGET_REPLACE,
854
+ alpha: float = 1.0,
855
+ beta: float = 1.0,
856
+ ):
857
+ for _module, name, _child_module in _find_modules(
858
+ model, target_replace_module, search_class=[LoraInjectedLinear]
859
+ ):
860
+ weight = _child_module.linear.weight
861
+
862
+ up_weight = loras.pop(0)
863
+ down_weight = loras.pop(0)
864
+
865
+ _module._modules[name].lora_up.weight = nn.Parameter(
866
+ up_weight.type(weight.dtype).to(weight.device) * alpha
867
+ + _module._modules[name].lora_up.weight.to(weight.device) * beta
868
+ )
869
+ _module._modules[name].lora_down.weight = nn.Parameter(
870
+ down_weight.type(weight.dtype).to(weight.device) * alpha
871
+ + _module._modules[name].lora_down.weight.to(weight.device) * beta
872
+ )
873
+
874
+ _module._modules[name].to(weight.device)
875
+
876
+
877
+ def tune_lora_scale(model, alpha: float = 1.0):
878
+ for _module in model.modules():
879
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
880
+ _module.scale = alpha
881
+
882
+
883
+ def set_lora_diag(model, diag: torch.Tensor):
884
+ for _module in model.modules():
885
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
886
+ _module.set_selector_from_diag(diag)
887
+
888
+
889
+ def _text_lora_path(path: str) -> str:
890
+ assert path.endswith(".pt"), "Only .pt files are supported"
891
+ return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
892
+
893
+
894
+ def _ti_lora_path(path: str) -> str:
895
+ assert path.endswith(".pt"), "Only .pt files are supported"
896
+ return ".".join(path.split(".")[:-1] + ["ti", "pt"])
897
+
898
+
899
+ def apply_learned_embed_in_clip(
900
+ learned_embeds,
901
+ text_encoder,
902
+ tokenizer,
903
+ token: Optional[Union[str, List[str]]] = None,
904
+ idempotent=False,
905
+ ):
906
+ if isinstance(token, str):
907
+ trained_tokens = [token]
908
+ elif isinstance(token, list):
909
+ assert len(learned_embeds.keys()) == len(
910
+ token
911
+ ), "The number of tokens and the number of embeds should be the same"
912
+ trained_tokens = token
913
+ else:
914
+ trained_tokens = list(learned_embeds.keys())
915
+
916
+ for token in trained_tokens:
917
+ print(token)
918
+ embeds = learned_embeds[token]
919
+
920
+ # cast to dtype of text_encoder
921
+ dtype = text_encoder.get_input_embeddings().weight.dtype
922
+ num_added_tokens = tokenizer.add_tokens(token)
923
+
924
+ i = 1
925
+ if not idempotent:
926
+ while num_added_tokens == 0:
927
+ print(f"The tokenizer already contains the token {token}.")
928
+ token = f"{token[:-1]}-{i}>"
929
+ print(f"Attempting to add the token {token}.")
930
+ num_added_tokens = tokenizer.add_tokens(token)
931
+ i += 1
932
+ elif num_added_tokens == 0 and idempotent:
933
+ print(f"The tokenizer already contains the token {token}.")
934
+ print(f"Replacing {token} embedding.")
935
+
936
+ # resize the token embeddings
937
+ text_encoder.resize_token_embeddings(len(tokenizer))
938
+
939
+ # get the id for the token and assign the embeds
940
+ token_id = tokenizer.convert_tokens_to_ids(token)
941
+ text_encoder.get_input_embeddings().weight.data[token_id] = embeds
942
+ return token
943
+
944
+
945
+ def load_learned_embed_in_clip(
946
+ learned_embeds_path,
947
+ text_encoder,
948
+ tokenizer,
949
+ token: Optional[Union[str, List[str]]] = None,
950
+ idempotent=False,
951
+ ):
952
+ learned_embeds = torch.load(learned_embeds_path)
953
+ return apply_learned_embed_in_clip(
954
+ learned_embeds, text_encoder, tokenizer, token, idempotent
955
+ )
956
+
957
+
958
+ def patch_pipe(
959
+ pipe,
960
+ maybe_unet_path,
961
+ token: Optional[str] = None,
962
+ r: int = 4,
963
+ patch_unet=True,
964
+ patch_text=True,
965
+ patch_ti=True,
966
+ idempotent_token=True,
967
+ unet_target_replace_module=DEFAULT_TARGET_REPLACE,
968
+ text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
969
+ ):
970
+ if maybe_unet_path.endswith(".pt"):
971
+ # torch format
972
+
973
+ if maybe_unet_path.endswith(".ti.pt"):
974
+ unet_path = maybe_unet_path[:-6] + ".pt"
975
+ elif maybe_unet_path.endswith(".text_encoder.pt"):
976
+ unet_path = maybe_unet_path[:-16] + ".pt"
977
+ else:
978
+ unet_path = maybe_unet_path
979
+
980
+ ti_path = _ti_lora_path(unet_path)
981
+ text_path = _text_lora_path(unet_path)
982
+
983
+ if patch_unet:
984
+ print("LoRA : Patching Unet")
985
+ monkeypatch_or_replace_lora(
986
+ pipe.unet,
987
+ torch.load(unet_path),
988
+ r=r,
989
+ target_replace_module=unet_target_replace_module,
990
+ )
991
+
992
+ if patch_text:
993
+ print("LoRA : Patching text encoder")
994
+ monkeypatch_or_replace_lora(
995
+ pipe.text_encoder,
996
+ torch.load(text_path),
997
+ target_replace_module=text_target_replace_module,
998
+ r=r,
999
+ )
1000
+ if patch_ti:
1001
+ print("LoRA : Patching token input")
1002
+ token = load_learned_embed_in_clip(
1003
+ ti_path,
1004
+ pipe.text_encoder,
1005
+ pipe.tokenizer,
1006
+ token=token,
1007
+ idempotent=idempotent_token,
1008
+ )
1009
+
1010
+ elif maybe_unet_path.endswith(".safetensors"):
1011
+ safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
1012
+ monkeypatch_or_replace_safeloras(pipe, safeloras)
1013
+ tok_dict = parse_safeloras_embeds(safeloras)
1014
+ if patch_ti:
1015
+ apply_learned_embed_in_clip(
1016
+ tok_dict,
1017
+ pipe.text_encoder,
1018
+ pipe.tokenizer,
1019
+ token=token,
1020
+ idempotent=idempotent_token,
1021
+ )
1022
+ return tok_dict
1023
+
1024
+
1025
+ @torch.no_grad()
1026
+ def inspect_lora(model):
1027
+ moved = {}
1028
+
1029
+ for name, _module in model.named_modules():
1030
+ if _module.__class__.__name__ in ["LoraInjectedLinear", "LoraInjectedConv2d"]:
1031
+ ups = _module.lora_up.weight.data.clone()
1032
+ downs = _module.lora_down.weight.data.clone()
1033
+
1034
+ wght: torch.Tensor = ups.flatten(1) @ downs.flatten(1)
1035
+
1036
+ dist = wght.flatten().abs().mean().item()
1037
+ if name in moved:
1038
+ moved[name].append(dist)
1039
+ else:
1040
+ moved[name] = [dist]
1041
+
1042
+ return moved
1043
+
1044
+
1045
+ def save_all(
1046
+ unet,
1047
+ text_encoder,
1048
+ save_path,
1049
+ placeholder_token_ids=None,
1050
+ placeholder_tokens=None,
1051
+ save_lora=True,
1052
+ save_ti=True,
1053
+ target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
1054
+ target_replace_module_unet=DEFAULT_TARGET_REPLACE,
1055
+ safe_form=True,
1056
+ ):
1057
+ if not safe_form:
1058
+ # save ti
1059
+ if save_ti:
1060
+ ti_path = _ti_lora_path(save_path)
1061
+ learned_embeds_dict = {}
1062
+ for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1063
+ learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1064
+ print(
1065
+ f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1066
+ learned_embeds[:4],
1067
+ )
1068
+ learned_embeds_dict[tok] = learned_embeds.detach().cpu()
1069
+
1070
+ torch.save(learned_embeds_dict, ti_path)
1071
+ print("Ti saved to ", ti_path)
1072
+
1073
+ # save text encoder
1074
+ if save_lora:
1075
+
1076
+ save_lora_weight(
1077
+ unet, save_path, target_replace_module=target_replace_module_unet
1078
+ )
1079
+ print("Unet saved to ", save_path)
1080
+
1081
+ save_lora_weight(
1082
+ text_encoder,
1083
+ _text_lora_path(save_path),
1084
+ target_replace_module=target_replace_module_text,
1085
+ )
1086
+ print("Text Encoder saved to ", _text_lora_path(save_path))
1087
+
1088
+ else:
1089
+ assert save_path.endswith(
1090
+ ".safetensors"
1091
+ ), f"Save path : {save_path} should end with .safetensors"
1092
+
1093
+ loras = {}
1094
+ embeds = {}
1095
+
1096
+ if save_lora:
1097
+
1098
+ loras["unet"] = (unet, target_replace_module_unet)
1099
+ loras["text_encoder"] = (text_encoder, target_replace_module_text)
1100
+
1101
+ if save_ti:
1102
+ for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
1103
+ learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
1104
+ print(
1105
+ f"Current Learned Embeddings for {tok}:, id {tok_id} ",
1106
+ learned_embeds[:4],
1107
+ )
1108
+ embeds[tok] = learned_embeds.detach().cpu()
1109
+
1110
+ save_safeloras_with_embeds(loras, embeds, save_path)
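For orientation, a minimal usage sketch of the patching helpers above, assuming the stock diffusers StableDiffusionPipeline; the checkpoint name example_lora.safetensors, the prompt, and the <s1> token are placeholders rather than files or tokens shipped with this commit.

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion.lora import patch_pipe, tune_lora_scale

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Inject LoRA weights (and any learned tokens) from a .safetensors checkpoint (placeholder path).
patch_pipe(pipe, "example_lora.safetensors", patch_unet=True, patch_text=True, patch_ti=True)

# Blend the LoRA contribution with the base weights: 0.0 disables it, 1.0 is full strength.
tune_lora_scale(pipe.unet, 0.8)
tune_lora_scale(pipe.text_encoder, 0.8)

image = pipe("a photo of <s1> riding a bicycle", num_inference_steps=30).images[0]
image.save("lora_sample.png")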
lora_diffusion/lora_manager.py ADDED
@@ -0,0 +1,144 @@
1
+ from typing import List
2
+ import torch
3
+ from safetensors import safe_open
4
+ from diffusers import StableDiffusionPipeline
5
+ from .lora import (
6
+ monkeypatch_or_replace_safeloras,
7
+ apply_learned_embed_in_clip,
8
+ set_lora_diag,
9
+ parse_safeloras_embeds,
10
+ )
11
+
12
+
13
+ def lora_join(lora_safetenors: list):
14
+ metadatas = [dict(safelora.metadata()) for safelora in lora_safetenors]
15
+ _total_metadata = {}
16
+ total_metadata = {}
17
+ total_tensor = {}
18
+ total_rank = 0
19
+ ranklist = []
20
+ for _metadata in metadatas:
21
+ rankset = []
22
+ for k, v in _metadata.items():
23
+ if k.endswith("rank"):
24
+ rankset.append(int(v))
25
+
26
+ assert len(set(rankset)) <= 1, "Rank should be the same per model"
27
+ if len(rankset) == 0:
28
+ rankset = [0]
29
+
30
+ total_rank += rankset[0]
31
+ _total_metadata.update(_metadata)
32
+ ranklist.append(rankset[0])
33
+
34
+ # remove metadata about tokens
35
+ for k, v in _total_metadata.items():
36
+ if v != "<embed>":
37
+ total_metadata[k] = v
38
+
39
+ tensorkeys = set()
40
+ for safelora in lora_safetenors:
41
+ tensorkeys.update(safelora.keys())
42
+
43
+ for keys in tensorkeys:
44
+ if keys.startswith("text_encoder") or keys.startswith("unet"):
45
+ tensorset = [safelora.get_tensor(keys) for safelora in lora_safetenors]
46
+
47
+ is_down = keys.endswith("down")
48
+
49
+ if is_down:
50
+ _tensor = torch.cat(tensorset, dim=0)
51
+ assert _tensor.shape[0] == total_rank
52
+ else:
53
+ _tensor = torch.cat(tensorset, dim=1)
54
+ assert _tensor.shape[1] == total_rank
55
+
56
+ total_tensor[keys] = _tensor
57
+ keys_rank = ":".join(keys.split(":")[:-1]) + ":rank"
58
+ total_metadata[keys_rank] = str(total_rank)
59
+ token_size_list = []
60
+ for idx, safelora in enumerate(lora_safetenors):
61
+ tokens = [k for k, v in safelora.metadata().items() if v == "<embed>"]
62
+ for jdx, token in enumerate(sorted(tokens)):
63
+
64
+ total_tensor[f"<s{idx}-{jdx}>"] = safelora.get_tensor(token)
65
+ total_metadata[f"<s{idx}-{jdx}>"] = "<embed>"
66
+
67
+ print(f"Embedding {token} replaced to <s{idx}-{jdx}>")
68
+
69
+ token_size_list.append(len(tokens))
70
+
71
+ return total_tensor, total_metadata, ranklist, token_size_list
72
+
73
+
74
+ class DummySafeTensorObject:
75
+ def __init__(self, tensor: dict, metadata):
76
+ self.tensor = tensor
77
+ self._metadata = metadata
78
+
79
+ def keys(self):
80
+ return self.tensor.keys()
81
+
82
+ def metadata(self):
83
+ return self._metadata
84
+
85
+ def get_tensor(self, key):
86
+ return self.tensor[key]
87
+
88
+
89
+ class LoRAManager:
90
+ def __init__(self, lora_paths_list: List[str], pipe: StableDiffusionPipeline):
91
+
92
+ self.lora_paths_list = lora_paths_list
93
+ self.pipe = pipe
94
+ self._setup()
95
+
96
+ def _setup(self):
97
+
98
+ self._lora_safetenors = [
99
+ safe_open(path, framework="pt", device="cpu")
100
+ for path in self.lora_paths_list
101
+ ]
102
+
103
+ (
104
+ total_tensor,
105
+ total_metadata,
106
+ self.ranklist,
107
+ self.token_size_list,
108
+ ) = lora_join(self._lora_safetenors)
109
+
110
+ self.total_safelora = DummySafeTensorObject(total_tensor, total_metadata)
111
+
112
+ monkeypatch_or_replace_safeloras(self.pipe, self.total_safelora)
113
+ tok_dict = parse_safeloras_embeds(self.total_safelora)
114
+
115
+ apply_learned_embed_in_clip(
116
+ tok_dict,
117
+ self.pipe.text_encoder,
118
+ self.pipe.tokenizer,
119
+ token=None,
120
+ idempotent=True,
121
+ )
122
+
123
+ def tune(self, scales):
124
+
125
+ assert len(scales) == len(
126
+ self.ranklist
127
+ ), "Scale list should be the same length as ranklist"
128
+
129
+ diags = []
130
+ for scale, rank in zip(scales, self.ranklist):
131
+ diags = diags + [scale] * rank
132
+
133
+ set_lora_diag(self.pipe.unet, torch.tensor(diags))
134
+
135
+ def prompt(self, prompt):
136
+ if prompt is not None:
137
+ for idx, tok_size in enumerate(self.token_size_list):
138
+ prompt = prompt.replace(
139
+ f"<{idx + 1}>",
140
+ "".join([f"<s{idx}-{jdx}>" for jdx in range(tok_size)]),
141
+ )
142
+ # TODO : Rescale LoRA + Text inputs based on prompt scale params
143
+
144
+ return prompt
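A short sketch, assuming two already-trained .safetensors checkpoints (the paths below are placeholders), of how LoRAManager joins them and rescales each one; <1> and <2> in the prompt expand to the learned tokens of the first and second LoRA.

import torch
from diffusers import StableDiffusionPipeline
from lora_diffusion.lora_manager import LoRAManager

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Placeholder paths: each file is a LoRA saved with save_safeloras_with_embeds.
manager = LoRAManager(["style_lora.safetensors", "subject_lora.safetensors"], pipe)
manager.tune([0.7, 1.0])  # one scale per joined LoRA, applied over its rank block

prompt = manager.prompt("a portrait of <2> in the style of <1>")
image = pipe(prompt, num_inference_steps=30).images[0]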
lora_diffusion/preprocess_files.py ADDED
@@ -0,0 +1,327 @@
1
+ # Have SwinIR upsample
2
+ # Have BLIP auto caption
3
+ # Have CLIPSeg auto mask concept
4
+
5
+ from typing import List, Literal, Union, Optional, Tuple
6
+ import os
7
+ from PIL import Image, ImageFilter
8
+ import torch
9
+ import numpy as np
10
+ import fire
11
+ from tqdm import tqdm
12
+ import glob
13
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
14
+
15
+
16
+ @torch.no_grad()
17
+ def swin_ir_sr(
18
+ images: List[Image.Image],
19
+ model_id: Literal[
20
+ "caidas/swin2SR-classical-sr-x2-64", "caidas/swin2SR-classical-sr-x4-48"
21
+ ] = "caidas/swin2SR-classical-sr-x2-64",
22
+ target_size: Optional[Tuple[int, int]] = None,
23
+ device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
24
+ **kwargs,
25
+ ) -> List[Image.Image]:
26
+ """
27
+ Upscales images using SwinIR. Returns a list of PIL images.
28
+ """
29
+ # Swin2SR is only available in recent transformers releases, so import it lazily here.
30
+ from transformers import Swin2SRForImageSuperResolution, Swin2SRImageProcessor
31
+
32
+ model = Swin2SRForImageSuperResolution.from_pretrained(
33
+ model_id,
34
+ ).to(device)
35
+ processor = Swin2SRImageProcessor()
36
+
37
+ out_images = []
38
+
39
+ for image in tqdm(images):
40
+
41
+ ori_w, ori_h = image.size
42
+ if target_size is not None:
43
+ if ori_w >= target_size[0] and ori_h >= target_size[1]:
44
+ out_images.append(image)
45
+ continue
46
+
47
+ inputs = processor(image, return_tensors="pt").to(device)
48
+ with torch.no_grad():
49
+ outputs = model(**inputs)
50
+
51
+ output = (
52
+ outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
53
+ )
54
+ output = np.moveaxis(output, source=0, destination=-1)
55
+ output = (output * 255.0).round().astype(np.uint8)
56
+ output = Image.fromarray(output)
57
+
58
+ out_images.append(output)
59
+
60
+ return out_images
61
+
62
+
63
+ @torch.no_grad()
64
+ def clipseg_mask_generator(
65
+ images: List[Image.Image],
66
+ target_prompts: Union[List[str], str],
67
+ model_id: Literal[
68
+ "CIDAS/clipseg-rd64-refined", "CIDAS/clipseg-rd16"
69
+ ] = "CIDAS/clipseg-rd64-refined",
70
+ device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
71
+ bias: float = 0.01,
72
+ temp: float = 1.0,
73
+ **kwargs,
74
+ ) -> List[Image.Image]:
75
+ """
76
+ Returns a greyscale mask for each image, where the mask is the probability of the target prompt being present in the image
77
+ """
78
+
79
+ if isinstance(target_prompts, str):
80
+ print(
81
+ f'Warning: only one target prompt "{target_prompts}" was given, so it will be used for all images'
82
+ )
83
+
84
+ target_prompts = [target_prompts] * len(images)
85
+
86
+ processor = CLIPSegProcessor.from_pretrained(model_id)
87
+ model = CLIPSegForImageSegmentation.from_pretrained(model_id).to(device)
88
+
89
+ masks = []
90
+
91
+ for image, prompt in tqdm(zip(images, target_prompts)):
92
+
93
+ original_size = image.size
94
+
95
+ inputs = processor(
96
+ text=[prompt, ""],
97
+ images=[image] * 2,
98
+ padding="max_length",
99
+ truncation=True,
100
+ return_tensors="pt",
101
+ ).to(device)
102
+
103
+ outputs = model(**inputs)
104
+
105
+ logits = outputs.logits
106
+ probs = torch.nn.functional.softmax(logits / temp, dim=0)[0]
107
+ probs = (probs + bias).clamp_(0, 1)
108
+ probs = 255 * probs / probs.max()
109
+
110
+ # make mask greyscale
111
+ mask = Image.fromarray(probs.cpu().numpy()).convert("L")
112
+
113
+ # resize mask to original size
114
+ mask = mask.resize(original_size)
115
+
116
+ masks.append(mask)
117
+
118
+ return masks
119
+
120
+
121
+ @torch.no_grad()
122
+ def blip_captioning_dataset(
123
+ images: List[Image.Image],
124
+ text: Optional[str] = None,
125
+ model_id: Literal[
126
+ "Salesforce/blip-image-captioning-large",
127
+ "Salesforce/blip-image-captioning-base",
128
+ ] = "Salesforce/blip-image-captioning-large",
129
+ device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
130
+ **kwargs,
131
+ ) -> List[str]:
132
+ """
133
+ Returns a list of captions for the given images
134
+ """
135
+
136
+ from transformers import BlipProcessor, BlipForConditionalGeneration
137
+
138
+ processor = BlipProcessor.from_pretrained(model_id)
139
+ model = BlipForConditionalGeneration.from_pretrained(model_id).to(device)
140
+ captions = []
141
+
142
+ for image in tqdm(images):
143
+ inputs = processor(image, text=text, return_tensors="pt").to("cuda")
144
+ out = model.generate(
145
+ **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7
146
+ )
147
+ caption = processor.decode(out[0], skip_special_tokens=True)
148
+
149
+ captions.append(caption)
150
+
151
+ return captions
152
+
153
+
154
+ def face_mask_google_mediapipe(
155
+ images: List[Image.Image], blur_amount: float = 80.0, bias: float = 0.05
156
+ ) -> List[Image.Image]:
157
+ """
158
+ Returns a list of images with mask on the face parts.
159
+ """
160
+ import mediapipe as mp
161
+
162
+ mp_face_detection = mp.solutions.face_detection
163
+
164
+ face_detection = mp_face_detection.FaceDetection(
165
+ model_selection=1, min_detection_confidence=0.5
166
+ )
167
+
168
+ masks = []
169
+ for image in tqdm(images):
170
+
171
+ image = np.array(image)
172
+
173
+ results = face_detection.process(image)
174
+ black_image = np.ones((image.shape[0], image.shape[1]), dtype=np.uint8)
175
+
176
+ if results.detections:
177
+
178
+ for detection in results.detections:
179
+
180
+ x_min = int(
181
+ detection.location_data.relative_bounding_box.xmin * image.shape[1]
182
+ )
183
+ y_min = int(
184
+ detection.location_data.relative_bounding_box.ymin * image.shape[0]
185
+ )
186
+ width = int(
187
+ detection.location_data.relative_bounding_box.width * image.shape[1]
188
+ )
189
+ height = int(
190
+ detection.location_data.relative_bounding_box.height
191
+ * image.shape[0]
192
+ )
193
+
194
+ # mark the detected face box as foreground (255) in the mask
195
+ black_image[y_min : y_min + height, x_min : x_min + width] = 255
196
+
197
+ black_image = Image.fromarray(black_image)
198
+ masks.append(black_image)
199
+
200
+ return masks
201
+
202
+
203
+ def _crop_to_square(
204
+ image: Image.Image, com: List[Tuple[int, int]], resize_to: Optional[int] = None
205
+ ):
206
+ cx, cy = com
207
+ width, height = image.size
208
+ if width > height:
209
+ left_possible = max(cx - height / 2, 0)
210
+ left = min(left_possible, width - height)
211
+ right = left + height
212
+ top = 0
213
+ bottom = height
214
+ else:
215
+ left = 0
216
+ right = width
217
+ top_possible = max(cy - width / 2, 0)
218
+ top = min(top_possible, height - width)
219
+ bottom = top + width
220
+
221
+ image = image.crop((left, top, right, bottom))
222
+
223
+ if resize_to:
224
+ image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS)
225
+
226
+ return image
227
+
228
+
229
+ def _center_of_mass(mask: Image.Image):
230
+ """
231
+ Returns the center of mass of the mask
232
+ """
233
+ x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1]))
234
+
235
+ x_ = x * np.array(mask)
236
+ y_ = y * np.array(mask)
237
+
238
+ x = np.sum(x_) / np.sum(mask)
239
+ y = np.sum(y_) / np.sum(mask)
240
+
241
+ return x, y
242
+
243
+
244
+ def load_and_save_masks_and_captions(
245
+ files: Union[str, List[str]],
246
+ output_dir: str,
247
+ caption_text: Optional[str] = None,
248
+ target_prompts: Optional[Union[List[str], str]] = None,
249
+ target_size: int = 512,
250
+ crop_based_on_salience: bool = True,
251
+ use_face_detection_instead: bool = False,
252
+ temp: float = 1.0,
253
+ n_length: int = -1,
254
+ ):
255
+ """
256
+ Loads images from the given files, generates masks for them, and saves the masks and captions and upscale images
257
+ to output dir.
258
+ """
259
+ os.makedirs(output_dir, exist_ok=True)
260
+
261
+ # load images
262
+ if isinstance(files, str):
263
+ # check if it is a directory
264
+ if os.path.isdir(files):
265
+ # get all the .png .jpg in the directory
266
+ files = glob.glob(os.path.join(files, "*.png")) + glob.glob(
267
+ os.path.join(files, "*.jpg")
268
+ )
269
+
270
+ if len(files) == 0:
271
+ raise Exception(
272
+ f"No files found in {files}. Either {files} is not a directory or it does not contain any .png or .jpg files."
273
+ )
274
+ if n_length == -1:
275
+ n_length = len(files)
276
+ files = sorted(files)[:n_length]
277
+
278
+ images = [Image.open(file) for file in files]
279
+
280
+ # captions
281
+ print(f"Generating {len(images)} captions...")
282
+ captions = blip_captioning_dataset(images, text=caption_text)
283
+
284
+ if target_prompts is None:
285
+ target_prompts = captions
286
+
287
+ print(f"Generating {len(images)} masks...")
288
+ if not use_face_detection_instead:
289
+ seg_masks = clipseg_mask_generator(
290
+ images=images, target_prompts=target_prompts, temp=temp
291
+ )
292
+ else:
293
+ seg_masks = face_mask_google_mediapipe(images=images)
294
+
295
+ # find the center of mass of the mask
296
+ if crop_based_on_salience:
297
+ coms = [_center_of_mass(mask) for mask in seg_masks]
298
+ else:
299
+ coms = [(image.size[0] / 2, image.size[1] / 2) for image in images]
300
+ # based on the center of mass, crop the image to a square
301
+ images = [
302
+ _crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms)
303
+ ]
304
+
305
+ print(f"Upscaling {len(images)} images...")
306
+ # upscale images anyways
307
+ images = swin_ir_sr(images, target_size=(target_size, target_size))
308
+ images = [
309
+ image.resize((target_size, target_size), Image.Resampling.LANCZOS)
310
+ for image in images
311
+ ]
312
+
313
+ seg_masks = [
314
+ _crop_to_square(mask, com, resize_to=target_size)
315
+ for mask, com in zip(seg_masks, coms)
316
+ ]
317
+ with open(os.path.join(output_dir, "caption.txt"), "w") as f:
318
+ # save images and masks
319
+ for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)):
320
+ image.save(os.path.join(output_dir, f"{idx}.src.jpg"), quality=99)
321
+ mask.save(os.path.join(output_dir, f"{idx}.mask.png"))
322
+
323
+ f.write(caption + "\n")
324
+
325
+
326
+ def main():
327
+ fire.Fire(load_and_save_masks_and_captions)
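Because the entry point is exposed through fire.Fire, the same preprocessing can also be driven from Python; the directory names below are placeholders.

from lora_diffusion.preprocess_files import load_and_save_masks_and_captions

# Placeholder paths: captions, CLIPSeg masks and upscaled square crops are written to output_dir.
load_and_save_masks_and_captions(
    files="data/raw_photos",           # directory of .png / .jpg images
    output_dir="data/processed",       # receives {idx}.src.jpg, {idx}.mask.png and caption.txt
    caption_text=None,                 # optional prefix passed to the BLIP captioner
    target_prompts=None,               # defaults to the generated captions for CLIPSeg
    target_size=512,
    use_face_detection_instead=False,  # True switches to the MediaPipe face masks
)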
lora_diffusion/safe_open.py ADDED
@@ -0,0 +1,68 @@
1
+ """
2
+ Pure python version of Safetensors safe_open
3
+ From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282
4
+ """
5
+
6
+ import json
7
+ import mmap
8
+ import os
9
+
10
+ import torch
11
+
12
+
13
+ class SafetensorsWrapper:
14
+ def __init__(self, metadata, tensors):
15
+ self._metadata = metadata
16
+ self._tensors = tensors
17
+
18
+ def metadata(self):
19
+ return self._metadata
20
+
21
+ def keys(self):
22
+ return self._tensors.keys()
23
+
24
+ def get_tensor(self, k):
25
+ return self._tensors[k]
26
+
27
+
28
+ DTYPES = {
29
+ "F32": torch.float32,
30
+ "F16": torch.float16,
31
+ "BF16": torch.bfloat16,
32
+ }
33
+
34
+
35
+ def create_tensor(storage, info, offset):
36
+ dtype = DTYPES[info["dtype"]]
37
+ shape = info["shape"]
38
+ start, stop = info["data_offsets"]
39
+ return (
40
+ torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8)
41
+ .view(dtype=dtype)
42
+ .reshape(shape)
43
+ )
44
+
45
+
46
+ def safe_open(filename, framework="pt", device="cpu"):
47
+ if framework != "pt":
48
+ raise ValueError("`framework` must be 'pt'")
49
+
50
+ with open(filename, mode="r", encoding="utf8") as file_obj:
51
+ with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as m:
52
+ header = m.read(8)
53
+ n = int.from_bytes(header, "little")
54
+ metadata_bytes = m.read(n)
55
+ metadata = json.loads(metadata_bytes)
56
+
57
+ size = os.stat(filename).st_size
58
+ storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
59
+ offset = n + 8
60
+
61
+ return SafetensorsWrapper(
62
+ metadata=metadata.get("__metadata__", {}),
63
+ tensors={
64
+ name: create_tensor(storage, info, offset).to(device)
65
+ for name, info in metadata.items()
66
+ if name != "__metadata__"
67
+ },
68
+ )
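A small sketch of the reader defined above; example.safetensors is a placeholder file name.

from lora_diffusion.safe_open import safe_open

sf = safe_open("example.safetensors", framework="pt", device="cpu")
print(sf.metadata())              # the "__metadata__" dict (ranks, <embed> markers, ...)
for name in sf.keys():
    tensor = sf.get_tensor(name)  # torch.Tensor rebuilt from the mmapped byte storage
    print(name, tuple(tensor.shape), tensor.dtype)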
lora_diffusion/to_ckpt_v2.py ADDED
@@ -0,0 +1,232 @@
1
+ # from https://gist.github.com/jachiam/8a5c0b607e38fcc585168b90c686eb05
2
+ # Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
3
+ # *Only* converts the UNet, VAE, and Text Encoder.
4
+ # Does not convert optimizer state or any other thing.
5
+ # Written by jachiam
6
+ import argparse
7
+ import os.path as osp
8
+
9
+ import torch
10
+
11
+
12
+ # =================#
13
+ # UNet Conversion #
14
+ # =================#
15
+
16
+ unet_conversion_map = [
17
+ # (stable-diffusion, HF Diffusers)
18
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
19
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
20
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
21
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
22
+ ("input_blocks.0.0.weight", "conv_in.weight"),
23
+ ("input_blocks.0.0.bias", "conv_in.bias"),
24
+ ("out.0.weight", "conv_norm_out.weight"),
25
+ ("out.0.bias", "conv_norm_out.bias"),
26
+ ("out.2.weight", "conv_out.weight"),
27
+ ("out.2.bias", "conv_out.bias"),
28
+ ]
29
+
30
+ unet_conversion_map_resnet = [
31
+ # (stable-diffusion, HF Diffusers)
32
+ ("in_layers.0", "norm1"),
33
+ ("in_layers.2", "conv1"),
34
+ ("out_layers.0", "norm2"),
35
+ ("out_layers.3", "conv2"),
36
+ ("emb_layers.1", "time_emb_proj"),
37
+ ("skip_connection", "conv_shortcut"),
38
+ ]
39
+
40
+ unet_conversion_map_layer = []
41
+ # hardcoded number of downblocks and resnets/attentions...
42
+ # would need smarter logic for other networks.
43
+ for i in range(4):
44
+ # loop over downblocks/upblocks
45
+
46
+ for j in range(2):
47
+ # loop over resnets/attentions for downblocks
48
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
49
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
50
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
51
+
52
+ if i < 3:
53
+ # no attention layers in down_blocks.3
54
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
55
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
56
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
57
+
58
+ for j in range(3):
59
+ # loop over resnets/attentions for upblocks
60
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
61
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
62
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
63
+
64
+ if i > 0:
65
+ # no attention layers in up_blocks.0
66
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
67
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
68
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
69
+
70
+ if i < 3:
71
+ # no downsample in down_blocks.3
72
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
73
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
74
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
75
+
76
+ # no upsample in up_blocks.3
77
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
78
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
79
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
80
+
81
+ hf_mid_atn_prefix = "mid_block.attentions.0."
82
+ sd_mid_atn_prefix = "middle_block.1."
83
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
84
+
85
+ for j in range(2):
86
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
87
+ sd_mid_res_prefix = f"middle_block.{2*j}."
88
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
89
+
90
+
91
+ def convert_unet_state_dict(unet_state_dict):
92
+ # buyer beware: this is a *brittle* function,
93
+ # and correct output requires that all of these pieces interact in
94
+ # the exact order in which I have arranged them.
95
+ mapping = {k: k for k in unet_state_dict.keys()}
96
+ for sd_name, hf_name in unet_conversion_map:
97
+ mapping[hf_name] = sd_name
98
+ for k, v in mapping.items():
99
+ if "resnets" in k:
100
+ for sd_part, hf_part in unet_conversion_map_resnet:
101
+ v = v.replace(hf_part, sd_part)
102
+ mapping[k] = v
103
+ for k, v in mapping.items():
104
+ for sd_part, hf_part in unet_conversion_map_layer:
105
+ v = v.replace(hf_part, sd_part)
106
+ mapping[k] = v
107
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
108
+ return new_state_dict
109
+
110
+
111
+ # ================#
112
+ # VAE Conversion #
113
+ # ================#
114
+
115
+ vae_conversion_map = [
116
+ # (stable-diffusion, HF Diffusers)
117
+ ("nin_shortcut", "conv_shortcut"),
118
+ ("norm_out", "conv_norm_out"),
119
+ ("mid.attn_1.", "mid_block.attentions.0."),
120
+ ]
121
+
122
+ for i in range(4):
123
+ # down_blocks have two resnets
124
+ for j in range(2):
125
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
126
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
127
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
128
+
129
+ if i < 3:
130
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
131
+ sd_downsample_prefix = f"down.{i}.downsample."
132
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
133
+
134
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
135
+ sd_upsample_prefix = f"up.{3-i}.upsample."
136
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
137
+
138
+ # up_blocks have three resnets
139
+ # also, up blocks in hf are numbered in reverse from sd
140
+ for j in range(3):
141
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
142
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
143
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
144
+
145
+ # this part accounts for mid blocks in both the encoder and the decoder
146
+ for i in range(2):
147
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
148
+ sd_mid_res_prefix = f"mid.block_{i+1}."
149
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
150
+
151
+
152
+ vae_conversion_map_attn = [
153
+ # (stable-diffusion, HF Diffusers)
154
+ ("norm.", "group_norm."),
155
+ ("q.", "query."),
156
+ ("k.", "key."),
157
+ ("v.", "value."),
158
+ ("proj_out.", "proj_attn."),
159
+ ]
160
+
161
+
162
+ def reshape_weight_for_sd(w):
163
+ # convert HF linear weights to SD conv2d weights
164
+ return w.reshape(*w.shape, 1, 1)
165
+
166
+
167
+ def convert_vae_state_dict(vae_state_dict):
168
+ mapping = {k: k for k in vae_state_dict.keys()}
169
+ for k, v in mapping.items():
170
+ for sd_part, hf_part in vae_conversion_map:
171
+ v = v.replace(hf_part, sd_part)
172
+ mapping[k] = v
173
+ for k, v in mapping.items():
174
+ if "attentions" in k:
175
+ for sd_part, hf_part in vae_conversion_map_attn:
176
+ v = v.replace(hf_part, sd_part)
177
+ mapping[k] = v
178
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
179
+ weights_to_convert = ["q", "k", "v", "proj_out"]
180
+ for k, v in new_state_dict.items():
181
+ for weight_name in weights_to_convert:
182
+ if f"mid.attn_1.{weight_name}.weight" in k:
183
+ print(f"Reshaping {k} for SD format")
184
+ new_state_dict[k] = reshape_weight_for_sd(v)
185
+ return new_state_dict
186
+
187
+
188
+ # =========================#
189
+ # Text Encoder Conversion #
190
+ # =========================#
191
+ # pretty much a no-op
192
+
193
+
194
+ def convert_text_enc_state_dict(text_enc_dict):
195
+ return text_enc_dict
196
+
197
+
198
+ def convert_to_ckpt(model_path, checkpoint_path, as_half):
199
+
200
+ assert model_path is not None, "Must provide a model path!"
201
+
202
+ assert checkpoint_path is not None, "Must provide a checkpoint path!"
203
+
204
+ unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin")
205
+ vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin")
206
+ text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
207
+
208
+ # Convert the UNet model
209
+ unet_state_dict = torch.load(unet_path, map_location="cpu")
210
+ unet_state_dict = convert_unet_state_dict(unet_state_dict)
211
+ unet_state_dict = {
212
+ "model.diffusion_model." + k: v for k, v in unet_state_dict.items()
213
+ }
214
+
215
+ # Convert the VAE model
216
+ vae_state_dict = torch.load(vae_path, map_location="cpu")
217
+ vae_state_dict = convert_vae_state_dict(vae_state_dict)
218
+ vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
219
+
220
+ # Convert the text encoder model
221
+ text_enc_dict = torch.load(text_enc_path, map_location="cpu")
222
+ text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
223
+ text_enc_dict = {
224
+ "cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()
225
+ }
226
+
227
+ # Put together new checkpoint
228
+ state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
229
+ if as_half:
230
+ state_dict = {k: v.half() for k, v in state_dict.items()}
231
+ state_dict = {"state_dict": state_dict}
232
+ torch.save(state_dict, checkpoint_path)
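The converter can be called directly as below; the diffusers folder and the output path are placeholders.

from lora_diffusion.to_ckpt_v2 import convert_to_ckpt

# Placeholder paths: model_path must contain unet/, vae/ and text_encoder/ subfolders
# saved in the HF Diffusers layout; the result is a single SD-style checkpoint.
convert_to_ckpt(
    model_path="output/my_diffusers_model",
    checkpoint_path="output/my_model.ckpt",
    as_half=True,  # store fp16 weights to roughly halve the file size
)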
lora_diffusion/utils.py ADDED
@@ -0,0 +1,214 @@
1
+ from typing import List, Union
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import (
6
+ CLIPProcessor,
7
+ CLIPTextModelWithProjection,
8
+ CLIPTokenizer,
9
+ CLIPVisionModelWithProjection,
10
+ )
11
+
12
+ from diffusers import StableDiffusionPipeline
13
+ from .lora import patch_pipe, tune_lora_scale, _text_lora_path, _ti_lora_path
14
+ import os
15
+ import glob
16
+ import math
17
+
18
+ EXAMPLE_PROMPTS = [
19
+ "<obj> swimming in a pool",
20
+ "<obj> at a beach with a view of seashore",
21
+ "<obj> in times square",
22
+ "<obj> wearing sunglasses",
23
+ "<obj> in a construction outfit",
24
+ "<obj> playing with a ball",
25
+ "<obj> wearing headphones",
26
+ "<obj> oil painting ghibli inspired",
27
+ "<obj> working on the laptop",
28
+ "<obj> with mountains and sunset in background",
29
+ "Painting of <obj> at a beach by artist claude monet",
30
+ "<obj> digital painting 3d render geometric style",
31
+ "A screaming <obj>",
32
+ "A depressed <obj>",
33
+ "A sleeping <obj>",
34
+ "A sad <obj>",
35
+ "A joyous <obj>",
36
+ "A frowning <obj>",
37
+ "A sculpture of <obj>",
38
+ "<obj> near a pool",
39
+ "<obj> at a beach with a view of seashore",
40
+ "<obj> in a garden",
41
+ "<obj> in grand canyon",
42
+ "<obj> floating in ocean",
43
+ "<obj> and an armchair",
44
+ "A maple tree on the side of <obj>",
45
+ "<obj> and an orange sofa",
46
+ "<obj> with chocolate cake on it",
47
+ "<obj> with a vase of rose flowers on it",
48
+ "A digital illustration of <obj>",
49
+ "Georgia O'Keeffe style <obj> painting",
50
+ "A watercolor painting of <obj> on a beach",
51
+ ]
52
+
53
+
54
+ def image_grid(_imgs, rows=None, cols=None):
55
+
56
+ if rows is None and cols is None:
57
+ rows = cols = math.ceil(len(_imgs) ** 0.5)
58
+
59
+ if rows is None:
60
+ rows = math.ceil(len(_imgs) / cols)
61
+ if cols is None:
62
+ cols = math.ceil(len(_imgs) / rows)
63
+
64
+ w, h = _imgs[0].size
65
+ grid = Image.new("RGB", size=(cols * w, rows * h))
66
+ grid_w, grid_h = grid.size
67
+
68
+ for i, img in enumerate(_imgs):
69
+ grid.paste(img, box=(i % cols * w, i // cols * h))
70
+ return grid
71
+
72
+
73
+ def text_img_alignment(img_embeds, text_embeds, target_img_embeds):
74
+ # evaluation inspired from textual inversion paper
75
+ # https://arxiv.org/abs/2208.01618
76
+
77
+ # text alignment
78
+ assert img_embeds.shape[0] == text_embeds.shape[0]
79
+ text_img_sim = (img_embeds * text_embeds).sum(dim=-1) / (
80
+ img_embeds.norm(dim=-1) * text_embeds.norm(dim=-1)
81
+ )
82
+
83
+ # image alignment
84
+ img_embed_normalized = img_embeds / img_embeds.norm(dim=-1, keepdim=True)
85
+
86
+ avg_target_img_embed = (
87
+ (target_img_embeds / target_img_embeds.norm(dim=-1, keepdim=True))
88
+ .mean(dim=0)
89
+ .unsqueeze(0)
90
+ .repeat(img_embeds.shape[0], 1)
91
+ )
92
+
93
+ img_img_sim = (img_embed_normalized * avg_target_img_embed).sum(dim=-1)
94
+
95
+ return {
96
+ "text_alignment_avg": text_img_sim.mean().item(),
97
+ "image_alignment_avg": img_img_sim.mean().item(),
98
+ "text_alignment_all": text_img_sim.tolist(),
99
+ "image_alignment_all": img_img_sim.tolist(),
100
+ }
101
+
102
+
103
+ def prepare_clip_model_sets(eval_clip_id: str = "openai/clip-vit-large-patch14"):
104
+ text_model = CLIPTextModelWithProjection.from_pretrained(eval_clip_id)
105
+ tokenizer = CLIPTokenizer.from_pretrained(eval_clip_id)
106
+ vis_model = CLIPVisionModelWithProjection.from_pretrained(eval_clip_id)
107
+ processor = CLIPProcessor.from_pretrained(eval_clip_id)
108
+
109
+ return text_model, tokenizer, vis_model, processor
110
+
111
+
112
+ def evaluate_pipe(
113
+ pipe,
114
+ target_images: List[Image.Image],
115
+ class_token: str = "",
116
+ learnt_token: str = "",
117
+ guidance_scale: float = 5.0,
118
+ seed=0,
119
+ clip_model_sets=None,
120
+ eval_clip_id: str = "openai/clip-vit-large-patch14",
121
+ n_test: int = 10,
122
+ n_step: int = 50,
123
+ ):
124
+
125
+ if clip_model_sets is not None:
126
+ text_model, tokenizer, vis_model, processor = clip_model_sets
127
+ else:
128
+ text_model, tokenizer, vis_model, processor = prepare_clip_model_sets(
129
+ eval_clip_id
130
+ )
131
+
132
+ images = []
133
+ img_embeds = []
134
+ text_embeds = []
135
+ for prompt in EXAMPLE_PROMPTS[:n_test]:
136
+ prompt = prompt.replace("<obj>", learnt_token)
137
+ torch.manual_seed(seed)
138
+ with torch.autocast("cuda"):
139
+ img = pipe(
140
+ prompt, num_inference_steps=n_step, guidance_scale=guidance_scale
141
+ ).images[0]
142
+ images.append(img)
143
+
144
+ # image
145
+ inputs = processor(images=img, return_tensors="pt")
146
+ img_embed = vis_model(**inputs).image_embeds
147
+ img_embeds.append(img_embed)
148
+
149
+ prompt = prompt.replace(learnt_token, class_token)
150
+ # prompts
151
+ inputs = tokenizer([prompt], padding=True, return_tensors="pt")
152
+ outputs = text_model(**inputs)
153
+ text_embed = outputs.text_embeds
154
+ text_embeds.append(text_embed)
155
+
156
+ # target images
157
+ inputs = processor(images=target_images, return_tensors="pt")
158
+ target_img_embeds = vis_model(**inputs).image_embeds
159
+
160
+ img_embeds = torch.cat(img_embeds, dim=0)
161
+ text_embeds = torch.cat(text_embeds, dim=0)
162
+
163
+ return text_img_alignment(img_embeds, text_embeds, target_img_embeds)
164
+
165
+
166
+ def visualize_progress(
167
+ path_alls: Union[str, List[str]],
168
+ prompt: str,
169
+ model_id: str = "runwayml/stable-diffusion-v1-5",
170
+ device="cuda:0",
171
+ patch_unet=True,
172
+ patch_text=True,
173
+ patch_ti=True,
174
+ unet_scale=1.0,
175
+ text_sclae=1.0,
176
+ num_inference_steps=50,
177
+ guidance_scale=5.0,
178
+ offset: int = 0,
179
+ limit: int = 10,
180
+ seed: int = 0,
181
+ ):
182
+
183
+ imgs = []
184
+ if isinstance(path_alls, str):
185
+ alls = list(set(glob.glob(path_alls)))
186
+
187
+ alls.sort(key=os.path.getmtime)
188
+ else:
189
+ alls = path_alls
190
+
191
+ pipe = StableDiffusionPipeline.from_pretrained(
192
+ model_id, torch_dtype=torch.float16
193
+ ).to(device)
194
+
195
+ print(f"Found {len(alls)} checkpoints")
196
+ for path in alls[offset:limit]:
197
+ print(path)
198
+
199
+ patch_pipe(
200
+ pipe, path, patch_unet=patch_unet, patch_text=patch_text, patch_ti=patch_ti
201
+ )
202
+
203
+ tune_lora_scale(pipe.unet, unet_scale)
204
+ tune_lora_scale(pipe.text_encoder, text_sclae)
205
+
206
+ torch.manual_seed(seed)
207
+ image = pipe(
208
+ prompt,
209
+ num_inference_steps=num_inference_steps,
210
+ guidance_scale=guidance_scale,
211
+ ).images[0]
212
+ imgs.append(image)
213
+
214
+ return imgs
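A hedged example of the evaluation helpers above; the checkpoint glob, the target image folder and the tokens are placeholders.

import glob
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from lora_diffusion.utils import evaluate_pipe, visualize_progress

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# CLIP text/image alignment of an already-patched pipeline against reference photos.
targets = [Image.open(p) for p in glob.glob("data/processed/*.src.jpg")]
scores = evaluate_pipe(pipe, targets, class_token="dog", learnt_token="<s1>", n_test=5)
print(scores["text_alignment_avg"], scores["image_alignment_avg"])

# Render one prompt across saved checkpoints to eyeball training progress.
frames = visualize_progress("output/lora_step_*.safetensors", prompt="a photo of <s1>", limit=4)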
lora_diffusion/xformers_utils.py ADDED
@@ -0,0 +1,70 @@
1
+ import functools
2
+
3
+ import torch
4
+ from diffusers.models.attention import BasicTransformerBlock
5
+ from diffusers.utils.import_utils import is_xformers_available
6
+
7
+ from .lora import LoraInjectedLinear
8
+
9
+ if is_xformers_available():
10
+ import xformers
11
+ import xformers.ops
12
+ else:
13
+ xformers = None
14
+
15
+
16
+ @functools.cache
17
+ def test_xformers_backwards(size):
18
+ @torch.enable_grad()
19
+ def _grad(size):
20
+ q = torch.randn((1, 4, size), device="cuda")
21
+ k = torch.randn((1, 4, size), device="cuda")
22
+ v = torch.randn((1, 4, size), device="cuda")
23
+
24
+ q = q.detach().requires_grad_()
25
+ k = k.detach().requires_grad_()
26
+ v = v.detach().requires_grad_()
27
+
28
+ out = xformers.ops.memory_efficient_attention(q, k, v)
29
+ loss = out.sum(2).mean(0).sum()
30
+
31
+ return torch.autograd.grad(loss, v)
32
+
33
+ try:
34
+ _grad(size)
35
+ print(size, "pass")
36
+ return True
37
+ except Exception as e:
38
+ print(size, "fail")
39
+ return False
40
+
41
+
42
+ def set_use_memory_efficient_attention_xformers(
43
+ module: torch.nn.Module, valid: bool
44
+ ) -> None:
45
+ def fn_test_dim_head(module: torch.nn.Module):
46
+ if isinstance(module, BasicTransformerBlock):
47
+ # dim_head isn't stored anywhere, so back-calculate
48
+ source = module.attn1.to_v
49
+ if isinstance(source, LoraInjectedLinear):
50
+ source = source.linear
51
+
52
+ dim_head = source.out_features // module.attn1.heads
53
+
54
+ result = test_xformers_backwards(dim_head)
55
+
56
+ # If the xformers backward pass fails for this dim_head, turn xformers off for this block
57
+ if not result:
58
+ module.set_use_memory_efficient_attention_xformers(False)
59
+
60
+ for child in module.children():
61
+ fn_test_dim_head(child)
62
+
63
+ if not is_xformers_available() and valid:
64
+ print("XFormers is not available. Skipping.")
65
+ return
66
+
67
+ module.set_use_memory_efficient_attention_xformers(valid)
68
+
69
+ if valid:
70
+ fn_test_dim_head(module)
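A brief sketch of applying the guard above to a UNet; the model id is the usual Stable Diffusion 1.5 repository and a CUDA device is assumed.

from diffusers import UNet2DConditionModel
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
).to("cuda")

# Turns xformers attention on, then probes every BasicTransformerBlock's dim_head and
# switches it back off for blocks whose backward pass fails.
set_use_memory_efficient_attention_xformers(unet, True)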
scene/__init__.py ADDED
@@ -0,0 +1,98 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import os
13
+ import random
14
+ import json
15
+ from utils.system_utils import searchForMaxIteration
16
+ from scene.dataset_readers import sceneLoadTypeCallbacks,GenerateRandomCameras,GeneratePurnCameras,GenerateCircleCameras
17
+ from scene.gaussian_model import GaussianModel
18
+ from arguments import ModelParams, GenerateCamParams
19
+ from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON, cameraList_from_RcamInfos
20
+
21
+ class Scene:
22
+
23
+ gaussians : GaussianModel
24
+
25
+ def __init__(self, args : ModelParams, pose_args : GenerateCamParams, gaussians : GaussianModel, load_iteration=None, shuffle=False, resolution_scales=[1.0]):
26
+ """b
27
+ :param path: Path to colmap scene main folder.
28
+ """
29
+ self.model_path = args._model_path
30
+ self.pretrained_model_path = args.pretrained_model_path
31
+ self.loaded_iter = None
32
+ self.gaussians = gaussians
33
+ self.resolution_scales = resolution_scales
34
+ self.pose_args = pose_args
35
+ self.args = args
36
+ if load_iteration:
37
+ if load_iteration == -1:
38
+ self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
39
+ else:
40
+ self.loaded_iter = load_iteration
41
+ print("Loading trained model at iteration {}".format(self.loaded_iter))
42
+
43
+ self.test_cameras = {}
44
+ scene_info = sceneLoadTypeCallbacks["RandomCam"](self.model_path ,pose_args)
45
+
46
+ json_cams = []
47
+ camlist = []
48
+ if scene_info.test_cameras:
49
+ camlist.extend(scene_info.test_cameras)
50
+ for id, cam in enumerate(camlist):
51
+ json_cams.append(camera_to_JSON(id, cam))
52
+ with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
53
+ json.dump(json_cams, file)
54
+
55
+ if shuffle:
56
+ random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
57
+ self.cameras_extent = pose_args.default_radius # scene_info.nerf_normalization["radius"]
58
+ for resolution_scale in resolution_scales:
59
+ self.test_cameras[resolution_scale] = cameraList_from_RcamInfos(scene_info.test_cameras, resolution_scale, self.pose_args)
60
+ if self.loaded_iter:
61
+ self.gaussians.load_ply(os.path.join(self.model_path,
62
+ "point_cloud",
63
+ "iteration_" + str(self.loaded_iter),
64
+ "point_cloud.ply"))
65
+ elif self.pretrained_model_path is not None:
66
+ self.gaussians.load_ply(self.pretrained_model_path)
67
+ else:
68
+ self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
69
+
70
+ def save(self, iteration):
71
+ point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
72
+ self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
73
+
74
+ def getRandTrainCameras(self, scale=1.0):
75
+ rand_train_cameras = GenerateRandomCameras(self.pose_args, self.args.batch, SSAA=True)
76
+ train_cameras = {}
77
+ for resolution_scale in self.resolution_scales:
78
+ train_cameras[resolution_scale] = cameraList_from_RcamInfos(rand_train_cameras, resolution_scale, self.pose_args, SSAA=True)
79
+ return train_cameras[scale]
80
+
81
+
82
+ def getPurnTrainCameras(self, scale=1.0):
83
+ rand_train_cameras = GeneratePurnCameras(self.pose_args)
84
+ train_cameras = {}
85
+ for resolution_scale in self.resolution_scales:
86
+ train_cameras[resolution_scale] = cameraList_from_RcamInfos(rand_train_cameras, resolution_scale, self.pose_args)
87
+ return train_cameras[scale]
88
+
89
+
90
+ def getTestCameras(self, scale=1.0):
91
+ return self.test_cameras[scale]
92
+
93
+ def getCircleVideoCameras(self, scale=1.0,batch_size=120, render45 = True):
94
+ video_circle_cameras = GenerateCircleCameras(self.pose_args,batch_size,render45)
95
+ video_cameras = {}
96
+ for resolution_scale in self.resolution_scales:
97
+ video_cameras[resolution_scale] = cameraList_from_RcamInfos(video_circle_cameras, resolution_scale, self.pose_args)
98
+ return video_cameras[scale]
scene/cameras.py ADDED
@@ -0,0 +1,138 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import torch
13
+ from torch import nn
14
+ import numpy as np
15
+ from utils.graphics_utils import getWorld2View2, getProjectionMatrix, fov2focal
16
+
17
+ def get_rays_torch(focal, c2w, H=64,W=64):
18
+ """Computes rays using a General Pinhole Camera Model
19
+ Returns per-pixel ray origins and unit view directions for the given focal length, camera-to-world matrix c2w, and (H, W) grid.
20
+ """
21
+ x, y = torch.meshgrid(
22
+ torch.arange(W), # X-Axis (columns)
23
+ torch.arange(H), # Y-Axis (rows)
24
+ indexing='xy')
25
+ camera_directions = torch.stack(
26
+ [(x - W * 0.5 + 0.5) / focal,
27
+ -(y - H * 0.5 + 0.5) / focal,
28
+ -torch.ones_like(x)],
29
+ dim=-1).to(c2w)
30
+
31
+ # Rotate ray directions from camera frame to the world frame
32
+ directions = ((camera_directions[ None,..., None, :] * c2w[None,None, None, :3, :3]).sum(axis=-1)) # Translate camera frame's origin to the world frame
33
+ origins = torch.broadcast_to(c2w[ None,None, None, :3, -1], directions.shape)
34
+ viewdirs = directions / torch.linalg.norm(directions, axis=-1, keepdims=True)
35
+
36
+ return torch.cat((origins,viewdirs),dim=-1)
37
+
38
+
39
+ class Camera(nn.Module):
40
+ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
41
+ image_name, uid,
42
+ trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
43
+ ):
44
+ super(Camera, self).__init__()
45
+
46
+ self.uid = uid
47
+ self.colmap_id = colmap_id
48
+ self.R = R
49
+ self.T = T
50
+ self.FoVx = FoVx
51
+ self.FoVy = FoVy
52
+ self.image_name = image_name
53
+
54
+ try:
55
+ self.data_device = torch.device(data_device)
56
+ except Exception as e:
57
+ print(e)
58
+ print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
59
+ self.data_device = torch.device("cuda")
60
+
61
+ self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
62
+ self.image_width = self.original_image.shape[2]
63
+ self.image_height = self.original_image.shape[1]
64
+
65
+ if gt_alpha_mask is not None:
66
+ self.original_image *= gt_alpha_mask.to(self.data_device)
67
+ else:
68
+ self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
69
+
70
+ self.zfar = 100.0
71
+ self.znear = 0.01
72
+
73
+ self.trans = trans
74
+ self.scale = scale
75
+
76
+ self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
77
+ self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
78
+ self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
79
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
80
+
81
+
82
+ class RCamera(nn.Module):
83
+ def __init__(self, colmap_id, R, T, FoVx, FoVy, uid, delta_polar, delta_azimuth, delta_radius, opt,
84
+ trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", SSAA=False
85
+ ):
86
+ super(RCamera, self).__init__()
87
+
88
+ self.uid = uid
89
+ self.colmap_id = colmap_id
90
+ self.R = R
91
+ self.T = T
92
+ self.FoVx = FoVx
93
+ self.FoVy = FoVy
94
+ self.delta_polar = delta_polar
95
+ self.delta_azimuth = delta_azimuth
96
+ self.delta_radius = delta_radius
97
+ try:
98
+ self.data_device = torch.device(data_device)
99
+ except Exception as e:
100
+ print(e)
101
+ print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
102
+ self.data_device = torch.device("cuda")
103
+
104
+ self.zfar = 100.0
105
+ self.znear = 0.01
106
+
107
+ if SSAA:
108
+ ssaa = opt.SSAA
109
+ else:
110
+ ssaa = 1
111
+
112
+ self.image_width = opt.image_w * ssaa
113
+ self.image_height = opt.image_h * ssaa
114
+
115
+ self.trans = trans
116
+ self.scale = scale
117
+
118
+ RT = torch.tensor(getWorld2View2(R, T, trans, scale))
119
+ self.world_view_transform = RT.transpose(0, 1).cuda()
120
+ self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
121
+ self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
122
+ self.camera_center = self.world_view_transform.inverse()[3, :3]
123
+ # self.rays = get_rays_torch(fov2focal(FoVx, 64), RT).cuda()
124
+ self.rays = get_rays_torch(fov2focal(FoVx, self.image_width//8), RT, H=self.image_height//8, W=self.image_width//8).cuda()
125
+
126
+ class MiniCam:
127
+ def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
128
+ self.image_width = width
129
+ self.image_height = height
130
+ self.FoVy = fovy
131
+ self.FoVx = fovx
132
+ self.znear = znear
133
+ self.zfar = zfar
134
+ self.world_view_transform = world_view_transform
135
+ self.full_proj_transform = full_proj_transform
136
+ view_inv = torch.inverse(self.world_view_transform)
137
+ self.camera_center = view_inv[3][:3]
138
+
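The three camera classes above share one matrix convention: world_view_transform and projection_matrix are stored transposed (row-vector convention), full_proj_transform is their product, and the camera center is read from the last row of the inverted view matrix. A minimal sketch of driving MiniCam directly, assuming this repo's utils.graphics_utils and a CUDA device (the pose values are arbitrary):

# Illustrative sketch only, not part of the commit.
import numpy as np
import torch
from scene.cameras import MiniCam
from utils.graphics_utils import getWorld2View2, getProjectionMatrix

R = np.eye(3)                          # camera rotation
T = np.array([0.0, 0.0, 4.0])          # camera translation
world_view = torch.tensor(getWorld2View2(R, T)).transpose(0, 1).cuda()
proj = getProjectionMatrix(znear=0.01, zfar=100.0, fovX=0.7, fovY=0.7).transpose(0, 1).cuda()
full_proj = world_view.unsqueeze(0).bmm(proj.unsqueeze(0)).squeeze(0)
cam = MiniCam(width=512, height=512, fovy=0.7, fovx=0.7, znear=0.01, zfar=100.0,
              world_view_transform=world_view, full_proj_transform=full_proj)
print(cam.camera_center)               # last row of the inverted view matrix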
scene/dataset_readers.py ADDED
@@ -0,0 +1,466 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import os
13
+ import sys
14
+ import torch
15
+ import random
16
+ import torch.nn.functional as F
17
+ from PIL import Image
18
+ from typing import NamedTuple
19
+ from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
20
+ import numpy as np
21
+ import json
22
+ from pathlib import Path
23
+ from utils.pointe_utils import init_from_pointe
24
+ from plyfile import PlyData, PlyElement
25
+ from utils.sh_utils import SH2RGB
26
+ from utils.general_utils import inverse_sigmoid_np
27
+ from scene.gaussian_model import BasicPointCloud
28
+
29
+
30
+ class RandCameraInfo(NamedTuple):
31
+ uid: int
32
+ R: np.array
33
+ T: np.array
34
+ FovY: np.array
35
+ FovX: np.array
36
+ width: int
37
+ height: int
38
+ delta_polar : np.array
39
+ delta_azimuth : np.array
40
+ delta_radius : np.array
41
+
42
+
43
+ class SceneInfo(NamedTuple):
44
+ point_cloud: BasicPointCloud
45
+ train_cameras: list
46
+ test_cameras: list
47
+ nerf_normalization: dict
48
+ ply_path: str
49
+
50
+
51
+ class RSceneInfo(NamedTuple):
52
+ point_cloud: BasicPointCloud
53
+ test_cameras: list
54
+ ply_path: str
55
+
56
+ # def getNerfppNorm(cam_info):
57
+ # def get_center_and_diag(cam_centers):
58
+ # cam_centers = np.hstack(cam_centers)
59
+ # avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
60
+ # center = avg_cam_center
61
+ # dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
62
+ # diagonal = np.max(dist)
63
+ # return center.flatten(), diagonal
64
+
65
+ # cam_centers = []
66
+
67
+ # for cam in cam_info:
68
+ # W2C = getWorld2View2(cam.R, cam.T)
69
+ # C2W = np.linalg.inv(W2C)
70
+ # cam_centers.append(C2W[:3, 3:4])
71
+
72
+ # center, diagonal = get_center_and_diag(cam_centers)
73
+ # radius = diagonal * 1.1
74
+
75
+ # translate = -center
76
+
77
+ # return {"translate": translate, "radius": radius}
78
+
79
+
80
+
81
+ def fetchPly(path):
82
+ plydata = PlyData.read(path)
83
+ vertices = plydata['vertex']
84
+ positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
85
+ colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
86
+ normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
87
+ return BasicPointCloud(points=positions, colors=colors, normals=normals)
88
+
89
+ def storePly(path, xyz, rgb):
90
+ # Define the dtype for the structured array
91
+ dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
92
+ ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
93
+ ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
94
+
95
+ normals = np.zeros_like(xyz)
96
+
97
+ elements = np.empty(xyz.shape[0], dtype=dtype)
98
+ attributes = np.concatenate((xyz, normals, rgb), axis=1)
99
+ elements[:] = list(map(tuple, attributes))
100
+
101
+ # Create the PlyData object and write to file
102
+ vertex_element = PlyElement.describe(elements, 'vertex')
103
+ ply_data = PlyData([vertex_element])
104
+ ply_data.write(path)
105
+
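A small round-trip sketch for the two PLY helpers above (illustrative only; the path is a placeholder):

# Illustrative round trip for storePly / fetchPly.
import numpy as np

xyz = np.random.rand(100, 3).astype(np.float32) - 0.5     # 100 random points
rgb = (np.random.rand(100, 3) * 255).astype(np.uint8)     # their colors in 0..255
storePly("/tmp/demo_points3d.ply", xyz, rgb)               # writes positions, zero normals, colors
pcd = fetchPly("/tmp/demo_points3d.ply")                   # colors come back scaled to 0..1
assert pcd.points.shape == (100, 3) and pcd.colors.max() <= 1.0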
106
+ # only test cameras are generated; the random-camera scene has no training views
107
+ def readCircleCamInfo(path,opt):
108
+ print("Reading Test Transforms")
109
+ test_cam_infos = GenerateCircleCameras(opt,render45 = opt.render_45)
110
+ ply_path = os.path.join(path, "init_points3d.ply")
111
+ if not os.path.exists(ply_path):
112
+ # Since this dataset has no COLMAP data, we start from randomly initialized points
113
+ num_pts = opt.init_num_pts
114
+ if opt.init_shape == 'sphere':
115
+ thetas = np.random.rand(num_pts)*np.pi
116
+ phis = np.random.rand(num_pts)*2*np.pi
117
+ radius = np.random.rand(num_pts)*0.5
118
+ # We create random points inside the bounds of sphere
119
+ xyz = np.stack([
120
+ radius * np.sin(thetas) * np.sin(phis),
121
+ radius * np.sin(thetas) * np.cos(phis),
122
+ radius * np.cos(thetas),
123
+ ], axis=-1) # [B, 3]
124
+ elif opt.init_shape == 'box':
125
+ xyz = np.random.random((num_pts, 3)) * 1.0 - 0.5
126
+ elif opt.init_shape == 'rectangle_x':
127
+ xyz = np.random.random((num_pts, 3))
128
+ xyz[:, 0] = xyz[:, 0] * 0.6 - 0.3
129
+ xyz[:, 1] = xyz[:, 1] * 1.2 - 0.6
130
+ xyz[:, 2] = xyz[:, 2] * 0.5 - 0.25
131
+ elif opt.init_shape == 'rectangle_z':
132
+ xyz = np.random.random((num_pts, 3))
133
+ xyz[:, 0] = xyz[:, 0] * 0.8 - 0.4
134
+ xyz[:, 1] = xyz[:, 1] * 0.6 - 0.3
135
+ xyz[:, 2] = xyz[:, 2] * 1.2 - 0.6
136
+ elif opt.init_shape == 'pointe':
137
+ num_pts = int(num_pts/5000)
138
+ xyz,rgb = init_from_pointe(opt.init_prompt)
139
+ xyz[:,1] = - xyz[:,1]
140
+ xyz[:,2] = xyz[:,2] + 0.15
141
+ thetas = np.random.rand(num_pts)*np.pi
142
+ phis = np.random.rand(num_pts)*2*np.pi
143
+ radius = np.random.rand(num_pts)*0.05
144
+ # We create random points inside the bounds of sphere
145
+ xyz_ball = np.stack([
146
+ radius * np.sin(thetas) * np.sin(phis),
147
+ radius * np.sin(thetas) * np.cos(phis),
148
+ radius * np.cos(thetas),
149
+ ], axis=-1) # [B, 3]
150
+ rgb_ball = np.random.random((4096, num_pts, 3))*0.0001
151
+ rgb = (np.expand_dims(rgb,axis=1)+rgb_ball).reshape(-1,3)
152
+ xyz = (np.expand_dims(xyz,axis=1)+np.expand_dims(xyz_ball,axis=0)).reshape(-1,3)
153
+ xyz = xyz * 1.
154
+ num_pts = xyz.shape[0]
155
+ elif opt.init_shape == 'scene':
156
+ thetas = np.random.rand(num_pts)*np.pi
157
+ phis = np.random.rand(num_pts)*2*np.pi
158
+ radius = np.random.rand(num_pts) + opt.radius_range[-1]*3
159
+ # We create random points inside the bounds of sphere
160
+ xyz = np.stack([
161
+ radius * np.sin(thetas) * np.sin(phis),
162
+ radius * np.sin(thetas) * np.cos(phis),
163
+ radius * np.cos(thetas),
164
+ ], axis=-1) # [B, 3]
165
+ else:
166
+ raise NotImplementedError()
167
+ print(f"Generating random point cloud ({num_pts})...")
168
+
169
+ shs = np.random.random((num_pts, 3)) / 255.0
170
+
171
+ if opt.init_shape == 'pointe' and opt.use_pointe_rgb:
172
+ pcd = BasicPointCloud(points=xyz, colors=rgb, normals=np.zeros((num_pts, 3)))
173
+ storePly(ply_path, xyz, rgb * 255)
174
+ else:
175
+ pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
176
+ storePly(ply_path, xyz, SH2RGB(shs) * 255)
177
+ try:
178
+ pcd = fetchPly(ply_path)
179
+ except:
180
+ pcd = None
181
+
182
+ scene_info = RSceneInfo(point_cloud=pcd,
183
+ test_cameras=test_cam_infos,
184
+ ply_path=ply_path)
185
+ return scene_info
186
+ # borrowed from https://github.com/ashawkey/stable-dreamfusion
187
+
188
+ def safe_normalize(x, eps=1e-20):
189
+ return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))
190
+
191
+ # def circle_poses(radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), angle_overhead=30, angle_front=60):
192
+
193
+ # theta = theta / 180 * np.pi
194
+ # phi = phi / 180 * np.pi
195
+ # angle_overhead = angle_overhead / 180 * np.pi
196
+ # angle_front = angle_front / 180 * np.pi
197
+
198
+ # centers = torch.stack([
199
+ # radius * torch.sin(theta) * torch.sin(phi),
200
+ # radius * torch.cos(theta),
201
+ # radius * torch.sin(theta) * torch.cos(phi),
202
+ # ], dim=-1) # [B, 3]
203
+
204
+ # # lookat
205
+ # forward_vector = safe_normalize(centers)
206
+ # up_vector = torch.FloatTensor([0, 1, 0]).unsqueeze(0).repeat(len(centers), 1)
207
+ # right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
208
+ # up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
209
+
210
+ # poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(len(centers), 1, 1)
211
+ # poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
212
+ # poses[:, :3, 3] = centers
213
+
214
+ # return poses.numpy()
215
+
216
+ def circle_poses(radius=torch.tensor([3.2]), theta=torch.tensor([60]), phi=torch.tensor([0]), angle_overhead=30, angle_front=60):
217
+
218
+ theta = theta / 180 * np.pi
219
+ phi = phi / 180 * np.pi
220
+ angle_overhead = angle_overhead / 180 * np.pi
221
+ angle_front = angle_front / 180 * np.pi
222
+
223
+ centers = torch.stack([
224
+ radius * torch.sin(theta) * torch.sin(phi),
225
+ radius * torch.sin(theta) * torch.cos(phi),
226
+ radius * torch.cos(theta),
227
+ ], dim=-1) # [B, 3]
228
+
229
+ # lookat
230
+ forward_vector = safe_normalize(centers)
231
+ up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(len(centers), 1)
232
+ right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
233
+ up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))
234
+
235
+ poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(len(centers), 1, 1)
236
+ poses[:, :3, :3] = torch.stack((-right_vector, up_vector, forward_vector), dim=-1)
237
+ poses[:, :3, 3] = centers
238
+
239
+ return poses.numpy()
240
+
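circle_poses returns camera-to-world matrices whose rotation columns are (-right, up, forward) with +z as the world up axis; a quick sanity sketch (illustrative only):

# One pose on the circle, checked for a valid rotation block.
import numpy as np
import torch

pose = circle_poses(radius=torch.tensor([3.5]), theta=torch.tensor([75.0]), phi=torch.tensor([30.0]))[0]
R_c2w = pose[:3, :3]
print(np.allclose(R_c2w @ R_c2w.T, np.eye(3), atol=1e-5))   # True: orthonormal rotation block
print(np.linalg.norm(pose[:3, 3]))                          # ~3.5: camera center sits on the orbit radius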
241
+ def gen_random_pos(size, param_range, gamma=1):
242
+ lower, higher = param_range[0], param_range[1]
243
+
244
+ mid = lower + (higher - lower) * 0.5
245
+ radius = (higher - lower) * 0.5
246
+
247
+ rand_ = torch.rand(size) # 0, 1
248
+ sign = torch.where(torch.rand(size) > 0.5, torch.ones(size) * -1., torch.ones(size))
249
+ rand_ = sign * (rand_ ** gamma)
250
+
251
+ return (rand_ * radius) + mid
252
+
253
+
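gen_random_pos samples symmetrically around the midpoint of param_range; raising gamma above 1 shrinks |rand|**gamma and so concentrates samples near that midpoint. A quick sketch of the bias:

# With a larger gamma, samples cluster around the middle of the range.
import torch

torch.manual_seed(0)
flat   = gen_random_pos(10000, [60.0, 120.0], gamma=1.0)   # roughly uniform over [60, 120]
peaked = gen_random_pos(10000, [60.0, 120.0], gamma=3.0)   # concentrated around 90
print(flat.std().item(), peaked.std().item())              # the gamma=3 spread is clearly smaller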
254
+ def rand_poses(size, opt, radius_range=[1, 1.5], theta_range=[0, 120], phi_range=[0, 360], angle_overhead=30, angle_front=60, uniform_sphere_rate=0.5, rand_cam_gamma=1):
255
+ ''' generate random poses from an orbit camera
256
+ Args:
257
+ size: batch size of generated poses.
258
+ opt: option namespace providing the pose-jitter settings.
259
+ radius_range: [min, max] orbit radius.
260
+ theta_range: [min, max] polar angle in degrees, within [0, 180].
261
+ phi_range: [min, max] azimuth in degrees, within [0, 360].
262
+ Return:
263
+ poses: [size, 4, 4]
264
+ '''
265
+
266
+ theta_range = np.array(theta_range) / 180 * np.pi
267
+ phi_range = np.array(phi_range) / 180 * np.pi
268
+ angle_overhead = angle_overhead / 180 * np.pi
269
+ angle_front = angle_front / 180 * np.pi
270
+
271
+ # radius = torch.rand(size) * (radius_range[1] - radius_range[0]) + radius_range[0]
272
+ radius = gen_random_pos(size, radius_range)
273
+
274
+ if random.random() < uniform_sphere_rate:
275
+ unit_centers = F.normalize(
276
+ torch.stack([
277
+ torch.randn(size),
278
+ torch.abs(torch.randn(size)),
279
+ torch.randn(size),
280
+ ], dim=-1), p=2, dim=1
281
+ )
282
+ thetas = torch.acos(unit_centers[:,1])
283
+ phis = torch.atan2(unit_centers[:,0], unit_centers[:,2])
284
+ phis[phis < 0] += 2 * np.pi
285
+ centers = unit_centers * radius.unsqueeze(-1)
286
+ else:
287
+ # thetas = torch.rand(size) * (theta_range[1] - theta_range[0]) + theta_range[0]
288
+ # phis = torch.rand(size) * (phi_range[1] - phi_range[0]) + phi_range[0]
289
+ # phis[phis < 0] += 2 * np.pi
290
+
291
+ # centers = torch.stack([
292
+ # radius * torch.sin(thetas) * torch.sin(phis),
293
+ # radius * torch.cos(thetas),
294
+ # radius * torch.sin(thetas) * torch.cos(phis),
295
+ # ], dim=-1) # [B, 3]
296
+ # thetas = torch.rand(size) * (theta_range[1] - theta_range[0]) + theta_range[0]
297
+ # phis = torch.rand(size) * (phi_range[1] - phi_range[0]) + phi_range[0]
298
+ thetas = gen_random_pos(size, theta_range, rand_cam_gamma)
299
+ phis = gen_random_pos(size, phi_range, rand_cam_gamma)
300
+ phis[phis < 0] += 2 * np.pi
301
+
302
+ centers = torch.stack([
303
+ radius * torch.sin(thetas) * torch.sin(phis),
304
+ radius * torch.sin(thetas) * torch.cos(phis),
305
+ radius * torch.cos(thetas),
306
+ ], dim=-1) # [B, 3]
307
+
308
+ targets = 0
309
+
310
+ # jitters
311
+ if opt.jitter_pose:
312
+ jit_center = opt.jitter_center # 0.015 # was 0.2
313
+ jit_target = opt.jitter_target
314
+ centers += torch.rand_like(centers) * jit_center - jit_center/2.0
315
+ targets += torch.randn_like(centers) * jit_target
316
+
317
+ # lookat
318
+ forward_vector = safe_normalize(centers - targets)
319
+ up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1)
320
+ #up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1)
321
+ right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
322
+
323
+ if opt.jitter_pose:
324
+ up_noise = torch.randn_like(up_vector) * opt.jitter_up
325
+ else:
326
+ up_noise = 0
327
+
328
+ up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1) + up_noise) #forward_vector
329
+
330
+ poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(size, 1, 1)
331
+ poses[:, :3, :3] = torch.stack((-right_vector, up_vector, forward_vector), dim=-1) #up_vector
332
+ poses[:, :3, 3] = centers
333
+
334
+
335
+ # back to degree
336
+ thetas = thetas / np.pi * 180
337
+ phis = phis / np.pi * 180
338
+
339
+ return poses.numpy(), thetas.numpy(), phis.numpy(), radius.numpy()
340
+
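A hedged usage sketch for rand_poses; the opt namespace below only stubs the jitter flag this code path reads (the real GenerateCamParams object carries many more fields):

# Draw a batch of random orbit poses with jitter disabled.
from types import SimpleNamespace

opt_stub = SimpleNamespace(jitter_pose=False)   # hypothetical stand-in for GenerateCamParams
poses, thetas, phis, radii = rand_poses(4, opt_stub,
                                        radius_range=[3.5, 5.0],
                                        theta_range=[45, 105],
                                        phi_range=[-180, 180],
                                        uniform_sphere_rate=0.0)
print(poses.shape)          # (4, 4, 4) camera-to-world matrices
print(thetas, phis, radii)  # polar/azimuth converted back to degrees, plus orbit radii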
341
+ def GenerateCircleCameras(opt, size=8, render45 = False):
342
+ # random focal
343
+ fov = opt.default_fovy
344
+ cam_infos = []
345
+ # pack the poses into RandCameraInfo structures
346
+ for idx in range(size):
347
+ thetas = torch.FloatTensor([opt.default_polar])
348
+ phis = torch.FloatTensor([(idx / size) * 360])
349
+ radius = torch.FloatTensor([opt.default_radius])
350
+ # random pose on the fly
351
+ poses = circle_poses(radius=radius, theta=thetas, phi=phis, angle_overhead=opt.angle_overhead, angle_front=opt.angle_front)
352
+ matrix = np.linalg.inv(poses[0])
353
+ R = -np.transpose(matrix[:3,:3])
354
+ R[:,0] = -R[:,0]
355
+ T = -matrix[:3, 3]
356
+ fovy = focal2fov(fov2focal(fov, opt.image_h), opt.image_w)
357
+ FovY = fovy
358
+ FovX = fov
359
+
360
+ # delta polar/azimuth/radius to default view
361
+ delta_polar = thetas - opt.default_polar
362
+ delta_azimuth = phis - opt.default_azimuth
363
+ delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
364
+ delta_radius = radius - opt.default_radius
365
+ cam_infos.append(RandCameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,width=opt.image_w,
366
+ height = opt.image_h, delta_polar = delta_polar,delta_azimuth = delta_azimuth, delta_radius = delta_radius))
367
+ if render45:
368
+ for idx in range(size):
369
+ thetas = torch.FloatTensor([opt.default_polar*2//3])
370
+ phis = torch.FloatTensor([(idx / size) * 360])
371
+ radius = torch.FloatTensor([opt.default_radius])
372
+ # random pose on the fly
373
+ poses = circle_poses(radius=radius, theta=thetas, phi=phis, angle_overhead=opt.angle_overhead, angle_front=opt.angle_front)
374
+ matrix = np.linalg.inv(poses[0])
375
+ R = -np.transpose(matrix[:3,:3])
376
+ R[:,0] = -R[:,0]
377
+ T = -matrix[:3, 3]
378
+ fovy = focal2fov(fov2focal(fov, opt.image_h), opt.image_w)
379
+ FovY = fovy
380
+ FovX = fov
381
+
382
+ # delta polar/azimuth/radius to default view
383
+ delta_polar = thetas - opt.default_polar
384
+ delta_azimuth = phis - opt.default_azimuth
385
+ delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
386
+ delta_radius = radius - opt.default_radius
387
+ cam_infos.append(RandCameraInfo(uid=idx+size, R=R, T=T, FovY=FovY, FovX=FovX,width=opt.image_w,
388
+ height = opt.image_h, delta_polar = delta_polar,delta_azimuth = delta_azimuth, delta_radius = delta_radius))
389
+ return cam_infos
390
+
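An illustrative call, again with a hypothetical options stub that mirrors only the GenerateCamParams fields the function reads (fovy in radians):

# Generate 8 evenly spaced evaluation cameras around the default view.
from types import SimpleNamespace

opt_stub = SimpleNamespace(default_fovy=0.55, default_polar=90, default_azimuth=0,
                           default_radius=3.5, angle_overhead=30, angle_front=60,
                           image_h=512, image_w=512)
eval_cams = GenerateCircleCameras(opt_stub, size=8, render45=False)
print(len(eval_cams))               # 8 views at evenly spaced azimuths
print(eval_cams[0].delta_azimuth)   # azimuth offset of each view relative to the default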
391
+
392
+ def GenerateRandomCameras(opt, size=2000, SSAA=True):
393
+ # random pose on the fly
394
+ poses, thetas, phis, radius = rand_poses(size, opt, radius_range=opt.radius_range, theta_range=opt.theta_range, phi_range=opt.phi_range,
395
+ angle_overhead=opt.angle_overhead, angle_front=opt.angle_front, uniform_sphere_rate=opt.uniform_sphere_rate,
396
+ rand_cam_gamma=opt.rand_cam_gamma)
397
+ # delta polar/azimuth/radius to default view
398
+ delta_polar = thetas - opt.default_polar
399
+ delta_azimuth = phis - opt.default_azimuth
400
+ delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
401
+ delta_radius = radius - opt.default_radius
402
+ # random focal
403
+ fov = random.random() * (opt.fovy_range[1] - opt.fovy_range[0]) + opt.fovy_range[0]
404
+
405
+ cam_infos = []
406
+
407
+ if SSAA:
408
+ ssaa = opt.SSAA
409
+ else:
410
+ ssaa = 1
411
+
412
+ image_h = opt.image_h * ssaa
413
+ image_w = opt.image_w * ssaa
414
+
415
+ # pack the poses into RandCameraInfo structures
416
+ for idx in range(size):
417
+ matrix = np.linalg.inv(poses[idx])
418
+ R = -np.transpose(matrix[:3,:3])
419
+ R[:,0] = -R[:,0]
420
+ T = -matrix[:3, 3]
421
+ # matrix = poses[idx]
422
+ # R = matrix[:3,:3]
423
+ # T = matrix[:3, 3]
424
+ fovy = focal2fov(fov2focal(fov, image_h), image_w)
425
+ FovY = fovy
426
+ FovX = fov
427
+
428
+ cam_infos.append(RandCameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,width=image_w,
429
+ height=image_h, delta_polar = delta_polar[idx],
430
+ delta_azimuth = delta_azimuth[idx], delta_radius = delta_radius[idx]))
431
+ return cam_infos
432
+
433
+ def GeneratePurnCameras(opt, size=300):
434
+ # random pose on the fly
435
+ poses, thetas, phis, radius = rand_poses(size, opt, radius_range=[opt.default_radius,opt.default_radius+0.1], theta_range=opt.theta_range, phi_range=opt.phi_range, angle_overhead=opt.angle_overhead, angle_front=opt.angle_front, uniform_sphere_rate=opt.uniform_sphere_rate)
436
+ # delta polar/azimuth/radius to default view
437
+ delta_polar = thetas - opt.default_polar
438
+ delta_azimuth = phis - opt.default_azimuth
439
+ delta_azimuth[delta_azimuth > 180] -= 360 # range in [-180, 180]
440
+ delta_radius = radius - opt.default_radius
441
+ # random focal
442
+ #fov = random.random() * (opt.fovy_range[1] - opt.fovy_range[0]) + opt.fovy_range[0]
443
+ fov = opt.default_fovy
444
+ cam_infos = []
445
+ # pack the poses into RandCameraInfo structures
446
+ for idx in range(size):
447
+ matrix = np.linalg.inv(poses[idx])
448
+ R = -np.transpose(matrix[:3,:3])
449
+ R[:,0] = -R[:,0]
450
+ T = -matrix[:3, 3]
451
+ # matrix = poses[idx]
452
+ # R = matrix[:3,:3]
453
+ # T = matrix[:3, 3]
454
+ fovy = focal2fov(fov2focal(fov, opt.image_h), opt.image_w)
455
+ FovY = fovy
456
+ FovX = fov
457
+
458
+ cam_infos.append(RandCameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX,width=opt.image_w,
459
+ height = opt.image_h, delta_polar = delta_polar[idx],delta_azimuth = delta_azimuth[idx], delta_radius = delta_radius[idx]))
460
+ return cam_infos
461
+
462
+ sceneLoadTypeCallbacks = {
463
+ # "Colmap": readColmapSceneInfo,
464
+ # "Blender" : readNerfSyntheticInfo,
465
+ "RandomCam" : readCircleCamInfo
466
+ }
scene/gaussian_model.py ADDED
@@ -0,0 +1,458 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import torch
13
+ import numpy as np
14
+ from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
15
+ from torch import nn
16
+ import os
17
+ from utils.system_utils import mkdir_p
18
+ from plyfile import PlyData, PlyElement
19
+ from utils.sh_utils import RGB2SH,SH2RGB
20
+ from simple_knn._C import distCUDA2
21
+ from utils.graphics_utils import BasicPointCloud
22
+ from utils.general_utils import strip_symmetric, build_scaling_rotation
23
+ # from .resnet import *
24
+
25
+ class GaussianModel:
26
+
27
+ def setup_functions(self):
28
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
29
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
30
+ actual_covariance = L @ L.transpose(1, 2)
31
+ symm = strip_symmetric(actual_covariance)
32
+ return symm
33
+
34
+ self.scaling_activation = torch.exp
35
+ self.scaling_inverse_activation = torch.log
36
+
37
+ self.covariance_activation = build_covariance_from_scaling_rotation
38
+
39
+ self.opacity_activation = torch.sigmoid
40
+ self.inverse_opacity_activation = inverse_sigmoid
41
+
42
+ self.rotation_activation = torch.nn.functional.normalize
43
+
44
+
45
+ def __init__(self, sh_degree : int):
46
+ self.active_sh_degree = 0
47
+ self.max_sh_degree = sh_degree
48
+ self._xyz = torch.empty(0)
49
+ self._features_dc = torch.empty(0)
50
+ self._features_rest = torch.empty(0)
51
+ self._scaling = torch.empty(0)
52
+ self._rotation = torch.empty(0)
53
+ self._opacity = torch.empty(0)
54
+ self._background = torch.empty(0)
55
+ self.max_radii2D = torch.empty(0)
56
+ self.xyz_gradient_accum = torch.empty(0)
57
+ self.denom = torch.empty(0)
58
+ self.optimizer = None
59
+ self.percent_dense = 0
60
+ self.spatial_lr_scale = 0
61
+ self.setup_functions()
62
+
63
+ def capture(self):
64
+ return (
65
+ self.active_sh_degree,
66
+ self._xyz,
67
+ self._features_dc,
68
+ self._features_rest,
69
+ self._scaling,
70
+ self._rotation,
71
+ self._opacity,
72
+ self.max_radii2D,
73
+ self.xyz_gradient_accum,
74
+ self.denom,
75
+ self.optimizer.state_dict(),
76
+ self.spatial_lr_scale,
77
+ )
78
+
79
+ def restore(self, model_args, training_args):
80
+ (self.active_sh_degree,
81
+ self._xyz,
82
+ self._features_dc,
83
+ self._features_rest,
84
+ self._scaling,
85
+ self._rotation,
86
+ self._opacity,
87
+ self.max_radii2D,
88
+ xyz_gradient_accum,
89
+ denom,
90
+ opt_dict,
91
+ self.spatial_lr_scale) = model_args
92
+ self.training_setup(training_args)
93
+ self.xyz_gradient_accum = xyz_gradient_accum
94
+ self.denom = denom
95
+ self.optimizer.load_state_dict(opt_dict)
96
+
97
+ @property
98
+ def get_scaling(self):
99
+ return self.scaling_activation(self._scaling)
100
+
101
+ @property
102
+ def get_rotation(self):
103
+ return self.rotation_activation(self._rotation)
104
+
105
+ @property
106
+ def get_xyz(self):
107
+ return self._xyz
108
+
109
+ @property
110
+ def get_background(self):
111
+ return torch.sigmoid(self._background)
112
+
113
+ @property
114
+ def get_features(self):
115
+ features_dc = self._features_dc
116
+ features_rest = self._features_rest
117
+ return torch.cat((features_dc, features_rest), dim=1)
118
+
119
+ @property
120
+ def get_opacity(self):
121
+ return self.opacity_activation(self._opacity)
122
+
123
+ def get_covariance(self, scaling_modifier = 1):
124
+ return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
125
+
126
+ def oneupSHdegree(self):
127
+ if self.active_sh_degree < self.max_sh_degree:
128
+ self.active_sh_degree += 1
129
+
130
+ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
131
+ self.spatial_lr_scale = spatial_lr_scale
132
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
133
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors))).float().cuda() #RGB2SH(
134
+ features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
135
+ features[:, :3, 0 ] = fused_color
136
+ features[:, 3:, 1:] = 0.0
137
+
138
+ print("Number of points at initialisation : ", fused_point_cloud.shape[0])
139
+
140
+ dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
141
+ scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
142
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
143
+ rots[:, 0] = 1
144
+
145
+ opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
146
+
147
+ self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
148
+ self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
149
+ self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
150
+ self._scaling = nn.Parameter(scales.requires_grad_(True))
151
+ self._rotation = nn.Parameter(rots.requires_grad_(True))
152
+ self._opacity = nn.Parameter(opacities.requires_grad_(True))
153
+ self._background = nn.Parameter(torch.zeros((3,1,1), device="cuda").requires_grad_(True))
154
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
155
+
156
+ def training_setup(self, training_args):
157
+ self.percent_dense = training_args.percent_dense
158
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
159
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
160
+
161
+ l = [
162
+ {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
163
+ {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
164
+ {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
165
+ {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
166
+ {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
167
+ {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
168
+ {'params': [self._background], 'lr': training_args.feature_lr, "name": "background"},
169
+ ]
170
+
171
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
172
+ self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
173
+ lr_final=training_args.position_lr_final*self.spatial_lr_scale,
174
+ lr_delay_mult=training_args.position_lr_delay_mult,
175
+ max_steps=training_args.iterations)
176
+
177
+
178
+ self.rotation_scheduler_args = get_expon_lr_func(lr_init=training_args.rotation_lr,
179
+ lr_final=training_args.rotation_lr_final,
180
+ lr_delay_mult=training_args.position_lr_delay_mult,
181
+ max_steps=training_args.iterations)
182
+
183
+ self.scaling_scheduler_args = get_expon_lr_func(lr_init=training_args.scaling_lr,
184
+ lr_final=training_args.scaling_lr_final,
185
+ lr_delay_mult=training_args.position_lr_delay_mult,
186
+ max_steps=training_args.iterations)
187
+
188
+ self.feature_scheduler_args = get_expon_lr_func(lr_init=training_args.feature_lr,
189
+ lr_final=training_args.feature_lr_final,
190
+ lr_delay_mult=training_args.position_lr_delay_mult,
191
+ max_steps=training_args.iterations)
192
+ def update_learning_rate(self, iteration):
193
+ ''' Learning rate scheduling per step '''
194
+ for param_group in self.optimizer.param_groups:
195
+ if param_group["name"] == "xyz":
196
+ lr = self.xyz_scheduler_args(iteration)
197
+ param_group['lr'] = lr
198
+ return lr
199
+
200
+ def update_feature_learning_rate(self, iteration):
201
+ ''' Learning rate scheduling per step '''
202
+ for param_group in self.optimizer.param_groups:
203
+ if param_group["name"] == "f_dc":
204
+ lr = self.feature_scheduler_args(iteration)
205
+ param_group['lr'] = lr
206
+ return lr
207
+
208
+ def update_rotation_learning_rate(self, iteration):
209
+ ''' Learning rate scheduling per step '''
210
+ for param_group in self.optimizer.param_groups:
211
+ if param_group["name"] == "rotation":
212
+ lr = self.rotation_scheduler_args(iteration)
213
+ param_group['lr'] = lr
214
+ return lr
215
+
216
+ def update_scaling_learning_rate(self, iteration):
217
+ ''' Learning rate scheduling per step '''
218
+ for param_group in self.optimizer.param_groups:
219
+ if param_group["name"] == "scaling":
220
+ lr = self.scaling_scheduler_args(iteration)
221
+ param_group['lr'] = lr
222
+ return lr
223
+
224
+
225
+ def construct_list_of_attributes(self):
226
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
227
+ # All channels except the 3 DC
228
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
229
+ l.append('f_dc_{}'.format(i))
230
+ for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
231
+ l.append('f_rest_{}'.format(i))
232
+ l.append('opacity')
233
+ for i in range(self._scaling.shape[1]):
234
+ l.append('scale_{}'.format(i))
235
+ for i in range(self._rotation.shape[1]):
236
+ l.append('rot_{}'.format(i))
237
+ return l
238
+
239
+ def save_ply(self, path):
240
+ mkdir_p(os.path.dirname(path))
241
+
242
+ xyz = self._xyz.detach().cpu().numpy()
243
+ normals = np.zeros_like(xyz)
244
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
245
+ f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
246
+ opacities = self._opacity.detach().cpu().numpy()
247
+ scale = self._scaling.detach().cpu().numpy()
248
+ rotation = self._rotation.detach().cpu().numpy()
249
+
250
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
251
+
252
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
253
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
254
+ elements[:] = list(map(tuple, attributes))
255
+ el = PlyElement.describe(elements, 'vertex')
256
+ PlyData([el]).write(path)
257
+ np.savetxt(os.path.join(os.path.split(path)[0],"point_cloud_rgb.txt"),np.concatenate((xyz, SH2RGB(f_dc)), axis=1))
258
+
259
+ def reset_opacity(self):
260
+ opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
261
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
262
+ self._opacity = optimizable_tensors["opacity"]
263
+
264
+ def load_ply(self, path):
265
+ plydata = PlyData.read(path)
266
+
267
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
268
+ np.asarray(plydata.elements[0]["y"]),
269
+ np.asarray(plydata.elements[0]["z"])), axis=1)
270
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
271
+
272
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
273
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
274
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
275
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
276
+
277
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
278
+ extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
279
+ assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
280
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
281
+ for idx, attr_name in enumerate(extra_f_names):
282
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
283
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
284
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
285
+
286
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
287
+ scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
288
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
289
+ for idx, attr_name in enumerate(scale_names):
290
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
291
+
292
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
293
+ rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
294
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
295
+ for idx, attr_name in enumerate(rot_names):
296
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
297
+
298
+ self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
299
+ self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
300
+ self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
301
+ self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
302
+ self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
303
+ self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
304
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
305
+ self.active_sh_degree = self.max_sh_degree
306
+
307
+ def replace_tensor_to_optimizer(self, tensor, name):
308
+ optimizable_tensors = {}
309
+ for group in self.optimizer.param_groups:
310
+ if group["name"] not in ['background']:
311
+ if group["name"] == name:
312
+ stored_state = self.optimizer.state.get(group['params'][0], None)
313
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
314
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
315
+
316
+ del self.optimizer.state[group['params'][0]]
317
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
318
+ self.optimizer.state[group['params'][0]] = stored_state
319
+
320
+ optimizable_tensors[group["name"]] = group["params"][0]
321
+ return optimizable_tensors
322
+
323
+ def _prune_optimizer(self, mask):
324
+ optimizable_tensors = {}
325
+ for group in self.optimizer.param_groups:
326
+ stored_state = self.optimizer.state.get(group['params'][0], None)
327
+ if group["name"] not in ['background']:
328
+ if stored_state is not None:
329
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
330
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
331
+
332
+ del self.optimizer.state[group['params'][0]]
333
+ group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
334
+ self.optimizer.state[group['params'][0]] = stored_state
335
+
336
+ optimizable_tensors[group["name"]] = group["params"][0]
337
+ else:
338
+ group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
339
+ optimizable_tensors[group["name"]] = group["params"][0]
340
+ return optimizable_tensors
341
+
342
+ def prune_points(self, mask):
343
+ valid_points_mask = ~mask
344
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
345
+
346
+ self._xyz = optimizable_tensors["xyz"]
347
+ self._features_dc = optimizable_tensors["f_dc"]
348
+ self._features_rest = optimizable_tensors["f_rest"]
349
+ self._opacity = optimizable_tensors["opacity"]
350
+ self._scaling = optimizable_tensors["scaling"]
351
+ self._rotation = optimizable_tensors["rotation"]
352
+
353
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
354
+
355
+ self.denom = self.denom[valid_points_mask]
356
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
357
+
358
+ def cat_tensors_to_optimizer(self, tensors_dict):
359
+ optimizable_tensors = {}
360
+ for group in self.optimizer.param_groups:
361
+ if group["name"] not in ['background']:
362
+ assert len(group["params"]) == 1
363
+ extension_tensor = tensors_dict[group["name"]]
364
+ stored_state = self.optimizer.state.get(group['params'][0], None)
365
+ if stored_state is not None:
366
+ stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
367
+ stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
368
+
369
+ del self.optimizer.state[group['params'][0]]
370
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
371
+ self.optimizer.state[group['params'][0]] = stored_state
372
+
373
+ optimizable_tensors[group["name"]] = group["params"][0]
374
+ else:
375
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
376
+ optimizable_tensors[group["name"]] = group["params"][0]
377
+
378
+ return optimizable_tensors
379
+
380
+ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
381
+ d = {"xyz": new_xyz,
382
+ "f_dc": new_features_dc,
383
+ "f_rest": new_features_rest,
384
+ "opacity": new_opacities,
385
+ "scaling" : new_scaling,
386
+ "rotation" : new_rotation}
387
+
388
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
389
+ self._xyz = optimizable_tensors["xyz"]
390
+ self._features_dc = optimizable_tensors["f_dc"]
391
+ self._features_rest = optimizable_tensors["f_rest"]
392
+ self._opacity = optimizable_tensors["opacity"]
393
+ self._scaling = optimizable_tensors["scaling"]
394
+ self._rotation = optimizable_tensors["rotation"]
395
+
396
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
397
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
398
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
399
+
400
+ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
401
+ n_init_points = self.get_xyz.shape[0]
402
+ # Extract points that satisfy the gradient condition
403
+ padded_grad = torch.zeros((n_init_points), device="cuda")
404
+ padded_grad[:grads.shape[0]] = grads.squeeze()
405
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
406
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
407
+ torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
408
+
409
+ stds = self.get_scaling[selected_pts_mask].repeat(N,1)
410
+ means = torch.zeros((stds.size(0), 3), device="cuda")
411
+ samples = torch.normal(mean=means, std=stds)
412
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
413
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
414
+ new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
415
+ new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
416
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
417
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
418
+ new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
419
+
420
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
421
+
422
+ prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
423
+ self.prune_points(prune_filter)
424
+
425
+ def densify_and_clone(self, grads, grad_threshold, scene_extent):
426
+ # Extract points that satisfy the gradient condition
427
+ selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
428
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
429
+ torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
430
+
431
+ new_xyz = self._xyz[selected_pts_mask]
432
+ new_features_dc = self._features_dc[selected_pts_mask]
433
+ new_features_rest = self._features_rest[selected_pts_mask]
434
+ new_opacities = self._opacity[selected_pts_mask]
435
+ new_scaling = self._scaling[selected_pts_mask]
436
+ new_rotation = self._rotation[selected_pts_mask]
437
+
438
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
439
+
440
+ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
441
+ grads = self.xyz_gradient_accum / self.denom
442
+ grads[grads.isnan()] = 0.0
443
+
444
+ self.densify_and_clone(grads, max_grad, extent)
445
+ self.densify_and_split(grads, max_grad, extent)
446
+
447
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
448
+ if max_screen_size:
449
+ big_points_vs = self.max_radii2D > max_screen_size
450
+ big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
451
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
452
+ self.prune_points(prune_mask)
453
+
454
+ torch.cuda.empty_cache()
455
+
456
+ def add_densification_stats(self, viewspace_point_tensor, update_filter):
457
+ self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
458
+ self.denom[update_filter] += 1
train.py ADDED
@@ -0,0 +1,553 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import random
13
+ import imageio
14
+ import os
15
+ import torch
16
+ import torch.nn as nn
17
+ from random import randint
18
+ from utils.loss_utils import l1_loss, ssim, tv_loss
19
+ from gaussian_renderer import render, network_gui
20
+ import sys
21
+ from scene import Scene, GaussianModel
22
+ from utils.general_utils import safe_state
23
+ import uuid
24
+ from tqdm import tqdm
25
+ from utils.image_utils import psnr
26
+ from argparse import ArgumentParser, Namespace
27
+ from arguments import ModelParams, PipelineParams, OptimizationParams, GenerateCamParams, GuidanceParams
28
+ import math
29
+ import yaml
30
+ from torchvision.utils import save_image
31
+ import torchvision.transforms as T
32
+
33
+ try:
34
+ from torch.utils.tensorboard import SummaryWriter
35
+ TENSORBOARD_FOUND = True
36
+ except ImportError:
37
+ TENSORBOARD_FOUND = False
38
+
39
+ sys.path.append('/root/yangxin/codebase/3D_Playground/GSDF')
40
+
41
+
42
+ def adjust_text_embeddings(embeddings, azimuth, guidance_opt):
43
+ # TODO: add perpneg (Perp-Neg) directional prompt functions
44
+ text_z_list = []
45
+ weights_list = []
46
+ K = 0
47
+ #for b in range(azimuth):
48
+ text_z_, weights_ = get_pos_neg_text_embeddings(embeddings, azimuth, guidance_opt)
49
+ K = max(K, weights_.shape[0])
50
+ text_z_list.append(text_z_)
51
+ weights_list.append(weights_)
52
+
53
+ # Interleave text_embeddings from different dirs to form a batch
54
+ text_embeddings = []
55
+ for i in range(K):
56
+ for text_z in text_z_list:
57
+ # if uneven length, pad with the first embedding
58
+ text_embeddings.append(text_z[i] if i < len(text_z) else text_z[0])
59
+ text_embeddings = torch.stack(text_embeddings, dim=0) # [B * K, 77, 768]
60
+
61
+ # Interleave weights from different dirs to form a batch
62
+ weights = []
63
+ for i in range(K):
64
+ for weights_ in weights_list:
65
+ weights.append(weights_[i] if i < len(weights_) else torch.zeros_like(weights_[0]))
66
+ weights = torch.stack(weights, dim=0) # [B * K]
67
+ return text_embeddings, weights
68
+
69
+ def get_pos_neg_text_embeddings(embeddings, azimuth_val, opt):
70
+ if azimuth_val >= -90 and azimuth_val < 90:
71
+ if azimuth_val >= 0:
72
+ r = 1 - azimuth_val / 90
73
+ else:
74
+ r = 1 + azimuth_val / 90
75
+ start_z = embeddings['front']
76
+ end_z = embeddings['side']
77
+ # if random.random() < 0.3:
78
+ # r = r + random.gauss(0, 0.08)
79
+ pos_z = r * start_z + (1 - r) * end_z
80
+ text_z = torch.cat([pos_z, embeddings['front'], embeddings['side']], dim=0)
81
+ if r > 0.8:
82
+ front_neg_w = 0.0
83
+ else:
84
+ front_neg_w = math.exp(-r * opt.front_decay_factor) * opt.negative_w
85
+ if r < 0.2:
86
+ side_neg_w = 0.0
87
+ else:
88
+ side_neg_w = math.exp(-(1-r) * opt.side_decay_factor) * opt.negative_w
89
+
90
+ weights = torch.tensor([1.0, front_neg_w, side_neg_w])
91
+ else:
92
+ if azimuth_val >= 0:
93
+ r = 1 - (azimuth_val - 90) / 90
94
+ else:
95
+ r = 1 + (azimuth_val + 90) / 90
96
+ start_z = embeddings['side']
97
+ end_z = embeddings['back']
98
+ # if random.random() < 0.3:
99
+ # r = r + random.gauss(0, 0.08)
100
+ pos_z = r * start_z + (1 - r) * end_z
101
+ text_z = torch.cat([pos_z, embeddings['side'], embeddings['front']], dim=0)
102
+ front_neg_w = opt.negative_w
103
+ if r > 0.8:
104
+ side_neg_w = 0.0
105
+ else:
106
+ side_neg_w = math.exp(-r * opt.side_decay_factor) * opt.negative_w / 2
107
+
108
+ weights = torch.tensor([1.0, side_neg_w, front_neg_w])
109
+ return text_z, weights.to(text_z.device)
110
+
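The blending above interpolates the 'front'/'side'/'back' prompt embeddings by azimuth and attaches decayed negative weights to the off-axis prompts (used by the perpneg path). A quick sketch with dummy embeddings and a hypothetical options stub:

# Inspect the per-direction weights produced for a 30-degree azimuth.
import torch
from types import SimpleNamespace

dummy = {k: torch.randn(1, 77, 768) for k in ['front', 'side', 'back']}    # stand-in text embeddings
opt_stub = SimpleNamespace(front_decay_factor=2.0, side_decay_factor=10.0, negative_w=-2.0)
text_z, weights = get_pos_neg_text_embeddings(dummy, azimuth_val=30.0, opt=opt_stub)
print(text_z.shape)   # torch.Size([3, 77, 768]): blended positive + the two directional prompts
print(weights)        # tensor([1.0, front_neg_w, side_neg_w])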
111
+ def prepare_embeddings(guidance_opt, guidance):
112
+ embeddings = {}
113
+ # text embeddings (stable-diffusion) and (IF)
114
+ embeddings['default'] = guidance.get_text_embeds([guidance_opt.text])
115
+ embeddings['uncond'] = guidance.get_text_embeds([guidance_opt.negative])
116
+
117
+ for d in ['front', 'side', 'back']:
118
+ embeddings[d] = guidance.get_text_embeds([f"{guidance_opt.text}, {d} view"])
119
+ embeddings['inverse_text'] = guidance.get_text_embeds(guidance_opt.inverse_text)
120
+ return embeddings
121
+
122
+ def guidance_setup(guidance_opt):
123
+ if guidance_opt.guidance=="SD":
124
+ from guidance.sd_utils import StableDiffusion
125
+ guidance = StableDiffusion(guidance_opt.g_device, guidance_opt.fp16, guidance_opt.vram_O,
126
+ guidance_opt.t_range, guidance_opt.max_t_range,
127
+ num_train_timesteps=guidance_opt.num_train_timesteps,
128
+ ddim_inv=guidance_opt.ddim_inv,
129
+ textual_inversion_path = guidance_opt.textual_inversion_path,
130
+ LoRA_path = guidance_opt.LoRA_path,
131
+ guidance_opt=guidance_opt)
132
+ else:
133
+ raise ValueError(f'{guidance_opt.guidance} not supported.')
134
+ if guidance is not None:
135
+ for p in guidance.parameters():
136
+ p.requires_grad = False
137
+ embeddings = prepare_embeddings(guidance_opt, guidance)
138
+ return guidance, embeddings
139
+
140
+
141
+ def training(dataset, opt, pipe, gcams, guidance_opt, testing_iterations, saving_iterations, checkpoint_iterations, checkpoint, debug_from, save_video):
142
+ first_iter = 0
143
+ tb_writer = prepare_output_and_logger(dataset)
144
+ gaussians = GaussianModel(dataset.sh_degree)
145
+ scene = Scene(dataset, gcams, gaussians)
146
+ gaussians.training_setup(opt)
147
+ if checkpoint:
148
+ (model_params, first_iter) = torch.load(checkpoint)
149
+ gaussians.restore(model_params, opt)
150
+
151
+ bg_color = [1, 1, 1] if dataset._white_background else [0, 0, 0]
152
+ background = torch.tensor(bg_color, dtype=torch.float32, device=dataset.data_device)
153
+ iter_start = torch.cuda.Event(enable_timing = True)
154
+ iter_end = torch.cuda.Event(enable_timing = True)
155
+
156
+ #
157
+ save_folder = os.path.join(dataset._model_path,"train_process/")
158
+ if not os.path.exists(save_folder):
159
+ os.makedirs(save_folder) # makedirs
160
+ print('train_process is in :', save_folder)
161
+ #controlnet
162
+ use_control_net = False
163
+ # set up the pretrained diffusion model and text embeddings
164
+ guidance, embeddings = guidance_setup(guidance_opt)
165
+ viewpoint_stack = None
166
+ viewpoint_stack_around = None
167
+ ema_loss_for_log = 0.0
168
+ progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
169
+ first_iter += 1
170
+
171
+ if opt.save_process:
172
+ save_folder_proc = os.path.join(scene.args._model_path,"process_videos/")
173
+ if not os.path.exists(save_folder_proc):
174
+ os.makedirs(save_folder_proc) # makedirs
175
+ process_view_points = scene.getCircleVideoCameras(batch_size=opt.pro_frames_num,render45=opt.pro_render_45).copy()
176
+ save_process_iter = opt.iterations // len(process_view_points)
177
+ pro_img_frames = []
178
+
179
+ for iteration in range(first_iter, opt.iterations + 1):
180
+ #TODO: DEBUG NETWORK_GUI
181
+ if network_gui.conn == None:
182
+ network_gui.try_connect()
183
+ while network_gui.conn != None:
184
+ try:
185
+ net_image_bytes = None
186
+ custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
187
+ if custom_cam != None:
188
+ net_image = render(custom_cam, gaussians, pipe, background, scaling_modifer)["render"]
189
+ net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
190
+ network_gui.send(net_image_bytes, guidance_opt.text)
191
+ if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
192
+ break
193
+ except Exception as e:
194
+ network_gui.conn = None
195
+
196
+ iter_start.record()
197
+
198
+ gaussians.update_learning_rate(iteration)
199
+ gaussians.update_feature_learning_rate(iteration)
200
+ gaussians.update_rotation_learning_rate(iteration)
201
+ gaussians.update_scaling_learning_rate(iteration)
202
+ # Every 500 its we increase the levels of SH up to a maximum degree
203
+ if iteration % 500 == 0:
204
+ gaussians.oneupSHdegree()
205
+
206
+ # progressively relaxing view range
207
+ if not opt.use_progressive:
208
+ if iteration >= opt.progressive_view_iter and iteration % opt.scale_up_cameras_iter == 0:
209
+ scene.pose_args.fovy_range[0] = max(scene.pose_args.max_fovy_range[0], scene.pose_args.fovy_range[0] * opt.fovy_scale_up_factor[0])
210
+ scene.pose_args.fovy_range[1] = min(scene.pose_args.max_fovy_range[1], scene.pose_args.fovy_range[1] * opt.fovy_scale_up_factor[1])
211
+
212
+ scene.pose_args.radius_range[1] = max(scene.pose_args.max_radius_range[1], scene.pose_args.radius_range[1] * opt.scale_up_factor)
213
+ scene.pose_args.radius_range[0] = max(scene.pose_args.max_radius_range[0], scene.pose_args.radius_range[0] * opt.scale_up_factor)
214
+
215
+ scene.pose_args.theta_range[1] = min(scene.pose_args.max_theta_range[1], scene.pose_args.theta_range[1] * opt.phi_scale_up_factor)
216
+ scene.pose_args.theta_range[0] = max(scene.pose_args.max_theta_range[0], scene.pose_args.theta_range[0] * 1/opt.phi_scale_up_factor)
217
+
218
+ # opt.reset_resnet_iter = max(500, opt.reset_resnet_iter // 1.25)
219
+ scene.pose_args.phi_range[0] = max(scene.pose_args.max_phi_range[0] , scene.pose_args.phi_range[0] * opt.phi_scale_up_factor)
220
+ scene.pose_args.phi_range[1] = min(scene.pose_args.max_phi_range[1], scene.pose_args.phi_range[1] * opt.phi_scale_up_factor)
221
+
222
+ print('scale up theta_range to:', scene.pose_args.theta_range)
223
+ print('scale up radius_range to:', scene.pose_args.radius_range)
224
+ print('scale up phi_range to:', scene.pose_args.phi_range)
225
+ print('scale up fovy_range to:', scene.pose_args.fovy_range)
226
+
227
+ # Pick a random Camera
228
+ if not viewpoint_stack:
229
+ viewpoint_stack = scene.getRandTrainCameras().copy()
230
+
231
+ C_batch_size = guidance_opt.C_batch_size
232
+ viewpoint_cams = []
233
+ images = []
234
+ text_z_ = []
235
+ weights_ = []
236
+ depths = []
237
+ alphas = []
238
+ scales = []
239
+
240
+ text_z_inverse = torch.cat([embeddings['uncond'], embeddings['inverse_text']], dim=0)
241
+
242
+ for i in range(C_batch_size):
243
+ try:
244
+ viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
245
+ except:
246
+ viewpoint_stack = scene.getRandTrainCameras().copy()
247
+ viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
248
+
249
+ #pred text_z
250
+ azimuth = viewpoint_cam.delta_azimuth
251
+ text_z = [embeddings['uncond']]
252
+
253
+
254
+ if guidance_opt.perpneg:
255
+ text_z_comp, weights = adjust_text_embeddings(embeddings, azimuth, guidance_opt)
256
+ text_z.append(text_z_comp)
257
+ weights_.append(weights)
258
+
259
+ else:
260
+ if azimuth >= -90 and azimuth < 90:
261
+ if azimuth >= 0:
262
+ r = 1 - azimuth / 90
263
+ else:
264
+ r = 1 + azimuth / 90
265
+ start_z = embeddings['front']
266
+ end_z = embeddings['side']
267
+ else:
268
+ if azimuth >= 0:
269
+ r = 1 - (azimuth - 90) / 90
270
+ else:
271
+ r = 1 + (azimuth + 90) / 90
272
+ start_z = embeddings['side']
273
+ end_z = embeddings['back']
274
+ text_z.append(r * start_z + (1 - r) * end_z)
275
+
276
+ text_z = torch.cat(text_z, dim=0)
277
+ text_z_.append(text_z)
278
+
279
+ # Render
280
+ if (iteration - 1) == debug_from:
281
+ pipe.debug = True
282
+ render_pkg = render(viewpoint_cam, gaussians, pipe, background,
283
+ sh_deg_aug_ratio = dataset.sh_deg_aug_ratio,
284
+ bg_aug_ratio = dataset.bg_aug_ratio,
285
+ shs_aug_ratio = dataset.shs_aug_ratio,
286
+ scale_aug_ratio = dataset.scale_aug_ratio)
287
+ image, viewspace_point_tensor, visibility_filter, radii = render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
288
+ depth, alpha = render_pkg["depth"], render_pkg["alpha"]
289
+
290
+ scales.append(render_pkg["scales"])
291
+ images.append(image)
292
+ depths.append(depth)
293
+ alphas.append(alpha)
294
+ viewpoint_cams.append(viewpoint_cam)  # append the camera, not the list itself
295
+
296
+ images = torch.stack(images, dim=0)
297
+ depths = torch.stack(depths, dim=0)
298
+ alphas = torch.stack(alphas, dim=0)
299
+
300
+ # Loss
301
+ warm_up_rate = 1. - min(iteration/opt.warmup_iter,1.)
302
+ guidance_scale = guidance_opt.guidance_scale
303
+ _aslatent = False
304
+ if iteration < opt.geo_iter or random.random()< opt.as_latent_ratio:
305
+ _aslatent=True
306
+ if iteration > opt.use_control_net_iter and (random.random() < guidance_opt.controlnet_ratio):
307
+ use_control_net = True
308
+ if guidance_opt.perpneg:
309
+ loss = guidance.train_step_perpneg(torch.stack(text_z_, dim=1), images,
310
+ pred_depth=depths, pred_alpha=alphas,
311
+ grad_scale=guidance_opt.lambda_guidance,
312
+ use_control_net = use_control_net ,save_folder = save_folder, iteration = iteration, warm_up_rate=warm_up_rate,
313
+ weights = torch.stack(weights_, dim=1), resolution=(gcams.image_h, gcams.image_w),
314
+ guidance_opt=guidance_opt,as_latent=_aslatent, embedding_inverse = text_z_inverse)
315
+ else:
316
+ loss = guidance.train_step(torch.stack(text_z_, dim=1), images,
317
+ pred_depth=depths, pred_alpha=alphas,
318
+ grad_scale=guidance_opt.lambda_guidance,
319
+ use_control_net = use_control_net ,save_folder = save_folder, iteration = iteration, warm_up_rate=warm_up_rate,
320
+ resolution=(gcams.image_h, gcams.image_w),
321
+ guidance_opt=guidance_opt,as_latent=_aslatent, embedding_inverse = text_z_inverse)
322
+ #raise ValueError(f'original version not supported.')
323
+ scales = torch.stack(scales, dim=0)
324
+
325
+ loss_scale = torch.mean(scales,dim=-1).mean()
326
+ loss_tv = tv_loss(images) + tv_loss(depths)
327
+ # loss_bin = torch.mean(torch.min(alphas - 0.0001, 1 - alphas))
328
+
329
+ loss = loss + opt.lambda_tv * loss_tv + opt.lambda_scale * loss_scale #opt.lambda_tv * loss_tv + opt.lambda_bin * loss_bin + opt.lambda_scale * loss_scale +
330
+ loss.backward()
331
+ iter_end.record()
332
+
333
+ with torch.no_grad():
334
+ # Progress bar
335
+ ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
336
+ if opt.save_process:
337
+ if iteration % save_process_iter == 0 and len(process_view_points) > 0:
338
+ viewpoint_cam_p = process_view_points.pop(0)
339
+ render_p = render(viewpoint_cam_p, gaussians, pipe, background, test=True)
340
+ img_p = torch.clamp(render_p["render"], 0.0, 1.0)
341
+ img_p = img_p.detach().cpu().permute(1,2,0).numpy()
342
+ img_p = (img_p * 255).round().astype('uint8')
343
+ pro_img_frames.append(img_p)
344
+
345
+ if iteration % 10 == 0:
346
+ progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
347
+ progress_bar.update(10)
348
+ if iteration == opt.iterations:
349
+ progress_bar.close()
350
+
351
+ # Log and save
352
+ training_report(tb_writer, iteration, iter_start.elapsed_time(iter_end), testing_iterations, scene, render, (pipe, background))
353
+ if (iteration in testing_iterations):
354
+ if save_video:
355
+ video_path = video_inference(iteration, scene, render, (pipe, background))
356
+
357
+ if (iteration in saving_iterations):
358
+ print("\n[ITER {}] Saving Gaussians".format(iteration))
359
+ scene.save(iteration)
360
+
361
+ # Densification
362
+ if iteration < opt.densify_until_iter:
363
+ # Keep track of max radii in image-space for pruning
364
+ gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
365
+ gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
366
+
367
+ if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
368
+ size_threshold = 20 if iteration > opt.opacity_reset_interval else None
369
+ gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
370
+
371
+ if iteration % opt.opacity_reset_interval == 0: #or (dataset._white_background and iteration == opt.densify_from_iter)
372
+ gaussians.reset_opacity()
373
+
374
+ # Optimizer step
375
+ if iteration < opt.iterations:
376
+ gaussians.optimizer.step()
377
+ gaussians.optimizer.zero_grad(set_to_none = True)
378
+
379
+ if (iteration in checkpoint_iterations):
380
+ print("\n[ITER {}] Saving Checkpoint".format(iteration))
381
+ torch.save((gaussians.capture(), iteration), scene._model_path + "/chkpnt" + str(iteration) + ".pth")
382
+
383
+ if opt.save_process:
384
+ imageio.mimwrite(os.path.join(save_folder_proc, "video_rgb.mp4"), pro_img_frames, fps=30, quality=8)
385
+ return video_path
386
+
387
+
388
+ def prepare_output_and_logger(args):
389
+ if not args._model_path:
390
+ if os.getenv('OAR_JOB_ID'):
391
+ unique_str=os.getenv('OAR_JOB_ID')
392
+ else:
393
+ unique_str = str(uuid.uuid4())
394
+ args._model_path = os.path.join("./output/", args.workspace)
395
+
396
+ # Set up output folder
397
+ print("Output folder: {}".format(args._model_path))
398
+ os.makedirs(args._model_path, exist_ok = True)
399
+
400
+ # copy configs
401
+ if args.opt_path is not None:
402
+ os.system(' '.join(['cp', args.opt_path, os.path.join(args._model_path, 'config.yaml')]))
403
+
404
+ with open(os.path.join(args._model_path, "cfg_args"), 'w') as cfg_log_f:
405
+ cfg_log_f.write(str(Namespace(**vars(args))))
406
+
407
+ # Create Tensorboard writer
408
+ tb_writer = None
409
+ if TENSORBOARD_FOUND:
410
+ tb_writer = SummaryWriter(args._model_path)
411
+ else:
412
+ print("Tensorboard not available: not logging progress")
413
+ return tb_writer
414
+
415
+ def training_report(tb_writer, iteration, elapsed, testing_iterations, scene : Scene, renderFunc, renderArgs):
416
+ if tb_writer:
417
+ tb_writer.add_scalar('iter_time', elapsed, iteration)
418
+ # Report test and samples of training set
419
+ if iteration in testing_iterations:
420
+ save_folder = os.path.join(scene.args._model_path,"test_six_views/{}_iteration".format(iteration))
421
+ if not os.path.exists(save_folder):
422
+ os.makedirs(save_folder) # makedirs creates the directory if the path does not already exist
423
+ print('test views are in:', save_folder)
424
+ torch.cuda.empty_cache()
425
+ config = ({'name': 'test', 'cameras' : scene.getTestCameras()})
426
+ if config['cameras'] and len(config['cameras']) > 0:
427
+ for idx, viewpoint in enumerate(config['cameras']):
428
+ render_out = renderFunc(viewpoint, scene.gaussians, *renderArgs, test=True)
429
+ rgb, depth = render_out["render"],render_out["depth"]
430
+ if depth is not None:
431
+ depth_norm = depth/depth.max()
432
+ save_image(depth_norm,os.path.join(save_folder,"render_depth_{}.png".format(viewpoint.uid)))
433
+
434
+ image = torch.clamp(rgb, 0.0, 1.0)
435
+ save_image(image,os.path.join(save_folder,"render_view_{}.png".format(viewpoint.uid)))
436
+ if tb_writer:
437
+ tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.uid), image[None], global_step=iteration)
438
+ print("\n[ITER {}] Eval Done!".format(iteration))
439
+ if tb_writer:
440
+ tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
441
+ tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
442
+ torch.cuda.empty_cache()
443
+
444
+ def video_inference(iteration, scene : Scene, renderFunc, renderArgs):
445
+ sharp = T.RandomAdjustSharpness(3, p=1.0)
446
+
447
+ save_folder = os.path.join(scene.args._model_path,"videos/{}_iteration".format(iteration))
448
+ if not os.path.exists(save_folder):
449
+ os.makedirs(save_folder) # create the output directory if it does not exist
450
+ print('videos are in:', save_folder)
451
+ torch.cuda.empty_cache()
452
+ config = ({'name': 'test', 'cameras' : scene.getCircleVideoCameras()})
453
+ if config['cameras'] and len(config['cameras']) > 0:
454
+ img_frames = []
455
+ depth_frames = []
456
+ print("Generating Video using", len(config['cameras']), "different view points")
457
+ for idx, viewpoint in enumerate(config['cameras']):
458
+ render_out = renderFunc(viewpoint, scene.gaussians, *renderArgs, test=True)
459
+ rgb,depth = render_out["render"],render_out["depth"]
460
+ if depth is not None:
461
+ depth_norm = depth/depth.max()
462
+ depths = torch.clamp(depth_norm, 0.0, 1.0)
463
+ depths = depths.detach().cpu().permute(1,2,0).numpy()
464
+ depths = (depths * 255).round().astype('uint8')
465
+ depth_frames.append(depths)
466
+
467
+ image = torch.clamp(rgb, 0.0, 1.0)
468
+ image = image.detach().cpu().permute(1,2,0).numpy()
469
+ image = (image * 255).round().astype('uint8')
470
+ img_frames.append(image)
471
+ #save_image(image,os.path.join(save_folder,"lora_view_{}.jpg".format(viewpoint.uid)))
472
+ # Img to Numpy
473
+ imageio.mimwrite(os.path.join(save_folder, "video_rgb_{}.mp4".format(iteration)), img_frames, fps=30, quality=8)
474
+ if len(depth_frames) > 0:
475
+ imageio.mimwrite(os.path.join(save_folder, "video_depth_{}.mp4".format(iteration)), depth_frames, fps=30, quality=8)
476
+ print("\n[ITER {}] Video Save Done!".format(iteration))
477
+ torch.cuda.empty_cache()
478
+ return os.path.join(save_folder, "video_rgb_{}.mp4".format(iteration))
479
+
480
+ def args_parser(default_opt=None):
481
+ # Set up command line argument parser
482
+ parser = ArgumentParser(description="Training script parameters")
483
+
484
+ parser.add_argument('--opt', type=str, default=default_opt)
485
+ parser.add_argument('--ip', type=str, default="127.0.0.1")
486
+ parser.add_argument('--port', type=int, default=6009)
487
+ parser.add_argument('--debug_from', type=int, default=-1)
488
+ parser.add_argument('--seed', type=int, default=0)
489
+ parser.add_argument('--detect_anomaly', action='store_true', default=False)
490
+ parser.add_argument("--test_ratio", type=int, default=5) # [2500,5000,7500,10000,12000]
491
+ parser.add_argument("--save_ratio", type=int, default=2) # [10000,12000]
492
+ parser.add_argument("--save_video", type=bool, default=False)
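+ # NOTE: argparse's type=bool treats any non-empty string (including 'False') as True; the value is normally overridden by the YAML config below.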
493
+ parser.add_argument("--quiet", action="store_true")
494
+ parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
495
+ parser.add_argument("--start_checkpoint", type=str, default = None)
496
+ parser.add_argument("--cuda", type=str, default='0')
497
+
498
+ lp = ModelParams(parser)
499
+ op = OptimizationParams(parser)
500
+ pp = PipelineParams(parser)
501
+ gcp = GenerateCamParams(parser)
502
+ gp = GuidanceParams(parser)
503
+
504
+ args = parser.parse_args(sys.argv[1:])
505
+
506
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
507
+ if args.opt is not None:
508
+ with open(args.opt) as f:
509
+ opts = yaml.load(f, Loader=yaml.FullLoader)
510
+ lp.load_yaml(opts.get('ModelParams', None))
511
+ op.load_yaml(opts.get('OptimizationParams', None))
512
+ pp.load_yaml(opts.get('PipelineParams', None))
513
+ gcp.load_yaml(opts.get('GenerateCamParams', None))
514
+ gp.load_yaml(opts.get('GuidanceParams', None))
515
+
516
+ lp.opt_path = args.opt
517
+ args.port = opts['port']
518
+ args.save_video = opts.get('save_video', True)
519
+ args.seed = opts.get('seed', 0)
520
+ args.device = opts.get('device', 'cuda')
521
+
522
+ # override device
523
+ gp.g_device = args.device
524
+ lp.data_device = args.device
525
+ gcp.device = args.device
526
+ return args, lp, op, pp, gcp, gp
527
+
528
+ def start_training(args, lp, op, pp, gcp, gp):
529
+ # save iterations
530
+ test_iter = [1] + [k * op.iterations // args.test_ratio for k in range(1, args.test_ratio)] + [op.iterations]
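+ # Evenly spaced test iterations, e.g. iterations=5000 with test_ratio=5 gives [1, 1000, 2000, 3000, 4000, 5000].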
531
+ args.test_iterations = test_iter
532
+
533
+ save_iter = [k * op.iterations // args.save_ratio for k in range(1, args.save_ratio)] + [op.iterations]
534
+ args.save_iterations = save_iter
535
+
536
+ print('Test iter:', args.test_iterations)
537
+ print('Save iter:', args.save_iterations)
538
+
539
+ print("Optimizing " + lp._model_path)
540
+
541
+ # Initialize system state (RNG)
542
+ safe_state(args.quiet, seed=args.seed)
543
+ # Start GUI server, configure and run training
544
+ network_gui.init(args.ip, args.port)
545
+ torch.autograd.set_detect_anomaly(args.detect_anomaly)
546
+ video_path = training(lp, op, pp, gcp, gp, args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from, args.save_video)
547
+ # All done
548
+ print("\nTraining complete.")
549
+ return video_path
550
+
551
+ if __name__ == "__main__":
552
+ args, lp, op, pp, gcp, gp = args_parser()
553
+ start_training(args, lp, op, pp, gcp, gp)
train.sh ADDED
@@ -0,0 +1 @@
1
+ python train.py --opt 'configs/bagel.yaml' --cuda 4
utils/camera_utils.py ADDED
@@ -0,0 +1,98 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ from scene.cameras import Camera, RCamera
13
+ import numpy as np
14
+ from utils.general_utils import PILtoTorch
15
+ from utils.graphics_utils import fov2focal
16
+
17
+ WARNED = False
18
+
19
+ def loadCam(args, id, cam_info, resolution_scale):
20
+ orig_w, orig_h = cam_info.image.size
21
+
22
+ if args.resolution in [1, 2, 4, 8]:
23
+ resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
24
+ else: # should be a type that converts to float
25
+ if args.resolution == -1:
26
+ if orig_w > 1600:
27
+ global WARNED
28
+ if not WARNED:
29
+ print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
30
+ "If this is not desired, please explicitly specify '--resolution/-r' as 1")
31
+ WARNED = True
32
+ global_down = orig_w / 1600
33
+ else:
34
+ global_down = 1
35
+ else:
36
+ global_down = orig_w / args.resolution
37
+
38
+ scale = float(global_down) * float(resolution_scale)
39
+ resolution = (int(orig_w / scale), int(orig_h / scale))
40
+
41
+ resized_image_rgb = PILtoTorch(cam_info.image, resolution)
42
+
43
+ gt_image = resized_image_rgb[:3, ...]
44
+ loaded_mask = None
45
+
46
+ if resized_image_rgb.shape[1] == 4:
47
+ loaded_mask = resized_image_rgb[3:4, ...]
48
+
49
+ return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
50
+ FoVx=cam_info.FovX, FoVy=cam_info.FovY,
51
+ image=gt_image, gt_alpha_mask=loaded_mask,
52
+ image_name=cam_info.image_name, uid=id, data_device=args.data_device)
53
+
54
+
55
+ def loadRandomCam(opt, id, cam_info, resolution_scale, SSAA=False):
56
+ return RCamera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
57
+ FoVx=cam_info.FovX, FoVy=cam_info.FovY, delta_polar=cam_info.delta_polar,
58
+ delta_azimuth=cam_info.delta_azimuth , delta_radius=cam_info.delta_radius, opt=opt,
59
+ uid=id, data_device=opt.device, SSAA=SSAA)
60
+
61
+ def cameraList_from_camInfos(cam_infos, resolution_scale, args):
62
+ camera_list = []
63
+
64
+ for id, c in enumerate(cam_infos):
65
+ camera_list.append(loadCam(args, id, c, resolution_scale))
66
+
67
+ return camera_list
68
+
69
+
70
+ def cameraList_from_RcamInfos(cam_infos, resolution_scale, opt, SSAA=False):
71
+ camera_list = []
72
+
73
+ for id, c in enumerate(cam_infos):
74
+ camera_list.append(loadRandomCam(opt, id, c, resolution_scale, SSAA=SSAA))
75
+
76
+ return camera_list
77
+
78
+ def camera_to_JSON(id, camera : Camera):
79
+ Rt = np.zeros((4, 4))
80
+ Rt[:3, :3] = camera.R.transpose()
81
+ Rt[:3, 3] = camera.T
82
+ Rt[3, 3] = 1.0
83
+
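+ # Rt is the world-to-view matrix; its inverse is the camera-to-world transform, whose translation component is the camera position in world coordinates.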
84
+ W2C = np.linalg.inv(Rt)
85
+ pos = W2C[:3, 3]
86
+ rot = W2C[:3, :3]
87
+ serializable_array_2d = [x.tolist() for x in rot]
88
+ camera_entry = {
89
+ 'id' : id,
90
+ 'img_name' : id,
91
+ 'width' : camera.width,
92
+ 'height' : camera.height,
93
+ 'position': pos.tolist(),
94
+ 'rotation': serializable_array_2d,
95
+ 'fy' : fov2focal(camera.FovY, camera.height),
96
+ 'fx' : fov2focal(camera.FovX, camera.width)
97
+ }
98
+ return camera_entry
utils/general_utils.py ADDED
@@ -0,0 +1,141 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import torch
13
+ import sys
14
+ from datetime import datetime
15
+ import numpy as np
16
+ import random
17
+
18
+ def inverse_sigmoid(x):
19
+ return torch.log(x/(1-x))
20
+
21
+ def inverse_sigmoid_np(x):
22
+ return np.log(x/(1-x))
23
+
24
+ def PILtoTorch(pil_image, resolution):
25
+ resized_image_PIL = pil_image.resize(resolution)
26
+ resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
27
+ if len(resized_image.shape) == 3:
28
+ return resized_image.permute(2, 0, 1)
29
+ else:
30
+ return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
31
+
32
+ def get_expon_lr_func(
33
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
34
+ ):
35
+ """
36
+ Copied from Plenoxels
37
+
38
+ Continuous learning rate decay function. Adapted from JaxNeRF
39
+ The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
40
+ is log-linearly interpolated elsewhere (equivalent to exponential decay).
41
+ If lr_delay_steps>0 then the learning rate will be scaled by some smooth
42
+ function of lr_delay_mult, such that the initial learning rate is
43
+ lr_init*lr_delay_mult at the beginning of optimization but will be eased back
44
+ to the normal learning rate when steps>lr_delay_steps.
45
+ :param conf: config subtree 'lr' or similar
46
+ :param max_steps: int, the number of steps during optimization.
47
+ :return HoF which takes step as input
48
+ """
49
+
50
+ def helper(step):
51
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
52
+ # Disable this parameter
53
+ return 0.0
54
+ if lr_delay_steps > 0:
55
+ # A kind of reverse cosine decay.
56
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
57
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
58
+ )
59
+ else:
60
+ delay_rate = 1.0
61
+ t = np.clip(step / max_steps, 0, 1)
62
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
63
+ return delay_rate * log_lerp
64
+
65
+ return helper
66
+
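+ # Pack the six unique entries of each symmetric 3x3 covariance matrix into an (N, 6) tensor.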
67
+ def strip_lowerdiag(L):
68
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
69
+
70
+ uncertainty[:, 0] = L[:, 0, 0]
71
+ uncertainty[:, 1] = L[:, 0, 1]
72
+ uncertainty[:, 2] = L[:, 0, 2]
73
+ uncertainty[:, 3] = L[:, 1, 1]
74
+ uncertainty[:, 4] = L[:, 1, 2]
75
+ uncertainty[:, 5] = L[:, 2, 2]
76
+ return uncertainty
77
+
78
+ def strip_symmetric(sym):
79
+ return strip_lowerdiag(sym)
80
+
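+ # Convert quaternions given as (r, x, y, z) into 3x3 rotation matrices after normalization.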
81
+ def build_rotation(r):
82
+ norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
83
+
84
+ q = r / norm[:, None]
85
+
86
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
87
+
88
+ r = q[:, 0]
89
+ x = q[:, 1]
90
+ y = q[:, 2]
91
+ z = q[:, 3]
92
+
93
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
94
+ R[:, 0, 1] = 2 * (x*y - r*z)
95
+ R[:, 0, 2] = 2 * (x*z + r*y)
96
+ R[:, 1, 0] = 2 * (x*y + r*z)
97
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
98
+ R[:, 1, 2] = 2 * (y*z - r*x)
99
+ R[:, 2, 0] = 2 * (x*z - r*y)
100
+ R[:, 2, 1] = 2 * (y*z + r*x)
101
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
102
+ return R
103
+
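+ # L = R @ diag(s): per-Gaussian axis-aligned scaling followed by rotation.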
104
+ def build_scaling_rotation(s, r):
105
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
106
+ R = build_rotation(r)
107
+
108
+ L[:,0,0] = s[:,0]
109
+ L[:,1,1] = s[:,1]
110
+ L[:,2,2] = s[:,2]
111
+
112
+ L = R @ L
113
+ return L
114
+
115
+ def safe_state(silent, seed=0):
116
+ old_f = sys.stdout
117
+ class F:
118
+ def __init__(self, silent):
119
+ self.silent = silent
120
+
121
+ def write(self, x):
122
+ if not self.silent:
123
+ if x.endswith("\n"):
124
+ old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
125
+ else:
126
+ old_f.write(x)
127
+
128
+ def flush(self):
129
+ old_f.flush()
130
+
131
+ sys.stdout = F(silent)
132
+ random.seed(seed)
133
+ np.random.seed(seed)
134
+ torch.manual_seed(seed)
135
+ torch.cuda.manual_seed_all(seed)
136
+
137
+ # if seed == 0:
138
+ torch.backends.cudnn.deterministic = True
139
+ torch.backends.cudnn.benchmark = False
140
+
141
+ # torch.cuda.set_device(torch.device("cuda:0"))
utils/graphics_utils.py ADDED
@@ -0,0 +1,81 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import torch
13
+ import math
14
+ import numpy as np
15
+ from typing import NamedTuple
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ from torch.functional import norm
19
+
20
+
21
+ class BasicPointCloud(NamedTuple):
22
+ points : np.array
23
+ colors : np.array
24
+ normals : np.array
25
+
26
+ def geom_transform_points(points, transf_matrix):
27
+ P, _ = points.shape
28
+ ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
29
+ points_hom = torch.cat([points, ones], dim=1)
30
+ points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
31
+
32
+ denom = points_out[..., 3:] + 0.0000001
33
+ return (points_out[..., :3] / denom).squeeze(dim=0)
34
+
35
+ def getWorld2View(R, t):
36
+ Rt = np.zeros((4, 4))
37
+ Rt[:3, :3] = R.transpose()
38
+ Rt[:3, 3] = t
39
+ Rt[3, 3] = 1.0
40
+ return np.float32(Rt)
41
+
42
+ def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
43
+ Rt = np.zeros((4, 4))
44
+ Rt[:3, :3] = R.transpose()
45
+ Rt[:3, 3] = t
46
+ Rt[3, 3] = 1.0
47
+
48
+ C2W = np.linalg.inv(Rt)
49
+ cam_center = C2W[:3, 3]
50
+ cam_center = (cam_center + translate) * scale
51
+ C2W[:3, 3] = cam_center
52
+ Rt = np.linalg.inv(C2W)
53
+ return np.float32(Rt)
54
+
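+ # Perspective projection built from the horizontal/vertical FoV; maps view-space depth in [znear, zfar] to [0, 1] after the perspective divide.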
55
+ def getProjectionMatrix(znear, zfar, fovX, fovY):
56
+ tanHalfFovY = math.tan((fovY / 2))
57
+ tanHalfFovX = math.tan((fovX / 2))
58
+
59
+ top = tanHalfFovY * znear
60
+ bottom = -top
61
+ right = tanHalfFovX * znear
62
+ left = -right
63
+
64
+ P = torch.zeros(4, 4)
65
+
66
+ z_sign = 1.0
67
+
68
+ P[0, 0] = 2.0 * znear / (right - left)
69
+ P[1, 1] = 2.0 * znear / (top - bottom)
70
+ P[0, 2] = (right + left) / (right - left)
71
+ P[1, 2] = (top + bottom) / (top - bottom)
72
+ P[3, 2] = z_sign
73
+ P[2, 2] = z_sign * zfar / (zfar - znear)
74
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
75
+ return P
76
+
77
+ def fov2focal(fov, pixels):
78
+ return pixels / (2 * math.tan(fov / 2))
79
+
80
+ def focal2fov(focal, pixels):
81
+ return 2*math.atan(pixels/(2*focal))
utils/image_utils.py ADDED
@@ -0,0 +1,19 @@
1
+ #
2
+ # Copyright (C) 2023, Inria
3
+ # GRAPHDECO research group, https://team.inria.fr/graphdeco
4
+ # All rights reserved.
5
+ #
6
+ # This software is free for non-commercial, research and evaluation use
7
+ # under the terms of the LICENSE.md file.
8
+ #
9
+ # For inquiries contact george.drettakis@inria.fr
10
+ #
11
+
12
+ import torch
13
+
14
+ def mse(img1, img2):
15
+ return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
16
+
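+ # PSNR in dB, assuming pixel values are normalized to [0, 1] (peak signal = 1.0).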
17
+ def psnr(img1, img2):
18
+ mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
19
+ return 20 * torch.log10(1.0 / torch.sqrt(mse))