Johannes Stelzer commited on
Commit
940cc9a
1 Parent(s): 7dbcdfe

new latent blending with diffusers, xl, ...

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +51 -1
  2. LICENSE +28 -0
  3. __pycache__/utils.cpython-311.pyc +0 -0
  4. animation.gif +0 -0
  5. configs/v1-inference.yaml +0 -70
  6. configs/v2-inference-v.yaml +0 -68
  7. configs/v2-inference.yaml +0 -67
  8. configs/v2-inpainting-inference.yaml +0 -158
  9. configs/v2-midas-inference.yaml +0 -74
  10. configs/x4-upscaling.yaml +0 -76
  11. example1.jpg +0 -0
  12. example_multi_trans.py +62 -0
  13. example_multi_trans_json.py +75 -0
  14. example_single_trans.py +23 -0
  15. gradio_ui.py +0 -500
  16. latentblending/__init__.py +3 -0
  17. latentblending/__pycache__/diffusers_holder.cpython-311.pyc +0 -0
  18. latent_blending.py → latentblending/blending_engine.py +273 -320
  19. latentblending/diffusers_holder.py +474 -0
  20. latentblending/gradio_ui.py +153 -0
  21. utils.py → latentblending/utils.py +3 -1
  22. ldm/__pycache__/util.cpython-310.pyc +0 -0
  23. ldm/__pycache__/util.cpython-38.pyc +0 -0
  24. ldm/__pycache__/util.cpython-39.pyc +0 -0
  25. ldm/data/__init__.py +0 -0
  26. ldm/data/util.py +0 -24
  27. ldm/ldm +0 -1
  28. ldm/models/__pycache__/autoencoder.cpython-310.pyc +0 -0
  29. ldm/models/__pycache__/autoencoder.cpython-38.pyc +0 -0
  30. ldm/models/__pycache__/autoencoder.cpython-39.pyc +0 -0
  31. ldm/models/autoencoder.py +0 -219
  32. ldm/models/diffusion/__init__.py +0 -0
  33. ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
  34. ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc +0 -0
  35. ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  36. ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc +0 -0
  37. ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc +0 -0
  38. ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc +0 -0
  39. ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc +0 -0
  40. ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc +0 -0
  41. ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc +0 -0
  42. ldm/models/diffusion/__pycache__/plms.cpython-39.pyc +0 -0
  43. ldm/models/diffusion/__pycache__/sampling_util.cpython-39.pyc +0 -0
  44. ldm/models/diffusion/ddim.py +0 -336
  45. ldm/models/diffusion/ddpm.py +0 -1795
  46. ldm/models/diffusion/dpm_solver/__init__.py +0 -1
  47. ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-39.pyc +0 -0
  48. ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-39.pyc +0 -0
  49. ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-39.pyc +0 -0
  50. ldm/models/diffusion/dpm_solver/dpm_solver.py +0 -1154
