JunhaoZhuang committed
Commit 0a5c8a9
1 Parent(s): d5eb601
Files changed (36)
  1. .DS_Store +0 -0
  2. app.py +1 -1
  3. assets/example_0/input.jpg +0 -0
  4. assets/example_0/ref1.jpg +0 -0
  5. assets/example_1/input.jpg +0 -0
  6. assets/example_1/ref1.jpg +0 -0
  7. assets/example_1/ref2.jpg +0 -0
  8. assets/example_1/ref3.jpg +0 -0
  9. assets/example_2/input.png +0 -0
  10. assets/example_2/ref1.png +0 -0
  11. assets/example_2/ref2.png +0 -0
  12. assets/example_2/ref3.png +0 -0
  13. assets/example_3/input.png +0 -0
  14. assets/example_3/ref1.png +0 -0
  15. assets/example_3/ref2.png +0 -0
  16. assets/example_3/ref3.png +0 -0
  17. assets/example_4/input.jpg +0 -0
  18. assets/example_4/ref1.jpg +0 -0
  19. assets/example_4/ref2.jpg +0 -0
  20. assets/example_4/ref3.jpg +0 -0
  21. assets/example_5/input.png +0 -0
  22. assets/example_5/ref1.png +0 -0
  23. assets/example_5/ref2.png +0 -0
  24. assets/example_5/ref3.png +0 -0
  25. assets/mask.png +0 -0
  26. diffusers/src/diffusers/models/autoencoders_/__init__.py +0 -8
  27. diffusers/src/diffusers/models/autoencoders_/autoencoder_asym_kl.py +0 -184
  28. diffusers/src/diffusers/models/autoencoders_/autoencoder_kl.py +0 -570
  29. diffusers/src/diffusers/models/autoencoders_/autoencoder_kl_cogvideox.py +0 -1374
  30. diffusers/src/diffusers/models/autoencoders_/autoencoder_kl_temporal_decoder.py +0 -401
  31. diffusers/src/diffusers/models/autoencoders_/autoencoder_oobleck.py +0 -464
  32. diffusers/src/diffusers/models/autoencoders_/autoencoder_tiny.py +0 -348
  33. diffusers/src/diffusers/models/autoencoders_/consistency_decoder_vae.py +0 -460
  34. diffusers/src/diffusers/models/autoencoders_/vae.py +0 -1005
  35. diffusers/src/diffusers/models/autoencoders_/vq_model.py +0 -182
  36. diffusers/src/diffusers/pipelines/colorflow/pipeline_colorflow_sd.py +1 -7
.DS_Store ADDED
Binary file (6.15 kB)
 
app.py CHANGED
@@ -504,4 +504,4 @@ with gr.Blocks() as demo:
  label="Examples",
  examples_per_page=6,
  )
- demo.launch(server_name="0.0.0.0", server_port=22348)
+ demo.launch()
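Note: dropping the hard-coded `server_name`/`server_port` lets Gradio and the Hugging Face Spaces runtime pick the host and port automatically. For a local run an explicit binding still works; a minimal sketch, assuming the `gr.Blocks` app built in `app.py` is named `demo` (the placeholder UI below stands in for the real ColorFlow interface):

```python
import gradio as gr

# Placeholder UI; in app.py `demo` is the full ColorFlow interface.
with gr.Blocks() as demo:
    gr.Markdown("placeholder")

# On Spaces the bare call is enough; the platform manages host and port.
demo.launch()
# For a fixed local binding, the removed arguments are still valid Gradio options:
# demo.launch(server_name="0.0.0.0", server_port=22348)
```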
assets/example_0/input.jpg DELETED
Binary file (130 kB)
 
assets/example_0/ref1.jpg DELETED
Binary file (130 kB)
 
assets/example_1/input.jpg DELETED
Binary file (139 kB)
 
assets/example_1/ref1.jpg DELETED
Binary file (169 kB)
 
assets/example_1/ref2.jpg DELETED
Binary file (169 kB)
 
assets/example_1/ref3.jpg DELETED
Binary file (172 kB)
 
assets/example_2/input.png DELETED
Binary file (671 kB)
 
assets/example_2/ref1.png DELETED
Binary file (729 kB)
 
assets/example_2/ref2.png DELETED
Binary file (684 kB)
 
assets/example_2/ref3.png DELETED
Binary file (629 kB)
 
assets/example_3/input.png DELETED
Binary file (661 kB)
 
assets/example_3/ref1.png DELETED
Binary file (812 kB)
 
assets/example_3/ref2.png DELETED
Binary file (538 kB)
 
assets/example_3/ref3.png DELETED
Binary file (717 kB)
 
assets/example_4/input.jpg DELETED
Binary file (165 kB)
 
assets/example_4/ref1.jpg DELETED
Binary file (175 kB)
 
assets/example_4/ref2.jpg DELETED
Binary file (163 kB)
 
assets/example_4/ref3.jpg DELETED
Binary file (179 kB)
 
assets/example_5/input.png DELETED
Binary file (687 kB)
 
assets/example_5/ref1.png DELETED
Binary file (907 kB)
 
assets/example_5/ref2.png DELETED
Binary file (805 kB)
 
assets/example_5/ref3.png DELETED
Binary file (487 kB)
 
assets/mask.png DELETED
Binary file (2.31 kB)
 
diffusers/src/diffusers/models/autoencoders_/__init__.py DELETED
@@ -1,8 +0,0 @@
- from .autoencoder_asym_kl import AsymmetricAutoencoderKL
- from .autoencoder_kl import AutoencoderKL
- from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
- from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
- from .autoencoder_oobleck import AutoencoderOobleck
- from .autoencoder_tiny import AutoencoderTiny
- from .consistency_decoder_vae import ConsistencyDecoderVAE
- from .vq_model import VQModel
 
diffusers/src/diffusers/models/autoencoders_/autoencoder_asym_kl.py DELETED
@@ -1,184 +0,0 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Optional, Tuple, Union
-
- import torch
- import torch.nn as nn
-
- from ...configuration_utils import ConfigMixin, register_to_config
- from ...utils.accelerate_utils import apply_forward_hook
- from ..modeling_outputs import AutoencoderKLOutput
- from ..modeling_utils import ModelMixin
- from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
-
-
- class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
- r"""
- Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632 . A VAE model with KL loss
- for encoding images into latents and decoding latent representations into images.
-
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
- for all models (such as downloading or saving).
-
- Parameters:
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
- Tuple of downsample block types.
- down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
- Tuple of down block output channels.
- layers_per_down_block (`int`, *optional*, defaults to `1`):
- Number layers for down block.
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
- Tuple of upsample block types.
- up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
- Tuple of up block output channels.
- layers_per_up_block (`int`, *optional*, defaults to `1`):
- Number layers for up block.
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
- latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
- norm_num_groups (`int`, *optional*, defaults to `32`):
- Number of groups to use for the first normalization layer in ResNet blocks.
- scaling_factor (`float`, *optional*, defaults to 0.18215):
- The component-wise standard deviation of the trained latent space computed using the first batch of the
- training set. This is used to scale the latent space to have unit variance when training the diffusion
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
- """
-
- @register_to_config
- def __init__(
- self,
- in_channels: int = 3,
- out_channels: int = 3,
- down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
- down_block_out_channels: Tuple[int, ...] = (64,),
- layers_per_down_block: int = 1,
- up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
- up_block_out_channels: Tuple[int, ...] = (64,),
- layers_per_up_block: int = 1,
- act_fn: str = "silu",
- latent_channels: int = 4,
- norm_num_groups: int = 32,
- sample_size: int = 32,
- scaling_factor: float = 0.18215,
- ) -> None:
- super().__init__()
-
- # pass init params to Encoder
- self.encoder = Encoder(
- in_channels=in_channels,
- out_channels=latent_channels,
- down_block_types=down_block_types,
- block_out_channels=down_block_out_channels,
- layers_per_block=layers_per_down_block,
- act_fn=act_fn,
- norm_num_groups=norm_num_groups,
- double_z=True,
- )
-
- # pass init params to Decoder
- self.decoder = MaskConditionDecoder(
- in_channels=latent_channels,
- out_channels=out_channels,
- up_block_types=up_block_types,
- block_out_channels=up_block_out_channels,
- layers_per_block=layers_per_up_block,
- act_fn=act_fn,
- norm_num_groups=norm_num_groups,
- )
-
- self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
- self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
-
- self.use_slicing = False
- self.use_tiling = False
-
- self.register_to_config(block_out_channels=up_block_out_channels)
- self.register_to_config(force_upcast=False)
-
- @apply_forward_hook
- def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor]]:
- h = self.encoder(x)
- moments = self.quant_conv(h)
- posterior = DiagonalGaussianDistribution(moments)
-
- if not return_dict:
- return (posterior,)
-
- return AutoencoderKLOutput(latent_dist=posterior)
-
- def _decode(
- self,
- z: torch.Tensor,
- image: Optional[torch.Tensor] = None,
- mask: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
- z = self.post_quant_conv(z)
- dec = self.decoder(z, image, mask)
-
- if not return_dict:
- return (dec,)
-
- return DecoderOutput(sample=dec)
-
- @apply_forward_hook
- def decode(
- self,
- z: torch.Tensor,
- generator: Optional[torch.Generator] = None,
- image: Optional[torch.Tensor] = None,
- mask: Optional[torch.Tensor] = None,
- return_dict: bool = True,
- ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
- decoded = self._decode(z, image, mask).sample
-
- if not return_dict:
- return (decoded,)
-
- return DecoderOutput(sample=decoded)
-
- def forward(
- self,
- sample: torch.Tensor,
- mask: Optional[torch.Tensor] = None,
- sample_posterior: bool = False,
- return_dict: bool = True,
- generator: Optional[torch.Generator] = None,
- ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
- r"""
- Args:
- sample (`torch.Tensor`): Input sample.
- mask (`torch.Tensor`, *optional*, defaults to `None`): Optional inpainting mask.
- sample_posterior (`bool`, *optional*, defaults to `False`):
- Whether to sample from the posterior.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
- """
- x = sample
- posterior = self.encode(x).latent_dist
- if sample_posterior:
- z = posterior.sample(generator=generator)
- else:
- z = posterior.mode()
- dec = self.decode(z, generator, sample, mask).sample
-
- if not return_dict:
- return (dec,)
-
- return DecoderOutput(sample=dec)
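Note: the vendored `AsymmetricAutoencoderKL` removed above mirrors the upstream `diffusers` class of the same name. A minimal sketch of the encode/decode round trip it implements, using the default constructor arguments shown above and random weights (no pretrained checkpoint; the upstream class is assumed to behave identically):

```python
import torch
from diffusers import AsymmetricAutoencoderKL  # upstream equivalent of the removed copy

# Randomly initialized model with the defaults from the removed file
# (one DownEncoderBlock2D / UpDecoderBlock2D pair, 4 latent channels).
vae = AsymmetricAutoencoderKL()

image = torch.randn(1, 3, 64, 64)  # dummy input sample
mask = torch.ones(1, 1, 64, 64)    # binary mask for the mask-conditioned decoder

with torch.no_grad():
    # forward(): encode -> sample (or mode) of the posterior -> mask-conditioned decode
    reconstruction = vae(image, mask=mask, sample_posterior=True).sample

print(reconstruction.shape)  # torch.Size([1, 3, 64, 64])
```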
 
diffusers/src/diffusers/models/autoencoders_/autoencoder_kl.py DELETED
@@ -1,570 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Dict, Optional, Tuple, Union
15
-
16
- import torch
17
- import torch.nn as nn
18
-
19
- from ...configuration_utils import ConfigMixin, register_to_config
20
- from ...loaders.single_file_model import FromOriginalModelMixin
21
- from ...utils import deprecate
22
- from ...utils.accelerate_utils import apply_forward_hook
23
- from ..attention_processor import (
24
- ADDED_KV_ATTENTION_PROCESSORS,
25
- CROSS_ATTENTION_PROCESSORS,
26
- Attention,
27
- AttentionProcessor,
28
- AttnAddedKVProcessor,
29
- AttnProcessor,
30
- FusedAttnProcessor2_0,
31
- )
32
- from ..modeling_outputs import AutoencoderKLOutput
33
- from ..modeling_utils import ModelMixin
34
- from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
35
-
36
-
37
- class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
38
- r"""
39
- A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
40
-
41
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
42
- for all models (such as downloading or saving).
43
-
44
- Parameters:
45
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
46
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
47
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
48
- Tuple of downsample block types.
49
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
50
- Tuple of upsample block types.
51
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
52
- Tuple of block output channels.
53
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
54
- latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
55
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
56
- scaling_factor (`float`, *optional*, defaults to 0.18215):
57
- The component-wise standard deviation of the trained latent space computed using the first batch of the
58
- training set. This is used to scale the latent space to have unit variance when training the diffusion
59
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
60
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
61
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
62
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
63
- force_upcast (`bool`, *optional*, default to `True`):
64
- If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
65
- can be fine-tuned / trained to a lower range without loosing too much precision in which case
66
- `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
67
- mid_block_add_attention (`bool`, *optional*, default to `True`):
68
- If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
69
- mid_block will only have resnet blocks
70
- """
71
-
72
- _supports_gradient_checkpointing = True
73
- _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
74
-
75
- @register_to_config
76
- def __init__(
77
- self,
78
- in_channels: int = 3,
79
- out_channels: int = 3,
80
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
81
- up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
82
- block_out_channels: Tuple[int] = (64,),
83
- layers_per_block: int = 1,
84
- act_fn: str = "silu",
85
- latent_channels: int = 4,
86
- norm_num_groups: int = 32,
87
- sample_size: int = 32,
88
- scaling_factor: float = 0.18215,
89
- shift_factor: Optional[float] = None,
90
- latents_mean: Optional[Tuple[float]] = None,
91
- latents_std: Optional[Tuple[float]] = None,
92
- force_upcast: float = True,
93
- use_quant_conv: bool = True,
94
- use_post_quant_conv: bool = True,
95
- mid_block_add_attention: bool = True,
96
- ):
97
- super().__init__()
98
-
99
- # pass init params to Encoder
100
- self.encoder = Encoder(
101
- in_channels=in_channels,
102
- out_channels=latent_channels,
103
- down_block_types=down_block_types,
104
- block_out_channels=block_out_channels,
105
- layers_per_block=layers_per_block,
106
- act_fn=act_fn,
107
- norm_num_groups=norm_num_groups,
108
- double_z=True,
109
- mid_block_add_attention=mid_block_add_attention,
110
- )
111
-
112
- # pass init params to Decoder
113
- self.decoder = Decoder(
114
- in_channels=latent_channels,
115
- out_channels=out_channels,
116
- up_block_types=up_block_types,
117
- block_out_channels=block_out_channels,
118
- layers_per_block=layers_per_block,
119
- norm_num_groups=norm_num_groups,
120
- act_fn=act_fn,
121
- mid_block_add_attention=mid_block_add_attention,
122
- )
123
-
124
- self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
125
- self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
126
-
127
- self.use_slicing = False
128
- self.use_tiling = False
129
-
130
- # only relevant if vae tiling is enabled
131
- self.tile_sample_min_size = self.config.sample_size
132
- sample_size = (
133
- self.config.sample_size[0]
134
- if isinstance(self.config.sample_size, (list, tuple))
135
- else self.config.sample_size
136
- )
137
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
138
- self.tile_overlap_factor = 0.25
139
-
140
- def _set_gradient_checkpointing(self, module, value=False):
141
- if isinstance(module, (Encoder, Decoder)):
142
- module.gradient_checkpointing = value
143
-
144
- def enable_tiling(self, use_tiling: bool = True):
145
- r"""
146
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
147
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
148
- processing larger images.
149
- """
150
- self.use_tiling = use_tiling
151
-
152
- def disable_tiling(self):
153
- r"""
154
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
155
- decoding in one step.
156
- """
157
- self.enable_tiling(False)
158
-
159
- def enable_slicing(self):
160
- r"""
161
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
162
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
163
- """
164
- self.use_slicing = True
165
-
166
- def disable_slicing(self):
167
- r"""
168
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
169
- decoding in one step.
170
- """
171
- self.use_slicing = False
172
-
173
- @property
174
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
175
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
176
- r"""
177
- Returns:
178
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
179
- indexed by its weight name.
180
- """
181
- # set recursively
182
- processors = {}
183
-
184
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
185
- if hasattr(module, "get_processor"):
186
- processors[f"{name}.processor"] = module.get_processor()
187
-
188
- for sub_name, child in module.named_children():
189
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
190
-
191
- return processors
192
-
193
- for name, module in self.named_children():
194
- fn_recursive_add_processors(name, module, processors)
195
-
196
- return processors
197
-
198
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
199
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
200
- r"""
201
- Sets the attention processor to use to compute attention.
202
-
203
- Parameters:
204
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
205
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
206
- for **all** `Attention` layers.
207
-
208
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
209
- processor. This is strongly recommended when setting trainable attention processors.
210
-
211
- """
212
- count = len(self.attn_processors.keys())
213
-
214
- if isinstance(processor, dict) and len(processor) != count:
215
- raise ValueError(
216
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
217
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
218
- )
219
-
220
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
221
- if hasattr(module, "set_processor"):
222
- if not isinstance(processor, dict):
223
- module.set_processor(processor)
224
- else:
225
- module.set_processor(processor.pop(f"{name}.processor"))
226
-
227
- for sub_name, child in module.named_children():
228
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
229
-
230
- for name, module in self.named_children():
231
- fn_recursive_attn_processor(name, module, processor)
232
-
233
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
234
- def set_default_attn_processor(self):
235
- """
236
- Disables custom attention processors and sets the default attention implementation.
237
- """
238
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
239
- processor = AttnAddedKVProcessor()
240
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
241
- processor = AttnProcessor()
242
- else:
243
- raise ValueError(
244
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
245
- )
246
-
247
- self.set_attn_processor(processor)
248
-
249
- def _encode(self, x: torch.Tensor) -> torch.Tensor:
250
- batch_size, num_channels, height, width = x.shape
251
-
252
- if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
253
- return self._tiled_encode(x)
254
-
255
- enc = self.encoder(x)
256
- if self.quant_conv is not None:
257
- enc = self.quant_conv(enc)
258
-
259
- return enc
260
-
261
- @apply_forward_hook
262
- def encode(
263
- self, x: torch.Tensor, return_dict: bool = True
264
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
265
- """
266
- Encode a batch of images into latents.
267
-
268
- Args:
269
- x (`torch.Tensor`): Input batch of images.
270
- return_dict (`bool`, *optional*, defaults to `True`):
271
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
272
-
273
- Returns:
274
- The latent representations of the encoded images. If `return_dict` is True, a
275
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
276
- """
277
- if self.use_slicing and x.shape[0] > 1:
278
- encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
279
- h = torch.cat(encoded_slices)
280
- else:
281
- h = self._encode(x)
282
-
283
- posterior = DiagonalGaussianDistribution(h)
284
-
285
- if not return_dict:
286
- return (posterior,)
287
-
288
- return AutoencoderKLOutput(latent_dist=posterior)
289
-
290
- def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
291
- if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
292
- return self.tiled_decode(z, return_dict=return_dict)
293
-
294
- if self.post_quant_conv is not None:
295
- z = self.post_quant_conv(z)
296
-
297
- dec = self.decoder(z)
298
-
299
- if not return_dict:
300
- return (dec,)
301
-
302
- return DecoderOutput(sample=dec)
303
-
304
- @apply_forward_hook
305
- def decode(
306
- self, z: torch.FloatTensor, return_dict: bool = True, generator=None
307
- ) -> Union[DecoderOutput, torch.FloatTensor]:
308
- """
309
- Decode a batch of images.
310
-
311
- Args:
312
- z (`torch.Tensor`): Input batch of latent vectors.
313
- return_dict (`bool`, *optional*, defaults to `True`):
314
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
315
-
316
- Returns:
317
- [`~models.vae.DecoderOutput`] or `tuple`:
318
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
319
- returned.
320
-
321
- """
322
- if self.use_slicing and z.shape[0] > 1:
323
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
324
- decoded = torch.cat(decoded_slices)
325
- else:
326
- decoded = self._decode(z).sample
327
-
328
- if not return_dict:
329
- return (decoded,)
330
-
331
- return DecoderOutput(sample=decoded)
332
-
333
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
334
- blend_extent = min(a.shape[2], b.shape[2], blend_extent)
335
- for y in range(blend_extent):
336
- b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
337
- return b
338
-
339
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
340
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
341
- for x in range(blend_extent):
342
- b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
343
- return b
344
-
345
- def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
346
- r"""Encode a batch of images using a tiled encoder.
347
-
348
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
349
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
350
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
351
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
352
- output, but they should be much less noticeable.
353
-
354
- Args:
355
- x (`torch.Tensor`): Input batch of images.
356
-
357
- Returns:
358
- `torch.Tensor`:
359
- The latent representation of the encoded videos.
360
- """
361
-
362
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
363
- blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
364
- row_limit = self.tile_latent_min_size - blend_extent
365
-
366
- # Split the image into 512x512 tiles and encode them separately.
367
- rows = []
368
- for i in range(0, x.shape[2], overlap_size):
369
- row = []
370
- for j in range(0, x.shape[3], overlap_size):
371
- tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
372
- tile = self.encoder(tile)
373
- if self.config.use_quant_conv:
374
- tile = self.quant_conv(tile)
375
- row.append(tile)
376
- rows.append(row)
377
- result_rows = []
378
- for i, row in enumerate(rows):
379
- result_row = []
380
- for j, tile in enumerate(row):
381
- # blend the above tile and the left tile
382
- # to the current tile and add the current tile to the result row
383
- if i > 0:
384
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
385
- if j > 0:
386
- tile = self.blend_h(row[j - 1], tile, blend_extent)
387
- result_row.append(tile[:, :, :row_limit, :row_limit])
388
- result_rows.append(torch.cat(result_row, dim=3))
389
-
390
- enc = torch.cat(result_rows, dim=2)
391
- return enc
392
-
393
- def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
394
- r"""Encode a batch of images using a tiled encoder.
395
-
396
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
397
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
398
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
399
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
400
- output, but they should be much less noticeable.
401
-
402
- Args:
403
- x (`torch.Tensor`): Input batch of images.
404
- return_dict (`bool`, *optional*, defaults to `True`):
405
- Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
406
-
407
- Returns:
408
- [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
409
- If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
410
- `tuple` is returned.
411
- """
412
- deprecation_message = (
413
- "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
414
- "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
415
- "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
416
- )
417
- deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
418
-
419
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
420
- blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
421
- row_limit = self.tile_latent_min_size - blend_extent
422
-
423
- # Split the image into 512x512 tiles and encode them separately.
424
- rows = []
425
- for i in range(0, x.shape[2], overlap_size):
426
- row = []
427
- for j in range(0, x.shape[3], overlap_size):
428
- tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
429
- tile = self.encoder(tile)
430
- if self.config.use_quant_conv:
431
- tile = self.quant_conv(tile)
432
- row.append(tile)
433
- rows.append(row)
434
- result_rows = []
435
- for i, row in enumerate(rows):
436
- result_row = []
437
- for j, tile in enumerate(row):
438
- # blend the above tile and the left tile
439
- # to the current tile and add the current tile to the result row
440
- if i > 0:
441
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
442
- if j > 0:
443
- tile = self.blend_h(row[j - 1], tile, blend_extent)
444
- result_row.append(tile[:, :, :row_limit, :row_limit])
445
- result_rows.append(torch.cat(result_row, dim=3))
446
-
447
- moments = torch.cat(result_rows, dim=2)
448
- posterior = DiagonalGaussianDistribution(moments)
449
-
450
- if not return_dict:
451
- return (posterior,)
452
-
453
- return AutoencoderKLOutput(latent_dist=posterior)
454
-
455
- def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
456
- r"""
457
- Decode a batch of images using a tiled decoder.
458
-
459
- Args:
460
- z (`torch.Tensor`): Input batch of latent vectors.
461
- return_dict (`bool`, *optional*, defaults to `True`):
462
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
463
-
464
- Returns:
465
- [`~models.vae.DecoderOutput`] or `tuple`:
466
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
467
- returned.
468
- """
469
- overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
470
- blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
471
- row_limit = self.tile_sample_min_size - blend_extent
472
-
473
- # Split z into overlapping 64x64 tiles and decode them separately.
474
- # The tiles have an overlap to avoid seams between tiles.
475
- rows = []
476
- for i in range(0, z.shape[2], overlap_size):
477
- row = []
478
- for j in range(0, z.shape[3], overlap_size):
479
- tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
480
- if self.config.use_post_quant_conv:
481
- tile = self.post_quant_conv(tile)
482
- decoded = self.decoder(tile)
483
- row.append(decoded)
484
- rows.append(row)
485
- result_rows = []
486
- for i, row in enumerate(rows):
487
- result_row = []
488
- for j, tile in enumerate(row):
489
- # blend the above tile and the left tile
490
- # to the current tile and add the current tile to the result row
491
- if i > 0:
492
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
493
- if j > 0:
494
- tile = self.blend_h(row[j - 1], tile, blend_extent)
495
- result_row.append(tile[:, :, :row_limit, :row_limit])
496
- result_rows.append(torch.cat(result_row, dim=3))
497
-
498
- dec = torch.cat(result_rows, dim=2)
499
- if not return_dict:
500
- return (dec,)
501
-
502
- return DecoderOutput(sample=dec)
503
-
504
- def forward(
505
- self,
506
- sample: torch.Tensor,
507
- sample_posterior: bool = False,
508
- return_dict: bool = True,
509
- generator: Optional[torch.Generator] = None,
510
- ) -> Union[DecoderOutput, torch.Tensor]:
511
- r"""
512
- Args:
513
- sample (`torch.Tensor`): Input sample.
514
- sample_posterior (`bool`, *optional*, defaults to `False`):
515
- Whether to sample from the posterior.
516
- return_dict (`bool`, *optional*, defaults to `True`):
517
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
518
- """
519
- x = sample
520
- posterior = self.encode(x).latent_dist
521
- if sample_posterior:
522
- z = posterior.sample(generator=generator)
523
- else:
524
- z = posterior.mode()
525
- dec = self.decode(z).sample
526
-
527
- if not return_dict:
528
- return (dec,)
529
-
530
- return DecoderOutput(sample=dec)
531
-
532
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
533
- def fuse_qkv_projections(self):
534
- """
535
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
536
- are fused. For cross-attention modules, key and value projection matrices are fused.
537
-
538
- <Tip warning={true}>
539
-
540
- This API is 🧪 experimental.
541
-
542
- </Tip>
543
- """
544
- self.original_attn_processors = None
545
-
546
- for _, attn_processor in self.attn_processors.items():
547
- if "Added" in str(attn_processor.__class__.__name__):
548
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
549
-
550
- self.original_attn_processors = self.attn_processors
551
-
552
- for module in self.modules():
553
- if isinstance(module, Attention):
554
- module.fuse_projections(fuse=True)
555
-
556
- self.set_attn_processor(FusedAttnProcessor2_0())
557
-
558
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
559
- def unfuse_qkv_projections(self):
560
- """Disables the fused QKV projection if enabled.
561
-
562
- <Tip warning={true}>
563
-
564
- This API is 🧪 experimental.
565
-
566
- </Tip>
567
-
568
- """
569
- if self.original_attn_processors is not None:
570
- self.set_attn_processor(self.original_attn_processors)
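Note: the removed `AutoencoderKL` copy likewise tracks the upstream `diffusers` implementation, including the memory-saving switches defined above (`enable_tiling`, `enable_slicing`). A minimal usage sketch with random weights, assuming the upstream class exposes the same API:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()   # default config, randomly initialized
vae.enable_tiling()     # encode/decode large inputs tile by tile (inputs larger than sample_size)
vae.enable_slicing()    # process batch elements one at a time

x = torch.randn(2, 3, 256, 256)
with torch.no_grad():
    posterior = vae.encode(x).latent_dist                 # DiagonalGaussianDistribution
    latents = posterior.sample() * vae.config.scaling_factor
    reconstruction = vae.decode(latents / vae.config.scaling_factor).sample

print(latents.shape, reconstruction.shape)
```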
 
diffusers/src/diffusers/models/autoencoders_/autoencoder_kl_cogvideox.py DELETED
@@ -1,1374 +0,0 @@
1
- # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Optional, Tuple, Union
17
-
18
- import numpy as np
19
- import torch
20
- import torch.nn as nn
21
- import torch.nn.functional as F
22
-
23
- from ...configuration_utils import ConfigMixin, register_to_config
24
- from ...loaders.single_file_model import FromOriginalModelMixin
25
- from ...utils import logging
26
- from ...utils.accelerate_utils import apply_forward_hook
27
- from ..activations import get_activation
28
- from ..downsampling import CogVideoXDownsample3D
29
- from ..modeling_outputs import AutoencoderKLOutput
30
- from ..modeling_utils import ModelMixin
31
- from ..upsampling import CogVideoXUpsample3D
32
- from .vae import DecoderOutput, DiagonalGaussianDistribution
33
-
34
-
35
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
-
37
-
38
- class CogVideoXSafeConv3d(nn.Conv3d):
39
- r"""
40
- A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
41
- """
42
-
43
- def forward(self, input: torch.Tensor) -> torch.Tensor:
44
- memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
45
-
46
- # Set to 2GB, suitable for CuDNN
47
- if memory_count > 2:
48
- kernel_size = self.kernel_size[0]
49
- part_num = int(memory_count / 2) + 1
50
- input_chunks = torch.chunk(input, part_num, dim=2)
51
-
52
- if kernel_size > 1:
53
- input_chunks = [input_chunks[0]] + [
54
- torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
55
- for i in range(1, len(input_chunks))
56
- ]
57
-
58
- output_chunks = []
59
- for input_chunk in input_chunks:
60
- output_chunks.append(super().forward(input_chunk))
61
- output = torch.cat(output_chunks, dim=2)
62
- return output
63
- else:
64
- return super().forward(input)
65
-
66
-
67
- class CogVideoXCausalConv3d(nn.Module):
68
- r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
69
-
70
- Args:
71
- in_channels (`int`): Number of channels in the input tensor.
72
- out_channels (`int`): Number of output channels produced by the convolution.
73
- kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
74
- stride (`int`, defaults to `1`): Stride of the convolution.
75
- dilation (`int`, defaults to `1`): Dilation rate of the convolution.
76
- pad_mode (`str`, defaults to `"constant"`): Padding mode.
77
- """
78
-
79
- def __init__(
80
- self,
81
- in_channels: int,
82
- out_channels: int,
83
- kernel_size: Union[int, Tuple[int, int, int]],
84
- stride: int = 1,
85
- dilation: int = 1,
86
- pad_mode: str = "constant",
87
- ):
88
- super().__init__()
89
-
90
- if isinstance(kernel_size, int):
91
- kernel_size = (kernel_size,) * 3
92
-
93
- time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
94
-
95
- self.pad_mode = pad_mode
96
- time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
97
- height_pad = height_kernel_size // 2
98
- width_pad = width_kernel_size // 2
99
-
100
- self.height_pad = height_pad
101
- self.width_pad = width_pad
102
- self.time_pad = time_pad
103
- self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
104
-
105
- self.temporal_dim = 2
106
- self.time_kernel_size = time_kernel_size
107
-
108
- stride = (stride, 1, 1)
109
- dilation = (dilation, 1, 1)
110
- self.conv = CogVideoXSafeConv3d(
111
- in_channels=in_channels,
112
- out_channels=out_channels,
113
- kernel_size=kernel_size,
114
- stride=stride,
115
- dilation=dilation,
116
- )
117
-
118
- self.conv_cache = None
119
-
120
- def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
121
- kernel_size = self.time_kernel_size
122
- if kernel_size > 1:
123
- cached_inputs = (
124
- [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
125
- )
126
- inputs = torch.cat(cached_inputs + [inputs], dim=2)
127
- return inputs
128
-
129
- def _clear_fake_context_parallel_cache(self):
130
- del self.conv_cache
131
- self.conv_cache = None
132
-
133
- def forward(self, inputs: torch.Tensor) -> torch.Tensor:
134
- inputs = self.fake_context_parallel_forward(inputs)
135
-
136
- self._clear_fake_context_parallel_cache()
137
- # Note: we could move these to the cpu for a lower maximum memory usage but its only a few
138
- # hundred megabytes and so let's not do it for now
139
- self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
140
-
141
- padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
142
- inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
143
-
144
- output = self.conv(inputs)
145
- return output
146
-
147
-
148
- class CogVideoXSpatialNorm3D(nn.Module):
149
- r"""
150
- Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
151
- to 3D-video like data.
152
-
153
- CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
154
-
155
- Args:
156
- f_channels (`int`):
157
- The number of channels for input to group normalization layer, and output of the spatial norm layer.
158
- zq_channels (`int`):
159
- The number of channels for the quantized vector as described in the paper.
160
- groups (`int`):
161
- Number of groups to separate the channels into for group normalization.
162
- """
163
-
164
- def __init__(
165
- self,
166
- f_channels: int,
167
- zq_channels: int,
168
- groups: int = 32,
169
- ):
170
- super().__init__()
171
- self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
172
- self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
173
- self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
174
-
175
- def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
176
- if f.shape[2] > 1 and f.shape[2] % 2 == 1:
177
- f_first, f_rest = f[:, :, :1], f[:, :, 1:]
178
- f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
179
- z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
180
- z_first = F.interpolate(z_first, size=f_first_size)
181
- z_rest = F.interpolate(z_rest, size=f_rest_size)
182
- zq = torch.cat([z_first, z_rest], dim=2)
183
- else:
184
- zq = F.interpolate(zq, size=f.shape[-3:])
185
-
186
- norm_f = self.norm_layer(f)
187
- new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
188
- return new_f
189
-
190
-
191
- class CogVideoXResnetBlock3D(nn.Module):
192
- r"""
193
- A 3D ResNet block used in the CogVideoX model.
194
-
195
- Args:
196
- in_channels (`int`):
197
- Number of input channels.
198
- out_channels (`int`, *optional*):
199
- Number of output channels. If None, defaults to `in_channels`.
200
- dropout (`float`, defaults to `0.0`):
201
- Dropout rate.
202
- temb_channels (`int`, defaults to `512`):
203
- Number of time embedding channels.
204
- groups (`int`, defaults to `32`):
205
- Number of groups to separate the channels into for group normalization.
206
- eps (`float`, defaults to `1e-6`):
207
- Epsilon value for normalization layers.
208
- non_linearity (`str`, defaults to `"swish"`):
209
- Activation function to use.
210
- conv_shortcut (bool, defaults to `False`):
211
- Whether or not to use a convolution shortcut.
212
- spatial_norm_dim (`int`, *optional*):
213
- The dimension to use for spatial norm if it is to be used instead of group norm.
214
- pad_mode (str, defaults to `"first"`):
215
- Padding mode.
216
- """
217
-
218
- def __init__(
219
- self,
220
- in_channels: int,
221
- out_channels: Optional[int] = None,
222
- dropout: float = 0.0,
223
- temb_channels: int = 512,
224
- groups: int = 32,
225
- eps: float = 1e-6,
226
- non_linearity: str = "swish",
227
- conv_shortcut: bool = False,
228
- spatial_norm_dim: Optional[int] = None,
229
- pad_mode: str = "first",
230
- ):
231
- super().__init__()
232
-
233
- out_channels = out_channels or in_channels
234
-
235
- self.in_channels = in_channels
236
- self.out_channels = out_channels
237
- self.nonlinearity = get_activation(non_linearity)
238
- self.use_conv_shortcut = conv_shortcut
239
-
240
- if spatial_norm_dim is None:
241
- self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
242
- self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
243
- else:
244
- self.norm1 = CogVideoXSpatialNorm3D(
245
- f_channels=in_channels,
246
- zq_channels=spatial_norm_dim,
247
- groups=groups,
248
- )
249
- self.norm2 = CogVideoXSpatialNorm3D(
250
- f_channels=out_channels,
251
- zq_channels=spatial_norm_dim,
252
- groups=groups,
253
- )
254
-
255
- self.conv1 = CogVideoXCausalConv3d(
256
- in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
257
- )
258
-
259
- if temb_channels > 0:
260
- self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
261
-
262
- self.dropout = nn.Dropout(dropout)
263
- self.conv2 = CogVideoXCausalConv3d(
264
- in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
265
- )
266
-
267
- if self.in_channels != self.out_channels:
268
- if self.use_conv_shortcut:
269
- self.conv_shortcut = CogVideoXCausalConv3d(
270
- in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
271
- )
272
- else:
273
- self.conv_shortcut = CogVideoXSafeConv3d(
274
- in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
275
- )
276
-
277
- def forward(
278
- self,
279
- inputs: torch.Tensor,
280
- temb: Optional[torch.Tensor] = None,
281
- zq: Optional[torch.Tensor] = None,
282
- ) -> torch.Tensor:
283
- hidden_states = inputs
284
-
285
- if zq is not None:
286
- hidden_states = self.norm1(hidden_states, zq)
287
- else:
288
- hidden_states = self.norm1(hidden_states)
289
-
290
- hidden_states = self.nonlinearity(hidden_states)
291
- hidden_states = self.conv1(hidden_states)
292
-
293
- if temb is not None:
294
- hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
295
-
296
- if zq is not None:
297
- hidden_states = self.norm2(hidden_states, zq)
298
- else:
299
- hidden_states = self.norm2(hidden_states)
300
-
301
- hidden_states = self.nonlinearity(hidden_states)
302
- hidden_states = self.dropout(hidden_states)
303
- hidden_states = self.conv2(hidden_states)
304
-
305
- if self.in_channels != self.out_channels:
306
- inputs = self.conv_shortcut(inputs)
307
-
308
- hidden_states = hidden_states + inputs
309
- return hidden_states
310
-
311
-
312
- class CogVideoXDownBlock3D(nn.Module):
313
- r"""
314
- A downsampling block used in the CogVideoX model.
315
-
316
- Args:
317
- in_channels (`int`):
318
- Number of input channels.
319
- out_channels (`int`, *optional*):
320
- Number of output channels. If None, defaults to `in_channels`.
321
- temb_channels (`int`, defaults to `512`):
322
- Number of time embedding channels.
323
- num_layers (`int`, defaults to `1`):
324
- Number of resnet layers.
325
- dropout (`float`, defaults to `0.0`):
326
- Dropout rate.
327
- resnet_eps (`float`, defaults to `1e-6`):
328
- Epsilon value for normalization layers.
329
- resnet_act_fn (`str`, defaults to `"swish"`):
330
- Activation function to use.
331
- resnet_groups (`int`, defaults to `32`):
332
- Number of groups to separate the channels into for group normalization.
333
- add_downsample (`bool`, defaults to `True`):
334
- Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
335
- compress_time (`bool`, defaults to `False`):
336
- Whether or not to downsample across temporal dimension.
337
- pad_mode (str, defaults to `"first"`):
338
- Padding mode.
339
- """
340
-
341
- _supports_gradient_checkpointing = True
342
-
343
- def __init__(
344
- self,
345
- in_channels: int,
346
- out_channels: int,
347
- temb_channels: int,
348
- dropout: float = 0.0,
349
- num_layers: int = 1,
350
- resnet_eps: float = 1e-6,
351
- resnet_act_fn: str = "swish",
352
- resnet_groups: int = 32,
353
- add_downsample: bool = True,
354
- downsample_padding: int = 0,
355
- compress_time: bool = False,
356
- pad_mode: str = "first",
357
- ):
358
- super().__init__()
359
-
360
- resnets = []
361
- for i in range(num_layers):
362
- in_channel = in_channels if i == 0 else out_channels
363
- resnets.append(
364
- CogVideoXResnetBlock3D(
365
- in_channels=in_channel,
366
- out_channels=out_channels,
367
- dropout=dropout,
368
- temb_channels=temb_channels,
369
- groups=resnet_groups,
370
- eps=resnet_eps,
371
- non_linearity=resnet_act_fn,
372
- pad_mode=pad_mode,
373
- )
374
- )
375
-
376
- self.resnets = nn.ModuleList(resnets)
377
- self.downsamplers = None
378
-
379
- if add_downsample:
380
- self.downsamplers = nn.ModuleList(
381
- [
382
- CogVideoXDownsample3D(
383
- out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
384
- )
385
- ]
386
- )
387
-
388
- self.gradient_checkpointing = False
389
-
390
- def forward(
391
- self,
392
- hidden_states: torch.Tensor,
393
- temb: Optional[torch.Tensor] = None,
394
- zq: Optional[torch.Tensor] = None,
395
- ) -> torch.Tensor:
396
- for resnet in self.resnets:
397
- if self.training and self.gradient_checkpointing:
398
-
399
- def create_custom_forward(module):
400
- def create_forward(*inputs):
401
- return module(*inputs)
402
-
403
- return create_forward
404
-
405
- hidden_states = torch.utils.checkpoint.checkpoint(
406
- create_custom_forward(resnet), hidden_states, temb, zq
407
- )
408
- else:
409
- hidden_states = resnet(hidden_states, temb, zq)
410
-
411
- if self.downsamplers is not None:
412
- for downsampler in self.downsamplers:
413
- hidden_states = downsampler(hidden_states)
414
-
415
- return hidden_states
416
-
417
-
418
- class CogVideoXMidBlock3D(nn.Module):
419
- r"""
420
- A middle block used in the CogVideoX model.
421
-
422
- Args:
423
- in_channels (`int`):
424
- Number of input channels.
425
- temb_channels (`int`, defaults to `512`):
426
- Number of time embedding channels.
427
- dropout (`float`, defaults to `0.0`):
428
- Dropout rate.
429
- num_layers (`int`, defaults to `1`):
430
- Number of resnet layers.
431
- resnet_eps (`float`, defaults to `1e-6`):
432
- Epsilon value for normalization layers.
433
- resnet_act_fn (`str`, defaults to `"swish"`):
434
- Activation function to use.
435
- resnet_groups (`int`, defaults to `32`):
436
- Number of groups to separate the channels into for group normalization.
437
- spatial_norm_dim (`int`, *optional*):
438
- The dimension to use for spatial norm if it is to be used instead of group norm.
439
- pad_mode (str, defaults to `"first"`):
440
- Padding mode.
441
- """
442
-
443
- _supports_gradient_checkpointing = True
444
-
445
- def __init__(
446
- self,
447
- in_channels: int,
448
- temb_channels: int,
449
- dropout: float = 0.0,
450
- num_layers: int = 1,
451
- resnet_eps: float = 1e-6,
452
- resnet_act_fn: str = "swish",
453
- resnet_groups: int = 32,
454
- spatial_norm_dim: Optional[int] = None,
455
- pad_mode: str = "first",
456
- ):
457
- super().__init__()
458
-
459
- resnets = []
460
- for _ in range(num_layers):
461
- resnets.append(
462
- CogVideoXResnetBlock3D(
463
- in_channels=in_channels,
464
- out_channels=in_channels,
465
- dropout=dropout,
466
- temb_channels=temb_channels,
467
- groups=resnet_groups,
468
- eps=resnet_eps,
469
- spatial_norm_dim=spatial_norm_dim,
470
- non_linearity=resnet_act_fn,
471
- pad_mode=pad_mode,
472
- )
473
- )
474
- self.resnets = nn.ModuleList(resnets)
475
-
476
- self.gradient_checkpointing = False
477
-
478
- def forward(
479
- self,
480
- hidden_states: torch.Tensor,
481
- temb: Optional[torch.Tensor] = None,
482
- zq: Optional[torch.Tensor] = None,
483
- ) -> torch.Tensor:
484
- for resnet in self.resnets:
485
- if self.training and self.gradient_checkpointing:
486
-
487
- def create_custom_forward(module):
488
- def create_forward(*inputs):
489
- return module(*inputs)
490
-
491
- return create_forward
492
-
493
- hidden_states = torch.utils.checkpoint.checkpoint(
494
- create_custom_forward(resnet), hidden_states, temb, zq
495
- )
496
- else:
497
- hidden_states = resnet(hidden_states, temb, zq)
498
-
499
- return hidden_states
500
-
501
-
502
- class CogVideoXUpBlock3D(nn.Module):
503
- r"""
504
- An upsampling block used in the CogVideoX model.
505
-
506
- Args:
507
- in_channels (`int`):
508
- Number of input channels.
509
- out_channels (`int`, *optional*):
510
- Number of output channels. If None, defaults to `in_channels`.
511
- temb_channels (`int`, defaults to `512`):
512
- Number of time embedding channels.
513
- dropout (`float`, defaults to `0.0`):
514
- Dropout rate.
515
- num_layers (`int`, defaults to `1`):
516
- Number of resnet layers.
517
- resnet_eps (`float`, defaults to `1e-6`):
518
- Epsilon value for normalization layers.
519
- resnet_act_fn (`str`, defaults to `"swish"`):
520
- Activation function to use.
521
- resnet_groups (`int`, defaults to `32`):
522
- Number of groups to separate the channels into for group normalization.
523
- spatial_norm_dim (`int`, defaults to `16`):
524
- The dimension to use for spatial norm if it is to be used instead of group norm.
525
- add_upsample (`bool`, defaults to `True`):
526
- Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
527
- compress_time (`bool`, defaults to `False`):
528
- Whether or not to downsample across temporal dimension.
529
- pad_mode (str, defaults to `"first"`):
530
- Padding mode.
531
- """
532
-
533
- def __init__(
534
- self,
535
- in_channels: int,
536
- out_channels: int,
537
- temb_channels: int,
538
- dropout: float = 0.0,
539
- num_layers: int = 1,
540
- resnet_eps: float = 1e-6,
541
- resnet_act_fn: str = "swish",
542
- resnet_groups: int = 32,
543
- spatial_norm_dim: int = 16,
544
- add_upsample: bool = True,
545
- upsample_padding: int = 1,
546
- compress_time: bool = False,
547
- pad_mode: str = "first",
548
- ):
549
- super().__init__()
550
-
551
- resnets = []
552
- for i in range(num_layers):
553
- in_channel = in_channels if i == 0 else out_channels
554
- resnets.append(
555
- CogVideoXResnetBlock3D(
556
- in_channels=in_channel,
557
- out_channels=out_channels,
558
- dropout=dropout,
559
- temb_channels=temb_channels,
560
- groups=resnet_groups,
561
- eps=resnet_eps,
562
- non_linearity=resnet_act_fn,
563
- spatial_norm_dim=spatial_norm_dim,
564
- pad_mode=pad_mode,
565
- )
566
- )
567
-
568
- self.resnets = nn.ModuleList(resnets)
569
- self.upsamplers = None
570
-
571
- if add_upsample:
572
- self.upsamplers = nn.ModuleList(
573
- [
574
- CogVideoXUpsample3D(
575
- out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
576
- )
577
- ]
578
- )
579
-
580
- self.gradient_checkpointing = False
581
-
582
- def forward(
583
- self,
584
- hidden_states: torch.Tensor,
585
- temb: Optional[torch.Tensor] = None,
586
- zq: Optional[torch.Tensor] = None,
587
- ) -> torch.Tensor:
588
- r"""Forward method of the `CogVideoXUpBlock3D` class."""
589
- for resnet in self.resnets:
590
- if self.training and self.gradient_checkpointing:
591
-
592
- def create_custom_forward(module):
593
- def create_forward(*inputs):
594
- return module(*inputs)
595
-
596
- return create_forward
597
-
598
- hidden_states = torch.utils.checkpoint.checkpoint(
599
- create_custom_forward(resnet), hidden_states, temb, zq
600
- )
601
- else:
602
- hidden_states = resnet(hidden_states, temb, zq)
603
-
604
- if self.upsamplers is not None:
605
- for upsampler in self.upsamplers:
606
- hidden_states = upsampler(hidden_states)
607
-
608
- return hidden_states
609
-
610
-
611
- class CogVideoXEncoder3D(nn.Module):
612
- r"""
613
- The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
614
-
615
- Args:
616
- in_channels (`int`, *optional*, defaults to 3):
617
- The number of input channels.
618
- out_channels (`int`, *optional*, defaults to 3):
619
- The number of output channels.
620
- down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
621
- The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
622
- options.
623
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
624
- The number of output channels for each block.
625
- act_fn (`str`, *optional*, defaults to `"silu"`):
626
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
627
- layers_per_block (`int`, *optional*, defaults to 2):
628
- The number of layers per block.
629
- norm_num_groups (`int`, *optional*, defaults to 32):
630
- The number of groups for normalization.
631
- """
632
-
633
- _supports_gradient_checkpointing = True
634
-
635
- def __init__(
636
- self,
637
- in_channels: int = 3,
638
- out_channels: int = 16,
639
- down_block_types: Tuple[str, ...] = (
640
- "CogVideoXDownBlock3D",
641
- "CogVideoXDownBlock3D",
642
- "CogVideoXDownBlock3D",
643
- "CogVideoXDownBlock3D",
644
- ),
645
- block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
646
- layers_per_block: int = 3,
647
- act_fn: str = "silu",
648
- norm_eps: float = 1e-6,
649
- norm_num_groups: int = 32,
650
- dropout: float = 0.0,
651
- pad_mode: str = "first",
652
- temporal_compression_ratio: float = 4,
653
- ):
654
- super().__init__()
655
-
656
- # log2 of temporal_compress_times
657
- temporal_compress_level = int(np.log2(temporal_compression_ratio))
658
-
659
- self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
660
- self.down_blocks = nn.ModuleList([])
661
-
662
- # down blocks
663
- output_channel = block_out_channels[0]
664
- for i, down_block_type in enumerate(down_block_types):
665
- input_channel = output_channel
666
- output_channel = block_out_channels[i]
667
- is_final_block = i == len(block_out_channels) - 1
668
- compress_time = i < temporal_compress_level
669
-
670
- if down_block_type == "CogVideoXDownBlock3D":
671
- down_block = CogVideoXDownBlock3D(
672
- in_channels=input_channel,
673
- out_channels=output_channel,
674
- temb_channels=0,
675
- dropout=dropout,
676
- num_layers=layers_per_block,
677
- resnet_eps=norm_eps,
678
- resnet_act_fn=act_fn,
679
- resnet_groups=norm_num_groups,
680
- add_downsample=not is_final_block,
681
- compress_time=compress_time,
682
- )
683
- else:
684
- raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
685
-
686
- self.down_blocks.append(down_block)
687
-
688
- # mid block
689
- self.mid_block = CogVideoXMidBlock3D(
690
- in_channels=block_out_channels[-1],
691
- temb_channels=0,
692
- dropout=dropout,
693
- num_layers=2,
694
- resnet_eps=norm_eps,
695
- resnet_act_fn=act_fn,
696
- resnet_groups=norm_num_groups,
697
- pad_mode=pad_mode,
698
- )
699
-
700
- self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
701
- self.conv_act = nn.SiLU()
702
- self.conv_out = CogVideoXCausalConv3d(
703
- block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
704
- )
705
-
706
- self.gradient_checkpointing = False
707
-
708
- def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
709
- r"""The forward method of the `CogVideoXEncoder3D` class."""
710
- hidden_states = self.conv_in(sample)
711
-
712
- if self.training and self.gradient_checkpointing:
713
-
714
- def create_custom_forward(module):
715
- def custom_forward(*inputs):
716
- return module(*inputs)
717
-
718
- return custom_forward
719
-
720
- # 1. Down
721
- for down_block in self.down_blocks:
722
- hidden_states = torch.utils.checkpoint.checkpoint(
723
- create_custom_forward(down_block), hidden_states, temb, None
724
- )
725
-
726
- # 2. Mid
727
- hidden_states = torch.utils.checkpoint.checkpoint(
728
- create_custom_forward(self.mid_block), hidden_states, temb, None
729
- )
730
- else:
731
- # 1. Down
732
- for down_block in self.down_blocks:
733
- hidden_states = down_block(hidden_states, temb, None)
734
-
735
- # 2. Mid
736
- hidden_states = self.mid_block(hidden_states, temb, None)
737
-
738
- # 3. Post-process
739
- hidden_states = self.norm_out(hidden_states)
740
- hidden_states = self.conv_act(hidden_states)
741
- hidden_states = self.conv_out(hidden_states)
742
- return hidden_states
743
-
744
-
745
- class CogVideoXDecoder3D(nn.Module):
746
- r"""
747
- The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
748
- sample.
749
-
750
- Args:
751
- in_channels (`int`, *optional*, defaults to 3):
752
- The number of input channels.
753
- out_channels (`int`, *optional*, defaults to 3):
754
- The number of output channels.
755
- up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
756
- The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
757
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
758
- The number of output channels for each block.
759
- act_fn (`str`, *optional*, defaults to `"silu"`):
760
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
761
- layers_per_block (`int`, *optional*, defaults to 2):
762
- The number of layers per block.
763
- norm_num_groups (`int`, *optional*, defaults to 32):
764
- The number of groups for normalization.
765
- """
766
-
767
- _supports_gradient_checkpointing = True
768
-
769
- def __init__(
770
- self,
771
- in_channels: int = 16,
772
- out_channels: int = 3,
773
- up_block_types: Tuple[str, ...] = (
774
- "CogVideoXUpBlock3D",
775
- "CogVideoXUpBlock3D",
776
- "CogVideoXUpBlock3D",
777
- "CogVideoXUpBlock3D",
778
- ),
779
- block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
780
- layers_per_block: int = 3,
781
- act_fn: str = "silu",
782
- norm_eps: float = 1e-6,
783
- norm_num_groups: int = 32,
784
- dropout: float = 0.0,
785
- pad_mode: str = "first",
786
- temporal_compression_ratio: float = 4,
787
- ):
788
- super().__init__()
789
-
790
- reversed_block_out_channels = list(reversed(block_out_channels))
791
-
792
- self.conv_in = CogVideoXCausalConv3d(
793
- in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
794
- )
795
-
796
- # mid block
797
- self.mid_block = CogVideoXMidBlock3D(
798
- in_channels=reversed_block_out_channels[0],
799
- temb_channels=0,
800
- num_layers=2,
801
- resnet_eps=norm_eps,
802
- resnet_act_fn=act_fn,
803
- resnet_groups=norm_num_groups,
804
- spatial_norm_dim=in_channels,
805
- pad_mode=pad_mode,
806
- )
807
-
808
- # up blocks
809
- self.up_blocks = nn.ModuleList([])
810
-
811
- output_channel = reversed_block_out_channels[0]
812
- temporal_compress_level = int(np.log2(temporal_compression_ratio))
813
-
814
- for i, up_block_type in enumerate(up_block_types):
815
- prev_output_channel = output_channel
816
- output_channel = reversed_block_out_channels[i]
817
- is_final_block = i == len(block_out_channels) - 1
818
- compress_time = i < temporal_compress_level
819
-
820
- if up_block_type == "CogVideoXUpBlock3D":
821
- up_block = CogVideoXUpBlock3D(
822
- in_channels=prev_output_channel,
823
- out_channels=output_channel,
824
- temb_channels=0,
825
- dropout=dropout,
826
- num_layers=layers_per_block + 1,
827
- resnet_eps=norm_eps,
828
- resnet_act_fn=act_fn,
829
- resnet_groups=norm_num_groups,
830
- spatial_norm_dim=in_channels,
831
- add_upsample=not is_final_block,
832
- compress_time=compress_time,
833
- pad_mode=pad_mode,
834
- )
835
- prev_output_channel = output_channel
836
- else:
837
- raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
838
-
839
- self.up_blocks.append(up_block)
840
-
841
- self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
842
- self.conv_act = nn.SiLU()
843
- self.conv_out = CogVideoXCausalConv3d(
844
- reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
845
- )
846
-
847
- self.gradient_checkpointing = False
848
-
849
- def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
850
- r"""The forward method of the `CogVideoXDecoder3D` class."""
851
- hidden_states = self.conv_in(sample)
852
-
853
- if self.training and self.gradient_checkpointing:
854
-
855
- def create_custom_forward(module):
856
- def custom_forward(*inputs):
857
- return module(*inputs)
858
-
859
- return custom_forward
860
-
861
- # 1. Mid
862
- hidden_states = torch.utils.checkpoint.checkpoint(
863
- create_custom_forward(self.mid_block), hidden_states, temb, sample
864
- )
865
-
866
- # 2. Up
867
- for up_block in self.up_blocks:
868
- hidden_states = torch.utils.checkpoint.checkpoint(
869
- create_custom_forward(up_block), hidden_states, temb, sample
870
- )
871
- else:
872
- # 1. Mid
873
- hidden_states = self.mid_block(hidden_states, temb, sample)
874
-
875
- # 2. Up
876
- for up_block in self.up_blocks:
877
- hidden_states = up_block(hidden_states, temb, sample)
878
-
879
- # 3. Post-process
880
- hidden_states = self.norm_out(hidden_states, sample)
881
- hidden_states = self.conv_act(hidden_states)
882
- hidden_states = self.conv_out(hidden_states)
883
- return hidden_states
884
-
885
-
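Before the full VAE wrapper that follows, the compression arithmetic implied by the defaults above is worth spelling out: with four blocks, only the three non-final blocks downsample spatially (8x in height and width), only the first log2(4) = 2 blocks compress time (4x), and the causal convolution keeps the very first frame. The sketch below is illustrative only; the `(frames - 1) // 4 + 1` latent-frame count is an assumption that matches the 13-latent-frame example in the comments further down.

import numpy as np

# Defaults from CogVideoXEncoder3D / CogVideoXDecoder3D above: 4 blocks, temporal_compression_ratio = 4.
num_blocks, temporal_compression_ratio = 4, 4
spatial_factor = 2 ** (num_blocks - 1)                      # 8x: every non-final block halves H and W
temporal_levels = int(np.log2(temporal_compression_ratio))  # the first 2 blocks also halve the time axis -> 4x

frames, height, width = 49, 480, 720
latent_shape = (
    (frames - 1) // temporal_compression_ratio + 1,  # causal first-frame handling -> 13
    height // spatial_factor,                        # 60
    width // spatial_factor,                         # 90
)
print(latent_shape)  # (13, 60, 90)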
886
- class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
887
- r"""
888
- A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
889
- [CogVideoX](https://github.com/THUDM/CogVideo).
890
-
891
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
892
- for all models (such as downloading or saving).
893
-
894
- Parameters:
895
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
896
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
897
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
898
- Tuple of downsample block types.
899
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
900
- Tuple of upsample block types.
901
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
902
- Tuple of block output channels.
903
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
904
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
905
- scaling_factor (`float`, *optional*, defaults to `1.15258426`):
906
- The component-wise standard deviation of the trained latent space computed using the first batch of the
907
- training set. This is used to scale the latent space to have unit variance when training the diffusion
908
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
909
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
910
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
911
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
912
- force_upcast (`bool`, *optional*, defaults to `True`):
913
- If enabled, it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. The VAE
914
- can be fine-tuned / trained to a lower range without losing too much precision, in which case
915
- `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
916
- """
917
-
918
- _supports_gradient_checkpointing = True
919
- _no_split_modules = ["CogVideoXResnetBlock3D"]
920
-
921
- @register_to_config
922
- def __init__(
923
- self,
924
- in_channels: int = 3,
925
- out_channels: int = 3,
926
- down_block_types: Tuple[str] = (
927
- "CogVideoXDownBlock3D",
928
- "CogVideoXDownBlock3D",
929
- "CogVideoXDownBlock3D",
930
- "CogVideoXDownBlock3D",
931
- ),
932
- up_block_types: Tuple[str] = (
933
- "CogVideoXUpBlock3D",
934
- "CogVideoXUpBlock3D",
935
- "CogVideoXUpBlock3D",
936
- "CogVideoXUpBlock3D",
937
- ),
938
- block_out_channels: Tuple[int] = (128, 256, 256, 512),
939
- latent_channels: int = 16,
940
- layers_per_block: int = 3,
941
- act_fn: str = "silu",
942
- norm_eps: float = 1e-6,
943
- norm_num_groups: int = 32,
944
- temporal_compression_ratio: float = 4,
945
- sample_height: int = 480,
946
- sample_width: int = 720,
947
- scaling_factor: float = 1.15258426,
948
- shift_factor: Optional[float] = None,
949
- latents_mean: Optional[Tuple[float]] = None,
950
- latents_std: Optional[Tuple[float]] = None,
951
- force_upcast: float = True,
952
- use_quant_conv: bool = False,
953
- use_post_quant_conv: bool = False,
954
- ):
955
- super().__init__()
956
-
957
- self.encoder = CogVideoXEncoder3D(
958
- in_channels=in_channels,
959
- out_channels=latent_channels,
960
- down_block_types=down_block_types,
961
- block_out_channels=block_out_channels,
962
- layers_per_block=layers_per_block,
963
- act_fn=act_fn,
964
- norm_eps=norm_eps,
965
- norm_num_groups=norm_num_groups,
966
- temporal_compression_ratio=temporal_compression_ratio,
967
- )
968
- self.decoder = CogVideoXDecoder3D(
969
- in_channels=latent_channels,
970
- out_channels=out_channels,
971
- up_block_types=up_block_types,
972
- block_out_channels=block_out_channels,
973
- layers_per_block=layers_per_block,
974
- act_fn=act_fn,
975
- norm_eps=norm_eps,
976
- norm_num_groups=norm_num_groups,
977
- temporal_compression_ratio=temporal_compression_ratio,
978
- )
979
- self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
980
- self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
981
-
982
- self.use_slicing = False
983
- self.use_tiling = False
984
-
985
- # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
986
- # recommended because the temporal parts of the VAE, here, are tricky to understand.
987
- # If you decode X latent frames together, the number of output frames is:
988
- # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
989
- #
990
- # Example with num_latent_frames_batch_size = 2:
991
- # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
992
- # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
993
- # => 6 * 8 = 48 frames
994
- # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
995
- # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
996
- # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
997
- # => 1 * 9 + 5 * 8 = 49 frames
998
- # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
999
- # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
1000
- # number of temporal frames.
1001
- self.num_latent_frames_batch_size = 2
1002
- self.num_sample_frames_batch_size = 8
1003
-
1004
- # We make the minimum height and width of sample for tiling half that of the generally supported resolution.
1005
- self.tile_sample_min_height = sample_height // 2
1006
- self.tile_sample_min_width = sample_width // 2
1007
- self.tile_latent_min_height = int(
1008
- self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1009
- )
1010
- self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1011
-
1012
- # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
1013
- # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
1014
- # and so the tiling implementation has only been tested on those specific resolutions.
1015
- self.tile_overlap_factor_height = 1 / 6
1016
- self.tile_overlap_factor_width = 1 / 5
1017
-
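The frame bookkeeping described in the comment above can be sanity-checked in isolation. The sketch below is illustrative only: it mirrors the per-slice `X + 6` arithmetic from the comment rather than the actual decode path.

def decoded_frame_count(num_latent_frames: int, batch_size: int = 2) -> int:
    # Each slice of X latent frames yields X + (2 conv cache) + (2 + 4 time upscales) - (2 causal downscale) frames.
    remaining = num_latent_frames % batch_size
    total = 0
    for i in range(num_latent_frames // batch_size):
        chunk = batch_size + (remaining if i == 0 else 0)  # the first slice absorbs any odd latent frame
        total += chunk + 6
    return total

print(decoded_frame_count(12), decoded_frame_count(13))  # 48 49, matching the worked examples above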
1018
- def _set_gradient_checkpointing(self, module, value=False):
1019
- if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1020
- module.gradient_checkpointing = value
1021
-
1022
- def _clear_fake_context_parallel_cache(self):
1023
- for name, module in self.named_modules():
1024
- if isinstance(module, CogVideoXCausalConv3d):
1025
- logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
1026
- module._clear_fake_context_parallel_cache()
1027
-
1028
- def enable_tiling(
1029
- self,
1030
- tile_sample_min_height: Optional[int] = None,
1031
- tile_sample_min_width: Optional[int] = None,
1032
- tile_overlap_factor_height: Optional[float] = None,
1033
- tile_overlap_factor_width: Optional[float] = None,
1034
- ) -> None:
1035
- r"""
1036
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1037
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
1038
- processing larger images.
1039
-
1040
- Args:
1041
- tile_sample_min_height (`int`, *optional*):
1042
- The minimum height required for a sample to be separated into tiles across the height dimension.
1043
- tile_sample_min_width (`int`, *optional*):
1044
- The minimum width required for a sample to be separated into tiles across the width dimension.
1045
- tile_overlap_factor_height (`float`, *optional*):
1046
- The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
1047
- no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
1048
- value might cause more tiles to be processed, leading to a slowdown of the decoding process.
1049
- tile_overlap_factor_width (`float`, *optional*):
1050
- The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
1051
- are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
1052
- value might cause more tiles to be processed, leading to a slowdown of the decoding process.
1053
- """
1054
- self.use_tiling = True
1055
- self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1056
- self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1057
- self.tile_latent_min_height = int(
1058
- self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1059
- )
1060
- self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1061
- self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
1062
- self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1063
-
1064
- def disable_tiling(self) -> None:
1065
- r"""
1066
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1067
- decoding in one step.
1068
- """
1069
- self.use_tiling = False
1070
-
1071
- def enable_slicing(self) -> None:
1072
- r"""
1073
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to
1074
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1075
- """
1076
- self.use_slicing = True
1077
-
1078
- def disable_slicing(self) -> None:
1079
- r"""
1080
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1081
- decoding in one step.
1082
- """
1083
- self.use_slicing = False
1084
-
1085
- def _encode(self, x: torch.Tensor) -> torch.Tensor:
1086
- batch_size, num_channels, num_frames, height, width = x.shape
1087
-
1088
- if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1089
- return self.tiled_encode(x)
1090
-
1091
- frame_batch_size = self.num_sample_frames_batch_size
1092
- # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1093
- num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
1094
- enc = []
1095
- for i in range(num_batches):
1096
- remaining_frames = num_frames % frame_batch_size
1097
- start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1098
- end_frame = frame_batch_size * (i + 1) + remaining_frames
1099
- x_intermediate = x[:, :, start_frame:end_frame]
1100
- x_intermediate = self.encoder(x_intermediate)
1101
- if self.quant_conv is not None:
1102
- x_intermediate = self.quant_conv(x_intermediate)
1103
- enc.append(x_intermediate)
1104
-
1105
- self._clear_fake_context_parallel_cache()
1106
- enc = torch.cat(enc, dim=2)
1107
-
1108
- return enc
1109
-
1110
- @apply_forward_hook
1111
- def encode(
1112
- self, x: torch.Tensor, return_dict: bool = True
1113
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1114
- """
1115
- Encode a batch of images into latents.
1116
-
1117
- Args:
1118
- x (`torch.Tensor`): Input batch of images.
1119
- return_dict (`bool`, *optional*, defaults to `True`):
1120
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1121
-
1122
- Returns:
1123
- The latent representations of the encoded videos. If `return_dict` is True, a
1124
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1125
- """
1126
- if self.use_slicing and x.shape[0] > 1:
1127
- encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1128
- h = torch.cat(encoded_slices)
1129
- else:
1130
- h = self._encode(x)
1131
-
1132
- posterior = DiagonalGaussianDistribution(h)
1133
-
1134
- if not return_dict:
1135
- return (posterior,)
1136
- return AutoencoderKLOutput(latent_dist=posterior)
1137
-
1138
- def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1139
- batch_size, num_channels, num_frames, height, width = z.shape
1140
-
1141
- if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
1142
- return self.tiled_decode(z, return_dict=return_dict)
1143
-
1144
- frame_batch_size = self.num_latent_frames_batch_size
1145
- num_batches = num_frames // frame_batch_size
1146
- dec = []
1147
- for i in range(num_batches):
1148
- remaining_frames = num_frames % frame_batch_size
1149
- start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1150
- end_frame = frame_batch_size * (i + 1) + remaining_frames
1151
- z_intermediate = z[:, :, start_frame:end_frame]
1152
- if self.post_quant_conv is not None:
1153
- z_intermediate = self.post_quant_conv(z_intermediate)
1154
- z_intermediate = self.decoder(z_intermediate)
1155
- dec.append(z_intermediate)
1156
-
1157
- self._clear_fake_context_parallel_cache()
1158
- dec = torch.cat(dec, dim=2)
1159
-
1160
- if not return_dict:
1161
- return (dec,)
1162
-
1163
- return DecoderOutput(sample=dec)
1164
-
1165
- @apply_forward_hook
1166
- def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1167
- """
1168
- Decode a batch of images.
1169
-
1170
- Args:
1171
- z (`torch.Tensor`): Input batch of latent vectors.
1172
- return_dict (`bool`, *optional*, defaults to `True`):
1173
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1174
-
1175
- Returns:
1176
- [`~models.vae.DecoderOutput`] or `tuple`:
1177
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1178
- returned.
1179
- """
1180
- if self.use_slicing and z.shape[0] > 1:
1181
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1182
- decoded = torch.cat(decoded_slices)
1183
- else:
1184
- decoded = self._decode(z).sample
1185
-
1186
- if not return_dict:
1187
- return (decoded,)
1188
- return DecoderOutput(sample=decoded)
1189
-
1190
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1191
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1192
- for y in range(blend_extent):
1193
- b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1194
- y / blend_extent
1195
- )
1196
- return b
1197
-
1198
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1199
- blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1200
- for x in range(blend_extent):
1201
- b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1202
- x / blend_extent
1203
- )
1204
- return b
1205
-
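`blend_v` and `blend_h` below simply cross-fade the overlapping strip of two neighbouring tiles with a linear ramp. A tiny standalone check of the same loop, on dummy 5D `[B, C, T, H, W]` tensors:

import torch

a = torch.ones(1, 1, 1, 4, 8)   # "left" tile, all ones
b = torch.zeros(1, 1, 1, 4, 8)  # "right" tile, all zeros
blend_extent = 4

# Same loop as blend_h: the first `blend_extent` columns of `b` ramp linearly from `a` into `b`.
for x in range(blend_extent):
    b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)

print(b[0, 0, 0, 0, :blend_extent])  # tensor([1.0000, 0.7500, 0.5000, 0.2500])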
1206
- def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
1207
- r"""Encode a batch of images using a tiled encoder.
1208
-
1209
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
1210
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
1211
- different from non-tiled encoding because each tile is encoded separately. To avoid tiling artifacts, the
1212
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
1213
- output, but they should be much less noticeable.
1214
-
1215
- Args:
1216
- x (`torch.Tensor`): Input batch of videos.
1217
-
1218
- Returns:
1219
- `torch.Tensor`:
1220
- The latent representation of the encoded videos.
1221
- """
1222
- # For a rough memory estimate, take a look at the `tiled_decode` method.
1223
- batch_size, num_channels, num_frames, height, width = x.shape
1224
-
1225
- overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
1226
- overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
1227
- blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
1228
- blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
1229
- row_limit_height = self.tile_latent_min_height - blend_extent_height
1230
- row_limit_width = self.tile_latent_min_width - blend_extent_width
1231
- frame_batch_size = self.num_sample_frames_batch_size
1232
-
1233
- # Split x into overlapping tiles and encode them separately.
1234
- # The tiles have an overlap to avoid seams between tiles.
1235
- rows = []
1236
- for i in range(0, height, overlap_height):
1237
- row = []
1238
- for j in range(0, width, overlap_width):
1239
- # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1240
- num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
1241
- time = []
1242
- for k in range(num_batches):
1243
- remaining_frames = num_frames % frame_batch_size
1244
- start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1245
- end_frame = frame_batch_size * (k + 1) + remaining_frames
1246
- tile = x[
1247
- :,
1248
- :,
1249
- start_frame:end_frame,
1250
- i : i + self.tile_sample_min_height,
1251
- j : j + self.tile_sample_min_width,
1252
- ]
1253
- tile = self.encoder(tile)
1254
- if self.quant_conv is not None:
1255
- tile = self.quant_conv(tile)
1256
- time.append(tile)
1257
- self._clear_fake_context_parallel_cache()
1258
- row.append(torch.cat(time, dim=2))
1259
- rows.append(row)
1260
-
1261
- result_rows = []
1262
- for i, row in enumerate(rows):
1263
- result_row = []
1264
- for j, tile in enumerate(row):
1265
- # blend the above tile and the left tile
1266
- # to the current tile and add the current tile to the result row
1267
- if i > 0:
1268
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1269
- if j > 0:
1270
- tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1271
- result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1272
- result_rows.append(torch.cat(result_row, dim=4))
1273
-
1274
- enc = torch.cat(result_rows, dim=3)
1275
- return enc
1276
-
1277
- def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1278
- r"""
1279
- Decode a batch of images using a tiled decoder.
1280
-
1281
- Args:
1282
- z (`torch.Tensor`): Input batch of latent vectors.
1283
- return_dict (`bool`, *optional*, defaults to `True`):
1284
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1285
-
1286
- Returns:
1287
- [`~models.vae.DecoderOutput`] or `tuple`:
1288
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1289
- returned.
1290
- """
1291
- # Rough memory assessment:
1292
- # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
1293
- # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
1294
- # - Assume fp16 (2 bytes per value).
1295
- # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
1296
- #
1297
- # Memory assessment when using tiling:
1298
- # - Assume everything as above but now HxW is 240x360 by tiling in half
1299
- # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
1300
-
1301
- batch_size, num_channels, num_frames, height, width = z.shape
1302
-
1303
- overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
1304
- overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
1305
- blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
1306
- blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
1307
- row_limit_height = self.tile_sample_min_height - blend_extent_height
1308
- row_limit_width = self.tile_sample_min_width - blend_extent_width
1309
- frame_batch_size = self.num_latent_frames_batch_size
1310
-
1311
- # Split z into overlapping tiles and decode them separately.
1312
- # The tiles have an overlap to avoid seams between tiles.
1313
- rows = []
1314
- for i in range(0, height, overlap_height):
1315
- row = []
1316
- for j in range(0, width, overlap_width):
1317
- num_batches = num_frames // frame_batch_size
1318
- time = []
1319
- for k in range(num_batches):
1320
- remaining_frames = num_frames % frame_batch_size
1321
- start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1322
- end_frame = frame_batch_size * (k + 1) + remaining_frames
1323
- tile = z[
1324
- :,
1325
- :,
1326
- start_frame:end_frame,
1327
- i : i + self.tile_latent_min_height,
1328
- j : j + self.tile_latent_min_width,
1329
- ]
1330
- if self.post_quant_conv is not None:
1331
- tile = self.post_quant_conv(tile)
1332
- tile = self.decoder(tile)
1333
- time.append(tile)
1334
- self._clear_fake_context_parallel_cache()
1335
- row.append(torch.cat(time, dim=2))
1336
- rows.append(row)
1337
-
1338
- result_rows = []
1339
- for i, row in enumerate(rows):
1340
- result_row = []
1341
- for j, tile in enumerate(row):
1342
- # blend the above tile and the left tile
1343
- # to the current tile and add the current tile to the result row
1344
- if i > 0:
1345
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1346
- if j > 0:
1347
- tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1348
- result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1349
- result_rows.append(torch.cat(result_row, dim=4))
1350
-
1351
- dec = torch.cat(result_rows, dim=3)
1352
-
1353
- if not return_dict:
1354
- return (dec,)
1355
-
1356
- return DecoderOutput(sample=dec)
1357
-
1358
- def forward(
1359
- self,
1360
- sample: torch.Tensor,
1361
- sample_posterior: bool = False,
1362
- return_dict: bool = True,
1363
- generator: Optional[torch.Generator] = None,
1364
- ) -> Union[DecoderOutput, torch.Tensor]:
1365
- x = sample
1366
- posterior = self.encode(x).latent_dist
1367
- if sample_posterior:
1368
- z = posterior.sample(generator=generator)
1369
- else:
1370
- z = posterior.mode()
1371
- dec = self.decode(z)
1372
- if not return_dict:
1373
- return (dec,)
1374
- return dec
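As a usage reference for the class above, here is a minimal sketch. It assumes the upstream `diffusers` release still ships `AutoencoderKLCogVideoX` (this commit only deletes the vendored copy) and uses a checkpoint id that is an assumption; substitute whatever CogVideoX VAE weights you actually use.

import torch
from diffusers import AutoencoderKLCogVideoX  # upstream class; the local autoencoders_ copy is removed above

# "THUDM/CogVideoX-2b" is an assumption -- point this at the CogVideoX VAE weights you actually have.
vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16)
vae = vae.to("cuda")
vae.enable_tiling()   # tiled spatial encode/decode, as documented in enable_tiling() above
vae.enable_slicing()  # per-sample batch slicing

# [batch, channels, frames, height, width]; the encoder expects 1, 8*k, or 8*k + 1 frames (here 9).
video = torch.randn(1, 3, 9, 480, 720, dtype=torch.float16, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample() * vae.config.scaling_factor
    frames = vae.decode(latents / vae.config.scaling_factor).sample
print(latents.shape, frames.shape)  # latents are 4x smaller in time and 8x smaller in space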
diffusers/src/diffusers/models/autoencoders_/autoencoder_kl_temporal_decoder.py DELETED
@@ -1,401 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Dict, Optional, Tuple, Union
15
-
16
- import torch
17
- import torch.nn as nn
18
-
19
- from ...configuration_utils import ConfigMixin, register_to_config
20
- from ...utils import is_torch_version
21
- from ...utils.accelerate_utils import apply_forward_hook
22
- from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
23
- from ..modeling_outputs import AutoencoderKLOutput
24
- from ..modeling_utils import ModelMixin
25
- from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
26
- from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
27
-
28
-
29
- class TemporalDecoder(nn.Module):
30
- def __init__(
31
- self,
32
- in_channels: int = 4,
33
- out_channels: int = 3,
34
- block_out_channels: Tuple[int] = (128, 256, 512, 512),
35
- layers_per_block: int = 2,
36
- ):
37
- super().__init__()
38
- self.layers_per_block = layers_per_block
39
-
40
- self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
41
- self.mid_block = MidBlockTemporalDecoder(
42
- num_layers=self.layers_per_block,
43
- in_channels=block_out_channels[-1],
44
- out_channels=block_out_channels[-1],
45
- attention_head_dim=block_out_channels[-1],
46
- )
47
-
48
- # up
49
- self.up_blocks = nn.ModuleList([])
50
- reversed_block_out_channels = list(reversed(block_out_channels))
51
- output_channel = reversed_block_out_channels[0]
52
- for i in range(len(block_out_channels)):
53
- prev_output_channel = output_channel
54
- output_channel = reversed_block_out_channels[i]
55
-
56
- is_final_block = i == len(block_out_channels) - 1
57
- up_block = UpBlockTemporalDecoder(
58
- num_layers=self.layers_per_block + 1,
59
- in_channels=prev_output_channel,
60
- out_channels=output_channel,
61
- add_upsample=not is_final_block,
62
- )
63
- self.up_blocks.append(up_block)
64
- prev_output_channel = output_channel
65
-
66
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-6)
67
-
68
- self.conv_act = nn.SiLU()
69
- self.conv_out = torch.nn.Conv2d(
70
- in_channels=block_out_channels[0],
71
- out_channels=out_channels,
72
- kernel_size=3,
73
- padding=1,
74
- )
75
-
76
- conv_out_kernel_size = (3, 1, 1)
77
- padding = [int(k // 2) for k in conv_out_kernel_size]
78
- self.time_conv_out = torch.nn.Conv3d(
79
- in_channels=out_channels,
80
- out_channels=out_channels,
81
- kernel_size=conv_out_kernel_size,
82
- padding=padding,
83
- )
84
-
85
- self.gradient_checkpointing = False
86
-
87
- def forward(
88
- self,
89
- sample: torch.Tensor,
90
- image_only_indicator: torch.Tensor,
91
- num_frames: int = 1,
92
- ) -> torch.Tensor:
93
- r"""The forward method of the `TemporalDecoder` class."""
94
-
95
- sample = self.conv_in(sample)
96
-
97
- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
98
- if self.training and self.gradient_checkpointing:
99
-
100
- def create_custom_forward(module):
101
- def custom_forward(*inputs):
102
- return module(*inputs)
103
-
104
- return custom_forward
105
-
106
- if is_torch_version(">=", "1.11.0"):
107
- # middle
108
- sample = torch.utils.checkpoint.checkpoint(
109
- create_custom_forward(self.mid_block),
110
- sample,
111
- image_only_indicator,
112
- use_reentrant=False,
113
- )
114
- sample = sample.to(upscale_dtype)
115
-
116
- # up
117
- for up_block in self.up_blocks:
118
- sample = torch.utils.checkpoint.checkpoint(
119
- create_custom_forward(up_block),
120
- sample,
121
- image_only_indicator,
122
- use_reentrant=False,
123
- )
124
- else:
125
- # middle
126
- sample = torch.utils.checkpoint.checkpoint(
127
- create_custom_forward(self.mid_block),
128
- sample,
129
- image_only_indicator,
130
- )
131
- sample = sample.to(upscale_dtype)
132
-
133
- # up
134
- for up_block in self.up_blocks:
135
- sample = torch.utils.checkpoint.checkpoint(
136
- create_custom_forward(up_block),
137
- sample,
138
- image_only_indicator,
139
- )
140
- else:
141
- # middle
142
- sample = self.mid_block(sample, image_only_indicator=image_only_indicator)
143
- sample = sample.to(upscale_dtype)
144
-
145
- # up
146
- for up_block in self.up_blocks:
147
- sample = up_block(sample, image_only_indicator=image_only_indicator)
148
-
149
- # post-process
150
- sample = self.conv_norm_out(sample)
151
- sample = self.conv_act(sample)
152
- sample = self.conv_out(sample)
153
-
154
- batch_frames, channels, height, width = sample.shape
155
- batch_size = batch_frames // num_frames
156
- sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
157
- sample = self.time_conv_out(sample)
158
-
159
- sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
160
-
161
- return sample
162
-
163
-
164
- class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
165
- r"""
166
- A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
167
-
168
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
169
- for all models (such as downloading or saving).
170
-
171
- Parameters:
172
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
173
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
174
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
175
- Tuple of downsample block types.
176
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
177
- Tuple of block output channels.
178
- layers_per_block (`int`, *optional*, defaults to 1): Number of layers per block.
179
- latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
180
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
181
- scaling_factor (`float`, *optional*, defaults to 0.18215):
182
- The component-wise standard deviation of the trained latent space computed using the first batch of the
183
- training set. This is used to scale the latent space to have unit variance when training the diffusion
184
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
185
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
186
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
187
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
188
- force_upcast (`bool`, *optional*, default to `True`):
189
- If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
190
- can be fine-tuned / trained to a lower range without loosing too much precision in which case
191
- `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
192
- """
193
-
194
- _supports_gradient_checkpointing = True
195
-
196
- @register_to_config
197
- def __init__(
198
- self,
199
- in_channels: int = 3,
200
- out_channels: int = 3,
201
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
202
- block_out_channels: Tuple[int] = (64,),
203
- layers_per_block: int = 1,
204
- latent_channels: int = 4,
205
- sample_size: int = 32,
206
- scaling_factor: float = 0.18215,
207
- force_upcast: float = True,
208
- ):
209
- super().__init__()
210
-
211
- # pass init params to Encoder
212
- self.encoder = Encoder(
213
- in_channels=in_channels,
214
- out_channels=latent_channels,
215
- down_block_types=down_block_types,
216
- block_out_channels=block_out_channels,
217
- layers_per_block=layers_per_block,
218
- double_z=True,
219
- )
220
-
221
- # pass init params to Decoder
222
- self.decoder = TemporalDecoder(
223
- in_channels=latent_channels,
224
- out_channels=out_channels,
225
- block_out_channels=block_out_channels,
226
- layers_per_block=layers_per_block,
227
- )
228
-
229
- self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
230
-
231
- sample_size = (
232
- self.config.sample_size[0]
233
- if isinstance(self.config.sample_size, (list, tuple))
234
- else self.config.sample_size
235
- )
236
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
237
- self.tile_overlap_factor = 0.25
238
-
239
- def _set_gradient_checkpointing(self, module, value=False):
240
- if isinstance(module, (Encoder, TemporalDecoder)):
241
- module.gradient_checkpointing = value
242
-
243
- @property
244
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
245
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
246
- r"""
247
- Returns:
248
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
249
- indexed by its weight name.
250
- """
251
- # set recursively
252
- processors = {}
253
-
254
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
255
- if hasattr(module, "get_processor"):
256
- processors[f"{name}.processor"] = module.get_processor()
257
-
258
- for sub_name, child in module.named_children():
259
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
260
-
261
- return processors
262
-
263
- for name, module in self.named_children():
264
- fn_recursive_add_processors(name, module, processors)
265
-
266
- return processors
267
-
268
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
269
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
270
- r"""
271
- Sets the attention processor to use to compute attention.
272
-
273
- Parameters:
274
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
275
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
276
- for **all** `Attention` layers.
277
-
278
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
279
- processor. This is strongly recommended when setting trainable attention processors.
280
-
281
- """
282
- count = len(self.attn_processors.keys())
283
-
284
- if isinstance(processor, dict) and len(processor) != count:
285
- raise ValueError(
286
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
287
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
288
- )
289
-
290
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
291
- if hasattr(module, "set_processor"):
292
- if not isinstance(processor, dict):
293
- module.set_processor(processor)
294
- else:
295
- module.set_processor(processor.pop(f"{name}.processor"))
296
-
297
- for sub_name, child in module.named_children():
298
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
299
-
300
- for name, module in self.named_children():
301
- fn_recursive_attn_processor(name, module, processor)
302
-
303
- def set_default_attn_processor(self):
304
- """
305
- Disables custom attention processors and sets the default attention implementation.
306
- """
307
- if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
308
- processor = AttnProcessor()
309
- else:
310
- raise ValueError(
311
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
312
- )
313
-
314
- self.set_attn_processor(processor)
315
-
316
- @apply_forward_hook
317
- def encode(
318
- self, x: torch.Tensor, return_dict: bool = True
319
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
320
- """
321
- Encode a batch of images into latents.
322
-
323
- Args:
324
- x (`torch.Tensor`): Input batch of images.
325
- return_dict (`bool`, *optional*, defaults to `True`):
326
- Whether to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a plain
327
- tuple.
328
-
329
- Returns:
330
- The latent representations of the encoded images. If `return_dict` is True, a
331
- [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is
332
- returned.
333
- """
334
- h = self.encoder(x)
335
- moments = self.quant_conv(h)
336
- posterior = DiagonalGaussianDistribution(moments)
337
-
338
- if not return_dict:
339
- return (posterior,)
340
-
341
- return AutoencoderKLOutput(latent_dist=posterior)
342
-
343
- @apply_forward_hook
344
- def decode(
345
- self,
346
- z: torch.Tensor,
347
- num_frames: int,
348
- return_dict: bool = True,
349
- ) -> Union[DecoderOutput, torch.Tensor]:
350
- """
351
- Decode a batch of images.
352
-
353
- Args:
354
- z (`torch.Tensor`): Input batch of latent vectors.
355
- return_dict (`bool`, *optional*, defaults to `True`):
356
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
357
-
358
- Returns:
359
- [`~models.vae.DecoderOutput`] or `tuple`:
360
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
361
- returned.
362
-
363
- """
364
- batch_size = z.shape[0] // num_frames
365
- image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
366
- decoded = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
367
-
368
- if not return_dict:
369
- return (decoded,)
370
-
371
- return DecoderOutput(sample=decoded)
372
-
373
- def forward(
374
- self,
375
- sample: torch.Tensor,
376
- sample_posterior: bool = False,
377
- return_dict: bool = True,
378
- generator: Optional[torch.Generator] = None,
379
- num_frames: int = 1,
380
- ) -> Union[DecoderOutput, torch.Tensor]:
381
- r"""
382
- Args:
383
- sample (`torch.Tensor`): Input sample.
384
- sample_posterior (`bool`, *optional*, defaults to `False`):
385
- Whether to sample from the posterior.
386
- return_dict (`bool`, *optional*, defaults to `True`):
387
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
388
- """
389
- x = sample
390
- posterior = self.encode(x).latent_dist
391
- if sample_posterior:
392
- z = posterior.sample(generator=generator)
393
- else:
394
- z = posterior.mode()
395
-
396
- dec = self.decode(z, num_frames=num_frames).sample
397
-
398
- if not return_dict:
399
- return (dec,)
400
-
401
- return DecoderOutput(sample=dec)
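For reference, a minimal sketch of driving this temporal-decoder VAE, assuming the upstream `diffusers` class and, as an assumption, the `stabilityai/stable-video-diffusion-img2vid-xt` checkpoint. The key detail is that frames are flattened into the batch dimension and `decode()` needs `num_frames` to fold them back for the temporal convolution.

import torch
from diffusers import AutoencoderKLTemporalDecoder  # upstream class; the local copy is removed above

# Checkpoint id is an assumption -- any SVD-style VAE with a temporal decoder should work the same way.
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

num_frames = 4
# Frames are stacked along the batch axis: [batch * num_frames, channels, height, width].
frames = torch.randn(num_frames, 3, 256, 256, dtype=torch.float16, device="cuda")

with torch.no_grad():
    latents = vae.encode(frames).latent_dist.mode()            # plain 2D, per-frame encoding
    video = vae.decode(latents, num_frames=num_frames).sample  # temporal decoder folds frames back together
print(latents.shape, video.shape)  # e.g. [4, 4, 32, 32] and [4, 3, 256, 256]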
diffusers/src/diffusers/models/autoencoders_/autoencoder_oobleck.py DELETED
@@ -1,464 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import math
15
- from dataclasses import dataclass
16
- from typing import Optional, Tuple, Union
17
-
18
- import numpy as np
19
- import torch
20
- import torch.nn as nn
21
- from torch.nn.utils import weight_norm
22
-
23
- from ...configuration_utils import ConfigMixin, register_to_config
24
- from ...utils import BaseOutput
25
- from ...utils.accelerate_utils import apply_forward_hook
26
- from ...utils.torch_utils import randn_tensor
27
- from ..modeling_utils import ModelMixin
28
-
29
-
30
- class Snake1d(nn.Module):
31
- """
32
- A 1-dimensional Snake activation function module.
33
- """
34
-
35
- def __init__(self, hidden_dim, logscale=True):
36
- super().__init__()
37
- self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
38
- self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
39
-
40
- self.alpha.requires_grad = True
41
- self.beta.requires_grad = True
42
- self.logscale = logscale
43
-
44
- def forward(self, hidden_states):
45
- shape = hidden_states.shape
46
-
47
- alpha = self.alpha if not self.logscale else torch.exp(self.alpha)
48
- beta = self.beta if not self.logscale else torch.exp(self.beta)
49
-
50
- hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
51
- hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
52
- hidden_states = hidden_states.reshape(shape)
53
- return hidden_states
54
-
55
-
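The Snake activation above is just `x + sin^2(alpha * x) / beta` with per-channel parameters (stored in log space when `logscale=True`, so the zero initialisation behaves like `alpha = beta = 1`). A small sketch of the same computation outside the module:

import torch

x = torch.randn(2, 16, 100)    # [batch, channels, time], as Snake1d expects
alpha = torch.zeros(1, 16, 1)  # per-channel parameters at their zero initialisation
beta = torch.zeros(1, 16, 1)

a, b = torch.exp(alpha), torch.exp(beta)  # logscale=True: exponentiate before use
y = x + (b + 1e-9).reciprocal() * torch.sin(a * x).pow(2)
assert y.shape == x.shape  # the activation is pointwise and preserves the shape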
56
- class OobleckResidualUnit(nn.Module):
57
- """
58
- A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
59
- """
60
-
61
- def __init__(self, dimension: int = 16, dilation: int = 1):
62
- super().__init__()
63
- pad = ((7 - 1) * dilation) // 2
64
-
65
- self.snake1 = Snake1d(dimension)
66
- self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
67
- self.snake2 = Snake1d(dimension)
68
- self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))
69
-
70
- def forward(self, hidden_state):
71
- """
72
- Forward pass through the residual unit.
73
-
74
- Args:
75
- hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
76
- Input tensor.
77
-
78
- Returns:
79
- output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`)
80
- Input tensor after passing through the residual unit.
81
- """
82
- output_tensor = hidden_state
83
- output_tensor = self.conv1(self.snake1(output_tensor))
84
- output_tensor = self.conv2(self.snake2(output_tensor))
85
-
86
- padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
87
- if padding > 0:
88
- hidden_state = hidden_state[..., padding:-padding]
89
- output_tensor = hidden_state + output_tensor
90
- return output_tensor
91
-
92
-
93
- class OobleckEncoderBlock(nn.Module):
94
- """Encoder block used in Oobleck encoder."""
95
-
96
- def __init__(self, input_dim, output_dim, stride: int = 1):
97
- super().__init__()
98
-
99
- self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1)
100
- self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3)
101
- self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9)
102
- self.snake1 = Snake1d(input_dim)
103
- self.conv1 = weight_norm(
104
- nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2))
105
- )
106
-
107
- def forward(self, hidden_state):
108
- hidden_state = self.res_unit1(hidden_state)
109
- hidden_state = self.res_unit2(hidden_state)
110
- hidden_state = self.snake1(self.res_unit3(hidden_state))
111
- hidden_state = self.conv1(hidden_state)
112
-
113
- return hidden_state
114
-
115
-
116
- class OobleckDecoderBlock(nn.Module):
117
- """Decoder block used in Oobleck decoder."""
118
-
119
- def __init__(self, input_dim, output_dim, stride: int = 1):
120
- super().__init__()
121
-
122
- self.snake1 = Snake1d(input_dim)
123
- self.conv_t1 = weight_norm(
124
- nn.ConvTranspose1d(
125
- input_dim,
126
- output_dim,
127
- kernel_size=2 * stride,
128
- stride=stride,
129
- padding=math.ceil(stride / 2),
130
- )
131
- )
132
- self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1)
133
- self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3)
134
- self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9)
135
-
136
- def forward(self, hidden_state):
137
- hidden_state = self.snake1(hidden_state)
138
- hidden_state = self.conv_t1(hidden_state)
139
- hidden_state = self.res_unit1(hidden_state)
140
- hidden_state = self.res_unit2(hidden_state)
141
- hidden_state = self.res_unit3(hidden_state)
142
-
143
- return hidden_state
144
-
145
-
146
- class OobleckDiagonalGaussianDistribution(object):
147
- def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
148
- self.parameters = parameters
149
- self.mean, self.scale = parameters.chunk(2, dim=1)
150
- self.std = nn.functional.softplus(self.scale) + 1e-4
151
- self.var = self.std * self.std
152
- self.logvar = torch.log(self.var)
153
- self.deterministic = deterministic
154
-
155
- def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
156
- # make sure sample is on the same device as the parameters and has same dtype
157
- sample = randn_tensor(
158
- self.mean.shape,
159
- generator=generator,
160
- device=self.parameters.device,
161
- dtype=self.parameters.dtype,
162
- )
163
- x = self.mean + self.std * sample
164
- return x
165
-
166
- def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor:
167
- if self.deterministic:
168
- return torch.Tensor([0.0])
169
- else:
170
- if other is None:
171
- return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean()
172
- else:
173
- normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var
174
- var_ratio = self.var / other.var
175
- logvar_diff = self.logvar - other.logvar
176
-
177
- kl = normalized_diff + var_ratio + logvar_diff - 1
178
-
179
- kl = kl.sum(1).mean()
180
- return kl
181
-
182
- def mode(self) -> torch.Tensor:
183
- return self.mean
184
-
185
-
186
- @dataclass
187
- class AutoencoderOobleckOutput(BaseOutput):
188
- """
189
- Output of AutoencoderOobleck encoding method.
190
-
191
- Args:
192
- latent_dist (`OobleckDiagonalGaussianDistribution`):
193
- Encoded outputs of `Encoder` represented as the mean and standard deviation of
194
- `OobleckDiagonalGaussianDistribution`. `OobleckDiagonalGaussianDistribution` allows for sampling latents
195
- from the distribution.
196
- """
197
-
198
- latent_dist: "OobleckDiagonalGaussianDistribution" # noqa: F821
199
-
200
-
201
- @dataclass
202
- class OobleckDecoderOutput(BaseOutput):
203
- r"""
204
- Output of decoding method.
205
-
206
- Args:
207
- sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`):
208
- The decoded output sample from the last layer of the model.
209
- """
210
-
211
- sample: torch.Tensor
212
-
213
-
214
- class OobleckEncoder(nn.Module):
215
- """Oobleck Encoder"""
216
-
217
- def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples):
218
- super().__init__()
219
-
220
- strides = downsampling_ratios
221
- channel_multiples = [1] + channel_multiples
222
-
223
- # Create first convolution
224
- self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3))
225
-
226
- self.block = []
227
- # Create EncoderBlocks that double channels as they downsample by `stride`
228
- for stride_index, stride in enumerate(strides):
229
- self.block += [
230
- OobleckEncoderBlock(
231
- input_dim=encoder_hidden_size * channel_multiples[stride_index],
232
- output_dim=encoder_hidden_size * channel_multiples[stride_index + 1],
233
- stride=stride,
234
- )
235
- ]
236
-
237
- self.block = nn.ModuleList(self.block)
238
- d_model = encoder_hidden_size * channel_multiples[-1]
239
- self.snake1 = Snake1d(d_model)
240
- self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1))
241
-
242
- def forward(self, hidden_state):
243
- hidden_state = self.conv1(hidden_state)
244
-
245
- for module in self.block:
246
- hidden_state = module(hidden_state)
247
-
248
- hidden_state = self.snake1(hidden_state)
249
- hidden_state = self.conv2(hidden_state)
250
-
251
- return hidden_state
252
-
253
-
254
- class OobleckDecoder(nn.Module):
255
- """Oobleck Decoder"""
256
-
257
- def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples):
258
- super().__init__()
259
-
260
- strides = upsampling_ratios
261
- channel_multiples = [1] + channel_multiples
262
-
263
- # Add first conv layer
264
- self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
265
-
266
- # Add upsampling + MRF blocks
267
- block = []
268
- for stride_index, stride in enumerate(strides):
269
- block += [
270
- OobleckDecoderBlock(
271
- input_dim=channels * channel_multiples[len(strides) - stride_index],
272
- output_dim=channels * channel_multiples[len(strides) - stride_index - 1],
273
- stride=stride,
274
- )
275
- ]
276
-
277
- self.block = nn.ModuleList(block)
278
- output_dim = channels
279
- self.snake1 = Snake1d(output_dim)
280
- self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))
281
-
282
- def forward(self, hidden_state):
283
- hidden_state = self.conv1(hidden_state)
284
-
285
- for layer in self.block:
286
- hidden_state = layer(hidden_state)
287
-
288
- hidden_state = self.snake1(hidden_state)
289
- hidden_state = self.conv2(hidden_state)
290
-
291
- return hidden_state
292
-
293
-
294
- class AutoencoderOobleck(ModelMixin, ConfigMixin):
295
- r"""
296
- An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
297
- introduced in Stable Audio.
298
-
299
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
300
- for all models (such as downloading or saving).
301
-
302
- Parameters:
303
- encoder_hidden_size (`int`, *optional*, defaults to 128):
304
- Intermediate representation dimension for the encoder.
305
- downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`):
306
- Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder.
307
- channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`):
308
- Multiples used to determine the hidden sizes of the hidden layers.
309
- decoder_channels (`int`, *optional*, defaults to 128):
310
- Intermediate representation dimension for the decoder.
311
- decoder_input_channels (`int`, *optional*, defaults to 64):
312
- Input dimension for the decoder. Corresponds to the latent dimension.
313
- audio_channels (`int`, *optional*, defaults to 2):
314
- Number of channels in the audio data. Either 1 for mono or 2 for stereo.
315
- sampling_rate (`int`, *optional*, defaults to 44100):
316
- The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz).
317
- """
318
-
319
- _supports_gradient_checkpointing = False
320
-
321
- @register_to_config
322
- def __init__(
323
- self,
324
- encoder_hidden_size=128,
325
- downsampling_ratios=[2, 4, 4, 8, 8],
326
- channel_multiples=[1, 2, 4, 8, 16],
327
- decoder_channels=128,
328
- decoder_input_channels=64,
329
- audio_channels=2,
330
- sampling_rate=44100,
331
- ):
332
- super().__init__()
333
-
334
- self.encoder_hidden_size = encoder_hidden_size
335
- self.downsampling_ratios = downsampling_ratios
336
- self.decoder_channels = decoder_channels
337
- self.upsampling_ratios = downsampling_ratios[::-1]
338
- self.hop_length = int(np.prod(downsampling_ratios))
339
- self.sampling_rate = sampling_rate
340
-
341
- self.encoder = OobleckEncoder(
342
- encoder_hidden_size=encoder_hidden_size,
343
- audio_channels=audio_channels,
344
- downsampling_ratios=downsampling_ratios,
345
- channel_multiples=channel_multiples,
346
- )
347
-
348
- self.decoder = OobleckDecoder(
349
- channels=decoder_channels,
350
- input_channels=decoder_input_channels,
351
- audio_channels=audio_channels,
352
- upsampling_ratios=self.upsampling_ratios,
353
- channel_multiples=channel_multiples,
354
- )
355
-
356
- self.use_slicing = False
357
-
358
- def enable_slicing(self):
359
- r"""
360
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
361
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
362
- """
363
- self.use_slicing = True
364
-
365
- def disable_slicing(self):
366
- r"""
367
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
368
- decoding in one step.
369
- """
370
- self.use_slicing = False
371
-
372
- @apply_forward_hook
373
- def encode(
374
- self, x: torch.Tensor, return_dict: bool = True
375
- ) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]:
376
- """
377
- Encode a batch of images into latents.
378
-
379
- Args:
380
- x (`torch.Tensor`): Input batch of images.
381
- return_dict (`bool`, *optional*, defaults to `True`):
382
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
383
-
384
- Returns:
385
- The latent representations of the encoded images. If `return_dict` is True, a
386
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
387
- """
388
- if self.use_slicing and x.shape[0] > 1:
389
- encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
390
- h = torch.cat(encoded_slices)
391
- else:
392
- h = self.encoder(x)
393
-
394
- posterior = OobleckDiagonalGaussianDistribution(h)
395
-
396
- if not return_dict:
397
- return (posterior,)
398
-
399
- return AutoencoderOobleckOutput(latent_dist=posterior)
400
-
401
- def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]:
402
- dec = self.decoder(z)
403
-
404
- if not return_dict:
405
- return (dec,)
406
-
407
- return OobleckDecoderOutput(sample=dec)
408
-
409
- @apply_forward_hook
410
- def decode(
411
- self, z: torch.FloatTensor, return_dict: bool = True, generator=None
412
- ) -> Union[OobleckDecoderOutput, torch.FloatTensor]:
413
- """
414
- Decode a batch of images.
415
-
416
- Args:
417
- z (`torch.Tensor`): Input batch of latent vectors.
418
- return_dict (`bool`, *optional*, defaults to `True`):
419
- Whether to return a [`~models.vae.OobleckDecoderOutput`] instead of a plain tuple.
420
-
421
- Returns:
422
- [`~models.vae.OobleckDecoderOutput`] or `tuple`:
423
- If return_dict is True, a [`~models.vae.OobleckDecoderOutput`] is returned, otherwise a plain `tuple`
424
- is returned.
425
-
426
- """
427
- if self.use_slicing and z.shape[0] > 1:
428
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
429
- decoded = torch.cat(decoded_slices)
430
- else:
431
- decoded = self._decode(z).sample
432
-
433
- if not return_dict:
434
- return (decoded,)
435
-
436
- return OobleckDecoderOutput(sample=decoded)
437
-
438
- def forward(
439
- self,
440
- sample: torch.Tensor,
441
- sample_posterior: bool = False,
442
- return_dict: bool = True,
443
- generator: Optional[torch.Generator] = None,
444
- ) -> Union[OobleckDecoderOutput, torch.Tensor]:
445
- r"""
446
- Args:
447
- sample (`torch.Tensor`): Input sample.
448
- sample_posterior (`bool`, *optional*, defaults to `False`):
449
- Whether to sample from the posterior.
450
- return_dict (`bool`, *optional*, defaults to `True`):
451
- Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple.
452
- """
453
- x = sample
454
- posterior = self.encode(x).latent_dist
455
- if sample_posterior:
456
- z = posterior.sample(generator=generator)
457
- else:
458
- z = posterior.mode()
459
- dec = self.decode(z).sample
460
-
461
- if not return_dict:
462
- return (dec,)
463
-
464
- return OobleckDecoderOutput(sample=dec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/src/diffusers/models/autoencoders_/autoencoder_tiny.py DELETED
@@ -1,348 +0,0 @@
1
- # Copyright 2024 Ollin Boer Bohan and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from dataclasses import dataclass
17
- from typing import Optional, Tuple, Union
18
-
19
- import torch
20
-
21
- from ...configuration_utils import ConfigMixin, register_to_config
22
- from ...utils import BaseOutput
23
- from ...utils.accelerate_utils import apply_forward_hook
24
- from ..modeling_utils import ModelMixin
25
- from .vae import DecoderOutput, DecoderTiny, EncoderTiny
26
-
27
-
28
- @dataclass
29
- class AutoencoderTinyOutput(BaseOutput):
30
- """
31
- Output of AutoencoderTiny encoding method.
32
-
33
- Args:
34
- latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
35
-
36
- """
37
-
38
- latents: torch.Tensor
39
-
40
-
41
- class AutoencoderTiny(ModelMixin, ConfigMixin):
42
- r"""
43
- A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
44
-
45
- [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.
46
-
47
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
48
- all models (such as downloading or saving).
49
-
50
- Parameters:
51
- in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
52
- out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
53
- encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
54
- Tuple of integers representing the number of output channels for each encoder block. The length of the
55
- tuple should be equal to the number of encoder blocks.
56
- decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
57
- Tuple of integers representing the number of output channels for each decoder block. The length of the
58
- tuple should be equal to the number of decoder blocks.
59
- act_fn (`str`, *optional*, defaults to `"relu"`):
60
- Activation function to be used throughout the model.
61
- latent_channels (`int`, *optional*, defaults to 4):
62
- Number of channels in the latent representation. The latent space acts as a compressed representation of
63
- the input image.
64
- upsampling_scaling_factor (`int`, *optional*, defaults to 2):
65
- Scaling factor for upsampling in the decoder. It determines the size of the output image during the
66
- upsampling process.
67
- num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
68
- Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
69
- length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
70
- number of encoder blocks.
71
- num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
72
- Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
73
- length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
74
- number of decoder blocks.
75
- latent_magnitude (`float`, *optional*, defaults to 3.0):
76
- Magnitude of the latent representation. This parameter scales the latent representation values to control
77
- the extent of information preservation.
78
- latent_shift (float, *optional*, defaults to 0.5):
79
- Shift applied to the latent representation. This parameter controls the center of the latent space.
80
- scaling_factor (`float`, *optional*, defaults to 1.0):
81
- The component-wise standard deviation of the trained latent space computed using the first batch of the
82
- training set. This is used to scale the latent space to have unit variance when training the diffusion
83
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
84
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
85
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
86
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
87
- however, no such scaling factor was used, hence the value of 1.0 as the default.
88
- force_upcast (`bool`, *optional*, default to `False`):
89
- If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
90
- can be fine-tuned / trained to a lower range without losing too much precision, in which case
91
- `force_upcast` can be set to `False` (see this fp16-friendly
92
- [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
93
- """
94
-
95
- _supports_gradient_checkpointing = True
96
-
97
- @register_to_config
98
- def __init__(
99
- self,
100
- in_channels: int = 3,
101
- out_channels: int = 3,
102
- encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
103
- decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
104
- act_fn: str = "relu",
105
- upsample_fn: str = "nearest",
106
- latent_channels: int = 4,
107
- upsampling_scaling_factor: int = 2,
108
- num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
109
- num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
110
- latent_magnitude: int = 3,
111
- latent_shift: float = 0.5,
112
- force_upcast: bool = False,
113
- scaling_factor: float = 1.0,
114
- shift_factor: float = 0.0,
115
- ):
116
- super().__init__()
117
-
118
- if len(encoder_block_out_channels) != len(num_encoder_blocks):
119
- raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
120
- if len(decoder_block_out_channels) != len(num_decoder_blocks):
121
- raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")
122
-
123
- self.encoder = EncoderTiny(
124
- in_channels=in_channels,
125
- out_channels=latent_channels,
126
- num_blocks=num_encoder_blocks,
127
- block_out_channels=encoder_block_out_channels,
128
- act_fn=act_fn,
129
- )
130
-
131
- self.decoder = DecoderTiny(
132
- in_channels=latent_channels,
133
- out_channels=out_channels,
134
- num_blocks=num_decoder_blocks,
135
- block_out_channels=decoder_block_out_channels,
136
- upsampling_scaling_factor=upsampling_scaling_factor,
137
- act_fn=act_fn,
138
- upsample_fn=upsample_fn,
139
- )
140
-
141
- self.latent_magnitude = latent_magnitude
142
- self.latent_shift = latent_shift
143
- self.scaling_factor = scaling_factor
144
-
145
- self.use_slicing = False
146
- self.use_tiling = False
147
-
148
- # only relevant if vae tiling is enabled
149
- self.spatial_scale_factor = 2**out_channels
150
- self.tile_overlap_factor = 0.125
151
- self.tile_sample_min_size = 512
152
- self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
153
-
154
- self.register_to_config(block_out_channels=decoder_block_out_channels)
155
- self.register_to_config(force_upcast=False)
156
-
157
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
158
- if isinstance(module, (EncoderTiny, DecoderTiny)):
159
- module.gradient_checkpointing = value
160
-
161
- def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
162
- """raw latents -> [0, 1]"""
163
- return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
164
-
165
- def unscale_latents(self, x: torch.Tensor) -> torch.Tensor:
166
- """[0, 1] -> raw latents"""
167
- return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
168
-
169
- def enable_slicing(self) -> None:
170
- r"""
171
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
172
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
173
- """
174
- self.use_slicing = True
175
-
176
- def disable_slicing(self) -> None:
177
- r"""
178
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
179
- decoding in one step.
180
- """
181
- self.use_slicing = False
182
-
183
- def enable_tiling(self, use_tiling: bool = True) -> None:
184
- r"""
185
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
186
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
187
- processing larger images.
188
- """
189
- self.use_tiling = use_tiling
190
-
191
- def disable_tiling(self) -> None:
192
- r"""
193
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
194
- decoding in one step.
195
- """
196
- self.enable_tiling(False)
197
-
198
- def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
199
- r"""Encode a batch of images using a tiled encoder.
200
-
201
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
202
- steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
203
- tiles overlap and are blended together to form a smooth output.
204
-
205
- Args:
206
- x (`torch.Tensor`): Input batch of images.
207
-
208
- Returns:
209
- `torch.Tensor`: Encoded batch of images.
210
- """
211
- # scale of encoder output relative to input
212
- sf = self.spatial_scale_factor
213
- tile_size = self.tile_sample_min_size
214
-
215
- # number of pixels to blend and to traverse between tile
216
- blend_size = int(tile_size * self.tile_overlap_factor)
217
- traverse_size = tile_size - blend_size
218
-
219
- # tiles index (up/left)
220
- ti = range(0, x.shape[-2], traverse_size)
221
- tj = range(0, x.shape[-1], traverse_size)
222
-
223
- # mask for blending
224
- blend_masks = torch.stack(
225
- torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
226
- )
227
- blend_masks = blend_masks.clamp(0, 1).to(x.device)
228
-
229
- # output array
230
- out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
231
- for i in ti:
232
- for j in tj:
233
- tile_in = x[..., i : i + tile_size, j : j + tile_size]
234
- # tile result
235
- tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
236
- tile = self.encoder(tile_in)
237
- h, w = tile.shape[-2], tile.shape[-1]
238
- # blend tile result into output
239
- blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
240
- blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
241
- blend_mask = blend_mask_i * blend_mask_j
242
- tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
243
- tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
244
- return out
245
-
246
- def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor:
247
- r"""Encode a batch of images using a tiled encoder.
248
-
249
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
250
- steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
251
- tiles overlap and are blended together to form a smooth output.
252
-
253
- Args:
254
- x (`torch.Tensor`): Input batch of images.
255
-
256
- Returns:
257
- `torch.Tensor`: Encoded batch of images.
258
- """
259
- # scale of decoder output relative to input
260
- sf = self.spatial_scale_factor
261
- tile_size = self.tile_latent_min_size
262
-
263
- # number of pixels to blend and to traverse between tiles
264
- blend_size = int(tile_size * self.tile_overlap_factor)
265
- traverse_size = tile_size - blend_size
266
-
267
- # tiles index (up/left)
268
- ti = range(0, x.shape[-2], traverse_size)
269
- tj = range(0, x.shape[-1], traverse_size)
270
-
271
- # mask for blending
272
- blend_masks = torch.stack(
273
- torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
274
- )
275
- blend_masks = blend_masks.clamp(0, 1).to(x.device)
276
-
277
- # output array
278
- out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
279
- for i in ti:
280
- for j in tj:
281
- tile_in = x[..., i : i + tile_size, j : j + tile_size]
282
- # tile result
283
- tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
284
- tile = self.decoder(tile_in)
285
- h, w = tile.shape[-2], tile.shape[-1]
286
- # blend tile result into output
287
- blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
288
- blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
289
- blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
290
- tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
291
- return out
292
-
293
- @apply_forward_hook
294
- def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
295
- if self.use_slicing and x.shape[0] > 1:
296
- output = [
297
- self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
298
- ]
299
- output = torch.cat(output)
300
- else:
301
- output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)
302
-
303
- if not return_dict:
304
- return (output,)
305
-
306
- return AutoencoderTinyOutput(latents=output)
307
-
308
- @apply_forward_hook
309
- def decode(
310
- self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
311
- ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
312
- if self.use_slicing and x.shape[0] > 1:
313
- output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
314
- output = torch.cat(output)
315
- else:
316
- output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
317
-
318
- if not return_dict:
319
- return (output,)
320
-
321
- return DecoderOutput(sample=output)
322
-
323
- def forward(
324
- self,
325
- sample: torch.Tensor,
326
- return_dict: bool = True,
327
- ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
328
- r"""
329
- Args:
330
- sample (`torch.Tensor`): Input sample.
331
- return_dict (`bool`, *optional*, defaults to `True`):
332
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
333
- """
334
- enc = self.encode(sample).latents
335
-
336
- # scale latents to be in [0, 1], then quantize latents to a byte tensor,
337
- # as if we were storing the latents in an RGBA uint8 image.
338
- scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
339
-
340
- # unquantize latents back into [0, 1], then unscale latents back to their original range,
341
- # as if we were loading the latents from an RGBA uint8 image.
342
- unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
343
-
344
- dec = self.decode(unscaled_enc)
345
-
346
- if not return_dict:
347
- return (dec,)
348
- return DecoderOutput(sample=dec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/src/diffusers/models/autoencoders_/consistency_decoder_vae.py DELETED
@@ -1,460 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Dict, Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn.functional as F
19
- from torch import nn
20
-
21
- from ...configuration_utils import ConfigMixin, register_to_config
22
- from ...schedulers import ConsistencyDecoderScheduler
23
- from ...utils import BaseOutput
24
- from ...utils.accelerate_utils import apply_forward_hook
25
- from ...utils.torch_utils import randn_tensor
26
- from ..attention_processor import (
27
- ADDED_KV_ATTENTION_PROCESSORS,
28
- CROSS_ATTENTION_PROCESSORS,
29
- AttentionProcessor,
30
- AttnAddedKVProcessor,
31
- AttnProcessor,
32
- )
33
- from ..modeling_utils import ModelMixin
34
- from ..unets.unet_2d import UNet2DModel
35
- from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
36
-
37
-
38
- @dataclass
39
- class ConsistencyDecoderVAEOutput(BaseOutput):
40
- """
41
- Output of encoding method.
42
-
43
- Args:
44
- latent_dist (`DiagonalGaussianDistribution`):
45
- Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
46
- `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
47
- """
48
-
49
- latent_dist: "DiagonalGaussianDistribution"
50
-
51
-
52
- class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
53
- r"""
54
- The consistency decoder used with DALL-E 3.
55
-
56
- Examples:
57
- ```py
58
- >>> import torch
59
- >>> from diffusers import StableDiffusionPipeline, ConsistencyDecoderVAE
60
-
61
- >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
62
- >>> pipe = StableDiffusionPipeline.from_pretrained(
63
- ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
64
- ... ).to("cuda")
65
-
66
- >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
67
- >>> image
68
- ```
69
- """
70
-
71
- @register_to_config
72
- def __init__(
73
- self,
74
- scaling_factor: float = 0.18215,
75
- latent_channels: int = 4,
76
- sample_size: int = 32,
77
- encoder_act_fn: str = "silu",
78
- encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
79
- encoder_double_z: bool = True,
80
- encoder_down_block_types: Tuple[str, ...] = (
81
- "DownEncoderBlock2D",
82
- "DownEncoderBlock2D",
83
- "DownEncoderBlock2D",
84
- "DownEncoderBlock2D",
85
- ),
86
- encoder_in_channels: int = 3,
87
- encoder_layers_per_block: int = 2,
88
- encoder_norm_num_groups: int = 32,
89
- encoder_out_channels: int = 4,
90
- decoder_add_attention: bool = False,
91
- decoder_block_out_channels: Tuple[int, ...] = (320, 640, 1024, 1024),
92
- decoder_down_block_types: Tuple[str, ...] = (
93
- "ResnetDownsampleBlock2D",
94
- "ResnetDownsampleBlock2D",
95
- "ResnetDownsampleBlock2D",
96
- "ResnetDownsampleBlock2D",
97
- ),
98
- decoder_downsample_padding: int = 1,
99
- decoder_in_channels: int = 7,
100
- decoder_layers_per_block: int = 3,
101
- decoder_norm_eps: float = 1e-05,
102
- decoder_norm_num_groups: int = 32,
103
- decoder_num_train_timesteps: int = 1024,
104
- decoder_out_channels: int = 6,
105
- decoder_resnet_time_scale_shift: str = "scale_shift",
106
- decoder_time_embedding_type: str = "learned",
107
- decoder_up_block_types: Tuple[str, ...] = (
108
- "ResnetUpsampleBlock2D",
109
- "ResnetUpsampleBlock2D",
110
- "ResnetUpsampleBlock2D",
111
- "ResnetUpsampleBlock2D",
112
- ),
113
- ):
114
- super().__init__()
115
- self.encoder = Encoder(
116
- act_fn=encoder_act_fn,
117
- block_out_channels=encoder_block_out_channels,
118
- double_z=encoder_double_z,
119
- down_block_types=encoder_down_block_types,
120
- in_channels=encoder_in_channels,
121
- layers_per_block=encoder_layers_per_block,
122
- norm_num_groups=encoder_norm_num_groups,
123
- out_channels=encoder_out_channels,
124
- )
125
-
126
- self.decoder_unet = UNet2DModel(
127
- add_attention=decoder_add_attention,
128
- block_out_channels=decoder_block_out_channels,
129
- down_block_types=decoder_down_block_types,
130
- downsample_padding=decoder_downsample_padding,
131
- in_channels=decoder_in_channels,
132
- layers_per_block=decoder_layers_per_block,
133
- norm_eps=decoder_norm_eps,
134
- norm_num_groups=decoder_norm_num_groups,
135
- num_train_timesteps=decoder_num_train_timesteps,
136
- out_channels=decoder_out_channels,
137
- resnet_time_scale_shift=decoder_resnet_time_scale_shift,
138
- time_embedding_type=decoder_time_embedding_type,
139
- up_block_types=decoder_up_block_types,
140
- )
141
- self.decoder_scheduler = ConsistencyDecoderScheduler()
142
- self.register_to_config(block_out_channels=encoder_block_out_channels)
143
- self.register_to_config(force_upcast=False)
144
- self.register_buffer(
145
- "means",
146
- torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
147
- persistent=False,
148
- )
149
- self.register_buffer(
150
- "stds", torch.tensor([0.9654121, 1.0440036, 0.76147926, 0.77022034])[None, :, None, None], persistent=False
151
- )
152
-
153
- self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
154
-
155
- self.use_slicing = False
156
- self.use_tiling = False
157
-
158
- # only relevant if vae tiling is enabled
159
- self.tile_sample_min_size = self.config.sample_size
160
- sample_size = (
161
- self.config.sample_size[0]
162
- if isinstance(self.config.sample_size, (list, tuple))
163
- else self.config.sample_size
164
- )
165
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
166
- self.tile_overlap_factor = 0.25
167
-
168
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
169
- def enable_tiling(self, use_tiling: bool = True):
170
- r"""
171
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
172
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
173
- processing larger images.
174
- """
175
- self.use_tiling = use_tiling
176
-
177
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
178
- def disable_tiling(self):
179
- r"""
180
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
181
- decoding in one step.
182
- """
183
- self.enable_tiling(False)
184
-
185
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
186
- def enable_slicing(self):
187
- r"""
188
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
189
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
190
- """
191
- self.use_slicing = True
192
-
193
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
194
- def disable_slicing(self):
195
- r"""
196
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
197
- decoding in one step.
198
- """
199
- self.use_slicing = False
200
-
201
- @property
202
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
203
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
204
- r"""
205
- Returns:
206
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
207
- indexed by its weight name.
208
- """
209
- # set recursively
210
- processors = {}
211
-
212
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
213
- if hasattr(module, "get_processor"):
214
- processors[f"{name}.processor"] = module.get_processor()
215
-
216
- for sub_name, child in module.named_children():
217
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
218
-
219
- return processors
220
-
221
- for name, module in self.named_children():
222
- fn_recursive_add_processors(name, module, processors)
223
-
224
- return processors
225
-
226
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
227
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
228
- r"""
229
- Sets the attention processor to use to compute attention.
230
-
231
- Parameters:
232
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
233
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
234
- for **all** `Attention` layers.
235
-
236
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
237
- processor. This is strongly recommended when setting trainable attention processors.
238
-
239
- """
240
- count = len(self.attn_processors.keys())
241
-
242
- if isinstance(processor, dict) and len(processor) != count:
243
- raise ValueError(
244
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
245
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
246
- )
247
-
248
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
249
- if hasattr(module, "set_processor"):
250
- if not isinstance(processor, dict):
251
- module.set_processor(processor)
252
- else:
253
- module.set_processor(processor.pop(f"{name}.processor"))
254
-
255
- for sub_name, child in module.named_children():
256
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
257
-
258
- for name, module in self.named_children():
259
- fn_recursive_attn_processor(name, module, processor)
260
-
261
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
262
- def set_default_attn_processor(self):
263
- """
264
- Disables custom attention processors and sets the default attention implementation.
265
- """
266
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
267
- processor = AttnAddedKVProcessor()
268
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
269
- processor = AttnProcessor()
270
- else:
271
- raise ValueError(
272
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
273
- )
274
-
275
- self.set_attn_processor(processor)
276
-
277
- @apply_forward_hook
278
- def encode(
279
- self, x: torch.Tensor, return_dict: bool = True
280
- ) -> Union[ConsistencyDecoderVAEOutput, Tuple[DiagonalGaussianDistribution]]:
281
- """
282
- Encode a batch of images into latents.
283
-
284
- Args:
285
- x (`torch.Tensor`): Input batch of images.
286
- return_dict (`bool`, *optional*, defaults to `True`):
287
- Whether to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
288
- instead of a plain tuple.
289
-
290
- Returns:
291
- The latent representations of the encoded images. If `return_dict` is True, a
292
- [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a
293
- plain `tuple` is returned.
294
- """
295
- if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
296
- return self.tiled_encode(x, return_dict=return_dict)
297
-
298
- if self.use_slicing and x.shape[0] > 1:
299
- encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
300
- h = torch.cat(encoded_slices)
301
- else:
302
- h = self.encoder(x)
303
-
304
- moments = self.quant_conv(h)
305
- posterior = DiagonalGaussianDistribution(moments)
306
-
307
- if not return_dict:
308
- return (posterior,)
309
-
310
- return ConsistencyDecoderVAEOutput(latent_dist=posterior)
311
-
312
- @apply_forward_hook
313
- def decode(
314
- self,
315
- z: torch.Tensor,
316
- generator: Optional[torch.Generator] = None,
317
- return_dict: bool = True,
318
- num_inference_steps: int = 2,
319
- ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
320
- """
321
- Decodes the input latent vector `z` using the consistency decoder VAE model.
322
-
323
- Args:
324
- z (torch.Tensor): The input latent vector.
325
- generator (Optional[torch.Generator]): The random number generator. Default is None.
326
- return_dict (bool): Whether to return the output as a dictionary. Default is True.
327
- num_inference_steps (int): The number of inference steps. Default is 2.
328
-
329
- Returns:
330
- Union[DecoderOutput, Tuple[torch.Tensor]]: The decoded output.
331
-
332
- """
333
- z = (z * self.config.scaling_factor - self.means) / self.stds
334
-
335
- scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
336
- z = F.interpolate(z, mode="nearest", scale_factor=scale_factor)
337
-
338
- batch_size, _, height, width = z.shape
339
-
340
- self.decoder_scheduler.set_timesteps(num_inference_steps, device=self.device)
341
-
342
- x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
343
- (batch_size, 3, height, width), generator=generator, dtype=z.dtype, device=z.device
344
- )
345
-
346
- for t in self.decoder_scheduler.timesteps:
347
- model_input = torch.concat([self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1)
348
- model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
349
- prev_sample = self.decoder_scheduler.step(model_output, t, x_t, generator).prev_sample
350
- x_t = prev_sample
351
-
352
- x_0 = x_t
353
-
354
- if not return_dict:
355
- return (x_0,)
356
-
357
- return DecoderOutput(sample=x_0)
358
-
359
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_v
360
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
361
- blend_extent = min(a.shape[2], b.shape[2], blend_extent)
362
- for y in range(blend_extent):
363
- b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
364
- return b
365
-
366
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.blend_h
367
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
368
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
369
- for x in range(blend_extent):
370
- b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
371
- return b
372
-
373
- def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
374
- r"""Encode a batch of images using a tiled encoder.
375
-
376
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
377
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
378
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
379
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
380
- output, but they should be much less noticeable.
381
-
382
- Args:
383
- x (`torch.Tensor`): Input batch of images.
384
- return_dict (`bool`, *optional*, defaults to `True`):
385
- Whether or not to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
386
- instead of a plain tuple.
387
-
388
- Returns:
389
- [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
390
- If return_dict is True, a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
391
- is returned, otherwise a plain `tuple` is returned.
392
- """
393
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
394
- blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
395
- row_limit = self.tile_latent_min_size - blend_extent
396
-
397
- # Split the image into 512x512 tiles and encode them separately.
398
- rows = []
399
- for i in range(0, x.shape[2], overlap_size):
400
- row = []
401
- for j in range(0, x.shape[3], overlap_size):
402
- tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
403
- tile = self.encoder(tile)
404
- tile = self.quant_conv(tile)
405
- row.append(tile)
406
- rows.append(row)
407
- result_rows = []
408
- for i, row in enumerate(rows):
409
- result_row = []
410
- for j, tile in enumerate(row):
411
- # blend the above tile and the left tile
412
- # to the current tile and add the current tile to the result row
413
- if i > 0:
414
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
415
- if j > 0:
416
- tile = self.blend_h(row[j - 1], tile, blend_extent)
417
- result_row.append(tile[:, :, :row_limit, :row_limit])
418
- result_rows.append(torch.cat(result_row, dim=3))
419
-
420
- moments = torch.cat(result_rows, dim=2)
421
- posterior = DiagonalGaussianDistribution(moments)
422
-
423
- if not return_dict:
424
- return (posterior,)
425
-
426
- return ConsistencyDecoderVAEOutput(latent_dist=posterior)
427
-
428
- def forward(
429
- self,
430
- sample: torch.Tensor,
431
- sample_posterior: bool = False,
432
- return_dict: bool = True,
433
- generator: Optional[torch.Generator] = None,
434
- ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
435
- r"""
436
- Args:
437
- sample (`torch.Tensor`): Input sample.
438
- sample_posterior (`bool`, *optional*, defaults to `False`):
439
- Whether to sample from the posterior.
440
- return_dict (`bool`, *optional*, defaults to `True`):
441
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
442
- generator (`torch.Generator`, *optional*, defaults to `None`):
443
- Generator to use for sampling.
444
-
445
- Returns:
446
- [`DecoderOutput`] or `tuple`:
447
- If return_dict is True, a [`DecoderOutput`] is returned, otherwise a plain `tuple` is returned.
448
- """
449
- x = sample
450
- posterior = self.encode(x).latent_dist
451
- if sample_posterior:
452
- z = posterior.sample(generator=generator)
453
- else:
454
- z = posterior.mode()
455
- dec = self.decode(z, generator=generator).sample
456
-
457
- if not return_dict:
458
- return (dec,)
459
-
460
- return DecoderOutput(sample=dec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/src/diffusers/models/autoencoders_/vae.py DELETED
@@ -1,1005 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Optional, Tuple
16
-
17
- import numpy as np
18
- import torch
19
- import torch.nn as nn
20
-
21
- from ...utils import BaseOutput, is_torch_version
22
- from ...utils.torch_utils import randn_tensor
23
- from ..activations import get_activation
24
- from ..attention_processor import SpatialNorm
25
- from ..unets.unet_2d_blocks import (
26
- AutoencoderTinyBlock,
27
- UNetMidBlock2D,
28
- get_down_block,
29
- get_up_block,
30
- )
31
-
32
-
33
- @dataclass
34
- class DecoderOutput(BaseOutput):
35
- r"""
36
- Output of decoding method.
37
-
38
- Args:
39
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
40
- The decoded output sample from the last layer of the model.
41
- """
42
-
43
- sample: torch.FloatTensor
44
-
45
-
46
- class Encoder(nn.Module):
47
- r"""
48
- The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
49
-
50
- Args:
51
- in_channels (`int`, *optional*, defaults to 3):
52
- The number of input channels.
53
- out_channels (`int`, *optional*, defaults to 3):
54
- The number of output channels.
55
- down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
56
- The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
57
- options.
58
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
59
- The number of output channels for each block.
60
- layers_per_block (`int`, *optional*, defaults to 2):
61
- The number of layers per block.
62
- norm_num_groups (`int`, *optional*, defaults to 32):
63
- The number of groups for normalization.
64
- act_fn (`str`, *optional*, defaults to `"silu"`):
65
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
66
- double_z (`bool`, *optional*, defaults to `True`):
67
- Whether to double the number of output channels for the last block.
68
- """
69
-
70
- def __init__(
71
- self,
72
- in_channels: int = 3,
73
- out_channels: int = 3,
74
- down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
75
- block_out_channels: Tuple[int, ...] = (64,),
76
- layers_per_block: int = 2,
77
- norm_num_groups: int = 32,
78
- act_fn: str = "silu",
79
- double_z: bool = True,
80
- mid_block_add_attention=True,
81
- ):
82
- super().__init__()
83
- self.layers_per_block = layers_per_block
84
-
85
- self.conv_in = nn.Conv2d(
86
- in_channels,
87
- block_out_channels[0],
88
- kernel_size=3,
89
- stride=1,
90
- padding=1,
91
- )
92
-
93
- self.mid_block = None
94
- self.down_blocks = nn.ModuleList([])
95
-
96
- # down
97
- output_channel = block_out_channels[0]
98
- for i, down_block_type in enumerate(down_block_types):
99
- input_channel = output_channel
100
- output_channel = block_out_channels[i]
101
- is_final_block = i == len(block_out_channels) - 1
102
-
103
- down_block = get_down_block(
104
- down_block_type,
105
- num_layers=self.layers_per_block,
106
- in_channels=input_channel,
107
- out_channels=output_channel,
108
- add_downsample=not is_final_block,
109
- resnet_eps=1e-6,
110
- downsample_padding=0,
111
- resnet_act_fn=act_fn,
112
- resnet_groups=norm_num_groups,
113
- attention_head_dim=output_channel,
114
- temb_channels=None,
115
- )
116
- self.down_blocks.append(down_block)
117
-
118
- # mid
119
- self.mid_block = UNetMidBlock2D(
120
- in_channels=block_out_channels[-1],
121
- resnet_eps=1e-6,
122
- resnet_act_fn=act_fn,
123
- output_scale_factor=1,
124
- resnet_time_scale_shift="default",
125
- attention_head_dim=block_out_channels[-1],
126
- resnet_groups=norm_num_groups,
127
- temb_channels=None,
128
- add_attention=mid_block_add_attention,
129
- )
130
-
131
- # out
132
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
133
- self.conv_act = nn.SiLU()
134
-
135
- conv_out_channels = 2 * out_channels if double_z else out_channels
136
- self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
137
-
138
- self.gradient_checkpointing = False
139
-
140
- def forward(self, sample: torch.FloatTensor, hidden_flag = False) -> torch.FloatTensor:
141
- r"""The forward method of the `Encoder` class."""
142
-
143
- sample = self.conv_in(sample)
144
- hidden_list = []
145
-
146
- if self.training and self.gradient_checkpointing:
147
-
148
- def create_custom_forward(module):
149
- def custom_forward(*inputs):
150
- return module(*inputs)
151
-
152
- return custom_forward
153
-
154
- # down
155
- if is_torch_version(">=", "1.11.0"):
156
- for down_block in self.down_blocks:
157
- sample = torch.utils.checkpoint.checkpoint(
158
- create_custom_forward(down_block), sample, use_reentrant=False
159
- )
160
- # middle
161
- sample = torch.utils.checkpoint.checkpoint(
162
- create_custom_forward(self.mid_block), sample, use_reentrant=False
163
- )
164
- else:
165
- for down_block in self.down_blocks:
166
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
167
- # middle
168
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
169
-
170
- else:
171
- # down
172
- hidden_list.append(sample)
173
- for down_block in self.down_blocks:
174
- sample = down_block(sample)
175
- hidden_list.append(sample)
176
-
177
- # middle
178
- sample = self.mid_block(sample)
179
- # hidden_list.append(sample)
180
-
181
- # post-process
182
- sample = self.conv_norm_out(sample)
183
- sample = self.conv_act(sample)
184
- sample = self.conv_out(sample)
185
-
186
- if hidden_flag:
187
- return sample, hidden_list
188
- else:
189
- return sample
190
-
191
-
192
- class Decoder(nn.Module):
193
- r"""
194
- The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
195
-
196
- Args:
197
- in_channels (`int`, *optional*, defaults to 3):
198
- The number of input channels.
199
- out_channels (`int`, *optional*, defaults to 3):
200
- The number of output channels.
201
- up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
202
- The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
203
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
204
- The number of output channels for each block.
205
- layers_per_block (`int`, *optional*, defaults to 2):
206
- The number of layers per block.
207
- norm_num_groups (`int`, *optional*, defaults to 32):
208
- The number of groups for normalization.
209
- act_fn (`str`, *optional*, defaults to `"silu"`):
210
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
211
- norm_type (`str`, *optional*, defaults to `"group"`):
212
- The normalization type to use. Can be either `"group"` or `"spatial"`.
213
- """
214
-
215
- def __init__(
216
- self,
217
- in_channels: int = 3,
218
- out_channels: int = 3,
219
- up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
220
- block_out_channels: Tuple[int, ...] = (64,),
221
- layers_per_block: int = 2,
222
- norm_num_groups: int = 32,
223
- act_fn: str = "silu",
224
- norm_type: str = "group", # group, spatial
225
- mid_block_add_attention=True,
226
- ):
227
- super().__init__()
228
- self.layers_per_block = layers_per_block
229
-
230
- self.conv_in = nn.Conv2d(
231
- in_channels,
232
- block_out_channels[-1],
233
- kernel_size=3,
234
- stride=1,
235
- padding=1,
236
- )
237
-
238
- self.mid_block = None
239
- self.up_blocks = nn.ModuleList([])
240
-
241
- temb_channels = in_channels if norm_type == "spatial" else None
242
-
243
- # mid
244
- self.mid_block = UNetMidBlock2D(
245
- in_channels=block_out_channels[-1],
246
- resnet_eps=1e-6,
247
- resnet_act_fn=act_fn,
248
- output_scale_factor=1,
249
- resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
250
- attention_head_dim=block_out_channels[-1],
251
- resnet_groups=norm_num_groups,
252
- temb_channels=temb_channels,
253
- add_attention=mid_block_add_attention,
254
- )
255
-
256
- # up
257
- reversed_block_out_channels = list(reversed(block_out_channels))
258
- output_channel = reversed_block_out_channels[0]
259
- for i, up_block_type in enumerate(up_block_types):
260
- prev_output_channel = output_channel
261
- output_channel = reversed_block_out_channels[i]
262
-
263
- is_final_block = i == len(block_out_channels) - 1
264
-
265
- up_block = get_up_block(
266
- up_block_type,
267
- num_layers=self.layers_per_block + 1,
268
- in_channels=prev_output_channel,
269
- out_channels=output_channel,
270
- prev_output_channel=None,
271
- add_upsample=not is_final_block,
272
- resnet_eps=1e-6,
273
- resnet_act_fn=act_fn,
274
- resnet_groups=norm_num_groups,
275
- attention_head_dim=output_channel,
276
- temb_channels=temb_channels,
277
- resnet_time_scale_shift=norm_type,
278
- )
279
- self.up_blocks.append(up_block)
280
- prev_output_channel = output_channel
281
-
282
- # out
283
- if norm_type == "spatial":
284
- self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
285
- else:
286
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
287
- self.conv_act = nn.SiLU()
288
- self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
289
-
290
- self.gradient_checkpointing = False
291
-
292
- def forward(
293
- self,
294
- sample: torch.FloatTensor,
295
- latent_embeds: Optional[torch.FloatTensor] = None,
296
- hidden_list: list = None,
297
- ) -> torch.FloatTensor:
298
- r"""The forward method of the `Decoder` class."""
299
-
300
- if hidden_list is not None:
301
- hidden_list.reverse()
302
- hidden_idx = 0
303
- sample = self.conv_in(sample)
304
-
305
- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
306
- if self.training and self.gradient_checkpointing:
307
-
308
- def create_custom_forward(module):
309
- def custom_forward(*inputs):
310
- return module(*inputs)
311
-
312
- return custom_forward
313
-
314
- if is_torch_version(">=", "1.11.0"):
315
- # middle
316
- sample = torch.utils.checkpoint.checkpoint(
317
- create_custom_forward(self.mid_block),
318
- sample,
319
- latent_embeds,
320
- use_reentrant=False,
321
- )
322
- sample = sample.to(upscale_dtype)
323
-
324
- # up
325
- for up_block in self.up_blocks:
326
- sample = torch.utils.checkpoint.checkpoint(
327
- create_custom_forward(up_block),
328
- sample,
329
- latent_embeds,
330
- use_reentrant=False,
331
- )
332
- else:
333
- # middle
334
- sample = torch.utils.checkpoint.checkpoint(
335
- create_custom_forward(self.mid_block), sample, latent_embeds
336
- )
337
- sample = sample.to(upscale_dtype)
338
-
339
- # up
340
- for up_block in self.up_blocks:
341
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
342
- else:
343
- # middle
344
- # print(sample.shape)
345
- if hidden_list is not None:
346
- # print(sample.shape, hidden_list[hidden_idx].shape)
347
- sample += hidden_list[hidden_idx]
348
- hidden_idx += 1
349
- sample = self.mid_block(sample, latent_embeds)
350
- sample = sample.to(upscale_dtype)
351
-
352
-
353
- # up
354
- for up_block in self.up_blocks:
355
- # print(sample.shape)
356
- if hidden_list is not None:
357
- # print(sample.shape, hidden_list[hidden_idx].shape)
358
- sample += hidden_list[hidden_idx]
359
- hidden_idx += 1
360
- sample = up_block(sample, latent_embeds)
361
-
362
- # post-process
363
- if latent_embeds is None:
364
- sample = self.conv_norm_out(sample)
365
- else:
366
- sample = self.conv_norm_out(sample, latent_embeds)
367
- sample = self.conv_act(sample)
368
- sample = self.conv_out(sample)
369
-
370
- return sample
371
-
372
-
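Compared with the stock diffusers Encoder/Decoder pair, the classes above thread intermediate encoder features through the decoder: calling the Encoder with hidden_flag=True also returns the feature maps collected before each down block, and the Decoder adds them back onto its activations before the mid block and before every up block. Below is a minimal sketch of that round trip under a toy single-block configuration; the local `vae` import is an assumption standing in for the deleted module at diffusers/src/diffusers/models/autoencoders_/vae.py.

import torch

from vae import Decoder, Encoder  # hypothetical import of the deleted module

# One down block and one up block, so nothing is downsampled and the two feature
# maps the encoder collects line up exactly with the two additions in the decoder.
enc = Encoder(in_channels=3, out_channels=4, block_out_channels=(64,), double_z=False)
dec = Decoder(in_channels=4, out_channels=3, block_out_channels=(64,))

x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    latents, hidden_list = enc(x, hidden_flag=True)       # hidden_list: [post-conv_in, post-down-block] features
    recon = dec(latents, hidden_list=list(hidden_list))   # copied because forward() reverses the list in place
print(latents.shape, recon.shape)                         # (1, 4, 64, 64) and (1, 3, 64, 64)

The decoder consumes one entry before the mid block and one per up block, so the encoder and decoder block counts (and spatial sizes) have to match for the additions to be shape-compatible.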
373
- class UpSample(nn.Module):
374
- r"""
375
- The `UpSample` layer of a variational autoencoder that upsamples its input.
376
-
377
- Args:
378
- in_channels (`int`, *optional*, defaults to 3):
379
- The number of input channels.
380
- out_channels (`int`, *optional*, defaults to 3):
381
- The number of output channels.
382
- """
383
-
384
- def __init__(
385
- self,
386
- in_channels: int,
387
- out_channels: int,
388
- ) -> None:
389
- super().__init__()
390
- self.in_channels = in_channels
391
- self.out_channels = out_channels
392
- self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
393
-
394
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
395
- r"""The forward method of the `UpSample` class."""
396
- x = torch.relu(x)
397
- x = self.deconv(x)
398
- return x
399
-
400
-
401
- class MaskConditionEncoder(nn.Module):
402
- """
403
- Encodes the masked image into a pyramid of shape-keyed feature maps; used in AsymmetricAutoencoderKL.
404
- """
405
-
406
- def __init__(
407
- self,
408
- in_ch: int,
409
- out_ch: int = 192,
410
- res_ch: int = 768,
411
- stride: int = 16,
412
- ) -> None:
413
- super().__init__()
414
-
415
- channels = []
416
- while stride > 1:
417
- stride = stride // 2
418
- in_ch_ = out_ch * 2
419
- if out_ch > res_ch:
420
- out_ch = res_ch
421
- if stride == 1:
422
- in_ch_ = res_ch
423
- channels.append((in_ch_, out_ch))
424
- out_ch *= 2
425
-
426
- out_channels = []
427
- for _in_ch, _out_ch in channels:
428
- out_channels.append(_out_ch)
429
- out_channels.append(channels[-1][0])
430
-
431
- layers = []
432
- in_ch_ = in_ch
433
- for l in range(len(out_channels)):
434
- out_ch_ = out_channels[l]
435
- if l == 0 or l == 1:
436
- layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
437
- else:
438
- layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
439
- in_ch_ = out_ch_
440
-
441
- self.layers = nn.Sequential(*layers)
442
-
443
- def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
444
- r"""The forward method of the `MaskConditionEncoder` class."""
445
- out = {}
446
- for l in range(len(self.layers)):
447
- layer = self.layers[l]
448
- x = layer(x)
449
- out[str(tuple(x.shape))] = x
450
- x = torch.relu(x)
451
- return out
452
-
453
-
454
- class MaskConditionDecoder(nn.Module):
455
- r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
456
- decoder with a conditioner on the mask and masked image.
457
-
458
- Args:
459
- in_channels (`int`, *optional*, defaults to 3):
460
- The number of input channels.
461
- out_channels (`int`, *optional*, defaults to 3):
462
- The number of output channels.
463
- up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
464
- The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
465
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
466
- The number of output channels for each block.
467
- layers_per_block (`int`, *optional*, defaults to 2):
468
- The number of layers per block.
469
- norm_num_groups (`int`, *optional*, defaults to 32):
470
- The number of groups for normalization.
471
- act_fn (`str`, *optional*, defaults to `"silu"`):
472
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
473
- norm_type (`str`, *optional*, defaults to `"group"`):
474
- The normalization type to use. Can be either `"group"` or `"spatial"`.
475
- """
476
-
477
- def __init__(
478
- self,
479
- in_channels: int = 3,
480
- out_channels: int = 3,
481
- up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
482
- block_out_channels: Tuple[int, ...] = (64,),
483
- layers_per_block: int = 2,
484
- norm_num_groups: int = 32,
485
- act_fn: str = "silu",
486
- norm_type: str = "group", # group, spatial
487
- ):
488
- super().__init__()
489
- self.layers_per_block = layers_per_block
490
-
491
- self.conv_in = nn.Conv2d(
492
- in_channels,
493
- block_out_channels[-1],
494
- kernel_size=3,
495
- stride=1,
496
- padding=1,
497
- )
498
-
499
- self.mid_block = None
500
- self.up_blocks = nn.ModuleList([])
501
-
502
- temb_channels = in_channels if norm_type == "spatial" else None
503
-
504
- # mid
505
- self.mid_block = UNetMidBlock2D(
506
- in_channels=block_out_channels[-1],
507
- resnet_eps=1e-6,
508
- resnet_act_fn=act_fn,
509
- output_scale_factor=1,
510
- resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
511
- attention_head_dim=block_out_channels[-1],
512
- resnet_groups=norm_num_groups,
513
- temb_channels=temb_channels,
514
- )
515
-
516
- # up
517
- reversed_block_out_channels = list(reversed(block_out_channels))
518
- output_channel = reversed_block_out_channels[0]
519
- for i, up_block_type in enumerate(up_block_types):
520
- prev_output_channel = output_channel
521
- output_channel = reversed_block_out_channels[i]
522
-
523
- is_final_block = i == len(block_out_channels) - 1
524
-
525
- up_block = get_up_block(
526
- up_block_type,
527
- num_layers=self.layers_per_block + 1,
528
- in_channels=prev_output_channel,
529
- out_channels=output_channel,
530
- prev_output_channel=None,
531
- add_upsample=not is_final_block,
532
- resnet_eps=1e-6,
533
- resnet_act_fn=act_fn,
534
- resnet_groups=norm_num_groups,
535
- attention_head_dim=output_channel,
536
- temb_channels=temb_channels,
537
- resnet_time_scale_shift=norm_type,
538
- )
539
- self.up_blocks.append(up_block)
540
- prev_output_channel = output_channel
541
-
542
- # condition encoder
543
- self.condition_encoder = MaskConditionEncoder(
544
- in_ch=out_channels,
545
- out_ch=block_out_channels[0],
546
- res_ch=block_out_channels[-1],
547
- )
548
-
549
- # out
550
- if norm_type == "spatial":
551
- self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
552
- else:
553
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
554
- self.conv_act = nn.SiLU()
555
- self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
556
-
557
- self.gradient_checkpointing = False
558
-
559
- def forward(
560
- self,
561
- z: torch.FloatTensor,
562
- image: Optional[torch.FloatTensor] = None,
563
- mask: Optional[torch.FloatTensor] = None,
564
- latent_embeds: Optional[torch.FloatTensor] = None,
565
- ) -> torch.FloatTensor:
566
- r"""The forward method of the `MaskConditionDecoder` class."""
567
- sample = z
568
- sample = self.conv_in(sample)
569
-
570
- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
571
- if self.training and self.gradient_checkpointing:
572
-
573
- def create_custom_forward(module):
574
- def custom_forward(*inputs):
575
- return module(*inputs)
576
-
577
- return custom_forward
578
-
579
- if is_torch_version(">=", "1.11.0"):
580
- # middle
581
- sample = torch.utils.checkpoint.checkpoint(
582
- create_custom_forward(self.mid_block),
583
- sample,
584
- latent_embeds,
585
- use_reentrant=False,
586
- )
587
- sample = sample.to(upscale_dtype)
588
-
589
- # condition encoder
590
- if image is not None and mask is not None:
591
- masked_image = (1 - mask) * image
592
- im_x = torch.utils.checkpoint.checkpoint(
593
- create_custom_forward(self.condition_encoder),
594
- masked_image,
595
- mask,
596
- use_reentrant=False,
597
- )
598
-
599
- # up
600
- for up_block in self.up_blocks:
601
- if image is not None and mask is not None:
602
- sample_ = im_x[str(tuple(sample.shape))]
603
- mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
604
- sample = sample * mask_ + sample_ * (1 - mask_)
605
- sample = torch.utils.checkpoint.checkpoint(
606
- create_custom_forward(up_block),
607
- sample,
608
- latent_embeds,
609
- use_reentrant=False,
610
- )
611
- if image is not None and mask is not None:
612
- sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
613
- else:
614
- # middle
615
- sample = torch.utils.checkpoint.checkpoint(
616
- create_custom_forward(self.mid_block), sample, latent_embeds
617
- )
618
- sample = sample.to(upscale_dtype)
619
-
620
- # condition encoder
621
- if image is not None and mask is not None:
622
- masked_image = (1 - mask) * image
623
- im_x = torch.utils.checkpoint.checkpoint(
624
- create_custom_forward(self.condition_encoder),
625
- masked_image,
626
- mask,
627
- )
628
-
629
- # up
630
- for up_block in self.up_blocks:
631
- if image is not None and mask is not None:
632
- sample_ = im_x[str(tuple(sample.shape))]
633
- mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
634
- sample = sample * mask_ + sample_ * (1 - mask_)
635
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
636
- if image is not None and mask is not None:
637
- sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
638
- else:
639
- # middle
640
- sample = self.mid_block(sample, latent_embeds)
641
- sample = sample.to(upscale_dtype)
642
-
643
- # condition encoder
644
- if image is not None and mask is not None:
645
- masked_image = (1 - mask) * image
646
- im_x = self.condition_encoder(masked_image, mask)
647
-
648
- # up
649
- for up_block in self.up_blocks:
650
- if image is not None and mask is not None:
651
- sample_ = im_x[str(tuple(sample.shape))]
652
- mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
653
- sample = sample * mask_ + sample_ * (1 - mask_)
654
- sample = up_block(sample, latent_embeds)
655
- if image is not None and mask is not None:
656
- sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
657
-
658
- # post-process
659
- if latent_embeds is None:
660
- sample = self.conv_norm_out(sample)
661
- else:
662
- sample = self.conv_norm_out(sample, latent_embeds)
663
- sample = self.conv_act(sample)
664
- sample = self.conv_out(sample)
665
-
666
- return sample
667
-
668
-
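As the docstring above describes, MaskConditionDecoder runs the masked image through MaskConditionEncoder and, at every resolution, blends its own activations with the matching condition features via sample * mask + encoded * (1 - mask), so known regions (mask == 0) are steered by the masked image while holes (mask == 1) are left to the decoder. A minimal sketch of a call under a toy single-block configuration, again assuming the deleted module is importable as `vae`:

import torch

from vae import MaskConditionDecoder  # hypothetical import of the deleted module

# A single up block keeps every feature map at the input resolution, so the
# shape-keyed lookup into the condition-encoder outputs always finds a match.
dec = MaskConditionDecoder(in_channels=3, out_channels=3, block_out_channels=(64,))

z = torch.randn(1, 3, 64, 64)                       # latent (same resolution in this toy setup)
image = torch.randn(1, 3, 64, 64)                   # original image
mask = (torch.rand(1, 1, 64, 64) > 0.5).float()     # 1 = hole to fill, 0 = known pixel

with torch.no_grad():
    out = dec(z, image=image, mask=mask)
print(out.shape)                                    # torch.Size([1, 3, 64, 64])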
669
- class VectorQuantizer(nn.Module):
670
- """
671
- Improved version of the original VectorQuantizer that can be used as a drop-in replacement. It mostly avoids costly
672
- matrix multiplications and allows for post-hoc remapping of indices.
673
- """
674
-
675
- # NOTE: Due to a bug, the beta term was applied to the wrong term. For
676
- # backwards compatibility we use the buggy version by default, but you can
677
- # specify legacy=False to fix it.
678
- def __init__(
679
- self,
680
- n_e: int,
681
- vq_embed_dim: int,
682
- beta: float,
683
- remap=None,
684
- unknown_index: str = "random",
685
- sane_index_shape: bool = False,
686
- legacy: bool = True,
687
- ):
688
- super().__init__()
689
- self.n_e = n_e
690
- self.vq_embed_dim = vq_embed_dim
691
- self.beta = beta
692
- self.legacy = legacy
693
-
694
- self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
695
- self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
696
-
697
- self.remap = remap
698
- if self.remap is not None:
699
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
700
- self.used: torch.Tensor
701
- self.re_embed = self.used.shape[0]
702
- self.unknown_index = unknown_index # "random" or "extra" or integer
703
- if self.unknown_index == "extra":
704
- self.unknown_index = self.re_embed
705
- self.re_embed = self.re_embed + 1
706
- print(
707
- f"Remapping {self.n_e} indices to {self.re_embed} indices. "
708
- f"Using {self.unknown_index} for unknown indices."
709
- )
710
- else:
711
- self.re_embed = n_e
712
-
713
- self.sane_index_shape = sane_index_shape
714
-
715
- def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
716
- ishape = inds.shape
717
- assert len(ishape) > 1
718
- inds = inds.reshape(ishape[0], -1)
719
- used = self.used.to(inds)
720
- match = (inds[:, :, None] == used[None, None, ...]).long()
721
- new = match.argmax(-1)
722
- unknown = match.sum(2) < 1
723
- if self.unknown_index == "random":
724
- new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
725
- else:
726
- new[unknown] = self.unknown_index
727
- return new.reshape(ishape)
728
-
729
- def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
730
- ishape = inds.shape
731
- assert len(ishape) > 1
732
- inds = inds.reshape(ishape[0], -1)
733
- used = self.used.to(inds)
734
- if self.re_embed > self.used.shape[0]: # extra token
735
- inds[inds >= self.used.shape[0]] = 0 # simply set to zero
736
- back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
737
- return back.reshape(ishape)
738
-
739
- def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
740
- # reshape z -> (batch, height, width, channel) and flatten
741
- z = z.permute(0, 2, 3, 1).contiguous()
742
- z_flattened = z.view(-1, self.vq_embed_dim)
743
-
744
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
745
- min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
746
-
747
- z_q = self.embedding(min_encoding_indices).view(z.shape)
748
- perplexity = None
749
- min_encodings = None
750
-
751
- # compute loss for embedding
752
- if not self.legacy:
753
- loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
754
- else:
755
- loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
756
-
757
- # preserve gradients
758
- z_q: torch.FloatTensor = z + (z_q - z).detach()
759
-
760
- # reshape back to match original input shape
761
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
762
-
763
- if self.remap is not None:
764
- min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
765
- min_encoding_indices = self.remap_to_used(min_encoding_indices)
766
- min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
767
-
768
- if self.sane_index_shape:
769
- min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
770
-
771
- return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
772
-
773
- def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
774
- # shape specifying (batch, height, width, channel)
775
- if self.remap is not None:
776
- indices = indices.reshape(shape[0], -1) # add batch axis
777
- indices = self.unmap_to_all(indices)
778
- indices = indices.reshape(-1) # flatten again
779
-
780
- # get quantized latent vectors
781
- z_q: torch.FloatTensor = self.embedding(indices)
782
-
783
- if shape is not None:
784
- z_q = z_q.view(shape)
785
- # reshape back to match original input shape
786
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
787
-
788
- return z_q
789
-
790
-
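VectorQuantizer snaps each spatial feature vector to its nearest codebook entry (via torch.cdist), passes gradients straight through with z + (z_q - z).detach(), and returns a commitment loss whose beta weighting lands on one term or the other depending on the legacy flag discussed in the NOTE above. A minimal sketch, with the `vae` import again standing in for the deleted module:

import torch

from vae import VectorQuantizer  # hypothetical import of the deleted module

# Toy codebook: 16 codes of dimension 8; sane_index_shape reshapes the indices
# back to (batch, height, width) instead of a flat column.
vq = VectorQuantizer(n_e=16, vq_embed_dim=8, beta=0.25, sane_index_shape=True)

z = torch.randn(2, 8, 4, 4)                         # (batch, channels, height, width)
z_q, loss, (_, _, indices) = vq(z)
print(z_q.shape, loss.item(), indices.shape)        # z_q keeps z's shape; indices are (2, 4, 4)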
791
- class DiagonalGaussianDistribution(object):
792
- def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
793
- self.parameters = parameters
794
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
795
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
796
- self.deterministic = deterministic
797
- self.std = torch.exp(0.5 * self.logvar)
798
- self.var = torch.exp(self.logvar)
799
- if self.deterministic:
800
- self.var = self.std = torch.zeros_like(
801
- self.mean, device=self.parameters.device, dtype=self.parameters.dtype
802
- )
803
-
804
- def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
805
- # make sure sample is on the same device as the parameters and has same dtype
806
- sample = randn_tensor(
807
- self.mean.shape,
808
- generator=generator,
809
- device=self.parameters.device,
810
- dtype=self.parameters.dtype,
811
- )
812
- x = self.mean + self.std * sample
813
- return x
814
-
815
- def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
816
- if self.deterministic:
817
- return torch.Tensor([0.0])
818
- else:
819
- if other is None:
820
- return 0.5 * torch.sum(
821
- torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
822
- dim=[1, 2, 3],
823
- )
824
- else:
825
- return 0.5 * torch.sum(
826
- torch.pow(self.mean - other.mean, 2) / other.var
827
- + self.var / other.var
828
- - 1.0
829
- - self.logvar
830
- + other.logvar,
831
- dim=[1, 2, 3],
832
- )
833
-
834
- def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
835
- if self.deterministic:
836
- return torch.Tensor([0.0])
837
- logtwopi = np.log(2.0 * np.pi)
838
- return 0.5 * torch.sum(
839
- logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
840
- dim=dims,
841
- )
842
-
843
- def mode(self) -> torch.Tensor:
844
- return self.mean
845
-
846
-
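DiagonalGaussianDistribution wraps an encoder output whose channels are the concatenation of a mean and a log-variance (clamped to [-30, 20]): sample() draws via the reparameterization trick mean + std * eps, kl() evaluates the closed-form KL against a standard normal (or another diagonal Gaussian) summed over the channel and spatial dimensions, and mode() simply returns the mean. A minimal sketch, import path assumed as before:

import torch

from vae import DiagonalGaussianDistribution  # hypothetical import of the deleted module

# 8 parameter channels split into 4 mean channels and 4 log-variance channels,
# which is what an AutoencoderKL encoder produces with double_z=True.
params = torch.randn(1, 8, 4, 4)
dist = DiagonalGaussianDistribution(params)

latent = dist.sample()                              # mean + std * eps
kl = dist.kl()                                      # KL against N(0, I), one value per batch element
print(latent.shape, kl.shape)                       # torch.Size([1, 4, 4, 4]) torch.Size([1])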
847
- class EncoderTiny(nn.Module):
848
- r"""
849
- The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
850
-
851
- Args:
852
- in_channels (`int`):
853
- The number of input channels.
854
- out_channels (`int`):
855
- The number of output channels.
856
- num_blocks (`Tuple[int, ...]`):
857
- Each value of the tuple represents a Conv2d layer followed by that many
858
- `AutoencoderTinyBlock` layers.
859
- block_out_channels (`Tuple[int, ...]`):
860
- The number of output channels for each block.
861
- act_fn (`str`):
862
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
863
- """
864
-
865
- def __init__(
866
- self,
867
- in_channels: int,
868
- out_channels: int,
869
- num_blocks: Tuple[int, ...],
870
- block_out_channels: Tuple[int, ...],
871
- act_fn: str,
872
- ):
873
- super().__init__()
874
-
875
- layers = []
876
- for i, num_block in enumerate(num_blocks):
877
- num_channels = block_out_channels[i]
878
-
879
- if i == 0:
880
- layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
881
- else:
882
- layers.append(
883
- nn.Conv2d(
884
- num_channels,
885
- num_channels,
886
- kernel_size=3,
887
- padding=1,
888
- stride=2,
889
- bias=False,
890
- )
891
- )
892
-
893
- for _ in range(num_block):
894
- layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
895
-
896
- layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
897
-
898
- self.layers = nn.Sequential(*layers)
899
- self.gradient_checkpointing = False
900
-
901
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
902
- r"""The forward method of the `EncoderTiny` class."""
903
- if self.training and self.gradient_checkpointing:
904
-
905
- def create_custom_forward(module):
906
- def custom_forward(*inputs):
907
- return module(*inputs)
908
-
909
- return custom_forward
910
-
911
- if is_torch_version(">=", "1.11.0"):
912
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
913
- else:
914
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
915
-
916
- else:
917
- # scale image from [-1, 1] to [0, 1] to match TAESD convention
918
- x = self.layers(x.add(1).div(2))
919
-
920
- return x
921
-
922
-
923
- class DecoderTiny(nn.Module):
924
- r"""
925
- The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
926
-
927
- Args:
928
- in_channels (`int`):
929
- The number of input channels.
930
- out_channels (`int`):
931
- The number of output channels.
932
- num_blocks (`Tuple[int, ...]`):
933
- Each value of the tuple represents a Conv2d layer followed by that many
934
- `AutoencoderTinyBlock` layers.
935
- block_out_channels (`Tuple[int, ...]`):
936
- The number of output channels for each block.
937
- upsampling_scaling_factor (`int`):
938
- The scaling factor to use for upsampling.
939
- act_fn (`str`):
940
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
941
- """
942
-
943
- def __init__(
944
- self,
945
- in_channels: int,
946
- out_channels: int,
947
- num_blocks: Tuple[int, ...],
948
- block_out_channels: Tuple[int, ...],
949
- upsampling_scaling_factor: int,
950
- act_fn: str,
951
- ):
952
- super().__init__()
953
-
954
- layers = [
955
- nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
956
- get_activation(act_fn),
957
- ]
958
-
959
- for i, num_block in enumerate(num_blocks):
960
- is_final_block = i == (len(num_blocks) - 1)
961
- num_channels = block_out_channels[i]
962
-
963
- for _ in range(num_block):
964
- layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
965
-
966
- if not is_final_block:
967
- layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
968
-
969
- conv_out_channel = num_channels if not is_final_block else out_channels
970
- layers.append(
971
- nn.Conv2d(
972
- num_channels,
973
- conv_out_channel,
974
- kernel_size=3,
975
- padding=1,
976
- bias=is_final_block,
977
- )
978
- )
979
-
980
- self.layers = nn.Sequential(*layers)
981
- self.gradient_checkpointing = False
982
-
983
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
984
- r"""The forward method of the `DecoderTiny` class."""
985
- # Clamp.
986
- x = torch.tanh(x / 3) * 3
987
-
988
- if self.training and self.gradient_checkpointing:
989
-
990
- def create_custom_forward(module):
991
- def custom_forward(*inputs):
992
- return module(*inputs)
993
-
994
- return custom_forward
995
-
996
- if is_torch_version(">=", "1.11.0"):
997
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
998
- else:
999
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
1000
-
1001
- else:
1002
- x = self.layers(x)
1003
-
1004
- # scale image from [0, 1] to [-1, 1] to match diffusers convention
1005
- return x.mul(2).sub(1)
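EncoderTiny and DecoderTiny follow the TAESD scale conventions flagged in their forward methods: the encoder shifts [-1, 1] images into [0, 1] before its layers, and the decoder clamps the incoming latent with tanh(x / 3) * 3 and rescales its output from [0, 1] back to [-1, 1]. A minimal round-trip sketch under an assumed two-level configuration (the `vae` import is again hypothetical):

import torch

from vae import DecoderTiny, EncoderTiny  # hypothetical import of the deleted module

enc = EncoderTiny(in_channels=3, out_channels=4,
                  num_blocks=(1, 1), block_out_channels=(32, 32), act_fn="relu")
dec = DecoderTiny(in_channels=4, out_channels=3,
                  num_blocks=(1, 1), block_out_channels=(32, 32),
                  upsampling_scaling_factor=2, act_fn="relu")

x = torch.rand(1, 3, 64, 64) * 2 - 1                # image in [-1, 1]
with torch.no_grad():
    latent = enc(x)                                 # (1, 4, 32, 32): one stride-2 conv halves the resolution
    recon = dec(latent)                             # (1, 3, 64, 64): upsampled once, then rescaled by x * 2 - 1
print(latent.shape, recon.shape)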
 
diffusers/src/diffusers/models/autoencoders_/vq_model.py DELETED
@@ -1,182 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn as nn
19
-
20
- from ...configuration_utils import ConfigMixin, register_to_config
21
- from ...utils import BaseOutput
22
- from ...utils.accelerate_utils import apply_forward_hook
23
- from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
24
- from ..modeling_utils import ModelMixin
25
-
26
-
27
- @dataclass
28
- class VQEncoderOutput(BaseOutput):
29
- """
30
- Output of VQModel encoding method.
31
-
32
- Args:
33
- latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
34
- The encoded output sample from the last layer of the model.
35
- """
36
-
37
- latents: torch.Tensor
38
-
39
-
40
- class VQModel(ModelMixin, ConfigMixin):
41
- r"""
42
- A VQ-VAE model for decoding latent representations.
43
-
44
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
45
- for all models (such as downloading or saving).
46
-
47
- Parameters:
48
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
49
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
50
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
51
- Tuple of downsample block types.
52
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
53
- Tuple of upsample block types.
54
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
55
- Tuple of block output channels.
56
- layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
57
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
58
- latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
59
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
60
- num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
61
- norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
62
- vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
63
- scaling_factor (`float`, *optional*, defaults to `0.18215`):
64
- The component-wise standard deviation of the trained latent space computed using the first batch of the
65
- training set. This is used to scale the latent space to have unit variance when training the diffusion
66
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
67
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
68
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
69
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
70
- norm_type (`str`, *optional*, defaults to `"group"`):
71
- Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
72
- """
73
-
74
- @register_to_config
75
- def __init__(
76
- self,
77
- in_channels: int = 3,
78
- out_channels: int = 3,
79
- down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
80
- up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
81
- block_out_channels: Tuple[int, ...] = (64,),
82
- layers_per_block: int = 1,
83
- act_fn: str = "silu",
84
- latent_channels: int = 3,
85
- sample_size: int = 32,
86
- num_vq_embeddings: int = 256,
87
- norm_num_groups: int = 32,
88
- vq_embed_dim: Optional[int] = None,
89
- scaling_factor: float = 0.18215,
90
- norm_type: str = "group", # group, spatial
91
- mid_block_add_attention=True,
92
- lookup_from_codebook=False,
93
- force_upcast=False,
94
- ):
95
- super().__init__()
96
-
97
- # pass init params to Encoder
98
- self.encoder = Encoder(
99
- in_channels=in_channels,
100
- out_channels=latent_channels,
101
- down_block_types=down_block_types,
102
- block_out_channels=block_out_channels,
103
- layers_per_block=layers_per_block,
104
- act_fn=act_fn,
105
- norm_num_groups=norm_num_groups,
106
- double_z=False,
107
- mid_block_add_attention=mid_block_add_attention,
108
- )
109
-
110
- vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
111
-
112
- self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
113
- self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
114
- self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
115
-
116
- # pass init params to Decoder
117
- self.decoder = Decoder(
118
- in_channels=latent_channels,
119
- out_channels=out_channels,
120
- up_block_types=up_block_types,
121
- block_out_channels=block_out_channels,
122
- layers_per_block=layers_per_block,
123
- act_fn=act_fn,
124
- norm_num_groups=norm_num_groups,
125
- norm_type=norm_type,
126
- mid_block_add_attention=mid_block_add_attention,
127
- )
128
-
129
- @apply_forward_hook
130
- def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
131
- h = self.encoder(x)
132
- h = self.quant_conv(h)
133
-
134
- if not return_dict:
135
- return (h,)
136
-
137
- return VQEncoderOutput(latents=h)
138
-
139
- @apply_forward_hook
140
- def decode(
141
- self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
142
- ) -> Union[DecoderOutput, torch.Tensor]:
143
- # also go through quantization layer
144
- if not force_not_quantize:
145
- quant, commit_loss, _ = self.quantize(h)
146
- elif self.config.lookup_from_codebook:
147
- quant = self.quantize.get_codebook_entry(h, shape)
148
- commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
149
- else:
150
- quant = h
151
- commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
152
- quant2 = self.post_quant_conv(quant)
153
- dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
154
-
155
- if not return_dict:
156
- return dec, commit_loss
157
-
158
- return DecoderOutput(sample=dec, commit_loss=commit_loss)
159
-
160
- def forward(
161
- self, sample: torch.Tensor, return_dict: bool = True
162
- ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
163
- r"""
164
- The [`VQModel`] forward method.
165
-
166
- Args:
167
- sample (`torch.Tensor`): Input sample.
168
- return_dict (`bool`, *optional*, defaults to `True`):
169
- Whether or not to return a [`models.autoencoders.vae.DecoderOutput`] instead of a plain tuple.
170
-
171
- Returns:
172
- [`~models.autoencoders.vae.DecoderOutput`] or `tuple`:
173
- If return_dict is True, a [`~models.autoencoders.vae.DecoderOutput`] is returned, otherwise a
174
- plain `tuple` is returned.
175
- """
176
-
177
- h = self.encode(sample).latents
178
- dec = self.decode(h)
179
-
180
- if not return_dict:
181
- return dec.sample, dec.commit_loss
182
- return dec
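Tying the deleted file together: VQModel encodes an image to continuous latents with the Encoder, and decode() runs them through the VectorQuantizer codebook (unless force_not_quantize is set) before handing the result to the Decoder. A minimal sketch with the default configuration, assuming the module is importable as `vq_model`:

import torch

from vq_model import VQModel  # hypothetical import of the deleted module

# Defaults: 3 latent channels, a 256-entry codebook, one down block and one up block.
# model.config.scaling_factor (0.18215 by default) is only stored here; pipelines
# are expected to apply z = z * scaling_factor themselves, as the docstring notes.
model = VQModel()

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    latents = model.encode(x).latents               # continuous, pre-quantization latents
    out = model.decode(latents)                     # nearest-codebook lookup, then decode
print(latents.shape, out.sample.shape)              # both (1, 3, 32, 32) with the defaults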
 
diffusers/src/diffusers/pipelines/colorflow/pipeline_colorflow_sd.py CHANGED
@@ -1069,9 +1069,6 @@ class ColorFlowSDPipeline(
1069
 
1070
  image_B.paste(image.crop((left, top, right, bottom)), (left, top))
1071
 
1072
- # image_A.save('/group/40034/zhuangjunhao/BrushNet_RAG/ref.png')
1073
- # image_B.save('/group/40034/zhuangjunhao/BrushNet_RAG/bw.png')
1074
-
1075
 
1076
 
1077
  image = self.prepare_image(
@@ -1186,10 +1183,7 @@ class ColorFlowSDPipeline(
1186
  is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1187
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1188
  for i, t in enumerate(timesteps):
1189
- # import os
1190
- # print(t)
1191
- # os.makedirs(f'/group/40034/zhuangjunhao/BrushNet_RAG/examples/colorguider/test_pipeline/test_result/all_output/paper_atten/300001/attenmap/{t.item()}/',exist_ok=True)
1192
- # Relevant thread:
1193
  # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1194
  if (is_unet_compiled and is_colorguider_compiled) and is_torch_higher_equal_2_1:
1195
  torch._inductor.cudagraph_mark_step_begin()
 
1069
 
1070
  image_B.paste(image.crop((left, top, right, bottom)), (left, top))
1071
 
 
 
 
1072
 
1073
 
1074
  image = self.prepare_image(
 
1183
  is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1184
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1185
  for i, t in enumerate(timesteps):
1186
+ # Relevant thread:
 
 
 
1187
  # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1188
  if (is_unet_compiled and is_colorguider_compiled) and is_torch_higher_equal_2_1:
1189
  torch._inductor.cudagraph_mark_step_begin()