Dockerfile CHANGED
@@ -1 +1,51 @@
1
- echo "test"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
2
+
3
+ # Configure environment
4
+ ENV DEBIAN_FRONTEND=noninteractive \
5
+ PIP_PREFER_BINARY=1 \
6
+ CUDA_HOME=/usr/local/cuda-12.1 \
7
+ TORCH_CUDA_ARCH_LIST="8.6"
8
+
9
+ # Redirect shell
10
+ RUN rm /bin/sh && ln -s /bin/bash /bin/sh
11
+
12
+ # Install prereqs
13
+ RUN apt-get update && apt-get install -y --no-install-recommends \
14
+ curl \
15
+ git-lfs \
16
+ ffmpeg \
17
+ libgl1-mesa-dev \
18
+ libglib2.0-0 \
19
+ git \
20
+ python3-dev \
21
+ python3-pip \
22
+ # Lunar Tools prereqs
23
+ libasound2-dev \
24
+ libportaudio2 \
25
+ && apt clean && rm -rf /var/lib/apt/lists/* \
26
+ && ln -s /usr/bin/python3 /usr/bin/python
27
+
28
+ # Set symbolic links
29
+ RUN echo "export PATH=/usr/local/cuda/bin:$PATH" >> /etc/bash.bashrc \
30
+ && echo "export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH" >> /etc/bash. bashrc \
31
+ && echo "export CUDA_HOME=/usr/local/cuda-12.1" >> /etc/bash.bashrc
32
+
33
+ # Install Python packages: Basic, then CUDA-compatible, then custom
34
+ RUN pip3 install \
35
+ wheel \
36
+ ninja && \
37
+ pip3 install \
38
+ torch==2.1.0 \
39
+ torchvision==0.16.0 \
40
+ xformers>=0.0.22 \
41
+ triton>=2.1.0 \
42
+ --index-url https://download.pytorch.org/whl/cu121 && \
43
+ pip3 install git+https://github.com/lunarring/latentblending \
44
+ git+https://github.com/chengzeyi/stable-fast.git@main#egg=stable-fast
45
+
46
+ # Optionally store weights in image
47
+ # RUN mkdir -p /root/.cache/torch/hub/checkpoints/ && curl -o /root/.cache/torch/hub/checkpoints//alexnet-owt-7be5be79.pth https://download.pytorch.org/models/alexnet-owt-7be5be79.pth
48
+ # RUN git lfs install && git clone https://huggingface.co/stabilityai/sdxl-turbo /sdxl-turbo
49
+
50
+ # Clone base repo because why not
51
+ RUN git clone https://github.com/lunarring/latentblending.git
LICENSE ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2023, Lunar Ring
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
__pycache__/utils.cpython-311.pyc ADDED
Binary file (12.6 kB). View file
animation.gif ADDED
configs/v1-inference.yaml DELETED
@@ -1,70 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-04
3
- target: ldm.models.diffusion.ddpm.LatentDiffusion
4
- params:
5
- linear_start: 0.00085
6
- linear_end: 0.0120
7
- num_timesteps_cond: 1
8
- log_every_t: 200
9
- timesteps: 1000
10
- first_stage_key: "jpg"
11
- cond_stage_key: "txt"
12
- image_size: 64
13
- channels: 4
14
- cond_stage_trainable: false # Note: different from the one we trained before
15
- conditioning_key: crossattn
16
- monitor: val/loss_simple_ema
17
- scale_factor: 0.18215
18
- use_ema: False
19
-
20
- scheduler_config: # 10000 warmup steps
21
- target: ldm.lr_scheduler.LambdaLinearScheduler
22
- params:
23
- warm_up_steps: [ 10000 ]
24
- cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
- f_start: [ 1.e-6 ]
26
- f_max: [ 1. ]
27
- f_min: [ 1. ]
28
-
29
- unet_config:
30
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
- params:
32
- image_size: 32 # unused
33
- in_channels: 4
34
- out_channels: 4
35
- model_channels: 320
36
- attention_resolutions: [ 4, 2, 1 ]
37
- num_res_blocks: 2
38
- channel_mult: [ 1, 2, 4, 4 ]
39
- num_heads: 8
40
- use_spatial_transformer: True
41
- transformer_depth: 1
42
- context_dim: 768
43
- use_checkpoint: True
44
- legacy: False
45
-
46
- first_stage_config:
47
- target: ldm.models.autoencoder.AutoencoderKL
48
- params:
49
- embed_dim: 4
50
- monitor: val/rec_loss
51
- ddconfig:
52
- double_z: true
53
- z_channels: 4
54
- resolution: 256
55
- in_channels: 3
56
- out_ch: 3
57
- ch: 128
58
- ch_mult:
59
- - 1
60
- - 2
61
- - 4
62
- - 4
63
- num_res_blocks: 2
64
- attn_resolutions: []
65
- dropout: 0.0
66
- lossconfig:
67
- target: torch.nn.Identity
68
-
69
- cond_stage_config:
70
- target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/v2-inference-v.yaml DELETED
@@ -1,68 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: ldm.models.diffusion.ddpm.LatentDiffusion
4
- params:
5
- parameterization: "v"
6
- linear_start: 0.00085
7
- linear_end: 0.0120
8
- num_timesteps_cond: 1
9
- log_every_t: 200
10
- timesteps: 1000
11
- first_stage_key: "jpg"
12
- cond_stage_key: "txt"
13
- image_size: 64
14
- channels: 4
15
- cond_stage_trainable: false
16
- conditioning_key: crossattn
17
- monitor: val/loss_simple_ema
18
- scale_factor: 0.18215
19
- use_ema: False # we set this to false because this is an inference only config
20
-
21
- unet_config:
22
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
23
- params:
24
- use_checkpoint: True
25
- use_fp16: True
26
- image_size: 32 # unused
27
- in_channels: 4
28
- out_channels: 4
29
- model_channels: 320
30
- attention_resolutions: [ 4, 2, 1 ]
31
- num_res_blocks: 2
32
- channel_mult: [ 1, 2, 4, 4 ]
33
- num_head_channels: 64 # need to fix for flash-attn
34
- use_spatial_transformer: True
35
- use_linear_in_transformer: True
36
- transformer_depth: 1
37
- context_dim: 1024
38
- legacy: False
39
-
40
- first_stage_config:
41
- target: ldm.models.autoencoder.AutoencoderKL
42
- params:
43
- embed_dim: 4
44
- monitor: val/rec_loss
45
- ddconfig:
46
- #attn_type: "vanilla-xformers"
47
- double_z: true
48
- z_channels: 4
49
- resolution: 256
50
- in_channels: 3
51
- out_ch: 3
52
- ch: 128
53
- ch_mult:
54
- - 1
55
- - 2
56
- - 4
57
- - 4
58
- num_res_blocks: 2
59
- attn_resolutions: []
60
- dropout: 0.0
61
- lossconfig:
62
- target: torch.nn.Identity
63
-
64
- cond_stage_config:
65
- target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
66
- params:
67
- freeze: True
68
- layer: "penultimate"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/v2-inference.yaml DELETED
@@ -1,67 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-4
3
- target: ldm.models.diffusion.ddpm.LatentDiffusion
4
- params:
5
- linear_start: 0.00085
6
- linear_end: 0.0120
7
- num_timesteps_cond: 1
8
- log_every_t: 200
9
- timesteps: 1000
10
- first_stage_key: "jpg"
11
- cond_stage_key: "txt"
12
- image_size: 64
13
- channels: 4
14
- cond_stage_trainable: false
15
- conditioning_key: crossattn
16
- monitor: val/loss_simple_ema
17
- scale_factor: 0.18215
18
- use_ema: False # we set this to false because this is an inference only config
19
-
20
- unet_config:
21
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22
- params:
23
- use_checkpoint: True
24
- use_fp16: True
25
- image_size: 32 # unused
26
- in_channels: 4
27
- out_channels: 4
28
- model_channels: 320
29
- attention_resolutions: [ 4, 2, 1 ]
30
- num_res_blocks: 2
31
- channel_mult: [ 1, 2, 4, 4 ]
32
- num_head_channels: 64 # need to fix for flash-attn
33
- use_spatial_transformer: True
34
- use_linear_in_transformer: True
35
- transformer_depth: 1
36
- context_dim: 1024
37
- legacy: False
38
-
39
- first_stage_config:
40
- target: ldm.models.autoencoder.AutoencoderKL
41
- params:
42
- embed_dim: 4
43
- monitor: val/rec_loss
44
- ddconfig:
45
- #attn_type: "vanilla-xformers"
46
- double_z: true
47
- z_channels: 4
48
- resolution: 256
49
- in_channels: 3
50
- out_ch: 3
51
- ch: 128
52
- ch_mult:
53
- - 1
54
- - 2
55
- - 4
56
- - 4
57
- num_res_blocks: 2
58
- attn_resolutions: []
59
- dropout: 0.0
60
- lossconfig:
61
- target: torch.nn.Identity
62
-
63
- cond_stage_config:
64
- target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
65
- params:
66
- freeze: True
67
- layer: "penultimate"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/v2-inpainting-inference.yaml DELETED
@@ -1,158 +0,0 @@
1
- model:
2
- base_learning_rate: 5.0e-05
3
- target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4
- params:
5
- linear_start: 0.00085
6
- linear_end: 0.0120
7
- num_timesteps_cond: 1
8
- log_every_t: 200
9
- timesteps: 1000
10
- first_stage_key: "jpg"
11
- cond_stage_key: "txt"
12
- image_size: 64
13
- channels: 4
14
- cond_stage_trainable: false
15
- conditioning_key: hybrid
16
- scale_factor: 0.18215
17
- monitor: val/loss_simple_ema
18
- finetune_keys: null
19
- use_ema: False
20
-
21
- unet_config:
22
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
23
- params:
24
- use_checkpoint: True
25
- image_size: 32 # unused
26
- in_channels: 9
27
- out_channels: 4
28
- model_channels: 320
29
- attention_resolutions: [ 4, 2, 1 ]
30
- num_res_blocks: 2
31
- channel_mult: [ 1, 2, 4, 4 ]
32
- num_head_channels: 64 # need to fix for flash-attn
33
- use_spatial_transformer: True
34
- use_linear_in_transformer: True
35
- transformer_depth: 1
36
- context_dim: 1024
37
- legacy: False
38
-
39
- first_stage_config:
40
- target: ldm.models.autoencoder.AutoencoderKL
41
- params:
42
- embed_dim: 4
43
- monitor: val/rec_loss
44
- ddconfig:
45
- #attn_type: "vanilla-xformers"
46
- double_z: true
47
- z_channels: 4
48
- resolution: 256
49
- in_channels: 3
50
- out_ch: 3
51
- ch: 128
52
- ch_mult:
53
- - 1
54
- - 2
55
- - 4
56
- - 4
57
- num_res_blocks: 2
58
- attn_resolutions: [ ]
59
- dropout: 0.0
60
- lossconfig:
61
- target: torch.nn.Identity
62
-
63
- cond_stage_config:
64
- target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
65
- params:
66
- freeze: True
67
- layer: "penultimate"
68
-
69
-
70
- data:
71
- target: ldm.data.laion.WebDataModuleFromConfig
72
- params:
73
- tar_base: null # for concat as in LAION-A
74
- p_unsafe_threshold: 0.1
75
- filter_word_list: "data/filters.yaml"
76
- max_pwatermark: 0.45
77
- batch_size: 8
78
- num_workers: 6
79
- multinode: True
80
- min_size: 512
81
- train:
82
- shards:
83
- - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-0/{00000..18699}.tar -"
84
- - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-1/{00000..18699}.tar -"
85
- - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-2/{00000..18699}.tar -"
86
- - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-3/{00000..18699}.tar -"
87
- - "pipe:aws s3 cp s3://stability-aws/laion-a-native/part-4/{00000..18699}.tar -" #{00000-94333}.tar"
88
- shuffle: 10000
89
- image_key: jpg
90
- image_transforms:
91
- - target: torchvision.transforms.Resize
92
- params:
93
- size: 512
94
- interpolation: 3
95
- - target: torchvision.transforms.RandomCrop
96
- params:
97
- size: 512
98
- postprocess:
99
- target: ldm.data.laion.AddMask
100
- params:
101
- mode: "512train-large"
102
- p_drop: 0.25
103
- # NOTE use enough shards to avoid empty validation loops in workers
104
- validation:
105
- shards:
106
- - "pipe:aws s3 cp s3://deep-floyd-s3/datasets/laion_cleaned-part5/{93001..94333}.tar - "
107
- shuffle: 0
108
- image_key: jpg
109
- image_transforms:
110
- - target: torchvision.transforms.Resize
111
- params:
112
- size: 512
113
- interpolation: 3
114
- - target: torchvision.transforms.CenterCrop
115
- params:
116
- size: 512
117
- postprocess:
118
- target: ldm.data.laion.AddMask
119
- params:
120
- mode: "512train-large"
121
- p_drop: 0.25
122
-
123
- lightning:
124
- find_unused_parameters: True
125
- modelcheckpoint:
126
- params:
127
- every_n_train_steps: 5000
128
-
129
- callbacks:
130
- metrics_over_trainsteps_checkpoint:
131
- params:
132
- every_n_train_steps: 10000
133
-
134
- image_logger:
135
- target: main.ImageLogger
136
- params:
137
- enable_autocast: False
138
- disabled: False
139
- batch_frequency: 1000
140
- max_images: 4
141
- increase_log_steps: False
142
- log_first_step: False
143
- log_images_kwargs:
144
- use_ema_scope: False
145
- inpaint: False
146
- plot_progressive_rows: False
147
- plot_diffusion_rows: False
148
- N: 4
149
- unconditional_guidance_scale: 5.0
150
- unconditional_guidance_label: [""]
151
- ddim_steps: 50 # todo check these out for depth2img,
152
- ddim_eta: 0.0 # todo check these out for depth2img,
153
-
154
- trainer:
155
- benchmark: True
156
- val_check_interval: 5000000
157
- num_sanity_val_steps: 0
158
- accumulate_grad_batches: 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/v2-midas-inference.yaml DELETED
@@ -1,74 +0,0 @@
1
- model:
2
- base_learning_rate: 5.0e-07
3
- target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
4
- params:
5
- linear_start: 0.00085
6
- linear_end: 0.0120
7
- num_timesteps_cond: 1
8
- log_every_t: 200
9
- timesteps: 1000
10
- first_stage_key: "jpg"
11
- cond_stage_key: "txt"
12
- image_size: 64
13
- channels: 4
14
- cond_stage_trainable: false
15
- conditioning_key: hybrid
16
- scale_factor: 0.18215
17
- monitor: val/loss_simple_ema
18
- finetune_keys: null
19
- use_ema: False
20
-
21
- depth_stage_config:
22
- target: ldm.modules.midas.api.MiDaSInference
23
- params:
24
- model_type: "dpt_hybrid"
25
-
26
- unet_config:
27
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
28
- params:
29
- use_checkpoint: True
30
- image_size: 32 # unused
31
- in_channels: 5
32
- out_channels: 4
33
- model_channels: 320
34
- attention_resolutions: [ 4, 2, 1 ]
35
- num_res_blocks: 2
36
- channel_mult: [ 1, 2, 4, 4 ]
37
- num_head_channels: 64 # need to fix for flash-attn
38
- use_spatial_transformer: True
39
- use_linear_in_transformer: True
40
- transformer_depth: 1
41
- context_dim: 1024
42
- legacy: False
43
-
44
- first_stage_config:
45
- target: ldm.models.autoencoder.AutoencoderKL
46
- params:
47
- embed_dim: 4
48
- monitor: val/rec_loss
49
- ddconfig:
50
- #attn_type: "vanilla-xformers"
51
- double_z: true
52
- z_channels: 4
53
- resolution: 256
54
- in_channels: 3
55
- out_ch: 3
56
- ch: 128
57
- ch_mult:
58
- - 1
59
- - 2
60
- - 4
61
- - 4
62
- num_res_blocks: 2
63
- attn_resolutions: [ ]
64
- dropout: 0.0
65
- lossconfig:
66
- target: torch.nn.Identity
67
-
68
- cond_stage_config:
69
- target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
70
- params:
71
- freeze: True
72
- layer: "penultimate"
73
-
74
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
configs/x4-upscaling.yaml DELETED
@@ -1,76 +0,0 @@
1
- model:
2
- base_learning_rate: 1.0e-04
3
- target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
4
- params:
5
- parameterization: "v"
6
- low_scale_key: "lr"
7
- linear_start: 0.0001
8
- linear_end: 0.02
9
- num_timesteps_cond: 1
10
- log_every_t: 200
11
- timesteps: 1000
12
- first_stage_key: "jpg"
13
- cond_stage_key: "txt"
14
- image_size: 128
15
- channels: 4
16
- cond_stage_trainable: false
17
- conditioning_key: "hybrid-adm"
18
- monitor: val/loss_simple_ema
19
- scale_factor: 0.08333
20
- use_ema: False
21
-
22
- low_scale_config:
23
- target: ldm.modules.diffusionmodules.upscaling.ImageConcatWithNoiseAugmentation
24
- params:
25
- noise_schedule_config: # image space
26
- linear_start: 0.0001
27
- linear_end: 0.02
28
- max_noise_level: 350
29
-
30
- unet_config:
31
- target: ldm.modules.diffusionmodules.openaimodel.UNetModel
32
- params:
33
- use_checkpoint: True
34
- num_classes: 1000 # timesteps for noise conditioning (here constant, just need one)
35
- image_size: 128
36
- in_channels: 7
37
- out_channels: 4
38
- model_channels: 256
39
- attention_resolutions: [ 2,4,8]
40
- num_res_blocks: 2
41
- channel_mult: [ 1, 2, 2, 4]
42
- disable_self_attentions: [True, True, True, False]
43
- disable_middle_self_attn: False
44
- num_heads: 8
45
- use_spatial_transformer: True
46
- transformer_depth: 1
47
- context_dim: 1024
48
- legacy: False
49
- use_linear_in_transformer: True
50
-
51
- first_stage_config:
52
- target: ldm.models.autoencoder.AutoencoderKL
53
- params:
54
- embed_dim: 4
55
- ddconfig:
56
- # attn_type: "vanilla-xformers" this model needs efficient attention to be feasible on HR data, also the decoder seems to break in half precision (UNet is fine though)
57
- double_z: True
58
- z_channels: 4
59
- resolution: 256
60
- in_channels: 3
61
- out_ch: 3
62
- ch: 128
63
- ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
64
- num_res_blocks: 2
65
- attn_resolutions: [ ]
66
- dropout: 0.0
67
-
68
- lossconfig:
69
- target: torch.nn.Identity
70
-
71
- cond_stage_config:
72
- target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
73
- params:
74
- freeze: True
75
- layer: "penultimate"
76
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
example1.jpg ADDED
example_multi_trans.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import warnings
3
+ from diffusers import AutoPipelineForText2Image
4
+ from lunar_tools import concatenate_movies
5
+ from latentblending.blending_engine import BlendingEngine
6
+ import numpy as np
7
+ torch.set_grad_enabled(False)
8
+ torch.backends.cudnn.benchmark = False
9
+ warnings.filterwarnings('ignore')
10
+
11
+ # %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
12
+ pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
13
+ # pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
14
+
15
+ pipe = AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
16
+ pipe.to('cuda')
17
+ be = BlendingEngine(pipe, do_compile=True)
18
+ be.set_negative_prompt("blurry, pale, low-res, lofi")
19
+ # %% Let's setup the multi transition
20
+ fps = 30
21
+ duration_single_trans = 10
22
+ be.set_dimensions((1024, 1024))
23
+ nmb_prompts = 20
24
+
25
+
26
+ # Specify a list of prompts below
27
+ #%%
28
+
29
+ list_prompts = []
30
+ list_prompts.append("high resolution ultra 8K image with lake and forest")
31
+ list_prompts.append("strange and alien desolate lanscapes 8K")
32
+ list_prompts.append("ultra high res psychedelic skyscraper city landscape 8K unreal engine")
33
+ #%%
34
+ fp_movie = f'surreal_nmb{len(list_prompts)}.mp4'
35
+ # Specify the seeds
36
+ list_seeds = np.random.randint(0, np.iinfo(np.int32).max, len(list_prompts))
37
+
38
+ list_movie_parts = []
39
+ for i in range(len(list_prompts) - 1):
40
+ # For a multi transition we can save some computation time and recycle the latents
41
+ if i == 0:
42
+ be.set_prompt1(list_prompts[i])
43
+ be.set_prompt2(list_prompts[i + 1])
44
+ recycle_img1 = False
45
+ else:
46
+ be.swap_forward()
47
+ be.set_prompt2(list_prompts[i + 1])
48
+ recycle_img1 = True
49
+
50
+ fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"
51
+ fixed_seeds = list_seeds[i:i + 2]
52
+ # Run latent blending
53
+ be.run_transition(
54
+ recycle_img1=recycle_img1,
55
+ fixed_seeds=fixed_seeds)
56
+
57
+ # Save movie
58
+ be.write_movie_transition(fp_movie_part, duration_single_trans)
59
+ list_movie_parts.append(fp_movie_part)
60
+
61
+ # Finally, concatente the result
62
+ concatenate_movies(fp_movie, list_movie_parts)
example_multi_trans_json.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import warnings
3
+ from diffusers import AutoPipelineForText2Image
4
+ from latentblending.blending_engine import BlendingEngine
5
+ from lunar_tools import concatenate_movies
6
+ import numpy as np
7
+ torch.set_grad_enabled(False)
8
+ torch.backends.cudnn.benchmark = False
9
+ warnings.filterwarnings('ignore')
10
+
11
+ import json
12
+ # %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
13
+ # pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
14
+ pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
15
+
16
+ pipe = AutoPipelineForText2Image.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
17
+ pipe.to('cuda')
18
+ be = BlendingEngine(pipe, do_compile=False)
19
+
20
+ fp_movie = f'test.mp4'
21
+ fp_json = "movie_240221_1520.json"
22
+ duration_single_trans = 10
23
+
24
+ # Load the JSON data from the file
25
+ with open(fp_json, 'r') as file:
26
+ data = json.load(file)
27
+
28
+ # Set up width, height, num_inference steps
29
+ width = data[0]["width"]
30
+ height = data[0]["height"]
31
+ num_inference_steps = data[0]["num_inference_steps"]
32
+
33
+ be.set_dimensions((width, height))
34
+ be.set_num_inference_steps(num_inference_steps)
35
+
36
+ # Initialize lists for prompts, negative prompts, and seeds
37
+ list_prompts = []
38
+ list_negative_prompts = []
39
+ list_seeds = []
40
+
41
+ # Extract prompts, negative prompts, and seeds from the data
42
+ for item in data[1:]: # Skip the first item as it contains settings
43
+ list_prompts.append(item["prompt"])
44
+ list_negative_prompts.append(item["negative_prompt"])
45
+ list_seeds.append(item["seed"])
46
+
47
+
48
+ list_movie_parts = []
49
+ for i in range(len(list_prompts) - 1):
50
+ # For a multi transition we can save some computation time and recycle the latents
51
+ if i == 0:
52
+ be.set_prompt1(list_prompts[i])
53
+ be.set_negative_prompt(list_negative_prompts[i])
54
+ be.set_prompt2(list_prompts[i + 1])
55
+ recycle_img1 = False
56
+ else:
57
+ be.swap_forward()
58
+ be.set_negative_prompt(list_negative_prompts[i+1])
59
+ be.set_prompt2(list_prompts[i + 1])
60
+ recycle_img1 = True
61
+
62
+ fp_movie_part = f"tmp_part_{str(i).zfill(3)}.mp4"
63
+ fixed_seeds = list_seeds[i:i + 2]
64
+ # Run latent blending
65
+ be.run_transition(
66
+ recycle_img1=recycle_img1,
67
+ fixed_seeds=fixed_seeds)
68
+
69
+ # Save movie
70
+ be.write_movie_transition(fp_movie_part, duration_single_trans)
71
+ list_movie_parts.append(fp_movie_part)
72
+
73
+ # Finally, concatente the result
74
+ concatenate_movies(fp_movie, list_movie_parts)
75
+ print(f"DONE! MOVIE SAVED IN {fp_movie}")
example_single_trans.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import warnings
3
+ from diffusers import AutoPipelineForText2Image
4
+ from latentblending.blending_engine import BlendingEngine
5
+
6
+ warnings.filterwarnings('ignore')
7
+ torch.set_grad_enabled(False)
8
+ torch.backends.cudnn.benchmark = False
9
+
10
+ # %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
11
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
12
+ pipe.to("cuda")
13
+
14
+ be = BlendingEngine(pipe)
15
+ be.set_prompt1("photo of underwater landscape, fish, und the sea, incredible detail, high resolution")
16
+ be.set_prompt2("rendering of an alien planet, strange plants, strange creatures, surreal")
17
+ be.set_negative_prompt("blurry, ugly, pale")
18
+
19
+ # Run latent blending
20
+ be.run_transition()
21
+
22
+ # Save movie
23
+ be.write_movie_transition('movie_example1.mp4', duration_transition=12)
gradio_ui.py DELETED
@@ -1,500 +0,0 @@
1
- # Copyright 2022 Lunar Ring. All rights reserved.
2
- # Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import os
17
- import torch
18
- torch.backends.cudnn.benchmark = False
19
- torch.set_grad_enabled(False)
20
- import numpy as np
21
- import warnings
22
- warnings.filterwarnings('ignore')
23
- import warnings
24
- from tqdm.auto import tqdm
25
- from PIL import Image
26
- from movie_util import MovieSaver, concatenate_movies
27
- from latent_blending import LatentBlending
28
- from stable_diffusion_holder import StableDiffusionHolder
29
- import gradio as gr
30
- from dotenv import find_dotenv, load_dotenv
31
- import shutil
32
- import uuid
33
- from utils import get_time, add_frames_linear_interp
34
- from huggingface_hub import hf_hub_download
35
-
36
-
37
- class BlendingFrontend():
38
- def __init__(
39
- self,
40
- sdh,
41
- share=False):
42
- r"""
43
- Gradio Helper Class to collect UI data and start latent blending.
44
- Args:
45
- sdh:
46
- StableDiffusionHolder
47
- share: bool
48
- Set true to get a shareable gradio link (e.g. for running a remote server)
49
- """
50
- self.share = share
51
-
52
- # UI Defaults
53
- self.num_inference_steps = 30
54
- self.depth_strength = 0.25
55
- self.seed1 = 420
56
- self.seed2 = 420
57
- self.prompt1 = ""
58
- self.prompt2 = ""
59
- self.negative_prompt = ""
60
- self.fps = 30
61
- self.duration_video = 8
62
- self.t_compute_max_allowed = 10
63
-
64
- self.lb = LatentBlending(sdh)
65
- self.lb.sdh.num_inference_steps = self.num_inference_steps
66
- self.init_parameters_from_lb()
67
- self.init_save_dir()
68
-
69
- # Vars
70
- self.list_fp_imgs_current = []
71
- self.recycle_img1 = False
72
- self.recycle_img2 = False
73
- self.list_all_segments = []
74
- self.dp_session = ""
75
- self.user_id = None
76
-
77
- def init_parameters_from_lb(self):
78
- r"""
79
- Automatically init parameters from latentblending instance
80
- """
81
- self.height = self.lb.sdh.height
82
- self.width = self.lb.sdh.width
83
- self.guidance_scale = self.lb.guidance_scale
84
- self.guidance_scale_mid_damper = self.lb.guidance_scale_mid_damper
85
- self.mid_compression_scaler = self.lb.mid_compression_scaler
86
- self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
87
- self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
88
- self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
89
- self.parental_crossfeed_power = self.lb.parental_crossfeed_power
90
- self.parental_crossfeed_range = self.lb.parental_crossfeed_range
91
- self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay
92
-
93
- def init_save_dir(self):
94
- r"""
95
- Initializes the directory where stuff is being saved.
96
- You can specify this directory in a ".env" file in your latentblending root, setting
97
- DIR_OUT='/path/to/saving'
98
- """
99
- load_dotenv(find_dotenv(), verbose=False)
100
- self.dp_out = os.getenv("DIR_OUT")
101
- if self.dp_out is None:
102
- self.dp_out = ""
103
- self.dp_imgs = os.path.join(self.dp_out, "imgs")
104
- os.makedirs(self.dp_imgs, exist_ok=True)
105
- self.dp_movies = os.path.join(self.dp_out, "movies")
106
- os.makedirs(self.dp_movies, exist_ok=True)
107
- self.save_empty_image()
108
-
109
- def save_empty_image(self):
110
- r"""
111
- Saves an empty/black dummy image.
112
- """
113
- self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
114
- Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
115
-
116
- def randomize_seed1(self):
117
- r"""
118
- Randomizes the first seed
119
- """
120
- seed = np.random.randint(0, 10000000)
121
- self.seed1 = int(seed)
122
- print(f"randomize_seed1: new seed = {self.seed1}")
123
- return seed
124
-
125
- def randomize_seed2(self):
126
- r"""
127
- Randomizes the second seed
128
- """
129
- seed = np.random.randint(0, 10000000)
130
- self.seed2 = int(seed)
131
- print(f"randomize_seed2: new seed = {self.seed2}")
132
- return seed
133
-
134
- def setup_lb(self, list_ui_vals):
135
- r"""
136
- Sets all parameters from the UI. Since gradio does not support to pass dictionaries,
137
- we have to instead pass keys (list_ui_keys, global) and values (list_ui_vals)
138
- """
139
- # Collect latent blending variables
140
- self.lb.set_width(list_ui_vals[list_ui_keys.index('width')])
141
- self.lb.set_height(list_ui_vals[list_ui_keys.index('height')])
142
- self.lb.set_prompt1(list_ui_vals[list_ui_keys.index('prompt1')])
143
- self.lb.set_prompt2(list_ui_vals[list_ui_keys.index('prompt2')])
144
- self.lb.set_negative_prompt(list_ui_vals[list_ui_keys.index('negative_prompt')])
145
- self.lb.guidance_scale = list_ui_vals[list_ui_keys.index('guidance_scale')]
146
- self.lb.guidance_scale_mid_damper = list_ui_vals[list_ui_keys.index('guidance_scale_mid_damper')]
147
- self.t_compute_max_allowed = list_ui_vals[list_ui_keys.index('duration_compute')]
148
- self.lb.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
149
- self.lb.sdh.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
150
- self.duration_video = list_ui_vals[list_ui_keys.index('duration_video')]
151
- self.lb.seed1 = list_ui_vals[list_ui_keys.index('seed1')]
152
- self.lb.seed2 = list_ui_vals[list_ui_keys.index('seed2')]
153
- self.lb.branch1_crossfeed_power = list_ui_vals[list_ui_keys.index('branch1_crossfeed_power')]
154
- self.lb.branch1_crossfeed_range = list_ui_vals[list_ui_keys.index('branch1_crossfeed_range')]
155
- self.lb.branch1_crossfeed_decay = list_ui_vals[list_ui_keys.index('branch1_crossfeed_decay')]
156
- self.lb.parental_crossfeed_power = list_ui_vals[list_ui_keys.index('parental_crossfeed_power')]
157
- self.lb.parental_crossfeed_range = list_ui_vals[list_ui_keys.index('parental_crossfeed_range')]
158
- self.lb.parental_crossfeed_power_decay = list_ui_vals[list_ui_keys.index('parental_crossfeed_power_decay')]
159
- self.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
160
- self.depth_strength = list_ui_vals[list_ui_keys.index('depth_strength')]
161
-
162
- if len(list_ui_vals[list_ui_keys.index('user_id')]) > 1:
163
- self.user_id = list_ui_vals[list_ui_keys.index('user_id')]
164
- else:
165
- # generate new user id
166
- self.user_id = uuid.uuid4().hex
167
- print(f"made new user_id: {self.user_id} at {get_time('second')}")
168
-
169
- def save_latents(self, fp_latents, list_latents):
170
- r"""
171
- Saves a latent trajectory on disk, in npy format.
172
- """
173
- list_latents_cpu = [l.cpu().numpy() for l in list_latents]
174
- np.save(fp_latents, list_latents_cpu)
175
-
176
- def load_latents(self, fp_latents):
177
- r"""
178
- Loads a latent trajectory from disk, converts to torch tensor.
179
- """
180
- list_latents_cpu = np.load(fp_latents)
181
- list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
182
- return list_latents
183
-
184
- def compute_img1(self, *args):
185
- r"""
186
- Computes the first transition image and returns it for display.
187
- Sets all other transition images and last image to empty (as they are obsolete with this operation)
188
- """
189
- list_ui_vals = args
190
- self.setup_lb(list_ui_vals)
191
- fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
192
- img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
193
- img1.save(fp_img1 + ".jpg")
194
- self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])
195
- self.recycle_img1 = True
196
- self.recycle_img2 = False
197
- return [fp_img1 + ".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
198
-
199
- def compute_img2(self, *args):
200
- r"""
201
- Computes the last transition image and returns it for display.
202
- Sets all other transition images to empty (as they are obsolete with this operation)
203
- """
204
- if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")): # don't do anything
205
- return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
206
- list_ui_vals = args
207
- self.setup_lb(list_ui_vals)
208
-
209
- self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
210
- fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
211
- img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
212
- img2.save(fp_img2 + '.jpg')
213
- self.save_latents(fp_img2 + ".npy", self.lb.tree_latents[-1])
214
- self.recycle_img2 = True
215
- # fixme save seeds. change filenames?
216
- return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2 + ".jpg", self.user_id]
217
-
218
- def compute_transition(self, *args):
219
- r"""
220
- Computes transition images and movie.
221
- """
222
- list_ui_vals = args
223
- self.setup_lb(list_ui_vals)
224
- print("STARTING TRANSITION...")
225
- fixed_seeds = [self.seed1, self.seed2]
226
- # Inject loaded latents (other user interference)
227
- self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
228
- self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
229
- imgs_transition = self.lb.run_transition(
230
- recycle_img1=self.recycle_img1,
231
- recycle_img2=self.recycle_img2,
232
- num_inference_steps=self.num_inference_steps,
233
- depth_strength=self.depth_strength,
234
- t_compute_max_allowed=self.t_compute_max_allowed,
235
- fixed_seeds=fixed_seeds)
236
- print(f"Latent Blending pass finished ({get_time('second')}). Resulted in {len(imgs_transition)} images")
237
-
238
- # Subselect three preview images
239
- idx_img_prev = np.round(np.linspace(0, len(imgs_transition) - 1, 5)[1:-1]).astype(np.int32)
240
-
241
- list_imgs_preview = []
242
- for j in idx_img_prev:
243
- list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
244
-
245
- # Save the preview imgs as jpgs on disk so we are not sending umcompressed data around
246
- current_timestamp = get_time('second')
247
- self.list_fp_imgs_current = []
248
- for i in range(len(list_imgs_preview)):
249
- fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{current_timestamp}.jpg")
250
- list_imgs_preview[i].save(fp_img)
251
- self.list_fp_imgs_current.append(fp_img)
252
- # Insert cheap frames for the movie
253
- imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
254
-
255
- # Save as movie
256
- self.fp_movie = self.get_fp_video_last()
257
- if os.path.isfile(self.fp_movie):
258
- os.remove(self.fp_movie)
259
- ms = MovieSaver(self.fp_movie, fps=self.fps)
260
- for img in tqdm(imgs_transition_ext):
261
- ms.write_frame(img)
262
- ms.finalize()
263
- print("DONE SAVING MOVIE! SENDING BACK...")
264
-
265
- # Assemble Output, updating the preview images and le movie
266
- list_return = self.list_fp_imgs_current + [self.fp_movie]
267
- return list_return
268
-
269
- def stack_forward(self, prompt2, seed2):
270
- r"""
271
- Allows to generate multi-segment movies. Sets last image -> first image with all
272
- relevant parameters.
273
- """
274
- # Save preview images, prompts and seeds into dictionary for stacking
275
- if len(self.list_all_segments) == 0:
276
- timestamp_session = get_time('second')
277
- self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
278
- os.makedirs(self.dp_session)
279
-
280
- idx_segment = len(self.list_all_segments)
281
- dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")
282
-
283
- self.list_all_segments.append(dp_segment)
284
- self.lb.write_imgs_transition(dp_segment)
285
-
286
- fp_movie_last = self.get_fp_video_last()
287
- fp_movie_next = self.get_fp_video_next()
288
-
289
- shutil.copyfile(fp_movie_last, fp_movie_next)
290
-
291
- self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
292
- self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
293
- self.lb.swap_forward()
294
-
295
- shutil.copyfile(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"), os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
296
- fp_multi = self.multi_concat()
297
- list_out = [fp_multi]
298
-
299
- list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
300
- list_out.extend([self.fp_img_empty] * 4)
301
- list_out.append(gr.update(interactive=False, value=prompt2))
302
- list_out.append(gr.update(interactive=False, value=seed2))
303
- list_out.append("")
304
- list_out.append(np.random.randint(0, 10000000))
305
- print(f"stack_forward: fp_multi {fp_multi}")
306
- return list_out
307
-
308
- def multi_concat(self):
309
- r"""
310
- Concatentates all stacked segments into one long movie.
311
- """
312
- list_fp_movies = self.get_fp_video_all()
313
- # Concatenate movies and save
314
- fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
315
- concatenate_movies(fp_final, list_fp_movies)
316
- return fp_final
317
-
318
- def get_fp_video_all(self):
319
- r"""
320
- Collects all stacked movie segments.
321
- """
322
- list_all = os.listdir(self.dp_movies)
323
- str_beg = f"movie_{self.user_id}_"
324
- list_user = [l for l in list_all if str_beg in l]
325
- list_user.sort()
326
- list_user = [os.path.join(self.dp_movies, l) for l in list_user]
327
- return list_user
328
-
329
- def get_fp_video_next(self):
330
- r"""
331
- Gets the filepath of the next movie segment.
332
- """
333
- list_videos = self.get_fp_video_all()
334
- if len(list_videos) == 0:
335
- idx_next = 0
336
- else:
337
- idx_next = len(list_videos)
338
- fp_video_next = os.path.join(self.dp_movies, f"movie_{self.user_id}_{str(idx_next).zfill(3)}.mp4")
339
- return fp_video_next
340
-
341
- def get_fp_video_last(self):
342
- r"""
343
- Gets the current video that was saved.
344
- """
345
- fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
346
- return fp_video_last
347
-
348
-
349
- if __name__ == "__main__":
350
- fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
351
- # fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
352
- bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
353
- # self = BlendingFrontend(None)
354
-
355
- with gr.Blocks() as demo:
356
- gr.HTML("""<h1>Latent Blending</h1>
357
- <p>Create butter-smooth transitions between prompts, powered by stable diffusion</p>
358
- <p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
359
- <br/>
360
- <a href="https://huggingface.co/spaces/lunarring/latentblending?duplicate=true">
361
- <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
362
- </p>""")
363
-
364
- with gr.Row():
365
- prompt1 = gr.Textbox(label="prompt 1")
366
- prompt2 = gr.Textbox(label="prompt 2")
367
-
368
- with gr.Row():
369
- duration_compute = gr.Slider(10, 25, bf.t_compute_max_allowed, step=1, label='waiting time', interactive=True)
370
- duration_video = gr.Slider(1, 100, bf.duration_video, step=0.1, label='video duration', interactive=True)
371
- height = gr.Slider(256, 1024, bf.height, step=128, label='height', interactive=True)
372
- width = gr.Slider(256, 1024, bf.width, step=128, label='width', interactive=True)
373
-
374
- with gr.Accordion("Advanced Settings (click to expand)", open=False):
375
-
376
- with gr.Accordion("Diffusion settings", open=True):
377
- with gr.Row():
378
- num_inference_steps = gr.Slider(5, 100, bf.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
379
- guidance_scale = gr.Slider(1, 25, bf.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
380
- negative_prompt = gr.Textbox(label="negative prompt")
381
-
382
- with gr.Accordion("Seed control: adjust seeds for first and last images", open=True):
383
- with gr.Row():
384
- b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
385
- seed1 = gr.Number(bf.seed1, label="seed 1", interactive=True)
386
- seed2 = gr.Number(bf.seed2, label="seed 2", interactive=True)
387
- b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
388
-
389
- with gr.Accordion("Last image crossfeeding.", open=True):
390
- with gr.Row():
391
- branch1_crossfeed_power = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_power, step=0.01, label='branch1 crossfeed power', interactive=True)
392
- branch1_crossfeed_range = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_range, step=0.01, label='branch1 crossfeed range', interactive=True)
393
- branch1_crossfeed_decay = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_decay, step=0.01, label='branch1 crossfeed decay', interactive=True)
394
-
395
- with gr.Accordion("Transition settings", open=True):
396
- with gr.Row():
397
- parental_crossfeed_power = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power, step=0.01, label='parental crossfeed power', interactive=True)
398
- parental_crossfeed_range = gr.Slider(0.0, 1.0, bf.parental_crossfeed_range, step=0.01, label='parental crossfeed range', interactive=True)
399
- parental_crossfeed_power_decay = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power_decay, step=0.01, label='parental crossfeed decay', interactive=True)
400
- with gr.Row():
401
- depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
402
- guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
403
-
404
- with gr.Row():
405
- b_compute1 = gr.Button('step1: compute first image', variant='primary')
406
- b_compute2 = gr.Button('step2: compute last image', variant='primary')
407
- b_compute_transition = gr.Button('step3: compute transition', variant='primary')
408
-
409
- with gr.Row():
410
- img1 = gr.Image(label="1/5")
411
- img2 = gr.Image(label="2/5", show_progress=False)
412
- img3 = gr.Image(label="3/5", show_progress=False)
413
- img4 = gr.Image(label="4/5", show_progress=False)
414
- img5 = gr.Image(label="5/5")
415
-
416
- with gr.Row():
417
- vid_single = gr.Video(label="current single trans")
418
- vid_multi = gr.Video(label="concatented multi trans")
419
-
420
- with gr.Row():
421
- b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')
422
-
423
- with gr.Row():
424
- gr.Markdown(
425
- """
426
- # Parameters
427
- ## Main
428
- - waiting time: set your waiting time for the transition. high values = better quality
429
- - video duration: seconds per segment
430
- - height/width: in pixels
431
-
432
- ## Diffusion settings
433
- - num_inference_steps: number of diffusion steps
434
- - guidance_scale: latent blending seems to prefer lower values here
435
- - negative prompt: enter negative prompt here, applied for all images
436
-
437
- ## Last image crossfeeding
438
- - branch1_crossfeed_power: Controls the level of cross-feeding between the first and last image branch. For preserving structures.
439
- - branch1_crossfeed_range: Sets the duration of active crossfeed during development. High values enforce strong structural similarity.
440
- - branch1_crossfeed_decay: Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.
441
-
442
- ## Transition settings
443
- - parental_crossfeed_power: Similar to branch1_crossfeed_power, however applied for the images withinin the transition.
444
- - parental_crossfeed_range: Similar to branch1_crossfeed_range, however applied for the images withinin the transition.
445
- - parental_crossfeed_power_decay: Similar to branch1_crossfeed_decay, however applied for the images withinin the transition.
446
- - depth_strength: Determines when the blending process will begin in terms of diffusion steps. Low values more inventive but can cause motion.
447
- - guidance_scale_mid_damper: Decreases the guidance scale in the middle of a transition.
448
- """)
449
-
450
- with gr.Row():
451
- user_id = gr.Textbox(label="user id", interactive=False)
452
-
453
- # Collect all UI elemts in list to easily pass as inputs in gradio
454
- dict_ui_elem = {}
455
- dict_ui_elem["prompt1"] = prompt1
456
- dict_ui_elem["negative_prompt"] = negative_prompt
457
- dict_ui_elem["prompt2"] = prompt2
458
-
459
- dict_ui_elem["duration_compute"] = duration_compute
460
- dict_ui_elem["duration_video"] = duration_video
461
- dict_ui_elem["height"] = height
462
- dict_ui_elem["width"] = width
463
-
464
- dict_ui_elem["depth_strength"] = depth_strength
465
- dict_ui_elem["branch1_crossfeed_power"] = branch1_crossfeed_power
466
- dict_ui_elem["branch1_crossfeed_range"] = branch1_crossfeed_range
467
- dict_ui_elem["branch1_crossfeed_decay"] = branch1_crossfeed_decay
468
-
469
- dict_ui_elem["num_inference_steps"] = num_inference_steps
470
- dict_ui_elem["guidance_scale"] = guidance_scale
471
- dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
472
- dict_ui_elem["seed1"] = seed1
473
- dict_ui_elem["seed2"] = seed2
474
-
475
- dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range
476
- dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power
477
- dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay
478
- dict_ui_elem["user_id"] = user_id
479
-
480
- # Convert to list, as gradio doesn't seem to accept dicts
481
- list_ui_vals = []
482
- list_ui_keys = []
483
- for k in dict_ui_elem.keys():
484
- list_ui_vals.append(dict_ui_elem[k])
485
- list_ui_keys.append(k)
486
- bf.list_ui_keys = list_ui_keys
487
-
488
- b_newseed1.click(bf.randomize_seed1, outputs=seed1)
489
- b_newseed2.click(bf.randomize_seed2, outputs=seed2)
490
- b_compute1.click(bf.compute_img1, inputs=list_ui_vals, outputs=[img1, img2, img3, img4, img5, user_id])
491
- b_compute2.click(bf.compute_img2, inputs=list_ui_vals, outputs=[img2, img3, img4, img5, user_id])
492
- b_compute_transition.click(bf.compute_transition,
493
- inputs=list_ui_vals,
494
- outputs=[img2, img3, img4, vid_single])
495
-
496
- b_stackforward.click(bf.stack_forward,
497
- inputs=[prompt2, seed2],
498
- outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
499
-
500
- demo.launch(share=bf.share, inbrowser=True, inline=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
latentblending/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ from .blending_engine import BlendingEngine
2
+ from .diffusers_holder import DiffusersHolder
3
+ from .utils import interpolate_spherical, add_frames_linear_interp, interpolate_linear, get_spacing, get_time, yml_load, yml_save
latentblending/__pycache__/diffusers_holder.cpython-311.pyc ADDED
Binary file (18.2 kB). View file
latent_blending.py → latentblending/blending_engine.py RENAMED
@@ -1,52 +1,33 @@
1
- # Copyright 2022 Lunar Ring. All rights reserved.
2
- # Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
  import os
17
  import torch
18
- torch.backends.cudnn.benchmark = False
19
- torch.set_grad_enabled(False)
20
  import numpy as np
21
  import warnings
22
- warnings.filterwarnings('ignore')
23
  import time
24
- import warnings
25
  from tqdm.auto import tqdm
26
  from PIL import Image
27
- from movie_util import MovieSaver
28
  from typing import List, Optional
29
- from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
30
  import lpips
31
- from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save
 
 
 
 
 
 
32
 
33
 
34
- class LatentBlending():
35
  def __init__(
36
  self,
37
- sdh: None,
38
- guidance_scale: float = 4,
39
  guidance_scale_mid_damper: float = 0.5,
40
  mid_compression_scaler: float = 1.2):
41
  r"""
42
  Initializes the latent blending class.
43
  Args:
44
- guidance_scale: float
45
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
46
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
47
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
48
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
49
- usually at the expense of lower image quality.
50
  guidance_scale_mid_damper: float = 0.5
51
  Reduces the guidance scale towards the middle of the transition.
52
  A value of 0.5 would decrease the guidance_scale towards the middle linearly by 0.5.
@@ -59,10 +40,11 @@ class LatentBlending():
59
  and guidance_scale_mid_damper <= 1.0, \
60
  f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
61
 
62
- self.sdh = sdh
63
- self.device = self.sdh.device
64
- self.width = self.sdh.width
65
- self.height = self.sdh.height
 
66
  self.guidance_scale_mid_damper = guidance_scale_mid_damper
67
  self.mid_compression_scaler = mid_compression_scaler
68
  self.seed1 = 0
@@ -71,7 +53,6 @@ class LatentBlending():
71
  # Initialize vars
72
  self.prompt1 = ""
73
  self.prompt2 = ""
74
- self.negative_prompt = ""
75
 
76
  self.tree_latents = [None, None]
77
  self.tree_fracts = None
@@ -79,61 +60,97 @@ class LatentBlending():
79
  self.tree_status = None
80
  self.tree_final_imgs = []
81
 
82
- self.list_nmb_branches_prev = []
83
- self.list_injection_idx_prev = []
84
  self.text_embedding1 = None
85
  self.text_embedding2 = None
86
  self.image1_lowres = None
87
  self.image2_lowres = None
88
  self.negative_prompt = None
89
- self.num_inference_steps = self.sdh.num_inference_steps
90
- self.noise_level_upscaling = 20
91
- self.list_injection_idx = None
92
- self.list_nmb_branches = None
93
-
94
- # Mixing parameters
95
- self.branch1_crossfeed_power = 0.1
96
- self.branch1_crossfeed_range = 0.6
97
- self.branch1_crossfeed_decay = 0.8
98
-
99
- self.parental_crossfeed_power = 0.1
100
- self.parental_crossfeed_range = 0.8
101
- self.parental_crossfeed_power_decay = 0.8
102
-
103
- self.set_guidance_scale(guidance_scale)
104
- self.init_mode()
105
  self.multi_transition_img_first = None
106
  self.multi_transition_img_last = None
107
- self.dt_per_diff = 0
108
- self.spatial_mask = None
109
- self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- def init_mode(self):
112
  r"""
113
- Sets the operational mode. Currently supported are standard, inpainting and x4 upscaling.
 
 
 
 
114
  """
115
- if isinstance(self.sdh.model, LatentUpscaleDiffusion):
116
- self.mode = 'upscale'
117
- elif isinstance(self.sdh.model, LatentInpaintDiffusion):
118
- self.sdh.image_source = None
119
- self.sdh.mask_image = None
120
- self.mode = 'inpaint'
121
- else:
122
- self.mode = 'standard'
123
 
124
- def set_guidance_scale(self, guidance_scale):
125
  r"""
126
  sets the guidance scale.
127
  """
 
 
 
 
 
 
128
  self.guidance_scale_base = guidance_scale
129
  self.guidance_scale = guidance_scale
130
- self.sdh.guidance_scale = guidance_scale
131
 
132
  def set_negative_prompt(self, negative_prompt):
133
  r"""Set the negative prompt. Currenty only one negative prompt is supported
134
  """
135
  self.negative_prompt = negative_prompt
136
- self.sdh.set_negative_prompt(negative_prompt)
137
 
138
  def set_guidance_mid_dampening(self, fract_mixing):
139
  r"""
@@ -144,9 +161,9 @@ class LatentBlending():
144
  max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
145
  guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
146
  self.guidance_scale = guidance_scale_effective
147
- self.sdh.guidance_scale = guidance_scale_effective
148
 
149
- def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
150
  r"""
151
  Sets the crossfeed parameters for the first branch to the last branch.
152
  Args:
@@ -161,7 +178,7 @@ class LatentBlending():
161
  self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
162
  self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
163
 
164
- def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
165
  r"""
166
  Sets the crossfeed parameters for all transition images (within the first and last branch).
167
  Args:
@@ -172,9 +189,22 @@ class LatentBlending():
172
  crossfeed_decay: float [0,1]
173
  Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.
174
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  self.parental_crossfeed_power = np.clip(crossfeed_power, 0, 1)
176
  self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
177
- self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)
178
 
179
  def set_prompt1(self, prompt: str):
180
  r"""
@@ -213,15 +243,59 @@ class LatentBlending():
213
  image: Image
214
  """
215
  self.image2_lowres = image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  def run_transition(
218
  self,
219
  recycle_img1: Optional[bool] = False,
220
  recycle_img2: Optional[bool] = False,
221
- num_inference_steps: Optional[int] = 30,
222
- depth_strength: Optional[float] = 0.3,
223
- t_compute_max_allowed: Optional[float] = None,
224
- nmb_max_branches: Optional[int] = None,
225
  fixed_seeds: Optional[List[int]] = None):
226
  r"""
227
  Function for computing transitions.
@@ -233,17 +307,7 @@ class LatentBlending():
233
  Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
234
  num_inference_steps:
235
  Number of diffusion steps. Higher values will take more compute time.
236
- depth_strength:
237
- Determines how deep the first injection will happen.
238
- Deeper injections will cause (unwanted) formation of new structures,
239
- more shallow values will go into alpha-blendy land.
240
- t_compute_max_allowed:
241
- Either provide t_compute_max_allowed or nmb_max_branches.
242
- The maximum time allowed for computation. Higher values give better results but take longer.
243
- nmb_max_branches: int
244
- Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
245
- results. Use this if you want to have controllable results independent
246
- of your computer.
247
  fixed_seeds: Optional[List[int)]:
248
  You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
249
  Otherwise random seeds will be taken.
@@ -252,6 +316,7 @@ class LatentBlending():
252
  # Sanity checks first
253
  assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
254
  assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
 
255
 
256
  # Random seeds
257
  if fixed_seeds is not None:
@@ -263,10 +328,7 @@ class LatentBlending():
263
  self.seed1 = fixed_seeds[0]
264
  self.seed2 = fixed_seeds[1]
265
 
266
- # Ensure correct num_inference_steps in holder
267
- self.num_inference_steps = num_inference_steps
268
- self.sdh.num_inference_steps = num_inference_steps
269
-
270
  # Compute / Recycle first image
271
  if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
272
  list_latents1 = self.compute_latents1()
@@ -282,29 +344,28 @@ class LatentBlending():
282
  # Reset the tree, injecting the edge latents1/2 we just generated/recycled
283
  self.tree_latents = [list_latents1, list_latents2]
284
  self.tree_fracts = [0.0, 1.0]
285
- self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
286
  self.tree_idx_injection = [0, 0]
 
287
 
288
- # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
289
- self.spatial_mask = None
290
-
291
- # Set up branching scheme (dependent on provided compute time)
292
- list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
293
 
294
  # Run iteratively, starting with the longest trajectory.
295
  # Always inserting new branches where they are needed most according to image similarity
296
- for s_idx in tqdm(range(len(list_idx_injection))):
297
- nmb_stems = list_nmb_stems[s_idx]
298
- idx_injection = list_idx_injection[s_idx]
299
 
300
  for i in range(nmb_stems):
301
  fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
302
  self.set_guidance_mid_dampening(fract_mixing)
303
  list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
304
  self.insert_into_tree(fract_mixing, idx_injection, list_latents)
305
- # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
306
 
307
  return self.tree_final_imgs
 
 
 
308
 
309
  def compute_latents1(self, return_image=False):
310
  r"""
@@ -322,10 +383,10 @@ class LatentBlending():
322
  latents_start=latents_start,
323
  idx_start=0)
324
  t1 = time.time()
325
- self.dt_per_diff = (t1 - t0) / self.num_inference_steps
326
  self.tree_latents[0] = list_latents1
327
  if return_image:
328
- return self.sdh.latent2image(list_latents1[-1])
329
  else:
330
  return list_latents1
331
 
@@ -357,7 +418,7 @@ class LatentBlending():
357
  self.tree_latents[-1] = list_latents2
358
 
359
  if return_image:
360
- return self.sdh.latent2image(list_latents2[-1])
361
  else:
362
  return list_latents2
363
 
@@ -392,7 +453,7 @@ class LatentBlending():
392
  mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
393
  nmb_mixing = idx_mixing_stop - idx_injection
394
  if nmb_mixing > 0:
395
- mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_power_decay, nmb_mixing)))
396
  mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
397
  latents_start = list_latents_parental_mix[idx_injection - 1]
398
  list_latents = self.run_diffusion(
@@ -421,8 +482,10 @@ class LatentBlending():
421
  results. Use this if you want to have controllable results independent
422
  of your computer.
423
  """
424
- idx_injection_base = int(round(self.num_inference_steps * depth_strength))
425
- list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps - 1, 3)
 
 
426
  list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
427
  t_compute = 0
428
 
@@ -440,10 +503,11 @@ class LatentBlending():
440
  while not stop_criterion_reached:
441
  list_compute_steps = self.num_inference_steps - list_idx_injection
442
  list_compute_steps *= list_nmb_stems
443
- t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
 
444
  increase_done = False
445
  for s_idx in range(len(list_nmb_stems) - 1):
446
- if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
447
  list_nmb_stems[s_idx] += 1
448
  increase_done = True
449
  break
@@ -474,15 +538,15 @@ class LatentBlending():
474
  the index in terms of diffusion steps, where the next insertion will start.
475
  """
476
  # get_lpips_similarity
477
- similarities = []
478
- for i in range(len(self.tree_final_imgs) - 1):
479
- similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
480
  b_closest1 = np.argmax(similarities)
481
  b_closest2 = b_closest1 + 1
482
  fract_closest1 = self.tree_fracts[b_closest1]
483
  fract_closest2 = self.tree_fracts[b_closest2]
 
484
 
485
- # Ensure that the parents are indeed older!
486
  b_parent1 = b_closest1
487
  while True:
488
  if self.tree_idx_injection[b_parent1] < idx_injection:
@@ -495,7 +559,6 @@ class LatentBlending():
495
  break
496
  else:
497
  b_parent2 += 1
498
- fract_mixing = (fract_closest1 + fract_closest2) / 2
499
  return fract_mixing, b_parent1, b_parent2
500
 
501
  def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
@@ -509,40 +572,21 @@ class LatentBlending():
509
  list_latents: list
510
  list of the latents to be inserted
511
  """
 
 
512
  b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
513
- self.tree_latents.insert(b_parent1 + 1, list_latents)
514
- self.tree_final_imgs.insert(b_parent1 + 1, self.sdh.latent2image(list_latents[-1]))
515
- self.tree_fracts.insert(b_parent1 + 1, fract_mixing)
516
- self.tree_idx_injection.insert(b_parent1 + 1, idx_injection)
517
-
518
- def get_spatial_mask_template(self):
519
- r"""
520
- Experimental helper function to get a spatial mask template.
521
- """
522
- shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
523
- C, H, W = shape_latents
524
- return np.ones((H, W))
525
-
526
- def set_spatial_mask(self, img_mask):
527
- r"""
528
- Experimental helper function to set a spatial mask.
529
- The mask forces latents to be overwritten.
530
- Args:
531
- img_mask:
532
- mask image [0,1]. You can get a template using get_spatial_mask_template
533
- """
534
- shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
535
- C, H, W = shape_latents
536
- img_mask = np.asarray(img_mask)
537
- assert len(img_mask.shape) == 2, "Currently, only 2D images are supported as mask"
538
- img_mask = np.clip(img_mask, 0, 1)
539
- assert img_mask.shape[0] == H, f"Your mask needs to be of dimension {H} x {W}"
540
- assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
541
- spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
542
- spatial_mask = torch.unsqueeze(spatial_mask, 0)
543
- spatial_mask = spatial_mask.repeat((C, 1, 1))
544
- spatial_mask = torch.unsqueeze(spatial_mask, 0)
545
- self.spatial_mask = spatial_mask
546
 
547
  def get_noise(self, seed):
548
  r"""
@@ -550,16 +594,7 @@ class LatentBlending():
550
  Args:
551
  seed: int
552
  """
553
- generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
554
- if self.mode == 'standard':
555
- shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
556
- C, H, W = shape_latents
557
- elif self.mode == 'upscale':
558
- w = self.image1_lowres.size[0]
559
- h = self.image1_lowres.size[1]
560
- shape_latents = [self.sdh.model.channels, h, w]
561
- C, H, W = shape_latents
562
- return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
563
 
564
  @torch.no_grad()
565
  def run_diffusion(
@@ -590,132 +625,32 @@ class LatentBlending():
590
  """
591
 
592
  # Ensure correct num_inference_steps in Holder
593
- self.sdh.num_inference_steps = self.num_inference_steps
594
  assert type(list_conditionings) is list, "list_conditionings need to be a list"
595
 
596
- if self.mode == 'standard':
597
- text_embeddings = list_conditionings[0]
598
- return self.sdh.run_diffusion_standard(
599
- text_embeddings=text_embeddings,
600
- latents_start=latents_start,
601
- idx_start=idx_start,
602
- list_latents_mixing=list_latents_mixing,
603
- mixing_coeffs=mixing_coeffs,
604
- spatial_mask=self.spatial_mask,
605
- return_image=return_image)
606
-
607
- elif self.mode == 'upscale':
608
- cond = list_conditionings[0]
609
- uc_full = list_conditionings[1]
610
- return self.sdh.run_diffusion_upscaling(
611
- cond,
612
- uc_full,
613
- latents_start=latents_start,
614
- idx_start=idx_start,
615
- list_latents_mixing=list_latents_mixing,
616
- mixing_coeffs=mixing_coeffs,
617
- return_image=return_image)
618
 
619
- def run_upscaling(
620
- self,
621
- dp_img: str,
622
- depth_strength: float = 0.65,
623
- num_inference_steps: int = 100,
624
- nmb_max_branches_highres: int = 5,
625
- nmb_max_branches_lowres: int = 6,
626
- duration_single_segment=3,
627
- fps=24,
628
- fixed_seeds: Optional[List[int]] = None):
629
- r"""
630
- Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.
631
 
632
- Args:
633
- dp_img: str
634
- Path to the low-res transition path (as saved in write_imgs_transition)
635
- depth_strength:
636
- Determines how deep the first injection will happen.
637
- Deeper injections will cause (unwanted) formation of new structures,
638
- more shallow values will go into alpha-blendy land.
639
- num_inference_steps:
640
- Number of diffusion steps. Higher values will take more compute time.
641
- nmb_max_branches_highres: int
642
- Number of final branches of the upscaling transition pass. Note this is the number
643
- of branches between each pair of low-res images.
644
- nmb_max_branches_lowres: int
645
- Number of input low-res images, subsampling all transition images written in the low-res pass.
646
- Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
647
- duration_single_segment: float
648
- The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
649
- fps: float
650
- frames per second of movie
651
- fixed_seeds: Optional[List[int)]:
652
- You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
653
- Otherwise random seeds will be taken.
654
- """
655
- fp_yml = os.path.join(dp_img, "lowres.yaml")
656
- fp_movie = os.path.join(dp_img, "movie_highres.mp4")
657
- ms = MovieSaver(fp_movie, fps=fps)
658
- assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
659
- dict_stuff = yml_load(fp_yml)
660
-
661
- # load lowres images
662
- nmb_images_lowres = dict_stuff['nmb_images']
663
- prompt1 = dict_stuff['prompt1']
664
- prompt2 = dict_stuff['prompt2']
665
- idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
666
- imgs_lowres = []
667
- for i in idx_img_lowres:
668
- fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
669
- assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
670
- imgs_lowres.append(Image.open(fp_img_lowres))
671
-
672
- # set up upscaling
673
- text_embeddingA = self.sdh.get_text_embedding(prompt1)
674
- text_embeddingB = self.sdh.get_text_embedding(prompt2)
675
- list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1)
676
- for i in range(nmb_max_branches_lowres - 1):
677
- print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
678
- self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
679
- self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i])
680
- if i == 0:
681
- recycle_img1 = False
682
- else:
683
- self.swap_forward()
684
- recycle_img1 = True
685
-
686
- self.set_image1(imgs_lowres[i])
687
- self.set_image2(imgs_lowres[i + 1])
688
-
689
- list_imgs = self.run_transition(
690
- recycle_img1=recycle_img1,
691
- recycle_img2=False,
692
- num_inference_steps=num_inference_steps,
693
- depth_strength=depth_strength,
694
- nmb_max_branches=nmb_max_branches_highres)
695
- list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)
696
-
697
- # Save movie frame
698
- for img in list_imgs_interp:
699
- ms.write_frame(img)
700
- ms.finalize()
701
 
702
  @torch.no_grad()
703
  def get_mixed_conditioning(self, fract_mixing):
704
- if self.mode == 'standard':
705
- text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
706
- list_conditionings = [text_embeddings_mix]
707
- elif self.mode == 'inpaint':
708
- text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
709
- list_conditionings = [text_embeddings_mix]
710
- elif self.mode == 'upscale':
711
- text_embeddings_mix = interpolate_linear(self.text_embedding1, self.text_embedding2, fract_mixing)
712
- cond, uc_full = self.sdh.get_cond_upscaling(self.image1_lowres, text_embeddings_mix, self.noise_level_upscaling)
713
- condB, uc_fullB = self.sdh.get_cond_upscaling(self.image2_lowres, text_embeddings_mix, self.noise_level_upscaling)
714
- cond['c_concat'][0] = interpolate_spherical(cond['c_concat'][0], condB['c_concat'][0], fract_mixing)
715
- uc_full['c_concat'][0] = interpolate_spherical(uc_full['c_concat'][0], uc_fullB['c_concat'][0], fract_mixing)
716
- list_conditionings = [cond, uc_full]
717
- else:
718
- raise ValueError(f"mix_conditioning: unknown mode {self.mode}")
719
  return list_conditionings
720
 
721
  @torch.no_grad()
@@ -729,7 +664,7 @@ class LatentBlending():
729
  prompt: str
730
  ABC trending on artstation painted by Old Greg.
731
  """
732
- return self.sdh.get_text_embedding(prompt)
733
 
734
  def write_imgs_transition(self, dp_img):
735
  r"""
@@ -745,7 +680,6 @@ class LatentBlending():
745
  img_leaf = Image.fromarray(img)
746
  img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
747
  fp_yml = os.path.join(dp_img, "lowres.yaml")
748
- self.save_statedict(fp_yml)
749
 
750
  def write_movie_transition(self, fp_movie, duration_transition, fps=30):
751
  r"""
@@ -761,22 +695,16 @@ class LatentBlending():
761
  """
762
 
763
  # Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
764
- imgs_transition_ext = add_frames_linear_interp(self.tree_final_imgs, duration_transition, fps)
765
 
766
  # Save as MP4
767
  if os.path.isfile(fp_movie):
768
  os.remove(fp_movie)
769
- ms = MovieSaver(fp_movie, fps=fps, shape_hw=[self.sdh.height, self.sdh.width])
770
  for img in tqdm(imgs_transition_ext):
771
  ms.write_frame(img)
772
  ms.finalize()
773
 
774
- def save_statedict(self, fp_yml):
775
- # Dump everything relevant into yaml
776
- imgs_transition = self.tree_final_imgs
777
- state_dict = self.get_state_dict()
778
- state_dict['nmb_images'] = len(imgs_transition)
779
- yml_save(fp_yml, state_dict)
780
 
781
  def get_state_dict(self):
782
  state_dict = {}
@@ -784,7 +712,7 @@ class LatentBlending():
784
  'num_inference_steps', 'depth_strength', 'guidance_scale',
785
  'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt',
786
  'branch1_crossfeed_power', 'branch1_crossfeed_range', 'branch1_crossfeed_decay'
787
- 'parental_crossfeed_power', 'parental_crossfeed_range', 'parental_crossfeed_power_decay']
788
  for v in grab_vars:
789
  if hasattr(self, v):
790
  if v == 'seed1' or v == 'seed2':
@@ -799,35 +727,6 @@ class LatentBlending():
799
  pass
800
  return state_dict
801
 
802
- def randomize_seed(self):
803
- r"""
804
- Set a random seed for a fresh start.
805
- """
806
- seed = np.random.randint(999999999)
807
- self.set_seed(seed)
808
-
809
- def set_seed(self, seed: int):
810
- r"""
811
- Set a the seed for a fresh start.
812
- """
813
- self.seed = seed
814
- self.sdh.seed = seed
815
-
816
- def set_width(self, width):
817
- r"""
818
- Set the width of the resulting image.
819
- """
820
- assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64"
821
- self.width = width
822
- self.sdh.width = width
823
-
824
- def set_height(self, height):
825
- r"""
826
- Set the height of the resulting image.
827
- """
828
- assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64"
829
- self.height = height
830
- self.sdh.height = height
831
 
832
  def swap_forward(self):
833
  r"""
@@ -848,16 +747,22 @@ class LatentBlending():
848
  Used to determine the optimal point of insertion to create smooth transitions.
849
  High values indicate low similarity.
850
  """
851
- tensorA = torch.from_numpy(imgA).float().cuda(self.device)
852
  tensorA = 2 * tensorA / 255.0 - 1
853
  tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)
854
- tensorB = torch.from_numpy(imgB).float().cuda(self.device)
855
  tensorB = 2 * tensorB / 255.0 - 1
856
  tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
857
  lploss = self.lpips(tensorA, tensorB)
858
  lploss = float(lploss[0][0][0][0])
859
  return lploss
860
 
 
 
 
 
 
 
861
  # Auxiliary functions
862
  def get_closest_idx(
863
  self,
@@ -882,3 +787,51 @@ class LatentBlending():
882
  b_parent1 = tmp
883
 
884
  return b_parent1, b_parent2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import torch
 
 
3
  import numpy as np
4
  import warnings
 
5
  import time
 
6
  from tqdm.auto import tqdm
7
  from PIL import Image
 
8
  from typing import List, Optional
 
9
  import lpips
10
+ import platform
11
+ from latentblending.diffusers_holder import DiffusersHolder
12
+ from latentblending.utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp
13
+ from lunar_tools import MovieSaver, fill_up_frames_linear_interpolation
14
+ warnings.filterwarnings('ignore')
15
+ torch.backends.cudnn.benchmark = False
16
+ torch.set_grad_enabled(False)
17
 
18
 
19
+ class BlendingEngine():
20
  def __init__(
21
  self,
22
+ pipe: None,
23
+ do_compile: bool = False,
24
  guidance_scale_mid_damper: float = 0.5,
25
  mid_compression_scaler: float = 1.2):
26
  r"""
27
  Initializes the latent blending class.
28
  Args:
29
+ pipe: diffusers pipeline (SDXL)
30
+ do_compile: compile pipeline for faster inference using stable fast
 
 
 
 
31
  guidance_scale_mid_damper: float = 0.5
32
  Reduces the guidance scale towards the middle of the transition.
33
  A value of 0.5 would decrease the guidance_scale towards the middle linearly by 0.5.
40
  and guidance_scale_mid_damper <= 1.0, \
41
  f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
42
 
43
+
44
+ self.dh = DiffusersHolder(pipe)
45
+ self.device = self.dh.device
46
+ self.set_dimensions()
47
+
48
  self.guidance_scale_mid_damper = guidance_scale_mid_damper
49
  self.mid_compression_scaler = mid_compression_scaler
50
  self.seed1 = 0
53
  # Initialize vars
54
  self.prompt1 = ""
55
  self.prompt2 = ""
 
56
 
57
  self.tree_latents = [None, None]
58
  self.tree_fracts = None
60
  self.tree_status = None
61
  self.tree_final_imgs = []
62
 
 
 
63
  self.text_embedding1 = None
64
  self.text_embedding2 = None
65
  self.image1_lowres = None
66
  self.image2_lowres = None
67
  self.negative_prompt = None
68
+
69
+ self.set_guidance_scale()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  self.multi_transition_img_first = None
71
  self.multi_transition_img_last = None
72
+ self.dt_unet_step = 0
73
+ if platform.system() == "Darwin":
74
+ self.lpips = lpips.LPIPS(net='alex')
75
+ else:
76
+ self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
77
+
78
+ self.set_prompt1("")
79
+ self.set_prompt2("")
80
+
81
+ self.set_branch1_crossfeed()
82
+ self.set_parental_crossfeed()
83
+
84
+ self.set_num_inference_steps()
85
+ self.benchmark_speed()
86
+ self.set_branching()
87
+
88
+ if do_compile:
89
+ print("starting compilation")
90
+ from sfast.compilers.diffusion_pipeline_compiler import (compile, CompilationConfig)
91
+ self.dh.pipe.enable_xformers_memory_efficient_attention()
92
+ config = CompilationConfig.Default()
93
+ config.enable_xformers = True
94
+ config.enable_triton = True
95
+ config.enable_cuda_graph = True
96
+ self.dh.pipe = compile(self.dh.pipe, config)
97
+
98
+
99
+
100
+ def benchmark_speed(self):
101
+ """
102
+ Measures the time per diffusion step and for the vae decoding
103
+ """
104
+ print("starting speed benchmark...")
105
+ text_embeddings = self.dh.get_text_embedding("test")
106
+ latents_start = self.dh.get_noise(np.random.randint(111111))
107
+ # warmup
108
+ list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
109
+ # bench unet
110
+ t0 = time.time()
111
+ list_latents = self.dh.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False, idx_start=self.num_inference_steps-1)
112
+ self.dt_unet_step = time.time() - t0
113
+
114
+ # bench vae
115
+ t0 = time.time()
116
+ img = self.dh.latent2image(list_latents[-1])
117
+ self.dt_vae = time.time() - t0
118
+ print(f"time per unet iteration: {self.dt_unet_step} time for vae: {self.dt_vae}")
119
 
120
+ def set_dimensions(self, size_output=None):
121
  r"""
122
+ sets the size of the output video.
123
+ Args:
124
+ size_output: tuple
125
+ width x height
126
+ Note: the size will get automatically adjusted to be divisable by 32.
127
  """
128
+ if size_output is None:
129
+ if self.dh.is_sdxl_turbo:
130
+ size_output = (512, 512)
131
+ else:
132
+ size_output = (1024, 1024)
133
+ self.dh.set_dimensions(size_output)
 
 
134
 
135
+ def set_guidance_scale(self, guidance_scale=None):
136
  r"""
137
  sets the guidance scale.
138
  """
139
+ if guidance_scale is None:
140
+ if self.dh.is_sdxl_turbo:
141
+ guidance_scale = 0.0
142
+ else:
143
+ guidance_scale = 4.0
144
+
145
  self.guidance_scale_base = guidance_scale
146
  self.guidance_scale = guidance_scale
147
+ self.dh.guidance_scale = guidance_scale
148
 
149
  def set_negative_prompt(self, negative_prompt):
150
  r"""Set the negative prompt. Currenty only one negative prompt is supported
151
  """
152
  self.negative_prompt = negative_prompt
153
+ self.dh.set_negative_prompt(negative_prompt)
154
 
155
  def set_guidance_mid_dampening(self, fract_mixing):
156
  r"""
161
  max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
162
  guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
163
  self.guidance_scale = guidance_scale_effective
164
+ self.dh.guidance_scale = guidance_scale_effective
165
 
166
+ def set_branch1_crossfeed(self, crossfeed_power=0, crossfeed_range=0, crossfeed_decay=0):
167
  r"""
168
  Sets the crossfeed parameters for the first branch to the last branch.
169
  Args:
178
  self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
179
  self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
180
 
181
+ def set_parental_crossfeed(self, crossfeed_power=None, crossfeed_range=None, crossfeed_decay=None):
182
  r"""
183
  Sets the crossfeed parameters for all transition images (within the first and last branch).
184
  Args:
189
  crossfeed_decay: float [0,1]
190
  Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.
191
  """
192
+
193
+ if self.dh.is_sdxl_turbo:
194
+ if crossfeed_power is None:
195
+ crossfeed_power = 1.0
196
+ if crossfeed_range is None:
197
+ crossfeed_range = 1.0
198
+ if crossfeed_decay is None:
199
+ crossfeed_decay = 1.0
200
+ else:
201
+ crossfeed_power = 0.3
202
+ crossfeed_range = 0.6
203
+ crossfeed_decay = 0.9
204
+
205
  self.parental_crossfeed_power = np.clip(crossfeed_power, 0, 1)
206
  self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
207
+ self.parental_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
208
 
209
  def set_prompt1(self, prompt: str):
210
  r"""
243
  image: Image
244
  """
245
  self.image2_lowres = image
246
+
247
+ def set_num_inference_steps(self, num_inference_steps=None):
248
+ if self.dh.is_sdxl_turbo:
249
+ if num_inference_steps is None:
250
+ num_inference_steps = 4
251
+ else:
252
+ if num_inference_steps is None:
253
+ num_inference_steps = 30
254
+
255
+ self.num_inference_steps = num_inference_steps
256
+ self.dh.set_num_inference_steps(num_inference_steps)
257
+
258
+ def set_branching(self, depth_strength=None, t_compute_max_allowed=None, nmb_max_branches=None):
259
+ """
260
+ Sets the branching structure of the blending tree. Default arguments depend on pipe!
261
+ depth_strength:
262
+ Determines how deep the first injection will happen.
263
+ Deeper injections will cause (unwanted) formation of new structures,
264
+ more shallow values will go into alpha-blendy land.
265
+ t_compute_max_allowed:
266
+ Either provide t_compute_max_allowed or nmb_max_branches.
267
+ The maximum time allowed for computation. Higher values give better results but take longer.
268
+ nmb_max_branches: int
269
+ Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
270
+ results. Use this if you want to have controllable results independent
271
+ of your computer.
272
+ """
273
+ if self.dh.is_sdxl_turbo:
274
+ assert t_compute_max_allowed is None, "time-based branching not supported for SDXL Turbo"
275
+ if depth_strength is not None:
276
+ idx_inject = int(round(self.num_inference_steps*depth_strength))
277
+ else:
278
+ idx_inject = 2
279
+ if nmb_max_branches is None:
280
+ nmb_max_branches = 10
281
+
282
+ self.list_idx_injection = [idx_inject]
283
+ self.list_nmb_stems = [nmb_max_branches]
284
+
285
+ else:
286
+ if depth_strength is None:
287
+ depth_strength = 0.5
288
+ if t_compute_max_allowed is None and nmb_max_branches is None:
289
+ t_compute_max_allowed = 20
290
+ elif t_compute_max_allowed is not None and nmb_max_branches is not None:
291
+ raise ValueErorr("Either specify t_compute_max_allowed or nmb_max_branches")
292
+
293
+ self.list_idx_injection, self.list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
294
 
295
  def run_transition(
296
  self,
297
  recycle_img1: Optional[bool] = False,
298
  recycle_img2: Optional[bool] = False,
 
 
 
 
299
  fixed_seeds: Optional[List[int]] = None):
300
  r"""
301
  Function for computing transitions.
307
  Don't recompute the latents for the second keyframe (purely prompt2). Saves compute.
308
  num_inference_steps:
309
  Number of diffusion steps. Higher values will take more compute time.
310
+
 
 
 
 
 
 
 
 
 
 
311
  fixed_seeds: Optional[List[int)]:
312
  You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
313
  Otherwise random seeds will be taken.
316
  # Sanity checks first
317
  assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
318
  assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
319
+
320
 
321
  # Random seeds
322
  if fixed_seeds is not None:
328
  self.seed1 = fixed_seeds[0]
329
  self.seed2 = fixed_seeds[1]
330
 
331
+
 
 
 
332
  # Compute / Recycle first image
333
  if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
334
  list_latents1 = self.compute_latents1()
344
  # Reset the tree, injecting the edge latents1/2 we just generated/recycled
345
  self.tree_latents = [list_latents1, list_latents2]
346
  self.tree_fracts = [0.0, 1.0]
347
+ self.tree_final_imgs = [self.dh.latent2image((self.tree_latents[0][-1])), self.dh.latent2image((self.tree_latents[-1][-1]))]
348
  self.tree_idx_injection = [0, 0]
349
+ self.tree_similarities = [self.get_tree_similarities]
350
 
 
 
 
 
 
351
 
352
  # Run iteratively, starting with the longest trajectory.
353
  # Always inserting new branches where they are needed most according to image similarity
354
+ for s_idx in tqdm(range(len(self.list_idx_injection))):
355
+ nmb_stems = self.list_nmb_stems[s_idx]
356
+ idx_injection = self.list_idx_injection[s_idx]
357
 
358
  for i in range(nmb_stems):
359
  fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
360
  self.set_guidance_mid_dampening(fract_mixing)
361
  list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
362
  self.insert_into_tree(fract_mixing, idx_injection, list_latents)
363
+ # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection} bp1 {b_parent1} bp2 {b_parent2}")
364
 
365
  return self.tree_final_imgs
366
+
367
+
368
+
369
 
370
  def compute_latents1(self, return_image=False):
371
  r"""
383
  latents_start=latents_start,
384
  idx_start=0)
385
  t1 = time.time()
386
+ self.dt_unet_step = (t1 - t0) / self.num_inference_steps
387
  self.tree_latents[0] = list_latents1
388
  if return_image:
389
+ return self.dh.latent2image(list_latents1[-1])
390
  else:
391
  return list_latents1
392
 
418
  self.tree_latents[-1] = list_latents2
419
 
420
  if return_image:
421
+ return self.dh.latent2image(list_latents2[-1])
422
  else:
423
  return list_latents2
424
 
453
  mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
454
  nmb_mixing = idx_mixing_stop - idx_injection
455
  if nmb_mixing > 0:
456
+ mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_decay, nmb_mixing)))
457
  mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
458
  latents_start = list_latents_parental_mix[idx_injection - 1]
459
  list_latents = self.run_diffusion(
482
  results. Use this if you want to have controllable results independent
483
  of your computer.
484
  """
485
+ idx_injection_base = int(np.floor(self.num_inference_steps * depth_strength))
486
+
487
+ steps = int(np.ceil(self.num_inference_steps/10))
488
+ list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps, steps)
489
  list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
490
  t_compute = 0
491
 
503
  while not stop_criterion_reached:
504
  list_compute_steps = self.num_inference_steps - list_idx_injection
505
  list_compute_steps *= list_nmb_stems
506
+ t_compute = np.sum(list_compute_steps) * self.dt_unet_step + self.dt_vae * np.sum(list_nmb_stems)
507
+ t_compute += 2 * (self.num_inference_steps * self.dt_unet_step + self.dt_vae) # outer branches
508
  increase_done = False
509
  for s_idx in range(len(list_nmb_stems) - 1):
510
+ if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 1:
511
  list_nmb_stems[s_idx] += 1
512
  increase_done = True
513
  break
538
  the index in terms of diffusion steps, where the next insertion will start.
539
  """
540
  # get_lpips_similarity
541
+ similarities = self.tree_similarities
542
+ # similarities = self.get_tree_similarities()
 
543
  b_closest1 = np.argmax(similarities)
544
  b_closest2 = b_closest1 + 1
545
  fract_closest1 = self.tree_fracts[b_closest1]
546
  fract_closest2 = self.tree_fracts[b_closest2]
547
+ fract_mixing = (fract_closest1 + fract_closest2) / 2
548
 
549
+ # Ensure that the parents are indeed older
550
  b_parent1 = b_closest1
551
  while True:
552
  if self.tree_idx_injection[b_parent1] < idx_injection:
559
  break
560
  else:
561
  b_parent2 += 1
 
562
  return fract_mixing, b_parent1, b_parent2
563
 
564
  def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
572
  list_latents: list
573
  list of the latents to be inserted
574
  """
575
+ img_insert = self.dh.latent2image(list_latents[-1])
576
+
577
  b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
578
+ left_sim = self.get_lpips_similarity(img_insert, self.tree_final_imgs[b_parent1])
579
+ right_sim = self.get_lpips_similarity(img_insert, self.tree_final_imgs[b_parent2])
580
+ idx_insert = b_parent1 + 1
581
+ self.tree_latents.insert(idx_insert, list_latents)
582
+ self.tree_final_imgs.insert(idx_insert, img_insert)
583
+ self.tree_fracts.insert(idx_insert, fract_mixing)
584
+ self.tree_idx_injection.insert(idx_insert, idx_injection)
585
+
586
+ # update similarities
587
+ self.tree_similarities[b_parent1] = left_sim
588
+ self.tree_similarities.insert(idx_insert, right_sim)
589
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
  def get_noise(self, seed):
592
  r"""
594
  Args:
595
  seed: int
596
  """
597
+ return self.dh.get_noise(seed)
 
 
 
 
 
 
 
 
 
598
 
599
  @torch.no_grad()
600
  def run_diffusion(
625
  """
626
 
627
  # Ensure correct num_inference_steps in Holder
628
+ self.dh.set_num_inference_steps(self.num_inference_steps)
629
  assert type(list_conditionings) is list, "list_conditionings need to be a list"
630
 
631
+ text_embeddings = list_conditionings[0]
632
+ return self.dh.run_diffusion_sd_xl(
633
+ text_embeddings=text_embeddings,
634
+ latents_start=latents_start,
635
+ idx_start=idx_start,
636
+ list_latents_mixing=list_latents_mixing,
637
+ mixing_coeffs=mixing_coeffs,
638
+ return_image=return_image)
639
+
 
 
 
 
 
 
 
 
 
 
 
 
 
640
 
 
 
 
 
 
 
 
 
 
 
 
 
641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642
 
643
  @torch.no_grad()
644
  def get_mixed_conditioning(self, fract_mixing):
645
+ text_embeddings_mix = []
646
+ for i in range(len(self.text_embedding1)):
647
+ if self.text_embedding1[i] is None:
648
+ mix = None
649
+ else:
650
+ mix = interpolate_linear(self.text_embedding1[i], self.text_embedding2[i], fract_mixing)
651
+ text_embeddings_mix.append(mix)
652
+ list_conditionings = [text_embeddings_mix]
653
+
 
 
 
 
 
 
654
  return list_conditionings
655
 
656
  @torch.no_grad()
664
  prompt: str
665
  ABC trending on artstation painted by Old Greg.
666
  """
667
+ return self.dh.get_text_embedding(prompt)
668
 
669
  def write_imgs_transition(self, dp_img):
670
  r"""
680
  img_leaf = Image.fromarray(img)
681
  img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
682
  fp_yml = os.path.join(dp_img, "lowres.yaml")
 
683
 
684
  def write_movie_transition(self, fp_movie, duration_transition, fps=30):
685
  r"""
695
  """
696
 
697
  # Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
698
+ imgs_transition_ext = fill_up_frames_linear_interpolation(self.tree_final_imgs, duration_transition, fps)
699
 
700
  # Save as MP4
701
  if os.path.isfile(fp_movie):
702
  os.remove(fp_movie)
703
+ ms = MovieSaver(fp_movie, fps=fps, shape_hw=[self.dh.height_img, self.dh.width_img])
704
  for img in tqdm(imgs_transition_ext):
705
  ms.write_frame(img)
706
  ms.finalize()
707
 
 
 
 
 
 
 
708
 
709
  def get_state_dict(self):
710
  state_dict = {}
712
  'num_inference_steps', 'depth_strength', 'guidance_scale',
713
  'guidance_scale_mid_damper', 'mid_compression_scaler', 'negative_prompt',
714
  'branch1_crossfeed_power', 'branch1_crossfeed_range', 'branch1_crossfeed_decay'
715
+ 'parental_crossfeed_power', 'parental_crossfeed_range', 'parental_crossfeed_decay']
716
  for v in grab_vars:
717
  if hasattr(self, v):
718
  if v == 'seed1' or v == 'seed2':
727
  pass
728
  return state_dict
729
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
730
 
731
  def swap_forward(self):
732
  r"""
747
  Used to determine the optimal point of insertion to create smooth transitions.
748
  High values indicate low similarity.
749
  """
750
+ tensorA = torch.from_numpy(np.asarray(imgA)).float().cuda(self.device)
751
  tensorA = 2 * tensorA / 255.0 - 1
752
  tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)
753
+ tensorB = torch.from_numpy(np.asarray(imgB)).float().cuda(self.device)
754
  tensorB = 2 * tensorB / 255.0 - 1
755
  tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
756
  lploss = self.lpips(tensorA, tensorB)
757
  lploss = float(lploss[0][0][0][0])
758
  return lploss
759
 
760
+ def get_tree_similarities(self):
761
+ similarities = []
762
+ for i in range(len(self.tree_final_imgs) - 1):
763
+ similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
764
+ return similarities
765
+
766
  # Auxiliary functions
767
  def get_closest_idx(
768
  self,
787
  b_parent1 = tmp
788
 
789
  return b_parent1, b_parent2
790
+
791
+ #%%
792
+ if __name__ == "__main__":
793
+
794
+ # %% First let us spawn a stable diffusion holder. Uncomment your version of choice.
795
+ from diffusers_holder import DiffusersHolder
796
+ from diffusers import DiffusionPipeline
797
+ from diffusers import AutoencoderTiny
798
+ # pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
799
+ pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
800
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path)
801
+
802
+
803
+ # pipe.to("mps")
804
+ pipe.to("cuda")
805
+
806
+ # pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16)
807
+ # pipe.vae = pipe.vae.cuda()
808
+
809
+ dh = DiffusersHolder(pipe)
810
+
811
+ xxx
812
+ # %% Next let's set up all parameters
813
+ prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution"
814
+ prompt2 = "rendering of an alien planet, strange plants, strange creatures, surreal"
815
+ negative_prompt = "blurry, ugly, pale" # Optional
816
+
817
+ duration_transition = 12 # In seconds
818
+
819
+ # Spawn latent blending
820
+ be = BlendingEngine(dh)
821
+ be.set_prompt1(prompt1)
822
+ be.set_prompt2(prompt2)
823
+ be.set_negative_prompt(negative_prompt)
824
+
825
+ # Run latent blending
826
+ t0 = time.time()
827
+ be.run_transition(fixed_seeds=[420, 421])
828
+ dt = time.time() - t0
829
+ print(f"dt = {dt}")
830
+
831
+ # Save movie
832
+ fp_movie = f'test.mp4'
833
+ be.write_movie_transition(fp_movie, duration_transition)
834
+
835
+
836
+
837
+
latentblending/diffusers_holder.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import warnings
4
+
5
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
+ from latentblending.utils import interpolate_spherical
7
+ from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, ControlNetModel
8
+ from diffusers.models.attention_processor import (
9
+ AttnProcessor2_0,
10
+ LoRAAttnProcessor2_0,
11
+ LoRAXFormersAttnProcessor,
12
+ XFormersAttnProcessor,
13
+ )
14
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import retrieve_timesteps
15
+ warnings.filterwarnings('ignore')
16
+ torch.backends.cudnn.benchmark = False
17
+ torch.set_grad_enabled(False)
18
+
19
+
20
+ class DiffusersHolder():
21
+ def __init__(self, pipe):
22
+ # Base settings
23
+ self.negative_prompt = ""
24
+ self.guidance_scale = 5.0
25
+ self.num_inference_steps = 30
26
+
27
+ # Check if valid pipe
28
+ self.pipe = pipe
29
+ self.device = str(pipe._execution_device)
30
+ self.init_types()
31
+
32
+ self.width_latent = self.pipe.unet.config.sample_size
33
+ self.height_latent = self.pipe.unet.config.sample_size
34
+ self.width_img = self.width_latent * self.pipe.vae_scale_factor
35
+ self.height_img = self.height_latent * self.pipe.vae_scale_factor
36
+
37
+
38
+ def init_types(self):
39
+ assert hasattr(self.pipe, "__class__"), "No valid diffusers pipeline found."
40
+ assert hasattr(self.pipe.__class__, "__name__"), "No valid diffusers pipeline found."
41
+ if self.pipe.__class__.__name__ == 'StableDiffusionXLPipeline':
42
+ self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
43
+ prompt_embeds, _, _, _ = self.pipe.encode_prompt("test")
44
+ else:
45
+ prompt_embeds = self.pipe._encode_prompt("test", self.device, 1, True)
46
+ self.dtype = prompt_embeds.dtype
47
+
48
+ self.is_sdxl_turbo = 'turbo' in self.pipe._name_or_path
49
+
50
+
51
+ def set_num_inference_steps(self, num_inference_steps):
52
+ self.num_inference_steps = num_inference_steps
53
+ self.pipe.scheduler.set_timesteps(self.num_inference_steps, device=self.device)
54
+
55
+ def set_dimensions(self, size_output):
56
+ s = self.pipe.vae_scale_factor
57
+ if size_output is None:
58
+ width = self.pipe.unet.config.sample_size
59
+ height = self.pipe.unet.config.sample_size
60
+ else:
61
+ width, height = size_output
62
+ self.width_img = int(round(width / s) * s)
63
+ self.width_latent = int(self.width_img / s)
64
+ self.height_img = int(round(height / s) * s)
65
+ self.height_latent = int(self.height_img / s)
66
+ print(f"set_dimensions to width={width} and height={height}")
67
+
68
+ def set_negative_prompt(self, negative_prompt):
69
+ r"""Set the negative prompt. Currenty only one negative prompt is supported
70
+ """
71
+ if isinstance(negative_prompt, str):
72
+ self.negative_prompt = [negative_prompt]
73
+ else:
74
+ self.negative_prompt = negative_prompt
75
+
76
+ if len(self.negative_prompt) > 1:
77
+ self.negative_prompt = [self.negative_prompt[0]]
78
+
79
+ def get_text_embedding(self, prompt):
80
+ do_classifier_free_guidance = self.guidance_scale > 1 and self.pipe.unet.config.time_cond_proj_dim is None
81
+ text_embeddings = self.pipe.encode_prompt(
82
+ prompt=prompt,
83
+ prompt_2=prompt,
84
+ device=self.pipe._execution_device,
85
+ num_images_per_prompt=1,
86
+ do_classifier_free_guidance=do_classifier_free_guidance,
87
+ negative_prompt=self.negative_prompt,
88
+ negative_prompt_2=self.negative_prompt,
89
+ prompt_embeds=None,
90
+ negative_prompt_embeds=None,
91
+ pooled_prompt_embeds=None,
92
+ negative_pooled_prompt_embeds=None,
93
+ lora_scale=None,
94
+ clip_skip=None,#self.pipe._clip_skip,
95
+ )
96
+ return text_embeddings
97
+
98
+ def get_noise(self, seed=420):
99
+
100
+ latents = self.pipe.prepare_latents(
101
+ 1,
102
+ self.pipe.unet.config.in_channels,
103
+ self.height_img,
104
+ self.width_img,
105
+ torch.float16,
106
+ self.pipe._execution_device,
107
+ torch.Generator(device=self.device).manual_seed(int(seed)),
108
+ None,
109
+ )
110
+
111
+ return latents
112
+
113
+
114
+ @torch.no_grad()
115
+ def latent2image(
116
+ self,
117
+ latents: torch.FloatTensor,
118
+ output_type="pil"):
119
+ r"""
120
+ Returns an image provided a latent representation from diffusion.
121
+ Args:
122
+ latents: torch.FloatTensor
123
+ Result of the diffusion process.
124
+ output_type: "pil" or "np"
125
+ """
126
+ assert output_type in ["pil", "np"]
127
+
128
+ # make sure the VAE is in float32 mode, as it overflows in float16
129
+ needs_upcasting = self.pipe.vae.dtype == torch.float16 and self.pipe.vae.config.force_upcast
130
+
131
+ if needs_upcasting:
132
+ self.pipe.upcast_vae()
133
+ latents = latents.to(next(iter(self.pipe.vae.post_quant_conv.parameters())).dtype)
134
+
135
+ image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor, return_dict=False)[0]
136
+
137
+ # cast back to fp16 if needed
138
+ if needs_upcasting:
139
+ self.pipe.vae.to(dtype=torch.float16)
140
+
141
+ image = self.pipe.image_processor.postprocess(image, output_type=output_type)[0]
142
+
143
+ return image
144
+
145
+
146
+ def prepare_mixing(self, mixing_coeffs, list_latents_mixing):
147
+ if type(mixing_coeffs) == float:
148
+ list_mixing_coeffs = (1 + self.num_inference_steps) * [mixing_coeffs]
149
+ elif type(mixing_coeffs) == list:
150
+ assert len(mixing_coeffs) == self.num_inference_steps, f"len(mixing_coeffs) {len(mixing_coeffs)} != self.num_inference_steps {self.num_inference_steps}"
151
+ list_mixing_coeffs = mixing_coeffs
152
+ else:
153
+ raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
154
+ if np.sum(list_mixing_coeffs) > 0:
155
+ assert len(list_latents_mixing) == self.num_inference_steps, f"len(list_latents_mixing) {len(list_latents_mixing)} != self.num_inference_steps {self.num_inference_steps}"
156
+ return list_mixing_coeffs
157
+
158
+ @torch.no_grad()
159
+ def run_diffusion(
160
+ self,
161
+ text_embeddings: torch.FloatTensor,
162
+ latents_start: torch.FloatTensor,
163
+ idx_start: int = 0,
164
+ list_latents_mixing=None,
165
+ mixing_coeffs=0.0,
166
+ return_image: Optional[bool] = False):
167
+
168
+ return self.run_diffusion_sd_xl(text_embeddings, latents_start, idx_start, list_latents_mixing, mixing_coeffs, return_image)
169
+
170
+
171
+
172
+ @torch.no_grad()
173
+ def run_diffusion_sd_xl(
174
+ self,
175
+ text_embeddings: tuple,
176
+ latents_start: torch.FloatTensor,
177
+ idx_start: int = 0,
178
+ list_latents_mixing=None,
179
+ mixing_coeffs=0.0,
180
+ return_image: Optional[bool] = False,
181
+ ):
182
+
183
+
184
+ prompt_2 = None
185
+ height = None
186
+ width = None
187
+ timesteps = None
188
+ denoising_end = None
189
+ negative_prompt_2 = None
190
+ num_images_per_prompt = 1
191
+ eta = 0.0
192
+ generator = None
193
+ latents = None
194
+ prompt_embeds = None
195
+ negative_prompt_embeds = None
196
+ pooled_prompt_embeds = None
197
+ negative_pooled_prompt_embeds = None
198
+ ip_adapter_image = None
199
+ output_type = "pil"
200
+ return_dict = True
201
+ cross_attention_kwargs = None
202
+ guidance_rescale = 0.0
203
+ original_size = None
204
+ crops_coords_top_left = (0, 0)
205
+ target_size = None
206
+ negative_original_size = None
207
+ negative_crops_coords_top_left = (0, 0)
208
+ negative_target_size = None
209
+ clip_skip = None
210
+ callback = None
211
+ callback_on_step_end = None
212
+ callback_on_step_end_tensor_inputs = ["latents"]
213
+ # kwargs are additional keyword arguments and don't need a default value set here.
214
+
215
+ # 0. Default height and width to unet
216
+ height = height or self.pipe.default_sample_size * self.pipe.vae_scale_factor
217
+ width = width or self.pipe.default_sample_size * self.pipe.vae_scale_factor
218
+
219
+ original_size = original_size or (height, width)
220
+ target_size = target_size or (height, width)
221
+
222
+ # 1. Check inputs. skipped.
223
+
224
+ self.pipe._guidance_scale = self.guidance_scale
225
+ self.pipe._guidance_rescale = guidance_rescale
226
+ self.pipe._clip_skip = clip_skip
227
+ self.pipe._cross_attention_kwargs = cross_attention_kwargs
228
+ self.pipe._denoising_end = denoising_end
229
+ self.pipe._interrupt = False
230
+
231
+ # 2. Define call parameters
232
+ list_mixing_coeffs = self.prepare_mixing(mixing_coeffs, list_latents_mixing)
233
+ batch_size = 1
234
+
235
+ device = self.pipe._execution_device
236
+
237
+ # 3. Encode input prompt
238
+ lora_scale = None
239
+ (
240
+ prompt_embeds,
241
+ negative_prompt_embeds,
242
+ pooled_prompt_embeds,
243
+ negative_pooled_prompt_embeds,
244
+ ) = text_embeddings
245
+
246
+ # 4. Prepare timesteps
247
+ timesteps, num_inference_steps = retrieve_timesteps(self.pipe.scheduler, self.num_inference_steps, device, timesteps)
248
+
249
+ # 5. Prepare latent variables
250
+ num_channels_latents = self.pipe.unet.config.in_channels
251
+ latents = latents_start.clone()
252
+ list_latents_out = []
253
+
254
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
255
+ extra_step_kwargs = self.pipe.prepare_extra_step_kwargs(generator, eta)
256
+
257
+ # 7. Prepare added time ids & embeddings
258
+ add_text_embeds = pooled_prompt_embeds
259
+ if self.pipe.text_encoder_2 is None:
260
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
261
+ else:
262
+ text_encoder_projection_dim = self.pipe.text_encoder_2.config.projection_dim
263
+
264
+ add_time_ids = self.pipe._get_add_time_ids(
265
+ original_size,
266
+ crops_coords_top_left,
267
+ target_size,
268
+ dtype=prompt_embeds.dtype,
269
+ text_encoder_projection_dim=text_encoder_projection_dim,
270
+ )
271
+ if negative_original_size is not None and negative_target_size is not None:
272
+ negative_add_time_ids = self.pipe._get_add_time_ids(
273
+ negative_original_size,
274
+ negative_crops_coords_top_left,
275
+ negative_target_size,
276
+ dtype=prompt_embeds.dtype,
277
+ text_encoder_projection_dim=text_encoder_projection_dim,
278
+ )
279
+ else:
280
+ negative_add_time_ids = add_time_ids
281
+
282
+ if self.pipe.do_classifier_free_guidance:
283
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
284
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
285
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
286
+
287
+ prompt_embeds = prompt_embeds.to(device)
288
+ add_text_embeds = add_text_embeds.to(device)
289
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
290
+
291
+ if ip_adapter_image is not None:
292
+ output_hidden_state = False if isinstance(self.pipe.unet.encoder_hid_proj, ImageProjection) else True
293
+ image_embeds, negative_image_embeds = self.pipe.encode_image(
294
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
295
+ )
296
+ if self.pipe.do_classifier_free_guidance:
297
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
298
+ image_embeds = image_embeds.to(device)
299
+
300
+ # 8. Denoising loop
301
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.pipe.scheduler.order, 0)
302
+
303
+ # 9. Optionally get Guidance Scale Embedding
304
+ timestep_cond = None
305
+ if self.pipe.unet.config.time_cond_proj_dim is not None:
306
+ guidance_scale_tensor = torch.tensor(self.pipe.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
307
+ timestep_cond = self.pipe.get_guidance_scale_embedding(
308
+ guidance_scale_tensor, embedding_dim=self.pipe.unet.config.time_cond_proj_dim
309
+ ).to(device=device, dtype=latents.dtype)
310
+
311
+ self.pipe._num_timesteps = len(timesteps)
312
+ for i, t in enumerate(timesteps):
313
+ # Set the right starting latents
314
+ # Write latents out and skip
315
+ if i < idx_start:
316
+ list_latents_out.append(None)
317
+ continue
318
+ elif i == idx_start:
319
+ latents = latents_start.clone()
320
+
321
+ # Mix latents for crossfeeding
322
+ if i > 0 and list_mixing_coeffs[i] > 0:
323
+ latents_mixtarget = list_latents_mixing[i - 1].clone()
324
+ latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
325
+
326
+
327
+ # expand the latents if we are doing classifier free guidance
328
+ latent_model_input = torch.cat([latents] * 2) if self.pipe.do_classifier_free_guidance else latents
329
+
330
+ latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)
331
+
332
+ # predict the noise residual
333
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
334
+ if ip_adapter_image is not None:
335
+ added_cond_kwargs["image_embeds"] = image_embeds
336
+ noise_pred = self.pipe.unet(
337
+ latent_model_input,
338
+ t,
339
+ encoder_hidden_states=prompt_embeds,
340
+ timestep_cond=timestep_cond,
341
+ cross_attention_kwargs=self.pipe.cross_attention_kwargs,
342
+ added_cond_kwargs=added_cond_kwargs,
343
+ return_dict=False,
344
+ )[0]
345
+
346
+ # perform guidance
347
+ if self.pipe.do_classifier_free_guidance:
348
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
349
+ noise_pred = noise_pred_uncond + self.pipe.guidance_scale * (noise_pred_text - noise_pred_uncond)
350
+
351
+ if self.pipe.do_classifier_free_guidance and self.pipe.guidance_rescale > 0.0:
352
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
353
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.pipe.guidance_rescale)
354
+
355
+ # compute the previous noisy sample x_t -> x_t-1
356
+ latents = self.pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
357
+
358
+ # Append latents
359
+ list_latents_out.append(latents.clone())
360
+
361
+
362
+
363
+ if return_image:
364
+ return self.latent2image(latents)
365
+ else:
366
+ return list_latents_out
367
+
368
+
369
+
370
+ #%%
371
+ if __name__ == "__main__":
372
+ from PIL import Image
373
+ from diffusers import AutoencoderTiny
374
+ # pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
375
+ pretrained_model_name_or_path = "stabilityai/sdxl-turbo"
376
+ pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, variant="fp16")
377
+ pipe.to("cuda")
378
+ #%
379
+ # pipe.vae = AutoencoderTiny.from_pretrained('madebyollin/taesdxl', torch_device='cuda', torch_dtype=torch.float16)
380
+ # pipe.vae = pipe.vae.cuda()
381
+ #%% resanity
382
+ import time
383
+ self = DiffusersHolder(pipe)
384
+ prompt1 = "photo of underwater landscape, fish, und the sea, incredible detail, high resolution"
385
+ negative_prompt = "blurry, ugly, pale"
386
+ num_inference_steps = 4
387
+ guidance_scale = 0
388
+
389
+ self.set_num_inference_steps(num_inference_steps)
390
+ self.guidance_scale = guidance_scale
391
+
392
+ prefix='turbo'
393
+ for i in range(10):
394
+ self.set_negative_prompt(negative_prompt)
395
+
396
+ text_embeddings = self.get_text_embedding(prompt1)
397
+ latents_start = self.get_noise(np.random.randint(111111))
398
+
399
+ t0 = time.time()
400
+
401
+ # img_refx = self.pipe(prompt=prompt1, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale)[0]
402
+
403
+ img_refx = self.run_diffusion_sd_xl(text_embeddings=text_embeddings, latents_start=latents_start, return_image=False)
404
+
405
+ dt_ref = time.time() - t0
406
+ img_refx.save(f"x_{prefix}_{i}.jpg")
407
+
408
+
409
+
410
+
411
+
412
+ # xxx
413
+
414
+ # self.set_negative_prompt(negative_prompt)
415
+ # self.set_num_inference_steps(num_inference_steps)
416
+ # text_embeddings1 = self.get_text_embedding(prompt1)
417
+ # prompt_embeds1, negative_prompt_embeds1, pooled_prompt_embeds1, negative_pooled_prompt_embeds1 = text_embeddings1
418
+ # latents_start = self.get_noise(420)
419
+ # t0 = time.time()
420
+ # img_dh = self.run_diffusion_sd_xl_resanity(text_embeddings1, latents_start, idx_start=0, return_image=True)
421
+ # dt_dh = time.time() - t0
422
+
423
+
424
+
425
+
426
+ # xxxx
427
+ # #%%
428
+
429
+ # self = DiffusersHolder(pipe)
430
+ # num_inference_steps = 4
431
+ # self.set_num_inference_steps(num_inference_steps)
432
+ # latents_start = self.get_noise(420)
433
+ # guidance_scale = 0
434
+ # self.guidance_scale = 0
435
+
436
+ # #% get embeddings1
437
+ # prompt1 = "Photo of a colorful landscape with a blue sky with clouds"
438
+ # text_embeddings1 = self.get_text_embedding(prompt1)
439
+ # prompt_embeds1, negative_prompt_embeds1, pooled_prompt_embeds1, negative_pooled_prompt_embeds1 = text_embeddings1
440
+
441
+ # #% get embeddings2
442
+ # prompt2 = "Photo of a tree"
443
+ # text_embeddings2 = self.get_text_embedding(prompt2)
444
+ # prompt_embeds2, negative_prompt_embeds2, pooled_prompt_embeds2, negative_pooled_prompt_embeds2 = text_embeddings2
445
+
446
+ # latents1 = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=False)
447
+
448
+ # img1 = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=True)
449
+ # img1B = self.run_diffusion_sd_xl(text_embeddings1, latents_start, idx_start=0, return_image=True)
450
+
451
+
452
+
453
+ # # latents2 = self.run_diffusion_sd_xl(text_embeddings2, latents_start, idx_start=0, return_image=False)
454
+
455
+
456
+ # # # check if brings same image if restarted
457
+ # # img1_return = self.run_diffusion_sd_xl(text_embeddings1, latents1[idx_mix-1], idx_start=idx_start, return_image=True)
458
+
459
+ # # mix latents
460
+ # #%%
461
+ # idx_mix = 2
462
+ # fract=0.8
463
+ # latents_start_mixed = interpolate_spherical(latents1[idx_mix-1], latents2[idx_mix-1], fract)
464
+ # prompt_embeds = interpolate_spherical(prompt_embeds1, prompt_embeds2, fract)
465
+ # pooled_prompt_embeds = interpolate_spherical(pooled_prompt_embeds1, pooled_prompt_embeds2, fract)
466
+ # negative_prompt_embeds = negative_prompt_embeds1
467
+ # negative_pooled_prompt_embeds = negative_pooled_prompt_embeds1
468
+ # text_embeddings_mix = [prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds]
469
+
470
+ # self.run_diffusion_sd_xl(text_embeddings_mix, latents_start_mixed, idx_start=idx_start, return_image=True)
471
+
472
+
473
+
474
+
latentblending/gradio_ui.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ torch.backends.cudnn.benchmark = False
4
+ torch.set_grad_enabled(False)
5
+ import numpy as np
6
+ import warnings
7
+ warnings.filterwarnings('ignore')
8
+ from tqdm.auto import tqdm
9
+ from PIL import Image
10
+ import gradio as gr
11
+ import shutil
12
+ import uuid
13
+ from diffusers import AutoPipelineForText2Image
14
+ from latentblending.blending_engine import BlendingEngine
15
+ import datetime
16
+
17
+ warnings.filterwarnings('ignore')
18
+ torch.set_grad_enabled(False)
19
+ torch.backends.cudnn.benchmark = False
20
+ import json
21
+
22
+
23
+
24
+ class BlendingFrontend():
25
+ def __init__(
26
+ self,
27
+ be,
28
+ share=False):
29
+ r"""
30
+ Gradio Helper Class to collect UI data and start latent blending.
31
+ Args:
32
+ be:
33
+ Blendingengine
34
+ share: bool
35
+ Set true to get a shareable gradio link (e.g. for running a remote server)
36
+ """
37
+ self.be = be
38
+ self.share = share
39
+
40
+ # UI Defaults
41
+ self.seed1 = 420
42
+ self.seed2 = 420
43
+ self.prompt1 = ""
44
+ self.prompt2 = ""
45
+ self.negative_prompt = ""
46
+
47
+ # Vars
48
+ self.prompt = None
49
+ self.negative_prompt = None
50
+ self.list_seeds = []
51
+ self.idx_movie = 0
52
+ self.data = []
53
+
54
+ def take_image0(self):
55
+ return self.take_image(0)
56
+
57
+ def take_image1(self):
58
+ return self.take_image(1)
59
+
60
+ def take_image2(self):
61
+ return self.take_image(2)
62
+
63
+ def take_image3(self):
64
+ return self.take_image(3)
65
+
66
+
67
+ def take_image(self, id_img):
68
+ if self.prompt is None:
69
+ print("Cannot take because no prompt was set!")
70
+ return [None, None, None, None, ""]
71
+ if self.idx_movie == 0:
72
+ current_time = datetime.datetime.now()
73
+ self.fp_out = "movie_" + current_time.strftime("%y%m%d_%H%M") + ".json"
74
+ self.data.append({"settings": "sdxl", "width": bf.be.dh.width_img, "height": self.be.dh.height_img, "num_inference_steps": self.be.dh.num_inference_steps})
75
+
76
+ seed = self.list_seeds[id_img]
77
+
78
+ self.data.append({"iteration": self.idx_movie, "seed": seed, "prompt": self.prompt, "negative_prompt": self.negative_prompt})
79
+
80
+ # Write the data list to a JSON file
81
+ with open(self.fp_out, 'w') as f:
82
+ json.dump(self.data, f, indent=4)
83
+
84
+ self.idx_movie += 1
85
+ self.prompt = None
86
+ return [None, None, None, None, ""]
87
+
88
+
89
+ def compute_imgs(self, prompt, negative_prompt):
90
+ self.prompt = prompt
91
+ self.negative_prompt = negative_prompt
92
+ self.be.set_prompt1(prompt)
93
+ self.be.set_prompt2(prompt)
94
+ self.be.set_negative_prompt(negative_prompt)
95
+ self.list_seeds = []
96
+ self.list_images = []
97
+ for i in range(4):
98
+ seed = np.random.randint(0, 1000000000)
99
+ self.be.seed1 = seed
100
+ self.list_seeds.append(seed)
101
+ img = self.be.compute_latents1(return_image=True)
102
+ self.list_images.append(img)
103
+ return self.list_images
104
+
105
+
106
+
107
+
108
+ if __name__ == "__main__":
109
+
110
+ width = 786
111
+ height = 1024
112
+ num_inference_steps = 4
113
+
114
+ pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
115
+ # pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16")
116
+ pipe.to("cuda")
117
+
118
+ be = BlendingEngine(pipe)
119
+ be.set_dimensions((width, height))
120
+ be.set_num_inference_steps(num_inference_steps)
121
+
122
+ bf = BlendingFrontend(be)
123
+
124
+ with gr.Blocks() as demo:
125
+
126
+ with gr.Row():
127
+ prompt = gr.Textbox(label="prompt")
128
+ negative_prompt = gr.Textbox(label="negative prompt")
129
+
130
+ with gr.Row():
131
+ b_compute = gr.Button('compute new images', variant='primary')
132
+
133
+ with gr.Row():
134
+ with gr.Column():
135
+ img0 = gr.Image(label="seed1")
136
+ b_take0 = gr.Button('take', variant='primary')
137
+ with gr.Column():
138
+ img1 = gr.Image(label="seed2")
139
+ b_take1 = gr.Button('take', variant='primary')
140
+ with gr.Column():
141
+ img2 = gr.Image(label="seed3")
142
+ b_take2 = gr.Button('take', variant='primary')
143
+ with gr.Column():
144
+ img3 = gr.Image(label="seed4")
145
+ b_take3 = gr.Button('take', variant='primary')
146
+
147
+ b_compute.click(bf.compute_imgs, inputs=[prompt, negative_prompt], outputs=[img0, img1, img2, img3])
148
+ b_take0.click(bf.take_image0, outputs=[img0, img1, img2, img3, prompt])
149
+ b_take1.click(bf.take_image1, outputs=[img0, img1, img2, img3, prompt])
150
+ b_take2.click(bf.take_image2, outputs=[img0, img1, img2, img3, prompt])
151
+ b_take3.click(bf.take_image3, outputs=[img0, img1, img2, img3, prompt])
152
+
153
+ demo.launch(share=bf.share, inbrowser=True, inline=False, server_name="10.40.49.100")
utils.py → latentblending/utils.py RENAMED
@@ -24,7 +24,7 @@ import datetime
24
  from typing import List, Union
25
  torch.set_grad_enabled(False)
26
  import yaml
27
-
28
 
29
  @torch.no_grad()
30
  def interpolate_spherical(p0, p1, fract_mixing: float):
@@ -142,6 +142,8 @@ def add_frames_linear_interp(
142
  if nmb_frames_missing < 1:
143
  return list_imgs
144
 
 
 
145
  list_imgs_float = [img.astype(np.float32) for img in list_imgs]
146
  # Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
147
  mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
24
  from typing import List, Union
25
  torch.set_grad_enabled(False)
26
  import yaml
27
+ import PIL
28
 
29
  @torch.no_grad()
30
  def interpolate_spherical(p0, p1, fract_mixing: float):
142
  if nmb_frames_missing < 1:
143
  return list_imgs
144
 
145
+ if type(list_imgs[0]) == PIL.Image.Image:
146
+ list_imgs = [np.asarray(l) for l in list_imgs]
147
  list_imgs_float = [img.astype(np.float32) for img in list_imgs]
148
  # Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
149
  mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
ldm/__pycache__/util.cpython-310.pyc DELETED
Binary file (6.18 kB)
ldm/__pycache__/util.cpython-38.pyc DELETED
Binary file (6.15 kB)
ldm/__pycache__/util.cpython-39.pyc DELETED
Binary file (6.16 kB)
ldm/data/__init__.py DELETED
File without changes
ldm/data/util.py DELETED
@@ -1,24 +0,0 @@
1
- import torch
2
-
3
- from ldm.modules.midas.api import load_midas_transform
4
-
5
-
6
- class AddMiDaS(object):
7
- def __init__(self, model_type):
8
- super().__init__()
9
- self.transform = load_midas_transform(model_type)
10
-
11
- def pt2np(self, x):
12
- x = ((x + 1.0) * .5).detach().cpu().numpy()
13
- return x
14
-
15
- def np2pt(self, x):
16
- x = torch.from_numpy(x) * 2 - 1.
17
- return x
18
-
19
- def __call__(self, sample):
20
- # sample['jpg'] is tensor hwc in [-1, 1] at this point
21
- x = self.pt2np(sample['jpg'])
22
- x = self.transform({"image": x})["image"]
23
- sample['midas_in'] = x
24
- return sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/ldm DELETED
@@ -1 +0,0 @@
1
- ldm
 
ldm/models/__pycache__/autoencoder.cpython-310.pyc DELETED
Binary file (7.72 kB)
ldm/models/__pycache__/autoencoder.cpython-38.pyc DELETED
Binary file (7.61 kB)
ldm/models/__pycache__/autoencoder.cpython-39.pyc DELETED
Binary file (7.68 kB)
ldm/models/autoencoder.py DELETED
@@ -1,219 +0,0 @@
1
- import torch
2
- import pytorch_lightning as pl
3
- import torch.nn.functional as F
4
- from contextlib import contextmanager
5
-
6
- from ldm.modules.diffusionmodules.model import Encoder, Decoder
7
- from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
8
-
9
- from ldm.util import instantiate_from_config
10
- from ldm.modules.ema import LitEma
11
-
12
-
13
- class AutoencoderKL(pl.LightningModule):
14
- def __init__(self,
15
- ddconfig,
16
- lossconfig,
17
- embed_dim,
18
- ckpt_path=None,
19
- ignore_keys=[],
20
- image_key="image",
21
- colorize_nlabels=None,
22
- monitor=None,
23
- ema_decay=None,
24
- learn_logvar=False
25
- ):
26
- super().__init__()
27
- self.learn_logvar = learn_logvar
28
- self.image_key = image_key
29
- self.encoder = Encoder(**ddconfig)
30
- self.decoder = Decoder(**ddconfig)
31
- self.loss = instantiate_from_config(lossconfig)
32
- assert ddconfig["double_z"]
33
- self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
34
- self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
35
- self.embed_dim = embed_dim
36
- if colorize_nlabels is not None:
37
- assert type(colorize_nlabels)==int
38
- self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
39
- if monitor is not None:
40
- self.monitor = monitor
41
-
42
- self.use_ema = ema_decay is not None
43
- if self.use_ema:
44
- self.ema_decay = ema_decay
45
- assert 0. < ema_decay < 1.
46
- self.model_ema = LitEma(self, decay=ema_decay)
47
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
48
-
49
- if ckpt_path is not None:
50
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
51
-
52
- def init_from_ckpt(self, path, ignore_keys=list()):
53
- sd = torch.load(path, map_location="cpu")["state_dict"]
54
- keys = list(sd.keys())
55
- for k in keys:
56
- for ik in ignore_keys:
57
- if k.startswith(ik):
58
- print("Deleting key {} from state_dict.".format(k))
59
- del sd[k]
60
- self.load_state_dict(sd, strict=False)
61
- print(f"Restored from {path}")
62
-
63
- @contextmanager
64
- def ema_scope(self, context=None):
65
- if self.use_ema:
66
- self.model_ema.store(self.parameters())
67
- self.model_ema.copy_to(self)
68
- if context is not None:
69
- print(f"{context}: Switched to EMA weights")
70
- try:
71
- yield None
72
- finally:
73
- if self.use_ema:
74
- self.model_ema.restore(self.parameters())
75
- if context is not None:
76
- print(f"{context}: Restored training weights")
77
-
78
- def on_train_batch_end(self, *args, **kwargs):
79
- if self.use_ema:
80
- self.model_ema(self)
81
-
82
- def encode(self, x):
83
- h = self.encoder(x)
84
- moments = self.quant_conv(h)
85
- posterior = DiagonalGaussianDistribution(moments)
86
- return posterior
87
-
88
- def decode(self, z):
89
- z = self.post_quant_conv(z)
90
- dec = self.decoder(z)
91
- return dec
92
-
93
- def forward(self, input, sample_posterior=True):
94
- posterior = self.encode(input)
95
- if sample_posterior:
96
- z = posterior.sample()
97
- else:
98
- z = posterior.mode()
99
- dec = self.decode(z)
100
- return dec, posterior
101
-
102
- def get_input(self, batch, k):
103
- x = batch[k]
104
- if len(x.shape) == 3:
105
- x = x[..., None]
106
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
107
- return x
108
-
109
- def training_step(self, batch, batch_idx, optimizer_idx):
110
- inputs = self.get_input(batch, self.image_key)
111
- reconstructions, posterior = self(inputs)
112
-
113
- if optimizer_idx == 0:
114
- # train encoder+decoder+logvar
115
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
116
- last_layer=self.get_last_layer(), split="train")
117
- self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
118
- self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
119
- return aeloss
120
-
121
- if optimizer_idx == 1:
122
- # train the discriminator
123
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
124
- last_layer=self.get_last_layer(), split="train")
125
-
126
- self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
127
- self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
128
- return discloss
129
-
130
- def validation_step(self, batch, batch_idx):
131
- log_dict = self._validation_step(batch, batch_idx)
132
- with self.ema_scope():
133
- log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
134
- return log_dict
135
-
136
- def _validation_step(self, batch, batch_idx, postfix=""):
137
- inputs = self.get_input(batch, self.image_key)
138
- reconstructions, posterior = self(inputs)
139
- aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
140
- last_layer=self.get_last_layer(), split="val"+postfix)
141
-
142
- discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
143
- last_layer=self.get_last_layer(), split="val"+postfix)
144
-
145
- self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
146
- self.log_dict(log_dict_ae)
147
- self.log_dict(log_dict_disc)
148
- return self.log_dict
149
-
150
- def configure_optimizers(self):
151
- lr = self.learning_rate
152
- ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
153
- self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
154
- if self.learn_logvar:
155
- print(f"{self.__class__.__name__}: Learning logvar")
156
- ae_params_list.append(self.loss.logvar)
157
- opt_ae = torch.optim.Adam(ae_params_list,
158
- lr=lr, betas=(0.5, 0.9))
159
- opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
160
- lr=lr, betas=(0.5, 0.9))
161
- return [opt_ae, opt_disc], []
162
-
163
- def get_last_layer(self):
164
- return self.decoder.conv_out.weight
165
-
166
- @torch.no_grad()
167
- def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
168
- log = dict()
169
- x = self.get_input(batch, self.image_key)
170
- x = x.to(self.device)
171
- if not only_inputs:
172
- xrec, posterior = self(x)
173
- if x.shape[1] > 3:
174
- # colorize with random projection
175
- assert xrec.shape[1] > 3
176
- x = self.to_rgb(x)
177
- xrec = self.to_rgb(xrec)
178
- log["samples"] = self.decode(torch.randn_like(posterior.sample()))
179
- log["reconstructions"] = xrec
180
- if log_ema or self.use_ema:
181
- with self.ema_scope():
182
- xrec_ema, posterior_ema = self(x)
183
- if x.shape[1] > 3:
184
- # colorize with random projection
185
- assert xrec_ema.shape[1] > 3
186
- xrec_ema = self.to_rgb(xrec_ema)
187
- log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
188
- log["reconstructions_ema"] = xrec_ema
189
- log["inputs"] = x
190
- return log
191
-
192
- def to_rgb(self, x):
193
- assert self.image_key == "segmentation"
194
- if not hasattr(self, "colorize"):
195
- self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
196
- x = F.conv2d(x, weight=self.colorize)
197
- x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
198
- return x
199
-
200
-
201
- class IdentityFirstStage(torch.nn.Module):
202
- def __init__(self, *args, vq_interface=False, **kwargs):
203
- self.vq_interface = vq_interface
204
- super().__init__()
205
-
206
- def encode(self, x, *args, **kwargs):
207
- return x
208
-
209
- def decode(self, x, *args, **kwargs):
210
- return x
211
-
212
- def quantize(self, x, *args, **kwargs):
213
- if self.vq_interface:
214
- return x, None, [None, None, None]
215
- return x
216
-
217
- def forward(self, x, *args, **kwargs):
218
- return x
219
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/models/diffusion/__init__.py DELETED
File without changes
ldm/models/diffusion/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (155 Bytes)
ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc DELETED
Binary file (153 Bytes)
ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (153 Bytes)
ldm/models/diffusion/__pycache__/ddim.cpython-310.pyc DELETED
Binary file (9.33 kB)
ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc DELETED
Binary file (9.27 kB)
ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc DELETED
Binary file (9.19 kB)
ldm/models/diffusion/__pycache__/ddpm.cpython-310.pyc DELETED
Binary file (52.8 kB)
ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc DELETED
Binary file (53 kB)
ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc DELETED
Binary file (53 kB)
ldm/models/diffusion/__pycache__/plms.cpython-39.pyc DELETED
Binary file (7.46 kB)
ldm/models/diffusion/__pycache__/sampling_util.cpython-39.pyc DELETED
Binary file (1.07 kB)
ldm/models/diffusion/ddim.py DELETED
@@ -1,336 +0,0 @@
1
- """SAMPLING ONLY."""
2
-
3
- import torch
4
- import numpy as np
5
- from tqdm import tqdm
6
-
7
- from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
8
-
9
-
10
- class DDIMSampler(object):
11
- def __init__(self, model, schedule="linear", **kwargs):
12
- super().__init__()
13
- self.model = model
14
- self.ddpm_num_timesteps = model.num_timesteps
15
- self.schedule = schedule
16
-
17
- def register_buffer(self, name, attr):
18
- if type(attr) == torch.Tensor:
19
- if attr.device != torch.device("cuda"):
20
- attr = attr.to(torch.device("cuda"))
21
- setattr(self, name, attr)
22
-
23
- def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
24
- self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
25
- num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
26
- alphas_cumprod = self.model.alphas_cumprod
27
- assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
28
- to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
29
-
30
- self.register_buffer('betas', to_torch(self.model.betas))
31
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
32
- self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
33
-
34
- # calculations for diffusion q(x_t | x_{t-1}) and others
35
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
36
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
37
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
38
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
39
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
40
-
41
- # ddim sampling parameters
42
- ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
43
- ddim_timesteps=self.ddim_timesteps,
44
- eta=ddim_eta,verbose=verbose)
45
- self.register_buffer('ddim_sigmas', ddim_sigmas)
46
- self.register_buffer('ddim_alphas', ddim_alphas)
47
- self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
48
- self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
49
- sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
50
- (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
51
- 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
52
- self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
53
-
54
- @torch.no_grad()
55
- def sample(self,
56
- S,
57
- batch_size,
58
- shape,
59
- conditioning=None,
60
- callback=None,
61
- normals_sequence=None,
62
- img_callback=None,
63
- quantize_x0=False,
64
- eta=0.,
65
- mask=None,
66
- x0=None,
67
- temperature=1.,
68
- noise_dropout=0.,
69
- score_corrector=None,
70
- corrector_kwargs=None,
71
- verbose=True,
72
- x_T=None,
73
- log_every_t=100,
74
- unconditional_guidance_scale=1.,
75
- unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
76
- dynamic_threshold=None,
77
- ucg_schedule=None,
78
- **kwargs
79
- ):
80
- if conditioning is not None:
81
- if isinstance(conditioning, dict):
82
- ctmp = conditioning[list(conditioning.keys())[0]]
83
- while isinstance(ctmp, list): ctmp = ctmp[0]
84
- cbs = ctmp.shape[0]
85
- if cbs != batch_size:
86
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
87
-
88
- elif isinstance(conditioning, list):
89
- for ctmp in conditioning:
90
- if ctmp.shape[0] != batch_size:
91
- print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
92
-
93
- else:
94
- if conditioning.shape[0] != batch_size:
95
- print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
96
-
97
- self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
98
- # sampling
99
- C, H, W = shape
100
- size = (batch_size, C, H, W)
101
- print(f'Data shape for DDIM sampling is {size}, eta {eta}')
102
-
103
- samples, intermediates = self.ddim_sampling(conditioning, size,
104
- callback=callback,
105
- img_callback=img_callback,
106
- quantize_denoised=quantize_x0,
107
- mask=mask, x0=x0,
108
- ddim_use_original_steps=False,
109
- noise_dropout=noise_dropout,
110
- temperature=temperature,
111
- score_corrector=score_corrector,
112
- corrector_kwargs=corrector_kwargs,
113
- x_T=x_T,
114
- log_every_t=log_every_t,
115
- unconditional_guidance_scale=unconditional_guidance_scale,
116
- unconditional_conditioning=unconditional_conditioning,
117
- dynamic_threshold=dynamic_threshold,
118
- ucg_schedule=ucg_schedule
119
- )
120
- return samples, intermediates
121
-
122
- @torch.no_grad()
123
- def ddim_sampling(self, cond, shape,
124
- x_T=None, ddim_use_original_steps=False,
125
- callback=None, timesteps=None, quantize_denoised=False,
126
- mask=None, x0=None, img_callback=None, log_every_t=100,
127
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
128
- unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
129
- ucg_schedule=None):
130
- device = self.model.betas.device
131
- b = shape[0]
132
- if x_T is None:
133
- img = torch.randn(shape, device=device)
134
- else:
135
- img = x_T
136
-
137
- if timesteps is None:
138
- timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
139
- elif timesteps is not None and not ddim_use_original_steps:
140
- subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
141
- timesteps = self.ddim_timesteps[:subset_end]
142
-
143
- intermediates = {'x_inter': [img], 'pred_x0': [img]}
144
- time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
145
- total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
146
- print(f"Running DDIM Sampling with {total_steps} timesteps")
147
-
148
- iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
149
-
150
- for i, step in enumerate(iterator):
151
- index = total_steps - i - 1
152
- ts = torch.full((b,), step, device=device, dtype=torch.long)
153
-
154
- if mask is not None:
155
- assert x0 is not None
156
- img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
157
- img = img_orig * mask + (1. - mask) * img
158
-
159
- if ucg_schedule is not None:
160
- assert len(ucg_schedule) == len(time_range)
161
- unconditional_guidance_scale = ucg_schedule[i]
162
-
163
- outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
164
- quantize_denoised=quantize_denoised, temperature=temperature,
165
- noise_dropout=noise_dropout, score_corrector=score_corrector,
166
- corrector_kwargs=corrector_kwargs,
167
- unconditional_guidance_scale=unconditional_guidance_scale,
168
- unconditional_conditioning=unconditional_conditioning,
169
- dynamic_threshold=dynamic_threshold)
170
- img, pred_x0 = outs
171
- if callback: callback(i)
172
- if img_callback: img_callback(pred_x0, i)
173
-
174
- if index % log_every_t == 0 or index == total_steps - 1:
175
- intermediates['x_inter'].append(img)
176
- intermediates['pred_x0'].append(pred_x0)
177
-
178
- return img, intermediates
179
-
180
- @torch.no_grad()
181
- def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
182
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
183
- unconditional_guidance_scale=1., unconditional_conditioning=None,
184
- dynamic_threshold=None):
185
- b, *_, device = *x.shape, x.device
186
-
187
- if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
188
- model_output = self.model.apply_model(x, t, c)
189
- else:
190
- x_in = torch.cat([x] * 2)
191
- t_in = torch.cat([t] * 2)
192
- if isinstance(c, dict):
193
- assert isinstance(unconditional_conditioning, dict)
194
- c_in = dict()
195
- for k in c:
196
- if isinstance(c[k], list):
197
- c_in[k] = [torch.cat([
198
- unconditional_conditioning[k][i],
199
- c[k][i]]) for i in range(len(c[k]))]
200
- else:
201
- c_in[k] = torch.cat([
202
- unconditional_conditioning[k],
203
- c[k]])
204
- elif isinstance(c, list):
205
- c_in = list()
206
- assert isinstance(unconditional_conditioning, list)
207
- for i in range(len(c)):
208
- c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
209
- else:
210
- c_in = torch.cat([unconditional_conditioning, c])
211
- model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
212
- model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
213
-
214
- if self.model.parameterization == "v":
215
- e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
216
- else:
217
- e_t = model_output
218
-
219
- if score_corrector is not None:
220
- assert self.model.parameterization == "eps", 'not implemented'
221
- e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
222
-
223
- alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
224
- alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
225
- sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
226
- sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
227
- # select parameters corresponding to the currently considered timestep
228
- a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
229
- a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
230
- sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
231
- sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
232
-
233
- # current prediction for x_0
234
- if self.model.parameterization != "v":
235
- pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
236
- else:
237
- pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
238
-
239
- if quantize_denoised:
240
- pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
241
-
242
- if dynamic_threshold is not None:
243
- raise NotImplementedError()
244
-
245
- # direction pointing to x_t
246
- dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
247
- noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
248
- if noise_dropout > 0.:
249
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
250
- x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
251
- return x_prev, pred_x0
252
-
253
- @torch.no_grad()
254
- def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
255
- unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
256
- num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
257
-
258
- assert t_enc <= num_reference_steps
259
- num_steps = t_enc
260
-
261
- if use_original_steps:
262
- alphas_next = self.alphas_cumprod[:num_steps]
263
- alphas = self.alphas_cumprod_prev[:num_steps]
264
- else:
265
- alphas_next = self.ddim_alphas[:num_steps]
266
- alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
267
-
268
- x_next = x0
269
- intermediates = []
270
- inter_steps = []
271
- for i in tqdm(range(num_steps), desc='Encoding Image'):
272
- t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
273
- if unconditional_guidance_scale == 1.:
274
- noise_pred = self.model.apply_model(x_next, t, c)
275
- else:
276
- assert unconditional_conditioning is not None
277
- e_t_uncond, noise_pred = torch.chunk(
278
- self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
279
- torch.cat((unconditional_conditioning, c))), 2)
280
- noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
281
-
282
- xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
283
- weighted_noise_pred = alphas_next[i].sqrt() * (
284
- (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
285
- x_next = xt_weighted + weighted_noise_pred
286
- if return_intermediates and i % (
287
- num_steps // return_intermediates) == 0 and i < num_steps - 1:
288
- intermediates.append(x_next)
289
- inter_steps.append(i)
290
- elif return_intermediates and i >= num_steps - 2:
291
- intermediates.append(x_next)
292
- inter_steps.append(i)
293
- if callback: callback(i)
294
-
295
- out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
296
- if return_intermediates:
297
- out.update({'intermediates': intermediates})
298
- return x_next, out
299
-
300
- @torch.no_grad()
301
- def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
302
- # fast, but does not allow for exact reconstruction
303
- # t serves as an index to gather the correct alphas
304
- if use_original_steps:
305
- sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
306
- sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
307
- else:
308
- sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
309
- sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
310
-
311
- if noise is None:
312
- noise = torch.randn_like(x0)
313
- return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
314
- extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
315
-
316
- @torch.no_grad()
317
- def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
318
- use_original_steps=False, callback=None):
319
-
320
- timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
321
- timesteps = timesteps[:t_start]
322
-
323
- time_range = np.flip(timesteps)
324
- total_steps = timesteps.shape[0]
325
- print(f"Running DDIM Sampling with {total_steps} timesteps")
326
-
327
- iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
328
- x_dec = x_latent
329
- for i, step in enumerate(iterator):
330
- index = total_steps - i - 1
331
- ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
332
- x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
333
- unconditional_guidance_scale=unconditional_guidance_scale,
334
- unconditional_conditioning=unconditional_conditioning)
335
- if callback: callback(i)
336
- return x_dec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/models/diffusion/ddpm.py DELETED
@@ -1,1795 +0,0 @@
1
- """
2
- wild mixture of
3
- https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
4
- https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
5
- https://github.com/CompVis/taming-transformers
6
- -- merci
7
- """
8
-
9
- import torch
10
- import torch.nn as nn
11
- import numpy as np
12
- import pytorch_lightning as pl
13
- from torch.optim.lr_scheduler import LambdaLR
14
- from einops import rearrange, repeat
15
- from contextlib import contextmanager, nullcontext
16
- from functools import partial
17
- import itertools
18
- from tqdm import tqdm
19
- from torchvision.utils import make_grid
20
- from pytorch_lightning.utilities.distributed import rank_zero_only
21
- from omegaconf import ListConfig
22
-
23
- from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
24
- from ldm.modules.ema import LitEma
25
- from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
26
- from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
27
- from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
28
- from ldm.models.diffusion.ddim import DDIMSampler
29
-
30
-
31
- __conditioning_keys__ = {'concat': 'c_concat',
32
- 'crossattn': 'c_crossattn',
33
- 'adm': 'y'}
34
-
35
-
36
- def disabled_train(self, mode=True):
37
- """Overwrite model.train with this function to make sure train/eval mode
38
- does not change anymore."""
39
- return self
40
-
41
-
42
- def uniform_on_device(r1, r2, shape, device):
43
- return (r1 - r2) * torch.rand(*shape, device=device) + r2
44
-
45
-
46
- class DDPM(pl.LightningModule):
47
- # classic DDPM with Gaussian diffusion, in image space
48
- def __init__(self,
49
- unet_config,
50
- timesteps=1000,
51
- beta_schedule="linear",
52
- loss_type="l2",
53
- ckpt_path=None,
54
- ignore_keys=[],
55
- load_only_unet=False,
56
- monitor="val/loss",
57
- use_ema=True,
58
- first_stage_key="image",
59
- image_size=256,
60
- channels=3,
61
- log_every_t=100,
62
- clip_denoised=True,
63
- linear_start=1e-4,
64
- linear_end=2e-2,
65
- cosine_s=8e-3,
66
- given_betas=None,
67
- original_elbo_weight=0.,
68
- v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
69
- l_simple_weight=1.,
70
- conditioning_key=None,
71
- parameterization="eps", # all assuming fixed variance schedules
72
- scheduler_config=None,
73
- use_positional_encodings=False,
74
- learn_logvar=False,
75
- logvar_init=0.,
76
- make_it_fit=False,
77
- ucg_training=None,
78
- reset_ema=False,
79
- reset_num_ema_updates=False,
80
- ):
81
- super().__init__()
82
- assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
83
- self.parameterization = parameterization
84
- print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
85
- self.cond_stage_model = None
86
- self.clip_denoised = clip_denoised
87
- self.log_every_t = log_every_t
88
- self.first_stage_key = first_stage_key
89
- self.image_size = image_size # try conv?
90
- self.channels = channels
91
- self.use_positional_encodings = use_positional_encodings
92
- self.model = DiffusionWrapper(unet_config, conditioning_key)
93
- count_params(self.model, verbose=True)
94
- self.use_ema = use_ema
95
- if self.use_ema:
96
- self.model_ema = LitEma(self.model)
97
- print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
98
-
99
- self.use_scheduler = scheduler_config is not None
100
- if self.use_scheduler:
101
- self.scheduler_config = scheduler_config
102
-
103
- self.v_posterior = v_posterior
104
- self.original_elbo_weight = original_elbo_weight
105
- self.l_simple_weight = l_simple_weight
106
-
107
- if monitor is not None:
108
- self.monitor = monitor
109
- self.make_it_fit = make_it_fit
110
- if reset_ema: assert exists(ckpt_path)
111
- if ckpt_path is not None:
112
- self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
113
- if reset_ema:
114
- assert self.use_ema
115
- print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
116
- self.model_ema = LitEma(self.model)
117
- if reset_num_ema_updates:
118
- print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
119
- assert self.use_ema
120
- self.model_ema.reset_num_updates()
121
-
122
- self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
123
- linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
124
-
125
- self.loss_type = loss_type
126
-
127
- self.learn_logvar = learn_logvar
128
- self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
129
- if self.learn_logvar:
130
- self.logvar = nn.Parameter(self.logvar, requires_grad=True)
131
-
132
- self.ucg_training = ucg_training or dict()
133
- if self.ucg_training:
134
- self.ucg_prng = np.random.RandomState()
135
-
136
- def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
137
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
138
- if exists(given_betas):
139
- betas = given_betas
140
- else:
141
- betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
142
- cosine_s=cosine_s)
143
- alphas = 1. - betas
144
- alphas_cumprod = np.cumprod(alphas, axis=0)
145
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
146
-
147
- timesteps, = betas.shape
148
- self.num_timesteps = int(timesteps)
149
- self.linear_start = linear_start
150
- self.linear_end = linear_end
151
- assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
152
-
153
- to_torch = partial(torch.tensor, dtype=torch.float32)
154
-
155
- self.register_buffer('betas', to_torch(betas))
156
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
157
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
158
-
159
- # calculations for diffusion q(x_t | x_{t-1}) and others
160
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
161
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
162
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
163
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
164
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
165
-
166
- # calculations for posterior q(x_{t-1} | x_t, x_0)
167
- posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
168
- 1. - alphas_cumprod) + self.v_posterior * betas
169
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
170
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
171
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
172
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
173
- self.register_buffer('posterior_mean_coef1', to_torch(
174
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
175
- self.register_buffer('posterior_mean_coef2', to_torch(
176
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
177
-
178
- if self.parameterization == "eps":
179
- lvlb_weights = self.betas ** 2 / (
180
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
181
- elif self.parameterization == "x0":
182
- lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
183
- elif self.parameterization == "v":
184
- lvlb_weights = torch.ones_like(self.betas ** 2 / (
185
- 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
186
- else:
187
- raise NotImplementedError("mu not supported")
188
- lvlb_weights[0] = lvlb_weights[1]
189
- self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
190
- assert not torch.isnan(self.lvlb_weights).all()
191
-
192
- @contextmanager
193
- def ema_scope(self, context=None):
194
- if self.use_ema:
195
- self.model_ema.store(self.model.parameters())
196
- self.model_ema.copy_to(self.model)
197
- if context is not None:
198
- print(f"{context}: Switched to EMA weights")
199
- try:
200
- yield None
201
- finally:
202
- if self.use_ema:
203
- self.model_ema.restore(self.model.parameters())
204
- if context is not None:
205
- print(f"{context}: Restored training weights")
206
-
207
- @torch.no_grad()
208
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
209
- sd = torch.load(path, map_location="cpu")
210
- if "state_dict" in list(sd.keys()):
211
- sd = sd["state_dict"]
212
- keys = list(sd.keys())
213
- for k in keys:
214
- for ik in ignore_keys:
215
- if k.startswith(ik):
216
- print("Deleting key {} from state_dict.".format(k))
217
- del sd[k]
218
- if self.make_it_fit:
219
- n_params = len([name for name, _ in
220
- itertools.chain(self.named_parameters(),
221
- self.named_buffers())])
222
- for name, param in tqdm(
223
- itertools.chain(self.named_parameters(),
224
- self.named_buffers()),
225
- desc="Fitting old weights to new weights",
226
- total=n_params
227
- ):
228
- if not name in sd:
229
- continue
230
- old_shape = sd[name].shape
231
- new_shape = param.shape
232
- assert len(old_shape) == len(new_shape)
233
- if len(new_shape) > 2:
234
- # we only modify first two axes
235
- assert new_shape[2:] == old_shape[2:]
236
- # assumes first axis corresponds to output dim
237
- if not new_shape == old_shape:
238
- new_param = param.clone()
239
- old_param = sd[name]
240
- if len(new_shape) == 1:
241
- for i in range(new_param.shape[0]):
242
- new_param[i] = old_param[i % old_shape[0]]
243
- elif len(new_shape) >= 2:
244
- for i in range(new_param.shape[0]):
245
- for j in range(new_param.shape[1]):
246
- new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
247
-
248
- n_used_old = torch.ones(old_shape[1])
249
- for j in range(new_param.shape[1]):
250
- n_used_old[j % old_shape[1]] += 1
251
- n_used_new = torch.zeros(new_shape[1])
252
- for j in range(new_param.shape[1]):
253
- n_used_new[j] = n_used_old[j % old_shape[1]]
254
-
255
- n_used_new = n_used_new[None, :]
256
- while len(n_used_new.shape) < len(new_shape):
257
- n_used_new = n_used_new.unsqueeze(-1)
258
- new_param /= n_used_new
259
-
260
- sd[name] = new_param
261
-
262
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
263
- sd, strict=False)
264
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
265
- if len(missing) > 0:
266
- print(f"Missing Keys:\n {missing}")
267
- if len(unexpected) > 0:
268
- print(f"\nUnexpected Keys:\n {unexpected}")
269
-
270
- def q_mean_variance(self, x_start, t):
271
- """
272
- Get the distribution q(x_t | x_0).
273
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
274
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
275
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
276
- """
277
- mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
278
- variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
279
- log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
280
- return mean, variance, log_variance
281
-
282
- def predict_start_from_noise(self, x_t, t, noise):
283
- return (
284
- extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
285
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
286
- )
287
-
288
- def predict_start_from_z_and_v(self, x_t, t, v):
289
- # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
290
- # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
291
- return (
292
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
293
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
294
- )
295
-
296
- def predict_eps_from_z_and_v(self, x_t, t, v):
297
- return (
298
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
299
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
300
- )
301
-
302
- def q_posterior(self, x_start, x_t, t):
303
- posterior_mean = (
304
- extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
305
- extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
306
- )
307
- posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
308
- posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
309
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
310
-
311
- def p_mean_variance(self, x, t, clip_denoised: bool):
312
- model_out = self.model(x, t)
313
- if self.parameterization == "eps":
314
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
315
- elif self.parameterization == "x0":
316
- x_recon = model_out
317
- if clip_denoised:
318
- x_recon.clamp_(-1., 1.)
319
-
320
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
321
- return model_mean, posterior_variance, posterior_log_variance
322
-
323
- @torch.no_grad()
324
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
325
- b, *_, device = *x.shape, x.device
326
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
327
- noise = noise_like(x.shape, device, repeat_noise)
328
- # no noise when t == 0
329
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
330
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
331
-
332
- @torch.no_grad()
333
- def p_sample_loop(self, shape, return_intermediates=False):
334
- device = self.betas.device
335
- b = shape[0]
336
- img = torch.randn(shape, device=device)
337
- intermediates = [img]
338
- for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
339
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
340
- clip_denoised=self.clip_denoised)
341
- if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
342
- intermediates.append(img)
343
- if return_intermediates:
344
- return img, intermediates
345
- return img
346
-
347
- @torch.no_grad()
348
- def sample(self, batch_size=16, return_intermediates=False):
349
- image_size = self.image_size
350
- channels = self.channels
351
- return self.p_sample_loop((batch_size, channels, image_size, image_size),
352
- return_intermediates=return_intermediates)
353
-
354
- def q_sample(self, x_start, t, noise=None):
355
- noise = default(noise, lambda: torch.randn_like(x_start))
356
- return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
357
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
358
-
359
- def get_v(self, x, noise, t):
360
- return (
361
- extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
362
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
363
- )
364
-
365
- def get_loss(self, pred, target, mean=True):
366
- if self.loss_type == 'l1':
367
- loss = (target - pred).abs()
368
- if mean:
369
- loss = loss.mean()
370
- elif self.loss_type == 'l2':
371
- if mean:
372
- loss = torch.nn.functional.mse_loss(target, pred)
373
- else:
374
- loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
375
- else:
376
- raise NotImplementedError("unknown loss type '{loss_type}'")
377
-
378
- return loss
379
-
380
- def p_losses(self, x_start, t, noise=None):
381
- noise = default(noise, lambda: torch.randn_like(x_start))
382
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
383
- model_out = self.model(x_noisy, t)
384
-
385
- loss_dict = {}
386
- if self.parameterization == "eps":
387
- target = noise
388
- elif self.parameterization == "x0":
389
- target = x_start
390
- elif self.parameterization == "v":
391
- target = self.get_v(x_start, noise, t)
392
- else:
393
- raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
394
-
395
- loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
396
-
397
- log_prefix = 'train' if self.training else 'val'
398
-
399
- loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
400
- loss_simple = loss.mean() * self.l_simple_weight
401
-
402
- loss_vlb = (self.lvlb_weights[t] * loss).mean()
403
- loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
404
-
405
- loss = loss_simple + self.original_elbo_weight * loss_vlb
406
-
407
- loss_dict.update({f'{log_prefix}/loss': loss})
408
-
409
- return loss, loss_dict
410
-
411
- def forward(self, x, *args, **kwargs):
412
- # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
413
- # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
414
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
415
- return self.p_losses(x, t, *args, **kwargs)
416
-
417
- def get_input(self, batch, k):
418
- x = batch[k]
419
- if len(x.shape) == 3:
420
- x = x[..., None]
421
- x = rearrange(x, 'b h w c -> b c h w')
422
- x = x.to(memory_format=torch.contiguous_format).float()
423
- return x
424
-
425
- def shared_step(self, batch):
426
- x = self.get_input(batch, self.first_stage_key)
427
- loss, loss_dict = self(x)
428
- return loss, loss_dict
429
-
430
- def training_step(self, batch, batch_idx):
431
- for k in self.ucg_training:
432
- p = self.ucg_training[k]["p"]
433
- val = self.ucg_training[k]["val"]
434
- if val is None:
435
- val = ""
436
- for i in range(len(batch[k])):
437
- if self.ucg_prng.choice(2, p=[1 - p, p]):
438
- batch[k][i] = val
439
-
440
- loss, loss_dict = self.shared_step(batch)
441
-
442
- self.log_dict(loss_dict, prog_bar=True,
443
- logger=True, on_step=True, on_epoch=True)
444
-
445
- self.log("global_step", self.global_step,
446
- prog_bar=True, logger=True, on_step=True, on_epoch=False)
447
-
448
- if self.use_scheduler:
449
- lr = self.optimizers().param_groups[0]['lr']
450
- self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
451
-
452
- return loss
453
-
454
- @torch.no_grad()
455
- def validation_step(self, batch, batch_idx):
456
- _, loss_dict_no_ema = self.shared_step(batch)
457
- with self.ema_scope():
458
- _, loss_dict_ema = self.shared_step(batch)
459
- loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
460
- self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
461
- self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
462
-
463
- def on_train_batch_end(self, *args, **kwargs):
464
- if self.use_ema:
465
- self.model_ema(self.model)
466
-
467
- def _get_rows_from_list(self, samples):
468
- n_imgs_per_row = len(samples)
469
- denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
470
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
471
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
472
- return denoise_grid
473
-
474
- @torch.no_grad()
475
- def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
476
- log = dict()
477
- x = self.get_input(batch, self.first_stage_key)
478
- N = min(x.shape[0], N)
479
- n_row = min(x.shape[0], n_row)
480
- x = x.to(self.device)[:N]
481
- log["inputs"] = x
482
-
483
- # get diffusion row
484
- diffusion_row = list()
485
- x_start = x[:n_row]
486
-
487
- for t in range(self.num_timesteps):
488
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
489
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
490
- t = t.to(self.device).long()
491
- noise = torch.randn_like(x_start)
492
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
493
- diffusion_row.append(x_noisy)
494
-
495
- log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
496
-
497
- if sample:
498
- # get denoise row
499
- with self.ema_scope("Plotting"):
500
- samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
501
-
502
- log["samples"] = samples
503
- log["denoise_row"] = self._get_rows_from_list(denoise_row)
504
-
505
- if return_keys:
506
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
507
- return log
508
- else:
509
- return {key: log[key] for key in return_keys}
510
- return log
511
-
512
- def configure_optimizers(self):
513
- lr = self.learning_rate
514
- params = list(self.model.parameters())
515
- if self.learn_logvar:
516
- params = params + [self.logvar]
517
- opt = torch.optim.AdamW(params, lr=lr)
518
- return opt
519
-
520
-
521
- class LatentDiffusion(DDPM):
522
- """main class"""
523
-
524
- def __init__(self,
525
- first_stage_config,
526
- cond_stage_config,
527
- num_timesteps_cond=None,
528
- cond_stage_key="image",
529
- cond_stage_trainable=False,
530
- concat_mode=True,
531
- cond_stage_forward=None,
532
- conditioning_key=None,
533
- scale_factor=1.0,
534
- scale_by_std=False,
535
- force_null_conditioning=False,
536
- *args, **kwargs):
537
- self.force_null_conditioning = force_null_conditioning
538
- self.num_timesteps_cond = default(num_timesteps_cond, 1)
539
- self.scale_by_std = scale_by_std
540
- assert self.num_timesteps_cond <= kwargs['timesteps']
541
- # for backwards compatibility after implementation of DiffusionWrapper
542
- if conditioning_key is None:
543
- conditioning_key = 'concat' if concat_mode else 'crossattn'
544
- if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
545
- conditioning_key = None
546
- ckpt_path = kwargs.pop("ckpt_path", None)
547
- reset_ema = kwargs.pop("reset_ema", False)
548
- reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
549
- ignore_keys = kwargs.pop("ignore_keys", [])
550
- super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
551
- self.concat_mode = concat_mode
552
- self.cond_stage_trainable = cond_stage_trainable
553
- self.cond_stage_key = cond_stage_key
554
- try:
555
- self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
556
- except:
557
- self.num_downs = 0
558
- if not scale_by_std:
559
- self.scale_factor = scale_factor
560
- else:
561
- self.register_buffer('scale_factor', torch.tensor(scale_factor))
562
- self.instantiate_first_stage(first_stage_config)
563
- self.instantiate_cond_stage(cond_stage_config)
564
- self.cond_stage_forward = cond_stage_forward
565
- self.clip_denoised = False
566
- self.bbox_tokenizer = None
567
-
568
- self.restarted_from_ckpt = False
569
- if ckpt_path is not None:
570
- self.init_from_ckpt(ckpt_path, ignore_keys)
571
- self.restarted_from_ckpt = True
572
- if reset_ema:
573
- assert self.use_ema
574
- print(
575
- f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
576
- self.model_ema = LitEma(self.model)
577
- if reset_num_ema_updates:
578
- print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
579
- assert self.use_ema
580
- self.model_ema.reset_num_updates()
581
-
582
- def make_cond_schedule(self, ):
583
- self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
584
- ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
585
- self.cond_ids[:self.num_timesteps_cond] = ids
586
-
587
- @rank_zero_only
588
- @torch.no_grad()
589
- def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
590
- # only for very first batch
591
- if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
592
- assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
593
- # set rescale weight to 1./std of encodings
594
- print("### USING STD-RESCALING ###")
595
- x = super().get_input(batch, self.first_stage_key)
596
- x = x.to(self.device)
597
- encoder_posterior = self.encode_first_stage(x)
598
- z = self.get_first_stage_encoding(encoder_posterior).detach()
599
- del self.scale_factor
600
- self.register_buffer('scale_factor', 1. / z.flatten().std())
601
- print(f"setting self.scale_factor to {self.scale_factor}")
602
- print("### USING STD-RESCALING ###")
603
-
604
- def register_schedule(self,
605
- given_betas=None, beta_schedule="linear", timesteps=1000,
606
- linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
607
- super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
608
-
609
- self.shorten_cond_schedule = self.num_timesteps_cond > 1
610
- if self.shorten_cond_schedule:
611
- self.make_cond_schedule()
612
-
613
- def instantiate_first_stage(self, config):
614
- model = instantiate_from_config(config)
615
- self.first_stage_model = model.eval()
616
- self.first_stage_model.train = disabled_train
617
- for param in self.first_stage_model.parameters():
618
- param.requires_grad = False
619
-
620
- def instantiate_cond_stage(self, config):
621
- if not self.cond_stage_trainable:
622
- if config == "__is_first_stage__":
623
- print("Using first stage also as cond stage.")
624
- self.cond_stage_model = self.first_stage_model
625
- elif config == "__is_unconditional__":
626
- print(f"Training {self.__class__.__name__} as an unconditional model.")
627
- self.cond_stage_model = None
628
- # self.be_unconditional = True
629
- else:
630
- model = instantiate_from_config(config)
631
- self.cond_stage_model = model.eval()
632
- self.cond_stage_model.train = disabled_train
633
- for param in self.cond_stage_model.parameters():
634
- param.requires_grad = False
635
- else:
636
- assert config != '__is_first_stage__'
637
- assert config != '__is_unconditional__'
638
- model = instantiate_from_config(config)
639
- self.cond_stage_model = model
640
-
641
- def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
642
- denoise_row = []
643
- for zd in tqdm(samples, desc=desc):
644
- denoise_row.append(self.decode_first_stage(zd.to(self.device),
645
- force_not_quantize=force_no_decoder_quantization))
646
- n_imgs_per_row = len(denoise_row)
647
- denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
648
- denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
649
- denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
650
- denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
651
- return denoise_grid
652
-
653
- def get_first_stage_encoding(self, encoder_posterior):
654
- if isinstance(encoder_posterior, DiagonalGaussianDistribution):
655
- z = encoder_posterior.sample()
656
- elif isinstance(encoder_posterior, torch.Tensor):
657
- z = encoder_posterior
658
- else:
659
- raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
660
- return self.scale_factor * z
661
-
662
- def get_learned_conditioning(self, c):
663
- if self.cond_stage_forward is None:
664
- if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
665
- c = self.cond_stage_model.encode(c)
666
- if isinstance(c, DiagonalGaussianDistribution):
667
- c = c.mode()
668
- else:
669
- c = self.cond_stage_model(c)
670
- else:
671
- assert hasattr(self.cond_stage_model, self.cond_stage_forward)
672
- c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
673
- return c
674
-
675
- def meshgrid(self, h, w):
676
- y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
677
- x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
678
-
679
- arr = torch.cat([y, x], dim=-1)
680
- return arr
681
-
682
- def delta_border(self, h, w):
683
- """
684
- :param h: height
685
- :param w: width
686
- :return: normalized distance to image border,
687
- wtith min distance = 0 at border and max dist = 0.5 at image center
688
- """
689
- lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
690
- arr = self.meshgrid(h, w) / lower_right_corner
691
- dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
692
- dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
693
- edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
694
- return edge_dist
695
-
696
- def get_weighting(self, h, w, Ly, Lx, device):
697
- weighting = self.delta_border(h, w)
698
- weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
699
- self.split_input_params["clip_max_weight"], )
700
- weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
701
-
702
- if self.split_input_params["tie_braker"]:
703
- L_weighting = self.delta_border(Ly, Lx)
704
- L_weighting = torch.clip(L_weighting,
705
- self.split_input_params["clip_min_tie_weight"],
706
- self.split_input_params["clip_max_tie_weight"])
707
-
708
- L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
709
- weighting = weighting * L_weighting
710
- return weighting
711
-
712
- def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
713
- """
714
- :param x: img of size (bs, c, h, w)
715
- :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
716
- """
717
- bs, nc, h, w = x.shape
718
-
719
- # number of crops in image
720
- Ly = (h - kernel_size[0]) // stride[0] + 1
721
- Lx = (w - kernel_size[1]) // stride[1] + 1
722
-
723
- if uf == 1 and df == 1:
724
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
725
- unfold = torch.nn.Unfold(**fold_params)
726
-
727
- fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
728
-
729
- weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
730
- normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
731
- weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
732
-
733
- elif uf > 1 and df == 1:
734
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
735
- unfold = torch.nn.Unfold(**fold_params)
736
-
737
- fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
738
- dilation=1, padding=0,
739
- stride=(stride[0] * uf, stride[1] * uf))
740
- fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
741
-
742
- weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
743
- normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
744
- weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
745
-
746
- elif df > 1 and uf == 1:
747
- fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
748
- unfold = torch.nn.Unfold(**fold_params)
749
-
750
- fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
751
- dilation=1, padding=0,
752
- stride=(stride[0] // df, stride[1] // df))
753
- fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
754
-
755
- weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
756
- normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
757
- weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
758
-
759
- else:
760
- raise NotImplementedError
761
-
762
- return fold, unfold, normalization, weighting
763
-
764
- @torch.no_grad()
765
- def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
766
- cond_key=None, return_original_cond=False, bs=None, return_x=False):
767
- x = super().get_input(batch, k)
768
- if bs is not None:
769
- x = x[:bs]
770
- x = x.to(self.device)
771
- encoder_posterior = self.encode_first_stage(x)
772
- z = self.get_first_stage_encoding(encoder_posterior).detach()
773
-
774
- if self.model.conditioning_key is not None and not self.force_null_conditioning:
775
- if cond_key is None:
776
- cond_key = self.cond_stage_key
777
- if cond_key != self.first_stage_key:
778
- if cond_key in ['caption', 'coordinates_bbox', "txt"]:
779
- xc = batch[cond_key]
780
- elif cond_key in ['class_label', 'cls']:
781
- xc = batch
782
- else:
783
- xc = super().get_input(batch, cond_key).to(self.device)
784
- else:
785
- xc = x
786
- if not self.cond_stage_trainable or force_c_encode:
787
- if isinstance(xc, dict) or isinstance(xc, list):
788
- c = self.get_learned_conditioning(xc)
789
- else:
790
- c = self.get_learned_conditioning(xc.to(self.device))
791
- else:
792
- c = xc
793
- if bs is not None:
794
- c = c[:bs]
795
-
796
- if self.use_positional_encodings:
797
- pos_x, pos_y = self.compute_latent_shifts(batch)
798
- ckey = __conditioning_keys__[self.model.conditioning_key]
799
- c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
800
-
801
- else:
802
- c = None
803
- xc = None
804
- if self.use_positional_encodings:
805
- pos_x, pos_y = self.compute_latent_shifts(batch)
806
- c = {'pos_x': pos_x, 'pos_y': pos_y}
807
- out = [z, c]
808
- if return_first_stage_outputs:
809
- xrec = self.decode_first_stage(z)
810
- out.extend([x, xrec])
811
- if return_x:
812
- out.extend([x])
813
- if return_original_cond:
814
- out.append(xc)
815
- return out
816
-
817
- @torch.no_grad()
818
- def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
819
- if predict_cids:
820
- if z.dim() == 4:
821
- z = torch.argmax(z.exp(), dim=1).long()
822
- z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
823
- z = rearrange(z, 'b h w c -> b c h w').contiguous()
824
-
825
- z = 1. / self.scale_factor * z
826
- return self.first_stage_model.decode(z)
827
-
828
- @torch.no_grad()
829
- def encode_first_stage(self, x):
830
- return self.first_stage_model.encode(x)
831
-
832
- def shared_step(self, batch, **kwargs):
833
- x, c = self.get_input(batch, self.first_stage_key)
834
- loss = self(x, c)
835
- return loss
836
-
837
- def forward(self, x, c, *args, **kwargs):
838
- t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
839
- if self.model.conditioning_key is not None:
840
- assert c is not None
841
- if self.cond_stage_trainable:
842
- c = self.get_learned_conditioning(c)
843
- if self.shorten_cond_schedule: # TODO: drop this option
844
- tc = self.cond_ids[t].to(self.device)
845
- c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
846
- return self.p_losses(x, c, t, *args, **kwargs)
847
-
848
- def apply_model(self, x_noisy, t, cond, return_ids=False):
849
- if isinstance(cond, dict):
850
- # hybrid case, cond is expected to be a dict
851
- pass
852
- else:
853
- if not isinstance(cond, list):
854
- cond = [cond]
855
- key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
856
- cond = {key: cond}
857
-
858
- x_recon = self.model(x_noisy, t, **cond)
859
-
860
- if isinstance(x_recon, tuple) and not return_ids:
861
- return x_recon[0]
862
- else:
863
- return x_recon
864
-
865
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
866
- return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
867
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
868
-
869
- def _prior_bpd(self, x_start):
870
- """
871
- Get the prior KL term for the variational lower-bound, measured in
872
- bits-per-dim.
873
- This term can't be optimized, as it only depends on the encoder.
874
- :param x_start: the [N x C x ...] tensor of inputs.
875
- :return: a batch of [N] KL values (in bits), one per batch element.
876
- """
877
- batch_size = x_start.shape[0]
878
- t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
879
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
880
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
881
- return mean_flat(kl_prior) / np.log(2.0)
882
-
883
- def p_losses(self, x_start, cond, t, noise=None):
884
- noise = default(noise, lambda: torch.randn_like(x_start))
885
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
886
- model_output = self.apply_model(x_noisy, t, cond)
887
-
888
- loss_dict = {}
889
- prefix = 'train' if self.training else 'val'
890
-
891
- if self.parameterization == "x0":
892
- target = x_start
893
- elif self.parameterization == "eps":
894
- target = noise
895
- elif self.parameterization == "v":
896
- target = self.get_v(x_start, noise, t)
897
- else:
898
- raise NotImplementedError()
899
-
900
- loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
901
- loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
902
-
903
- logvar_t = self.logvar[t].to(self.device)
904
- loss = loss_simple / torch.exp(logvar_t) + logvar_t
905
- # loss = loss_simple / torch.exp(self.logvar) + self.logvar
906
- if self.learn_logvar:
907
- loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
908
- loss_dict.update({'logvar': self.logvar.data.mean()})
909
-
910
- loss = self.l_simple_weight * loss.mean()
911
-
912
- loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
913
- loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
914
- loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
915
- loss += (self.original_elbo_weight * loss_vlb)
916
- loss_dict.update({f'{prefix}/loss': loss})
917
-
918
- return loss, loss_dict
919
-
920
- def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
921
- return_x0=False, score_corrector=None, corrector_kwargs=None):
922
- t_in = t
923
- model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
924
-
925
- if score_corrector is not None:
926
- assert self.parameterization == "eps"
927
- model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
928
-
929
- if return_codebook_ids:
930
- model_out, logits = model_out
931
-
932
- if self.parameterization == "eps":
933
- x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
934
- elif self.parameterization == "x0":
935
- x_recon = model_out
936
- else:
937
- raise NotImplementedError()
938
-
939
- if clip_denoised:
940
- x_recon.clamp_(-1., 1.)
941
- if quantize_denoised:
942
- x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
943
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
944
- if return_codebook_ids:
945
- return model_mean, posterior_variance, posterior_log_variance, logits
946
- elif return_x0:
947
- return model_mean, posterior_variance, posterior_log_variance, x_recon
948
- else:
949
- return model_mean, posterior_variance, posterior_log_variance
950
-
951
- @torch.no_grad()
952
- def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
953
- return_codebook_ids=False, quantize_denoised=False, return_x0=False,
954
- temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
955
- b, *_, device = *x.shape, x.device
956
- outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
957
- return_codebook_ids=return_codebook_ids,
958
- quantize_denoised=quantize_denoised,
959
- return_x0=return_x0,
960
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
961
- if return_codebook_ids:
962
- raise DeprecationWarning("Support dropped.")
963
- model_mean, _, model_log_variance, logits = outputs
964
- elif return_x0:
965
- model_mean, _, model_log_variance, x0 = outputs
966
- else:
967
- model_mean, _, model_log_variance = outputs
968
-
969
- noise = noise_like(x.shape, device, repeat_noise) * temperature
970
- if noise_dropout > 0.:
971
- noise = torch.nn.functional.dropout(noise, p=noise_dropout)
972
- # no noise when t == 0
973
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
974
-
975
- if return_codebook_ids:
976
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
977
- if return_x0:
978
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
979
- else:
980
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
981
-
982
- @torch.no_grad()
983
- def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
984
- img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
985
- score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
986
- log_every_t=None):
987
- if not log_every_t:
988
- log_every_t = self.log_every_t
989
- timesteps = self.num_timesteps
990
- if batch_size is not None:
991
- b = batch_size if batch_size is not None else shape[0]
992
- shape = [batch_size] + list(shape)
993
- else:
994
- b = batch_size = shape[0]
995
- if x_T is None:
996
- img = torch.randn(shape, device=self.device)
997
- else:
998
- img = x_T
999
- intermediates = []
1000
- if cond is not None:
1001
- if isinstance(cond, dict):
1002
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1003
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1004
- else:
1005
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1006
-
1007
- if start_T is not None:
1008
- timesteps = min(timesteps, start_T)
1009
- iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
1010
- total=timesteps) if verbose else reversed(
1011
- range(0, timesteps))
1012
- if type(temperature) == float:
1013
- temperature = [temperature] * timesteps
1014
-
1015
- for i in iterator:
1016
- ts = torch.full((b,), i, device=self.device, dtype=torch.long)
1017
- if self.shorten_cond_schedule:
1018
- assert self.model.conditioning_key != 'hybrid'
1019
- tc = self.cond_ids[ts].to(cond.device)
1020
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1021
-
1022
- img, x0_partial = self.p_sample(img, cond, ts,
1023
- clip_denoised=self.clip_denoised,
1024
- quantize_denoised=quantize_denoised, return_x0=True,
1025
- temperature=temperature[i], noise_dropout=noise_dropout,
1026
- score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
1027
- if mask is not None:
1028
- assert x0 is not None
1029
- img_orig = self.q_sample(x0, ts)
1030
- img = img_orig * mask + (1. - mask) * img
1031
-
1032
- if i % log_every_t == 0 or i == timesteps - 1:
1033
- intermediates.append(x0_partial)
1034
- if callback: callback(i)
1035
- if img_callback: img_callback(img, i)
1036
- return img, intermediates
1037
-
1038
- @torch.no_grad()
1039
- def p_sample_loop(self, cond, shape, return_intermediates=False,
1040
- x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
1041
- mask=None, x0=None, img_callback=None, start_T=None,
1042
- log_every_t=None):
1043
-
1044
- if not log_every_t:
1045
- log_every_t = self.log_every_t
1046
- device = self.betas.device
1047
- b = shape[0]
1048
- if x_T is None:
1049
- img = torch.randn(shape, device=device)
1050
- else:
1051
- img = x_T
1052
-
1053
- intermediates = [img]
1054
- if timesteps is None:
1055
- timesteps = self.num_timesteps
1056
-
1057
- if start_T is not None:
1058
- timesteps = min(timesteps, start_T)
1059
- iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
1060
- range(0, timesteps))
1061
-
1062
- if mask is not None:
1063
- assert x0 is not None
1064
- assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
1065
-
1066
- for i in iterator:
1067
- ts = torch.full((b,), i, device=device, dtype=torch.long)
1068
- if self.shorten_cond_schedule:
1069
- assert self.model.conditioning_key != 'hybrid'
1070
- tc = self.cond_ids[ts].to(cond.device)
1071
- cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
1072
-
1073
- img = self.p_sample(img, cond, ts,
1074
- clip_denoised=self.clip_denoised,
1075
- quantize_denoised=quantize_denoised)
1076
- if mask is not None:
1077
- img_orig = self.q_sample(x0, ts)
1078
- img = img_orig * mask + (1. - mask) * img
1079
-
1080
- if i % log_every_t == 0 or i == timesteps - 1:
1081
- intermediates.append(img)
1082
- if callback: callback(i)
1083
- if img_callback: img_callback(img, i)
1084
-
1085
- if return_intermediates:
1086
- return img, intermediates
1087
- return img
1088
-
1089
- @torch.no_grad()
1090
- def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
1091
- verbose=True, timesteps=None, quantize_denoised=False,
1092
- mask=None, x0=None, shape=None, **kwargs):
1093
- if shape is None:
1094
- shape = (batch_size, self.channels, self.image_size, self.image_size)
1095
- if cond is not None:
1096
- if isinstance(cond, dict):
1097
- cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
1098
- list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
1099
- else:
1100
- cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
1101
- return self.p_sample_loop(cond,
1102
- shape,
1103
- return_intermediates=return_intermediates, x_T=x_T,
1104
- verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
1105
- mask=mask, x0=x0)
1106
-
1107
- @torch.no_grad()
1108
- def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
1109
- if ddim:
1110
- ddim_sampler = DDIMSampler(self)
1111
- shape = (self.channels, self.image_size, self.image_size)
1112
- samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
1113
- shape, cond, verbose=False, **kwargs)
1114
-
1115
- else:
1116
- samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
1117
- return_intermediates=True, **kwargs)
1118
-
1119
- return samples, intermediates
1120
-
1121
- @torch.no_grad()
1122
- def get_unconditional_conditioning(self, batch_size, null_label=None):
1123
- if null_label is not None:
1124
- xc = null_label
1125
- if isinstance(xc, ListConfig):
1126
- xc = list(xc)
1127
- if isinstance(xc, dict) or isinstance(xc, list):
1128
- c = self.get_learned_conditioning(xc)
1129
- else:
1130
- if hasattr(xc, "to"):
1131
- xc = xc.to(self.device)
1132
- c = self.get_learned_conditioning(xc)
1133
- else:
1134
- if self.cond_stage_key in ["class_label", "cls"]:
1135
- xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
1136
- return self.get_learned_conditioning(xc)
1137
- else:
1138
- raise NotImplementedError("todo")
1139
- if isinstance(c, list): # in case the encoder gives us a list
1140
- for i in range(len(c)):
1141
- c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
1142
- else:
1143
- c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
1144
- return c
1145
-
1146
- @torch.no_grad()
1147
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
1148
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1149
- plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
1150
- use_ema_scope=True,
1151
- **kwargs):
1152
- ema_scope = self.ema_scope if use_ema_scope else nullcontext
1153
- use_ddim = ddim_steps is not None
1154
-
1155
- log = dict()
1156
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
1157
- return_first_stage_outputs=True,
1158
- force_c_encode=True,
1159
- return_original_cond=True,
1160
- bs=N)
1161
- N = min(x.shape[0], N)
1162
- n_row = min(x.shape[0], n_row)
1163
- log["inputs"] = x
1164
- log["reconstruction"] = xrec
1165
- if self.model.conditioning_key is not None:
1166
- if hasattr(self.cond_stage_model, "decode"):
1167
- xc = self.cond_stage_model.decode(c)
1168
- log["conditioning"] = xc
1169
- elif self.cond_stage_key in ["caption", "txt"]:
1170
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1171
- log["conditioning"] = xc
1172
- elif self.cond_stage_key in ['class_label', "cls"]:
1173
- try:
1174
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1175
- log['conditioning'] = xc
1176
- except KeyError:
1177
- # probably no "human_label" in batch
1178
- pass
1179
- elif isimage(xc):
1180
- log["conditioning"] = xc
1181
- if ismap(xc):
1182
- log["original_conditioning"] = self.to_rgb(xc)
1183
-
1184
- if plot_diffusion_rows:
1185
- # get diffusion row
1186
- diffusion_row = list()
1187
- z_start = z[:n_row]
1188
- for t in range(self.num_timesteps):
1189
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1190
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1191
- t = t.to(self.device).long()
1192
- noise = torch.randn_like(z_start)
1193
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1194
- diffusion_row.append(self.decode_first_stage(z_noisy))
1195
-
1196
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1197
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1198
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1199
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1200
- log["diffusion_row"] = diffusion_grid
1201
-
1202
- if sample:
1203
- # get denoise row
1204
- with ema_scope("Sampling"):
1205
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1206
- ddim_steps=ddim_steps, eta=ddim_eta)
1207
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1208
- x_samples = self.decode_first_stage(samples)
1209
- log["samples"] = x_samples
1210
- if plot_denoise_rows:
1211
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1212
- log["denoise_row"] = denoise_grid
1213
-
1214
- if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
1215
- self.first_stage_model, IdentityFirstStage):
1216
- # also display when quantizing x0 while sampling
1217
- with ema_scope("Plotting Quantized Denoised"):
1218
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1219
- ddim_steps=ddim_steps, eta=ddim_eta,
1220
- quantize_denoised=True)
1221
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
1222
- # quantize_denoised=True)
1223
- x_samples = self.decode_first_stage(samples.to(self.device))
1224
- log["samples_x0_quantized"] = x_samples
1225
-
1226
- if unconditional_guidance_scale > 1.0:
1227
- uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1228
- if self.model.conditioning_key == "crossattn-adm":
1229
- uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
1230
- with ema_scope("Sampling with classifier-free guidance"):
1231
- samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1232
- ddim_steps=ddim_steps, eta=ddim_eta,
1233
- unconditional_guidance_scale=unconditional_guidance_scale,
1234
- unconditional_conditioning=uc,
1235
- )
1236
- x_samples_cfg = self.decode_first_stage(samples_cfg)
1237
- log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1238
-
1239
- if inpaint:
1240
- # make a simple center square
1241
- b, h, w = z.shape[0], z.shape[2], z.shape[3]
1242
- mask = torch.ones(N, h, w).to(self.device)
1243
- # zeros will be filled in
1244
- mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
1245
- mask = mask[:, None, ...]
1246
- with ema_scope("Plotting Inpaint"):
1247
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
1248
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1249
- x_samples = self.decode_first_stage(samples.to(self.device))
1250
- log["samples_inpainting"] = x_samples
1251
- log["mask"] = mask
1252
-
1253
- # outpaint
1254
- mask = 1. - mask
1255
- with ema_scope("Plotting Outpaint"):
1256
- samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
1257
- ddim_steps=ddim_steps, x0=z[:N], mask=mask)
1258
- x_samples = self.decode_first_stage(samples.to(self.device))
1259
- log["samples_outpainting"] = x_samples
1260
-
1261
- if plot_progressive_rows:
1262
- with ema_scope("Plotting Progressives"):
1263
- img, progressives = self.progressive_denoising(c,
1264
- shape=(self.channels, self.image_size, self.image_size),
1265
- batch_size=N)
1266
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1267
- log["progressive_row"] = prog_row
1268
-
1269
- if return_keys:
1270
- if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
1271
- return log
1272
- else:
1273
- return {key: log[key] for key in return_keys}
1274
- return log
1275
-
1276
- def configure_optimizers(self):
1277
- lr = self.learning_rate
1278
- params = list(self.model.parameters())
1279
- if self.cond_stage_trainable:
1280
- print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
1281
- params = params + list(self.cond_stage_model.parameters())
1282
- if self.learn_logvar:
1283
- print('Diffusion model optimizing logvar')
1284
- params.append(self.logvar)
1285
- opt = torch.optim.AdamW(params, lr=lr)
1286
- if self.use_scheduler:
1287
- assert 'target' in self.scheduler_config
1288
- scheduler = instantiate_from_config(self.scheduler_config)
1289
-
1290
- print("Setting up LambdaLR scheduler...")
1291
- scheduler = [
1292
- {
1293
- 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
1294
- 'interval': 'step',
1295
- 'frequency': 1
1296
- }]
1297
- return [opt], scheduler
1298
- return opt
1299
-
1300
- @torch.no_grad()
1301
- def to_rgb(self, x):
1302
- x = x.float()
1303
- if not hasattr(self, "colorize"):
1304
- self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
1305
- x = nn.functional.conv2d(x, weight=self.colorize)
1306
- x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
1307
- return x
1308
-
1309
-
1310
- class DiffusionWrapper(pl.LightningModule):
1311
- def __init__(self, diff_model_config, conditioning_key):
1312
- super().__init__()
1313
- self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
1314
- self.diffusion_model = instantiate_from_config(diff_model_config)
1315
- self.conditioning_key = conditioning_key
1316
- assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
1317
-
1318
- def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
1319
- if self.conditioning_key is None:
1320
- out = self.diffusion_model(x, t)
1321
- elif self.conditioning_key == 'concat':
1322
- xc = torch.cat([x] + c_concat, dim=1)
1323
- out = self.diffusion_model(xc, t)
1324
- elif self.conditioning_key == 'crossattn':
1325
- if not self.sequential_cross_attn:
1326
- cc = torch.cat(c_crossattn, 1)
1327
- else:
1328
- cc = c_crossattn
1329
- out = self.diffusion_model(x, t, context=cc)
1330
- elif self.conditioning_key == 'hybrid':
1331
- xc = torch.cat([x] + c_concat, dim=1)
1332
- cc = torch.cat(c_crossattn, 1)
1333
- out = self.diffusion_model(xc, t, context=cc)
1334
- elif self.conditioning_key == 'hybrid-adm':
1335
- assert c_adm is not None
1336
- xc = torch.cat([x] + c_concat, dim=1)
1337
- cc = torch.cat(c_crossattn, 1)
1338
- out = self.diffusion_model(xc, t, context=cc, y=c_adm)
1339
- elif self.conditioning_key == 'crossattn-adm':
1340
- assert c_adm is not None
1341
- cc = torch.cat(c_crossattn, 1)
1342
- out = self.diffusion_model(x, t, context=cc, y=c_adm)
1343
- elif self.conditioning_key == 'adm':
1344
- cc = c_crossattn[0]
1345
- out = self.diffusion_model(x, t, y=cc)
1346
- else:
1347
- raise NotImplementedError()
1348
-
1349
- return out
1350
-
1351
-
1352
- class LatentUpscaleDiffusion(LatentDiffusion):
1353
- def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
1354
- super().__init__(*args, **kwargs)
1355
- # assumes that neither the cond_stage nor the low_scale_model contain trainable params
1356
- assert not self.cond_stage_trainable
1357
- self.instantiate_low_stage(low_scale_config)
1358
- self.low_scale_key = low_scale_key
1359
- self.noise_level_key = noise_level_key
1360
-
1361
- def instantiate_low_stage(self, config):
1362
- model = instantiate_from_config(config)
1363
- self.low_scale_model = model.eval()
1364
- self.low_scale_model.train = disabled_train
1365
- for param in self.low_scale_model.parameters():
1366
- param.requires_grad = False
1367
-
1368
- @torch.no_grad()
1369
- def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
1370
- if not log_mode:
1371
- z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
1372
- else:
1373
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1374
- force_c_encode=True, return_original_cond=True, bs=bs)
1375
- x_low = batch[self.low_scale_key][:bs]
1376
- x_low = rearrange(x_low, 'b h w c -> b c h w')
1377
- x_low = x_low.to(memory_format=torch.contiguous_format).float()
1378
- zx, noise_level = self.low_scale_model(x_low)
1379
- if self.noise_level_key is not None:
1380
- # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
1381
- raise NotImplementedError('TODO')
1382
-
1383
- all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
1384
- if log_mode:
1385
- # TODO: maybe disable if too expensive
1386
- x_low_rec = self.low_scale_model.decode(zx)
1387
- return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
1388
- return z, all_conds
1389
-
1390
- @torch.no_grad()
1391
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1392
- plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
1393
- unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
1394
- **kwargs):
1395
- ema_scope = self.ema_scope if use_ema_scope else nullcontext
1396
- use_ddim = ddim_steps is not None
1397
-
1398
- log = dict()
1399
- z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
1400
- log_mode=True)
1401
- N = min(x.shape[0], N)
1402
- n_row = min(x.shape[0], n_row)
1403
- log["inputs"] = x
1404
- log["reconstruction"] = xrec
1405
- log["x_lr"] = x_low
1406
- log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
1407
- if self.model.conditioning_key is not None:
1408
- if hasattr(self.cond_stage_model, "decode"):
1409
- xc = self.cond_stage_model.decode(c)
1410
- log["conditioning"] = xc
1411
- elif self.cond_stage_key in ["caption", "txt"]:
1412
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1413
- log["conditioning"] = xc
1414
- elif self.cond_stage_key in ['class_label', 'cls']:
1415
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1416
- log['conditioning'] = xc
1417
- elif isimage(xc):
1418
- log["conditioning"] = xc
1419
- if ismap(xc):
1420
- log["original_conditioning"] = self.to_rgb(xc)
1421
-
1422
- if plot_diffusion_rows:
1423
- # get diffusion row
1424
- diffusion_row = list()
1425
- z_start = z[:n_row]
1426
- for t in range(self.num_timesteps):
1427
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1428
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1429
- t = t.to(self.device).long()
1430
- noise = torch.randn_like(z_start)
1431
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1432
- diffusion_row.append(self.decode_first_stage(z_noisy))
1433
-
1434
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1435
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1436
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1437
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1438
- log["diffusion_row"] = diffusion_grid
1439
-
1440
- if sample:
1441
- # get denoise row
1442
- with ema_scope("Sampling"):
1443
- samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1444
- ddim_steps=ddim_steps, eta=ddim_eta)
1445
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1446
- x_samples = self.decode_first_stage(samples)
1447
- log["samples"] = x_samples
1448
- if plot_denoise_rows:
1449
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1450
- log["denoise_row"] = denoise_grid
1451
-
1452
- if unconditional_guidance_scale > 1.0:
1453
- uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1454
- # TODO explore better "unconditional" choices for the other keys
1455
- # maybe guide away from empty text label and highest noise level and maximally degraded zx?
1456
- uc = dict()
1457
- for k in c:
1458
- if k == "c_crossattn":
1459
- assert isinstance(c[k], list) and len(c[k]) == 1
1460
- uc[k] = [uc_tmp]
1461
- elif k == "c_adm": # todo: only run with text-based guidance?
1462
- assert isinstance(c[k], torch.Tensor)
1463
- #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
1464
- uc[k] = c[k]
1465
- elif isinstance(c[k], list):
1466
- uc[k] = [c[k][i] for i in range(len(c[k]))]
1467
- else:
1468
- uc[k] = c[k]
1469
-
1470
- with ema_scope("Sampling with classifier-free guidance"):
1471
- samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
1472
- ddim_steps=ddim_steps, eta=ddim_eta,
1473
- unconditional_guidance_scale=unconditional_guidance_scale,
1474
- unconditional_conditioning=uc,
1475
- )
1476
- x_samples_cfg = self.decode_first_stage(samples_cfg)
1477
- log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1478
-
1479
- if plot_progressive_rows:
1480
- with ema_scope("Plotting Progressives"):
1481
- img, progressives = self.progressive_denoising(c,
1482
- shape=(self.channels, self.image_size, self.image_size),
1483
- batch_size=N)
1484
- prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
1485
- log["progressive_row"] = prog_row
1486
-
1487
- return log
1488
-
1489
-
1490
- class LatentFinetuneDiffusion(LatentDiffusion):
1491
- """
1492
- Basis for different finetunas, such as inpainting or depth2image
1493
- To disable finetuning mode, set finetune_keys to None
1494
- """
1495
-
1496
- def __init__(self,
1497
- concat_keys: tuple,
1498
- finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
1499
- "model_ema.diffusion_modelinput_blocks00weight"
1500
- ),
1501
- keep_finetune_dims=4,
1502
- # if model was trained without concat mode before and we would like to keep these channels
1503
- c_concat_log_start=None, # to log reconstruction of c_concat codes
1504
- c_concat_log_end=None,
1505
- *args, **kwargs
1506
- ):
1507
- ckpt_path = kwargs.pop("ckpt_path", None)
1508
- ignore_keys = kwargs.pop("ignore_keys", list())
1509
- super().__init__(*args, **kwargs)
1510
- self.finetune_keys = finetune_keys
1511
- self.concat_keys = concat_keys
1512
- self.keep_dims = keep_finetune_dims
1513
- self.c_concat_log_start = c_concat_log_start
1514
- self.c_concat_log_end = c_concat_log_end
1515
- if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
1516
- if exists(ckpt_path):
1517
- self.init_from_ckpt(ckpt_path, ignore_keys)
1518
-
1519
- def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
1520
- sd = torch.load(path, map_location="cpu")
1521
- if "state_dict" in list(sd.keys()):
1522
- sd = sd["state_dict"]
1523
- keys = list(sd.keys())
1524
- for k in keys:
1525
- for ik in ignore_keys:
1526
- if k.startswith(ik):
1527
- print("Deleting key {} from state_dict.".format(k))
1528
- del sd[k]
1529
-
1530
- # make it explicit, finetune by including extra input channels
1531
- if exists(self.finetune_keys) and k in self.finetune_keys:
1532
- new_entry = None
1533
- for name, param in self.named_parameters():
1534
- if name in self.finetune_keys:
1535
- print(
1536
- f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
1537
- new_entry = torch.zeros_like(param) # zero init
1538
- assert exists(new_entry), 'did not find matching parameter to modify'
1539
- new_entry[:, :self.keep_dims, ...] = sd[k]
1540
- sd[k] = new_entry
1541
-
1542
- missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
1543
- sd, strict=False)
1544
- print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
1545
- if len(missing) > 0:
1546
- print(f"Missing Keys: {missing}")
1547
- if len(unexpected) > 0:
1548
- print(f"Unexpected Keys: {unexpected}")
1549
-
1550
- @torch.no_grad()
1551
- def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
1552
- quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
1553
- plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
1554
- use_ema_scope=True,
1555
- **kwargs):
1556
- ema_scope = self.ema_scope if use_ema_scope else nullcontext
1557
- use_ddim = ddim_steps is not None
1558
-
1559
- log = dict()
1560
- z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
1561
- c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
1562
- N = min(x.shape[0], N)
1563
- n_row = min(x.shape[0], n_row)
1564
- log["inputs"] = x
1565
- log["reconstruction"] = xrec
1566
- if self.model.conditioning_key is not None:
1567
- if hasattr(self.cond_stage_model, "decode"):
1568
- xc = self.cond_stage_model.decode(c)
1569
- log["conditioning"] = xc
1570
- elif self.cond_stage_key in ["caption", "txt"]:
1571
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
1572
- log["conditioning"] = xc
1573
- elif self.cond_stage_key in ['class_label', 'cls']:
1574
- xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
1575
- log['conditioning'] = xc
1576
- elif isimage(xc):
1577
- log["conditioning"] = xc
1578
- if ismap(xc):
1579
- log["original_conditioning"] = self.to_rgb(xc)
1580
-
1581
- if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
1582
- log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
1583
-
1584
- if plot_diffusion_rows:
1585
- # get diffusion row
1586
- diffusion_row = list()
1587
- z_start = z[:n_row]
1588
- for t in range(self.num_timesteps):
1589
- if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
1590
- t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
1591
- t = t.to(self.device).long()
1592
- noise = torch.randn_like(z_start)
1593
- z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
1594
- diffusion_row.append(self.decode_first_stage(z_noisy))
1595
-
1596
- diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
1597
- diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
1598
- diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
1599
- diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
1600
- log["diffusion_row"] = diffusion_grid
1601
-
1602
- if sample:
1603
- # get denoise row
1604
- with ema_scope("Sampling"):
1605
- samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
1606
- batch_size=N, ddim=use_ddim,
1607
- ddim_steps=ddim_steps, eta=ddim_eta)
1608
- # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
1609
- x_samples = self.decode_first_stage(samples)
1610
- log["samples"] = x_samples
1611
- if plot_denoise_rows:
1612
- denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
1613
- log["denoise_row"] = denoise_grid
1614
-
1615
- if unconditional_guidance_scale > 1.0:
1616
- uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
1617
- uc_cat = c_cat
1618
- uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
1619
- with ema_scope("Sampling with classifier-free guidance"):
1620
- samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
1621
- batch_size=N, ddim=use_ddim,
1622
- ddim_steps=ddim_steps, eta=ddim_eta,
1623
- unconditional_guidance_scale=unconditional_guidance_scale,
1624
- unconditional_conditioning=uc_full,
1625
- )
1626
- x_samples_cfg = self.decode_first_stage(samples_cfg)
1627
- log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
1628
-
1629
- return log
1630
-
1631
-
1632
- class LatentInpaintDiffusion(LatentFinetuneDiffusion):
1633
- """
1634
- can either run as pure inpainting model (only concat mode) or with mixed conditionings,
1635
- e.g. mask as concat and text via cross-attn.
1636
- To disable finetuning mode, set finetune_keys to None
1637
- """
1638
-
1639
- def __init__(self,
1640
- concat_keys=("mask", "masked_image"),
1641
- masked_image_key="masked_image",
1642
- *args, **kwargs
1643
- ):
1644
- super().__init__(concat_keys, *args, **kwargs)
1645
- self.masked_image_key = masked_image_key
1646
- assert self.masked_image_key in concat_keys
1647
-
1648
- @torch.no_grad()
1649
- def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1650
- # note: restricted to non-trainable encoders currently
1651
- assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
1652
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1653
- force_c_encode=True, return_original_cond=True, bs=bs)
1654
-
1655
- assert exists(self.concat_keys)
1656
- c_cat = list()
1657
- for ck in self.concat_keys:
1658
- cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
1659
- if bs is not None:
1660
- cc = cc[:bs]
1661
- cc = cc.to(self.device)
1662
- bchw = z.shape
1663
- if ck != self.masked_image_key:
1664
- cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
1665
- else:
1666
- cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
1667
- c_cat.append(cc)
1668
- c_cat = torch.cat(c_cat, dim=1)
1669
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1670
- if return_first_stage_outputs:
1671
- return z, all_conds, x, xrec, xc
1672
- return z, all_conds
1673
-
1674
- @torch.no_grad()
1675
- def log_images(self, *args, **kwargs):
1676
- log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
1677
- log["masked_image"] = rearrange(args[0]["masked_image"],
1678
- 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
1679
- return log
1680
-
1681
-
1682
- class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
1683
- """
1684
- condition on monocular depth estimation
1685
- """
1686
-
1687
- def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
1688
- super().__init__(concat_keys=concat_keys, *args, **kwargs)
1689
- self.depth_model = instantiate_from_config(depth_stage_config)
1690
- self.depth_stage_key = concat_keys[0]
1691
-
1692
- @torch.no_grad()
1693
- def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1694
- # note: restricted to non-trainable encoders currently
1695
- assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
1696
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1697
- force_c_encode=True, return_original_cond=True, bs=bs)
1698
-
1699
- assert exists(self.concat_keys)
1700
- assert len(self.concat_keys) == 1
1701
- c_cat = list()
1702
- for ck in self.concat_keys:
1703
- cc = batch[ck]
1704
- if bs is not None:
1705
- cc = cc[:bs]
1706
- cc = cc.to(self.device)
1707
- cc = self.depth_model(cc)
1708
- cc = torch.nn.functional.interpolate(
1709
- cc,
1710
- size=z.shape[2:],
1711
- mode="bicubic",
1712
- align_corners=False,
1713
- )
1714
-
1715
- depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
1716
- keepdim=True)
1717
- cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
1718
- c_cat.append(cc)
1719
- c_cat = torch.cat(c_cat, dim=1)
1720
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1721
- if return_first_stage_outputs:
1722
- return z, all_conds, x, xrec, xc
1723
- return z, all_conds
1724
-
1725
- @torch.no_grad()
1726
- def log_images(self, *args, **kwargs):
1727
- log = super().log_images(*args, **kwargs)
1728
- depth = self.depth_model(args[0][self.depth_stage_key])
1729
- depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
1730
- torch.amax(depth, dim=[1, 2, 3], keepdim=True)
1731
- log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
1732
- return log
1733
-
1734
-
1735
- class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
1736
- """
1737
- condition on low-res image (and optionally on some spatial noise augmentation)
1738
- """
1739
- def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
1740
- low_scale_config=None, low_scale_key=None, *args, **kwargs):
1741
- super().__init__(concat_keys=concat_keys, *args, **kwargs)
1742
- self.reshuffle_patch_size = reshuffle_patch_size
1743
- self.low_scale_model = None
1744
- if low_scale_config is not None:
1745
- print("Initializing a low-scale model")
1746
- assert exists(low_scale_key)
1747
- self.instantiate_low_stage(low_scale_config)
1748
- self.low_scale_key = low_scale_key
1749
-
1750
- def instantiate_low_stage(self, config):
1751
- model = instantiate_from_config(config)
1752
- self.low_scale_model = model.eval()
1753
- self.low_scale_model.train = disabled_train
1754
- for param in self.low_scale_model.parameters():
1755
- param.requires_grad = False
1756
-
1757
- @torch.no_grad()
1758
- def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
1759
- # note: restricted to non-trainable encoders currently
1760
- assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
1761
- z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
1762
- force_c_encode=True, return_original_cond=True, bs=bs)
1763
-
1764
- assert exists(self.concat_keys)
1765
- assert len(self.concat_keys) == 1
1766
- # optionally make spatial noise_level here
1767
- c_cat = list()
1768
- noise_level = None
1769
- for ck in self.concat_keys:
1770
- cc = batch[ck]
1771
- cc = rearrange(cc, 'b h w c -> b c h w')
1772
- if exists(self.reshuffle_patch_size):
1773
- assert isinstance(self.reshuffle_patch_size, int)
1774
- cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
1775
- p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
1776
- if bs is not None:
1777
- cc = cc[:bs]
1778
- cc = cc.to(self.device)
1779
- if exists(self.low_scale_model) and ck == self.low_scale_key:
1780
- cc, noise_level = self.low_scale_model(cc)
1781
- c_cat.append(cc)
1782
- c_cat = torch.cat(c_cat, dim=1)
1783
- if exists(noise_level):
1784
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
1785
- else:
1786
- all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
1787
- if return_first_stage_outputs:
1788
- return z, all_conds, x, xrec, xc
1789
- return z, all_conds
1790
-
1791
- @torch.no_grad()
1792
- def log_images(self, *args, **kwargs):
1793
- log = super().log_images(*args, **kwargs)
1794
- log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
1795
- return log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ldm/models/diffusion/dpm_solver/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .sampler import DPMSolverSampler
 
ldm/models/diffusion/dpm_solver/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (212 Bytes)
ldm/models/diffusion/dpm_solver/__pycache__/dpm_solver.cpython-39.pyc DELETED
Binary file (51.6 kB)
ldm/models/diffusion/dpm_solver/__pycache__/sampler.cpython-39.pyc DELETED
Binary file (2.79 kB)
ldm/models/diffusion/dpm_solver/dpm_solver.py DELETED
@@ -1,1154 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import math
4
- from tqdm import tqdm
5
-
6
-
7
- class NoiseScheduleVP:
8
- def __init__(
9
- self,
10
- schedule='discrete',
11
- betas=None,
12
- alphas_cumprod=None,
13
- continuous_beta_0=0.1,
14
- continuous_beta_1=20.,
15
- ):
16
- """Create a wrapper class for the forward SDE (VP type).
17
- ***
18
- Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.
19
- We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.
20
- ***
21
- The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
22
- We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
23
- Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
24
- log_alpha_t = self.marginal_log_mean_coeff(t)
25
- sigma_t = self.marginal_std(t)
26
- lambda_t = self.marginal_lambda(t)
27
- Moreover, as lambda(t) is an invertible function, we also support its inverse function:
28
- t = self.inverse_lambda(lambda_t)
29
- ===============================================================
30
- We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
31
- 1. For discrete-time DPMs:
32
- For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
33
- t_i = (i + 1) / N
34
- e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
35
- We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
36
- Args:
37
- betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
38
- alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
39
- Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
40
- **Important**: Please pay special attention for the args for `alphas_cumprod`:
41
- The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
42
- q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
43
- Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
44
- alpha_{t_n} = \sqrt{\hat{alpha_n}},
45
- and
46
- log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
47
- 2. For continuous-time DPMs:
48
- We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
49
- schedule are the default settings in DDPM and improved-DDPM:
50
- Args:
51
- beta_min: A `float` number. The smallest beta for the linear schedule.
52
- beta_max: A `float` number. The largest beta for the linear schedule.
53
- cosine_s: A `float` number. The hyperparameter in the cosine schedule.
54
- cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
55
- T: A `float` number. The ending time of the forward process.
56
- ===============================================================
57
- Args:
58
- schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
59
- 'linear' or 'cosine' for continuous-time DPMs.
60
- Returns:
61
- A wrapper object of the forward SDE (VP type).
62
-
63
- ===============================================================
64
- Example:
65
- # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
66
- >>> ns = NoiseScheduleVP('discrete', betas=betas)
67
- # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
68
- >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
69
- # For continuous-time DPMs (VPSDE), linear schedule:
70
- >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
71
- """
72
-
73
- if schedule not in ['discrete', 'linear', 'cosine']:
74
- raise ValueError(
75
- "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
76
- schedule))
77
-
78
- self.schedule = schedule
79
- if schedule == 'discrete':
80
- if betas is not None:
81
- log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
82
- else:
83
- assert alphas_cumprod is not None
84
- log_alphas = 0.5 * torch.log(alphas_cumprod)
85
- self.total_N = len(log_alphas)
86
- self.T = 1.
87
- self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
88
- self.log_alpha_array = log_alphas.reshape((1, -1,))
89
- else:
90
- self.total_N = 1000
91
- self.beta_0 = continuous_beta_0
92
- self.beta_1 = continuous_beta_1
93
- self.cosine_s = 0.008
94
- self.cosine_beta_max = 999.
95
- self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
96
- 1. + self.cosine_s) / math.pi - self.cosine_s
97
- self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
98
- self.schedule = schedule
99
- if schedule == 'cosine':
100
- # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
101
- # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
102
- self.T = 0.9946
103
- else:
104
- self.T = 1.
105
-
106
- def marginal_log_mean_coeff(self, t):
107
- """
108
- Compute log(alpha_t) of a given continuous-time label t in [0, T].
109
- """
110
- if self.schedule == 'discrete':
111
- return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
112
- self.log_alpha_array.to(t.device)).reshape((-1))
113
- elif self.schedule == 'linear':
114
- return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
115
- elif self.schedule == 'cosine':
116
- log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
117
- log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
118
- return log_alpha_t
119
-
120
- def marginal_alpha(self, t):
121
- """
122
- Compute alpha_t of a given continuous-time label t in [0, T].
123
- """
124
- return torch.exp(self.marginal_log_mean_coeff(t))
125
-
126
- def marginal_std(self, t):
127
- """
128
- Compute sigma_t of a given continuous-time label t in [0, T].
129
- """
130
- return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
131
-
132
- def marginal_lambda(self, t):
133
- """
134
- Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
135
- """
136
- log_mean_coeff = self.marginal_log_mean_coeff(t)
137
- log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
138
- return log_mean_coeff - log_std
139
-
140
- def inverse_lambda(self, lamb):
141
- """
142
- Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
143
- """
144
- if self.schedule == 'linear':
145
- tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
146
- Delta = self.beta_0 ** 2 + tmp
147
- return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
148
- elif self.schedule == 'discrete':
149
- log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
150
- t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
151
- torch.flip(self.t_array.to(lamb.device), [1]))
152
- return t.reshape((-1,))
153
- else:
154
- log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
155
- t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
156
- 1. + self.cosine_s) / math.pi - self.cosine_s
157
- t = t_fn(log_alpha)
158
- return t
159
-
160
-
161
- def model_wrapper(
162
- model,
163
- noise_schedule,
164
- model_type="noise",
165
- model_kwargs={},
166
- guidance_type="uncond",
167
- condition=None,
168
- unconditional_condition=None,
169
- guidance_scale=1.,
170
- classifier_fn=None,
171
- classifier_kwargs={},
172
- ):
173
- """Create a wrapper function for the noise prediction model.
174
- DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
175
- firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
176
- We support four types of the diffusion model by setting `model_type`:
177
- 1. "noise": noise prediction model. (Trained by predicting noise).
178
- 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
179
- 3. "v": velocity prediction model. (Trained by predicting the velocity).
180
- The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
181
- [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
182
- arXiv preprint arXiv:2202.00512 (2022).
183
- [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
184
- arXiv preprint arXiv:2210.02303 (2022).
185
-
186
- 4. "score": marginal score function. (Trained by denoising score matching).
187
- Note that the score function and the noise prediction model follows a simple relationship:
188
- ```
189
- noise(x_t, t) = -sigma_t * score(x_t, t)
190
- ```
191
- We support three types of guided sampling by DPMs by setting `guidance_type`:
192
- 1. "uncond": unconditional sampling by DPMs.
193
- The input `model` has the following format:
194
- ``
195
- model(x, t_input, **model_kwargs) -> noise | x_start | v | score
196
- ``
197
- 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
198
- The input `model` has the following format:
199
- ``
200
- model(x, t_input, **model_kwargs) -> noise | x_start | v | score
201
- ``
202
- The input `classifier_fn` has the following format:
203
- ``
204
- classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
205
- ``
206
- [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
207
- in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
208
- 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
209
- The input `model` has the following format:
210
- ``
211
- model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
212
- ``
213
- And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
214
- [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
215
- arXiv preprint arXiv:2207.12598 (2022).
216
-
217
- The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
218
- or continuous-time labels (i.e. epsilon to T).
219
- We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
220
- ``
221
- def model_fn(x, t_continuous) -> noise:
222
- t_input = get_model_input_time(t_continuous)
223
- return noise_pred(model, x, t_input, **model_kwargs)
224
- ``
225
- where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
226
- ===============================================================
227
- Args:
228
- model: A diffusion model with the corresponding format described above.
229
- noise_schedule: A noise schedule object, such as NoiseScheduleVP.
230
- model_type: A `str`. The parameterization type of the diffusion model.
231
- "noise" or "x_start" or "v" or "score".
232
- model_kwargs: A `dict`. A dict for the other inputs of the model function.
233
- guidance_type: A `str`. The type of the guidance for sampling.
234
- "uncond" or "classifier" or "classifier-free".
235
- condition: A pytorch tensor. The condition for the guided sampling.
236
- Only used for "classifier" or "classifier-free" guidance type.
237
- unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
238
- Only used for "classifier-free" guidance type.
239
- guidance_scale: A `float`. The scale for the guided sampling.
240
- classifier_fn: A classifier function. Only used for the classifier guidance.
241
- classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
242
- Returns:
243
- A noise prediction model that accepts the noised data and the continuous time as the inputs.
244
- """
245
-
246
- def get_model_input_time(t_continuous):
247
- """
248
- Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
249
- For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
250
- For continuous-time DPMs, we just use `t_continuous`.
251
- """
252
- if noise_schedule.schedule == 'discrete':
253
- return (t_continuous - 1. / noise_schedule.total_N) * 1000.
254
- else:
255
- return t_continuous
256
-
257
- def noise_pred_fn(x, t_continuous, cond=None):
258
- if t_continuous.reshape((-1,)).shape[0] == 1:
259
- t_continuous = t_continuous.expand((x.shape[0]))
260
- t_input = get_model_input_time(t_continuous)
261
- if cond is None:
262
- output = model(x, t_input, **model_kwargs)
263
- else:
264
- output = model(x, t_input, cond, **model_kwargs)
265
- if model_type == "noise":
266
- return output
267
- elif model_type == "x_start":
268
- alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
269
- dims = x.dim()
270
- return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
271
- elif model_type == "v":
272
- alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
273
- dims = x.dim()
274
- return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
275
- elif model_type == "score":
276
- sigma_t = noise_schedule.marginal_std(t_continuous)
277
- dims = x.dim()
278
- return -expand_dims(sigma_t, dims) * output
279
-
280
- def cond_grad_fn(x, t_input):
281
- """
282
- Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
283
- """
284
- with torch.enable_grad():
285
- x_in = x.detach().requires_grad_(True)
286
- log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
287
- return torch.autograd.grad(log_prob.sum(), x_in)[0]
288
-
289
- def model_fn(x, t_continuous):
290
- """
291
- The noise predicition model function that is used for DPM-Solver.
292
- """
293
- if t_continuous.reshape((-1,)).shape[0] == 1:
294
- t_continuous = t_continuous.expand((x.shape[0]))
295
- if guidance_type == "uncond":
296
- return noise_pred_fn(x, t_continuous)
297
- elif guidance_type == "classifier":
298
- assert classifier_fn is not None
299
- t_input = get_model_input_time(t_continuous)
300
- cond_grad = cond_grad_fn(x, t_input)
301
- sigma_t = noise_schedule.marginal_std(t_continuous)
302
- noise = noise_pred_fn(x, t_continuous)
303
- return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
304
- elif guidance_type == "classifier-free":
305
- if guidance_scale == 1. or unconditional_condition is None:
306
- return noise_pred_fn(x, t_continuous, cond=condition)
307
- else:
308
- x_in = torch.cat([x] * 2)
309
- t_in = torch.cat([t_continuous] * 2)
310
- c_in = torch.cat([unconditional_condition, condition])
311
- noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
312
- return noise_uncond + guidance_scale * (noise - noise_uncond)
313
-
314
- assert model_type in ["noise", "x_start", "v"]
315
- assert guidance_type in ["uncond", "classifier", "classifier-free"]
316
- return model_fn
317
-
318
-
319
- class DPM_Solver:
320
- def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
321
- """Construct a DPM-Solver.
322
- We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
323
- If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
324
- If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
325
- In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
326
- The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
327
- Args:
328
- model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
329
- ``
330
- def model_fn(x, t_continuous):
331
- return noise
332
- ``
333
- noise_schedule: A noise schedule object, such as NoiseScheduleVP.
334
- predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
335
- thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
336
- max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
337
-
338
- [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
339
- """
340
- self.model = model_fn
341
- self.noise_schedule = noise_schedule
342
- self.predict_x0 = predict_x0
343
- self.thresholding = thresholding
344
- self.max_val = max_val
345
-
346
- def noise_prediction_fn(self, x, t):
347
- """
348
- Return the noise prediction model.
349
- """
350
- return self.model(x, t)
351
-
352
- def data_prediction_fn(self, x, t):
353
- """
354
- Return the data prediction model (with thresholding).
355
- """
356
- noise = self.noise_prediction_fn(x, t)
357
- dims = x.dim()
358
- alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
359
- x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
360
- if self.thresholding:
361
- p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
362
- s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
363
- s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
364
- x0 = torch.clamp(x0, -s, s) / s
365
- return x0
366
-
367
- def model_fn(self, x, t):
368
- """
369
- Convert the model to the noise prediction model or the data prediction model.
370
- """
371
- if self.predict_x0:
372
- return self.data_prediction_fn(x, t)
373
- else:
374
- return self.noise_prediction_fn(x, t)
375
-
376
- def get_time_steps(self, skip_type, t_T, t_0, N, device):
377
- """Compute the intermediate time steps for sampling.
378
- Args:
379
- skip_type: A `str`. The type for the spacing of the time steps. We support three types:
380
- - 'logSNR': uniform logSNR for the time steps.
381
- - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
382
- - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
383
- t_T: A `float`. The starting time of the sampling (default is T).
384
- t_0: A `float`. The ending time of the sampling (default is epsilon).
385
- N: A `int`. The total number of the spacing of the time steps.
386
- device: A torch device.
387
- Returns:
388
- A pytorch tensor of the time steps, with the shape (N + 1,).
389
- """
390
- if skip_type == 'logSNR':
391
- lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
392
- lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
393
- logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
394
- return self.noise_schedule.inverse_lambda(logSNR_steps)
395
- elif skip_type == 'time_uniform':
396
- return torch.linspace(t_T, t_0, N + 1).to(device)
397
- elif skip_type == 'time_quadratic':
398
- t_order = 2
399
- t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
400
- return t
401
- else:
402
- raise ValueError(
403
- "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
404
-
405
- def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
406
- """
407
- Get the order of each step for sampling by the singlestep DPM-Solver.
408
- We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
409
- Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
410
- - If order == 1:
411
- We take `steps` of DPM-Solver-1 (i.e. DDIM).
412
- - If order == 2:
413
- - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
414
- - If steps % 2 == 0, we use K steps of DPM-Solver-2.
415
- - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
416
- - If order == 3:
417
- - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
418
- - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
419
- - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
420
- - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
421
- ============================================
422
- Args:
423
- order: A `int`. The max order for the solver (2 or 3).
424
- steps: A `int`. The total number of function evaluations (NFE).
425
- skip_type: A `str`. The type for the spacing of the time steps. We support three types:
426
- - 'logSNR': uniform logSNR for the time steps.
427
- - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
428
- - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
429
- t_T: A `float`. The starting time of the sampling (default is T).
430
- t_0: A `float`. The ending time of the sampling (default is epsilon).
431
- device: A torch device.
432
- Returns:
433
- orders: A list of the solver order of each step.
434
- """
435
- if order == 3:
436
- K = steps // 3 + 1
437
- if steps % 3 == 0:
438
- orders = [3, ] * (K - 2) + [2, 1]
439
- elif steps % 3 == 1:
440
- orders = [3, ] * (K - 1) + [1]
441
- else:
442
- orders = [3, ] * (K - 1) + [2]
443
- elif order == 2:
444
- if steps % 2 == 0:
445
- K = steps // 2
446
- orders = [2, ] * K
447
- else:
448
- K = steps // 2 + 1
449
- orders = [2, ] * (K - 1) + [1]
450
- elif order == 1:
451
- K = 1
452
- orders = [1, ] * steps
453
- else:
454
- raise ValueError("'order' must be '1' or '2' or '3'.")
455
- if skip_type == 'logSNR':
456
- # To reproduce the results in DPM-Solver paper
457
- timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
458
- else:
459
- timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
460
- torch.cumsum(torch.tensor([0, ] + orders)).to(device)]
461
- return timesteps_outer, orders
462
-
463
- def denoise_to_zero_fn(self, x, s):
464
- """
465
- Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
466
- """
467
- return self.data_prediction_fn(x, s)
468
-
469
- def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
470
- """
471
- DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
472
- Args:
473
- x: A pytorch tensor. The initial value at time `s`.
474
- s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
475
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
476
- model_s: A pytorch tensor. The model function evaluated at time `s`.
477
- If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
478
- return_intermediate: A `bool`. If true, also return the model value at time `s`.
479
- Returns:
480
- x_t: A pytorch tensor. The approximated solution at time `t`.
481
- """
482
- ns = self.noise_schedule
483
- dims = x.dim()
484
- lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
485
- h = lambda_t - lambda_s
486
- log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
487
- sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
488
- alpha_t = torch.exp(log_alpha_t)
489
-
490
- if self.predict_x0:
491
- phi_1 = torch.expm1(-h)
492
- if model_s is None:
493
- model_s = self.model_fn(x, s)
494
- x_t = (
495
- expand_dims(sigma_t / sigma_s, dims) * x
496
- - expand_dims(alpha_t * phi_1, dims) * model_s
497
- )
498
- if return_intermediate:
499
- return x_t, {'model_s': model_s}
500
- else:
501
- return x_t
502
- else:
503
- phi_1 = torch.expm1(h)
504
- if model_s is None:
505
- model_s = self.model_fn(x, s)
506
- x_t = (
507
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
508
- - expand_dims(sigma_t * phi_1, dims) * model_s
509
- )
510
- if return_intermediate:
511
- return x_t, {'model_s': model_s}
512
- else:
513
- return x_t
514
-
515
- def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
516
- solver_type='dpm_solver'):
517
- """
518
- Singlestep solver DPM-Solver-2 from time `s` to time `t`.
519
- Args:
520
- x: A pytorch tensor. The initial value at time `s`.
521
- s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
522
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
523
- r1: A `float`. The hyperparameter of the second-order solver.
524
- model_s: A pytorch tensor. The model function evaluated at time `s`.
525
- If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
526
- return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
527
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
528
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
529
- Returns:
530
- x_t: A pytorch tensor. The approximated solution at time `t`.
531
- """
532
- if solver_type not in ['dpm_solver', 'taylor']:
533
- raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
534
- if r1 is None:
535
- r1 = 0.5
536
- ns = self.noise_schedule
537
- dims = x.dim()
538
- lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
539
- h = lambda_t - lambda_s
540
- lambda_s1 = lambda_s + r1 * h
541
- s1 = ns.inverse_lambda(lambda_s1)
542
- log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
543
- s1), ns.marginal_log_mean_coeff(t)
544
- sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
545
- alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
546
-
547
- if self.predict_x0:
548
- phi_11 = torch.expm1(-r1 * h)
549
- phi_1 = torch.expm1(-h)
550
-
551
- if model_s is None:
552
- model_s = self.model_fn(x, s)
553
- x_s1 = (
554
- expand_dims(sigma_s1 / sigma_s, dims) * x
555
- - expand_dims(alpha_s1 * phi_11, dims) * model_s
556
- )
557
- model_s1 = self.model_fn(x_s1, s1)
558
- if solver_type == 'dpm_solver':
559
- x_t = (
560
- expand_dims(sigma_t / sigma_s, dims) * x
561
- - expand_dims(alpha_t * phi_1, dims) * model_s
562
- - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
563
- )
564
- elif solver_type == 'taylor':
565
- x_t = (
566
- expand_dims(sigma_t / sigma_s, dims) * x
567
- - expand_dims(alpha_t * phi_1, dims) * model_s
568
- + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
569
- model_s1 - model_s)
570
- )
571
- else:
572
- phi_11 = torch.expm1(r1 * h)
573
- phi_1 = torch.expm1(h)
574
-
575
- if model_s is None:
576
- model_s = self.model_fn(x, s)
577
- x_s1 = (
578
- expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
579
- - expand_dims(sigma_s1 * phi_11, dims) * model_s
580
- )
581
- model_s1 = self.model_fn(x_s1, s1)
582
- if solver_type == 'dpm_solver':
583
- x_t = (
584
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
585
- - expand_dims(sigma_t * phi_1, dims) * model_s
586
- - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
587
- )
588
- elif solver_type == 'taylor':
589
- x_t = (
590
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
591
- - expand_dims(sigma_t * phi_1, dims) * model_s
592
- - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
593
- )
594
- if return_intermediate:
595
- return x_t, {'model_s': model_s, 'model_s1': model_s1}
596
- else:
597
- return x_t
598
-
599
- def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
600
- return_intermediate=False, solver_type='dpm_solver'):
601
- """
602
- Singlestep solver DPM-Solver-3 from time `s` to time `t`.
603
- Args:
604
- x: A pytorch tensor. The initial value at time `s`.
605
- s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
606
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
607
- r1: A `float`. The hyperparameter of the third-order solver.
608
- r2: A `float`. The hyperparameter of the third-order solver.
609
- model_s: A pytorch tensor. The model function evaluated at time `s`.
610
- If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
611
- model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
612
- If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
613
- return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
614
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
615
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
616
- Returns:
617
- x_t: A pytorch tensor. The approximated solution at time `t`.
618
- """
619
- if solver_type not in ['dpm_solver', 'taylor']:
620
- raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
621
- if r1 is None:
622
- r1 = 1. / 3.
623
- if r2 is None:
624
- r2 = 2. / 3.
625
- ns = self.noise_schedule
626
- dims = x.dim()
627
- lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
628
- h = lambda_t - lambda_s
629
- lambda_s1 = lambda_s + r1 * h
630
- lambda_s2 = lambda_s + r2 * h
631
- s1 = ns.inverse_lambda(lambda_s1)
632
- s2 = ns.inverse_lambda(lambda_s2)
633
- log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
634
- s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
635
- sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
636
- s2), ns.marginal_std(t)
637
- alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
638
-
639
- if self.predict_x0:
640
- phi_11 = torch.expm1(-r1 * h)
641
- phi_12 = torch.expm1(-r2 * h)
642
- phi_1 = torch.expm1(-h)
643
- phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
644
- phi_2 = phi_1 / h + 1.
645
- phi_3 = phi_2 / h - 0.5
646
-
647
- if model_s is None:
648
- model_s = self.model_fn(x, s)
649
- if model_s1 is None:
650
- x_s1 = (
651
- expand_dims(sigma_s1 / sigma_s, dims) * x
652
- - expand_dims(alpha_s1 * phi_11, dims) * model_s
653
- )
654
- model_s1 = self.model_fn(x_s1, s1)
655
- x_s2 = (
656
- expand_dims(sigma_s2 / sigma_s, dims) * x
657
- - expand_dims(alpha_s2 * phi_12, dims) * model_s
658
- + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
659
- )
660
- model_s2 = self.model_fn(x_s2, s2)
661
- if solver_type == 'dpm_solver':
662
- x_t = (
663
- expand_dims(sigma_t / sigma_s, dims) * x
664
- - expand_dims(alpha_t * phi_1, dims) * model_s
665
- + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
666
- )
667
- elif solver_type == 'taylor':
668
- D1_0 = (1. / r1) * (model_s1 - model_s)
669
- D1_1 = (1. / r2) * (model_s2 - model_s)
670
- D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
671
- D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
672
- x_t = (
673
- expand_dims(sigma_t / sigma_s, dims) * x
674
- - expand_dims(alpha_t * phi_1, dims) * model_s
675
- + expand_dims(alpha_t * phi_2, dims) * D1
676
- - expand_dims(alpha_t * phi_3, dims) * D2
677
- )
678
- else:
679
- phi_11 = torch.expm1(r1 * h)
680
- phi_12 = torch.expm1(r2 * h)
681
- phi_1 = torch.expm1(h)
682
- phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
683
- phi_2 = phi_1 / h - 1.
684
- phi_3 = phi_2 / h - 0.5
685
-
686
- if model_s is None:
687
- model_s = self.model_fn(x, s)
688
- if model_s1 is None:
689
- x_s1 = (
690
- expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
691
- - expand_dims(sigma_s1 * phi_11, dims) * model_s
692
- )
693
- model_s1 = self.model_fn(x_s1, s1)
694
- x_s2 = (
695
- expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
696
- - expand_dims(sigma_s2 * phi_12, dims) * model_s
697
- - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
698
- )
699
- model_s2 = self.model_fn(x_s2, s2)
700
- if solver_type == 'dpm_solver':
701
- x_t = (
702
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
703
- - expand_dims(sigma_t * phi_1, dims) * model_s
704
- - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
705
- )
706
- elif solver_type == 'taylor':
707
- D1_0 = (1. / r1) * (model_s1 - model_s)
708
- D1_1 = (1. / r2) * (model_s2 - model_s)
709
- D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
710
- D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
711
- x_t = (
712
- expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
713
- - expand_dims(sigma_t * phi_1, dims) * model_s
714
- - expand_dims(sigma_t * phi_2, dims) * D1
715
- - expand_dims(sigma_t * phi_3, dims) * D2
716
- )
717
-
718
- if return_intermediate:
719
- return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
720
- else:
721
- return x_t
722
-
723
- def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
724
- """
725
- Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
726
- Args:
727
- x: A pytorch tensor. The initial value at time `s`.
728
- model_prev_list: A list of pytorch tensor. The previous computed model values.
729
- t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
730
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
731
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
732
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
733
- Returns:
734
- x_t: A pytorch tensor. The approximated solution at time `t`.
735
- """
736
- if solver_type not in ['dpm_solver', 'taylor']:
737
- raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
738
- ns = self.noise_schedule
739
- dims = x.dim()
740
- model_prev_1, model_prev_0 = model_prev_list
741
- t_prev_1, t_prev_0 = t_prev_list
742
- lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
743
- t_prev_0), ns.marginal_lambda(t)
744
- log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
745
- sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
746
- alpha_t = torch.exp(log_alpha_t)
747
-
748
- h_0 = lambda_prev_0 - lambda_prev_1
749
- h = lambda_t - lambda_prev_0
750
- r0 = h_0 / h
751
- D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
752
- if self.predict_x0:
753
- if solver_type == 'dpm_solver':
754
- x_t = (
755
- expand_dims(sigma_t / sigma_prev_0, dims) * x
756
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
757
- - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
758
- )
759
- elif solver_type == 'taylor':
760
- x_t = (
761
- expand_dims(sigma_t / sigma_prev_0, dims) * x
762
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
763
- + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
764
- )
765
- else:
766
- if solver_type == 'dpm_solver':
767
- x_t = (
768
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
769
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
770
- - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
771
- )
772
- elif solver_type == 'taylor':
773
- x_t = (
774
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
775
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
776
- - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
777
- )
778
- return x_t
779
-
780
- def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
781
- """
782
- Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
783
- Args:
784
- x: A pytorch tensor. The initial value at time `s`.
785
- model_prev_list: A list of pytorch tensor. The previous computed model values.
786
- t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
787
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
788
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
789
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
790
- Returns:
791
- x_t: A pytorch tensor. The approximated solution at time `t`.
792
- """
793
- ns = self.noise_schedule
794
- dims = x.dim()
795
- model_prev_2, model_prev_1, model_prev_0 = model_prev_list
796
- t_prev_2, t_prev_1, t_prev_0 = t_prev_list
797
- lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
798
- t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
799
- log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
800
- sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
801
- alpha_t = torch.exp(log_alpha_t)
802
-
803
- h_1 = lambda_prev_1 - lambda_prev_2
804
- h_0 = lambda_prev_0 - lambda_prev_1
805
- h = lambda_t - lambda_prev_0
806
- r0, r1 = h_0 / h, h_1 / h
807
- D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
808
- D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
809
- D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
810
- D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
811
- if self.predict_x0:
812
- x_t = (
813
- expand_dims(sigma_t / sigma_prev_0, dims) * x
814
- - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
815
- + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
816
- - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
817
- )
818
- else:
819
- x_t = (
820
- expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
821
- - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
822
- - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
823
- - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
824
- )
825
- return x_t
826
-
827
- def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
828
- r2=None):
829
- """
830
- Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
831
- Args:
832
- x: A pytorch tensor. The initial value at time `s`.
833
- s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
834
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
835
- order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
836
- return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
837
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
838
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
839
- r1: A `float`. The hyperparameter of the second-order or third-order solver.
840
- r2: A `float`. The hyperparameter of the third-order solver.
841
- Returns:
842
- x_t: A pytorch tensor. The approximated solution at time `t`.
843
- """
844
- if order == 1:
845
- return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
846
- elif order == 2:
847
- return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
848
- solver_type=solver_type, r1=r1)
849
- elif order == 3:
850
- return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
851
- solver_type=solver_type, r1=r1, r2=r2)
852
- else:
853
- raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
854
-
855
- def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
856
- """
857
- Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
858
- Args:
859
- x: A pytorch tensor. The initial value at time `s`.
860
- model_prev_list: A list of pytorch tensor. The previous computed model values.
861
- t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)
862
- t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
863
- order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
864
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
865
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
866
- Returns:
867
- x_t: A pytorch tensor. The approximated solution at time `t`.
868
- """
869
- if order == 1:
870
- return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
871
- elif order == 2:
872
- return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
873
- elif order == 3:
874
- return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
875
- else:
876
- raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
877
-
878
- def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
879
- solver_type='dpm_solver'):
880
- """
881
- The adaptive step size solver based on singlestep DPM-Solver.
882
- Args:
883
- x: A pytorch tensor. The initial value at time `t_T`.
884
- order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.
885
- t_T: A `float`. The starting time of the sampling (default is T).
886
- t_0: A `float`. The ending time of the sampling (default is epsilon).
887
- h_init: A `float`. The initial step size (for logSNR).
888
- atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].
889
- rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
890
- theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].
891
- t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
892
- current time and `t_0` is less than `t_err`. The default setting is 1e-5.
893
- solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
894
- The type slightly impacts the performance. We recommend to use 'dpm_solver' type.
895
- Returns:
896
- x_0: A pytorch tensor. The approximated solution at time `t_0`.
897
- [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
898
- """
899
- ns = self.noise_schedule
900
- s = t_T * torch.ones((x.shape[0],)).to(x)
901
- lambda_s = ns.marginal_lambda(s)
902
- lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
903
- h = h_init * torch.ones_like(s).to(x)
904
- x_prev = x
905
- nfe = 0
906
- if order == 2:
907
- r1 = 0.5
908
- lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
909
- higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
910
- solver_type=solver_type,
911
- **kwargs)
912
- elif order == 3:
913
- r1, r2 = 1. / 3., 2. / 3.
914
- lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
915
- return_intermediate=True,
916
- solver_type=solver_type)
917
- higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
918
- solver_type=solver_type,
919
- **kwargs)
920
- else:
921
- raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
922
- while torch.abs((s - t_0)).mean() > t_err:
923
- t = ns.inverse_lambda(lambda_s + h)
924
- x_lower, lower_noise_kwargs = lower_update(x, s, t)
925
- x_higher = higher_update(x, s, t, **lower_noise_kwargs)
926
- delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
927
- norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
928
- E = norm_fn((x_higher - x_lower) / delta).max()
929
- if torch.all(E <= 1.):
930
- x = x_higher
931
- s = t
932
- x_prev = x_lower
933
- lambda_s = ns.marginal_lambda(s)
934
- h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
935
- nfe += order
936
- print('adaptive solver nfe', nfe)
937
- return x
938
-
939
- def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
940
- method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
941
- atol=0.0078, rtol=0.05,
942
- ):
943
- """
944
- Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
945
- =====================================================
946
- We support the following algorithms for both noise prediction model and data prediction model:
947
- - 'singlestep':
948
- Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
949
- We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
950
- The total number of function evaluations (NFE) == `steps`.
951
- Given a fixed NFE == `steps`, the sampling procedure is:
952
- - If `order` == 1:
953
- - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
954
- - If `order` == 2:
955
- - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
956
- - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
957
- - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
958
- - If `order` == 3:
959
- - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
960
- - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
961
- - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
962
- - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
963
- - 'multistep':
964
- Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
965
- We initialize the first `order` values by lower order multistep solvers.
966
- Given a fixed NFE == `steps`, the sampling procedure is:
967
- Denote K = steps.
968
- - If `order` == 1:
969
- - We use K steps of DPM-Solver-1 (i.e. DDIM).
970
- - If `order` == 2:
971
- - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.
972
- - If `order` == 3:
973
- - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.
974
- - 'singlestep_fixed':
975
- Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
976
- We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
977
- - 'adaptive':
978
- Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
979
- We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
980
- You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs
981
- (NFE) and the sample quality.
982
- - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
983
- - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
984
- =====================================================
985
- Some advices for choosing the algorithm:
986
- - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
987
- Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
988
- e.g.
989
- >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
990
- >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
991
- skip_type='time_uniform', method='singlestep')
992
- - For **guided sampling with large guidance scale** by DPMs:
993
- Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
994
- e.g.
995
- >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
996
- >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
997
- skip_type='time_uniform', method='multistep')
998
- We support three types of `skip_type`:
999
- - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**
1000
- - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1001
- - 'time_quadratic': quadratic time for the time steps.
1002
- =====================================================
1003
- Args:
1004
- x: A pytorch tensor. The initial value at time `t_start`
1005
- e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1006
- steps: A `int`. The total number of function evaluations (NFE).
1007
- t_start: A `float`. The starting time of the sampling.
1008
- If `T` is None, we use self.noise_schedule.T (default is 1.0).
1009
- t_end: A `float`. The ending time of the sampling.
1010
- If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1011
- e.g. if total_N == 1000, we have `t_end` == 1e-3.
1012
- For discrete-time DPMs:
1013
- - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1014
- For continuous-time DPMs:
1015
- - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1016
- order: A `int`. The order of DPM-Solver.
1017
- skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1018
- method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1019
- denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1020
- Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1021
- This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1022
- score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID
1023
- for diffusion models sampling by diffusion SDEs for low-resolutional images
1024
- (such as CIFAR-10). However, we observed that such trick does not matter for
1025
- high-resolutional images. As it needs an additional NFE, we do not recommend
1026
- it for high-resolutional images.
1027
- lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1028
- Only valid for `method=multistep` and `steps < 15`. We empirically find that
1029
- this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1030
- (especially for steps <= 10). So we recommend to set it to be `True`.
1031
- solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
1032
- atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1033
- rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1034
- Returns:
1035
- x_end: A pytorch tensor. The approximated solution at time `t_end`.
1036
- """
1037
- t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1038
- t_T = self.noise_schedule.T if t_start is None else t_start
1039
- device = x.device
1040
- if method == 'adaptive':
1041
- with torch.no_grad():
1042
- x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
1043
- solver_type=solver_type)
1044
- elif method == 'multistep':
1045
- assert steps >= order
1046
- timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1047
- assert timesteps.shape[0] - 1 == steps
1048
- with torch.no_grad():
1049
- vec_t = timesteps[0].expand((x.shape[0]))
1050
- model_prev_list = [self.model_fn(x, vec_t)]
1051
- t_prev_list = [vec_t]
1052
- # Init the first `order` values by lower order multistep DPM-Solver.
1053
- for init_order in tqdm(range(1, order), desc="DPM init order"):
1054
- vec_t = timesteps[init_order].expand(x.shape[0])
1055
- x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
1056
- solver_type=solver_type)
1057
- model_prev_list.append(self.model_fn(x, vec_t))
1058
- t_prev_list.append(vec_t)
1059
- # Compute the remaining values by `order`-th order multistep DPM-Solver.
1060
- for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
1061
- vec_t = timesteps[step].expand(x.shape[0])
1062
- if lower_order_final and steps < 15:
1063
- step_order = min(order, steps + 1 - step)
1064
- else:
1065
- step_order = order
1066
- x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
1067
- solver_type=solver_type)
1068
- for i in range(order - 1):
1069
- t_prev_list[i] = t_prev_list[i + 1]
1070
- model_prev_list[i] = model_prev_list[i + 1]
1071
- t_prev_list[-1] = vec_t
1072
- # We do not need to evaluate the final model value.
1073
- if step < steps:
1074
- model_prev_list[-1] = self.model_fn(x, vec_t)
1075
- elif method in ['singlestep', 'singlestep_fixed']:
1076
- if method == 'singlestep':
1077
- timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
1078
- skip_type=skip_type,
1079
- t_T=t_T, t_0=t_0,
1080
- device=device)
1081
- elif method == 'singlestep_fixed':
1082
- K = steps // order
1083
- orders = [order, ] * K
1084
- timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1085
- for i, order in enumerate(orders):
1086
- t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
1087
- timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
1088
- N=order, device=device)
1089
- lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1090
- vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
1091
- h = lambda_inner[-1] - lambda_inner[0]
1092
- r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1093
- r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1094
- x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
1095
- if denoise_to_zero:
1096
- x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
1097
- return x
1098
-
1099
-
1100
- #############################################################
1101
- # other utility functions
1102
- #############################################################
1103
-
1104
- def interpolate_fn(x, xp, yp):
1105
- """
1106
- A piecewise linear function y = f(x), using xp and yp as keypoints.
1107
- We implement f(x) in a differentiable way (i.e. applicable for autograd).
1108
- The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)
1109
- Args:
1110
- x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1111
- xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1112
- yp: PyTorch tensor with shape [C, K].
1113
- Returns:
1114
- The function values f(x), with shape [N, C].
1115
- """
1116
- N, K = x.shape[0], xp.shape[1]
1117
- all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1118
- sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1119
- x_idx = torch.argmin(x_indices, dim=2)
1120
- cand_start_idx = x_idx - 1
1121
- start_idx = torch.where(
1122
- torch.eq(x_idx, 0),
1123
- torch.tensor(1, device=x.device),
1124
- torch.where(
1125
- torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1126
- ),
1127
- )
1128
- end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1129
- start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1130
- end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1131
- start_idx2 = torch.where(
1132
- torch.eq(x_idx, 0),
1133
- torch.tensor(0, device=x.device),
1134
- torch.where(
1135
- torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1136
- ),
1137
- )
1138
- y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1139
- start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1140
- end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1141
- cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1142
- return cand
1143
-
1144
-
1145
- def expand_dims(v, dims):
1146
- """
1147
- Expand the tensor `v` to the dim `dims`.
1148
- Args:
1149
- `v`: a PyTorch tensor with shape [N].
1150
- `dim`: a `int`.
1151
- Returns:
1152
- a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1153
- """
1154
- return v[(...,) + (None,) * (dims - 1)]