Atin Sakkeer Hussain commited on
Commit
795ce43
1 Parent(s): eae7a25
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
.idea/M2UGen-Demo.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="GOOGLE" />
10
+ <option name="myDocStringFormat" value="Google" />
11
+ </component>
12
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="18">
8
+ <item index="0" class="java.lang.String" itemvalue="pandas" />
9
+ <item index="1" class="java.lang.String" itemvalue="tqdm" />
10
+ <item index="2" class="java.lang.String" itemvalue="absl-py" />
11
+ <item index="3" class="java.lang.String" itemvalue="dgl" />
12
+ <item index="4" class="java.lang.String" itemvalue="torch" />
13
+ <item index="5" class="java.lang.String" itemvalue="numpy" />
14
+ <item index="6" class="java.lang.String" itemvalue="Cython" />
15
+ <item index="7" class="java.lang.String" itemvalue="torchlibrosa" />
16
+ <item index="8" class="java.lang.String" itemvalue="gdown" />
17
+ <item index="9" class="java.lang.String" itemvalue="wget" />
18
+ <item index="10" class="java.lang.String" itemvalue="accelerate" />
19
+ <item index="11" class="java.lang.String" itemvalue="transformers" />
20
+ <item index="12" class="java.lang.String" itemvalue="gradio" />
21
+ <item index="13" class="java.lang.String" itemvalue="tensorboard" />
22
+ <item index="14" class="java.lang.String" itemvalue="diffusers" />
23
+ <item index="15" class="java.lang.String" itemvalue="opencv-python" />
24
+ <item index="16" class="java.lang.String" itemvalue="huggingface_hub" />
25
+ <item index="17" class="java.lang.String" itemvalue="Pillow" />
26
+ </list>
27
+ </value>
28
+ </option>
29
+ </inspection_tool>
30
+ <inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
31
+ <option name="ignoredErrors">
32
+ <list>
33
+ <option value="W605" />
34
+ </list>
35
+ </option>
36
+ </inspection_tool>
37
+ <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
38
+ <option name="ignoredErrors">
39
+ <list>
40
+ <option value="N806" />
41
+ <option value="N802" />
42
+ <option value="N803" />
43
+ </list>
44
+ </option>
45
+ </inspection_tool>
46
+ <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
47
+ <option name="ignoredIdentifiers">
48
+ <list>
49
+ <option value="tokenizers.BertWordPieceTokenizer" />
50
+ <option value="cv2.aruco" />
51
+ <option value="llama" />
52
+ </list>
53
+ </option>
54
+ </inspection_tool>
55
+ </profile>
56
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (AudioCaption)" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/M2UGen-Demo.iml" filepath="$PROJECT_DIR$/.idea/M2UGen-Demo.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
llama/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .llama import ModelArgs, Transformer
2
+ from .tokenizer import Tokenizer
3
+ from .m2ugen import *
4
+ from .utils import format_prompt
llama/audioldm2/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .pipeline_audioldm2 import AudioLDM2Pipeline
llama/audioldm2/modeling_audioldm2.py ADDED
@@ -0,0 +1,1513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ AttentionProcessor,
29
+ AttnAddedKVProcessor,
30
+ AttnProcessor,
31
+ )
32
+ from diffusers.models.embeddings import (
33
+ TimestepEmbedding,
34
+ Timesteps,
35
+ )
36
+ from diffusers.models.modeling_utils import ModelMixin
37
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
38
+ from diffusers.models.transformer_2d import Transformer2DModel
39
+ from diffusers.models.unet_2d_blocks import DownBlock2D, UpBlock2D
40
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
41
+ from diffusers.utils import BaseOutput, is_torch_version, logging
42
+
43
+
44
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
+
46
+
47
+ def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token):
48
+ batch_size = hidden_states.shape[0]
49
+
50
+ if attention_mask is not None:
51
+ # Add two more steps to attn mask
52
+ new_attn_mask_step = attention_mask.new_ones((batch_size, 1))
53
+ attention_mask = torch.concat([new_attn_mask_step, attention_mask, new_attn_mask_step], dim=-1)
54
+
55
+ # Add the SOS / EOS tokens at the start / end of the sequence respectively
56
+ sos_token = sos_token.expand(batch_size, 1, -1)
57
+ eos_token = eos_token.expand(batch_size, 1, -1)
58
+ hidden_states = torch.concat([sos_token, hidden_states, eos_token], dim=1)
59
+ return hidden_states, attention_mask
60
+
61
+
62
+ @dataclass
63
+ class AudioLDM2ProjectionModelOutput(BaseOutput):
64
+ """
65
+ Args:
66
+ Class for AudioLDM2 projection layer's outputs.
67
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
68
+ Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text
69
+ encoders and subsequently concatenating them together.
70
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
71
+ Mask to avoid performing attention on padding token indices, formed by concatenating the attention masks
72
+ for the two text encoders together. Mask values selected in `[0, 1]`:
73
+
74
+ - 1 for tokens that are **not masked**,
75
+ - 0 for tokens that are **masked**.
76
+ """
77
+
78
+ hidden_states: torch.FloatTensor
79
+ attention_mask: Optional[torch.LongTensor] = None
80
+
81
+
82
+ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
83
+ """
84
+ A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned
85
+ embedding vectors at the start and end of each text embedding sequence respectively. Each variable appended with
86
+ `_1` refers to that corresponding to the second text encoder. Otherwise, it is from the first.
87
+
88
+ Args:
89
+ text_encoder_dim (`int`):
90
+ Dimensionality of the text embeddings from the first text encoder (CLAP).
91
+ text_encoder_1_dim (`int`):
92
+ Dimensionality of the text embeddings from the second text encoder (T5 or VITS).
93
+ langauge_model_dim (`int`):
94
+ Dimensionality of the text embeddings from the language model (GPT2).
95
+ """
96
+
97
+ @register_to_config
98
+ def __init__(self, text_encoder_dim, text_encoder_1_dim, langauge_model_dim):
99
+ super().__init__()
100
+ # additional projection layers for each text encoder
101
+ self.projection = nn.Linear(text_encoder_dim, langauge_model_dim)
102
+ self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim)
103
+
104
+ # learnable SOS / EOS token embeddings for each text encoder
105
+ self.sos_embed = nn.Parameter(torch.ones(langauge_model_dim))
106
+ self.eos_embed = nn.Parameter(torch.ones(langauge_model_dim))
107
+
108
+ self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
109
+ self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
110
+
111
+ def forward(
112
+ self,
113
+ hidden_states: Optional[torch.FloatTensor] = None,
114
+ hidden_states_1: Optional[torch.FloatTensor] = None,
115
+ attention_mask: Optional[torch.LongTensor] = None,
116
+ attention_mask_1: Optional[torch.LongTensor] = None,
117
+ ):
118
+ hidden_states = self.projection(hidden_states)
119
+ hidden_states, attention_mask = add_special_tokens(
120
+ hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed
121
+ )
122
+
123
+ hidden_states_1 = self.projection_1(hidden_states_1)
124
+ hidden_states_1, attention_mask_1 = add_special_tokens(
125
+ hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1
126
+ )
127
+
128
+ # concatenate clap and t5 text encoding
129
+ hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1)
130
+
131
+ # concatenate attention masks
132
+ if attention_mask is None and attention_mask_1 is not None:
133
+ attention_mask = attention_mask_1.new_ones((hidden_states[:2]))
134
+ elif attention_mask is not None and attention_mask_1 is None:
135
+ attention_mask_1 = attention_mask.new_ones((hidden_states_1[:2]))
136
+
137
+ if attention_mask is not None and attention_mask_1 is not None:
138
+ attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1)
139
+ else:
140
+ attention_mask = None
141
+
142
+ return AudioLDM2ProjectionModelOutput(
143
+ hidden_states=hidden_states,
144
+ attention_mask=attention_mask,
145
+ )
146
+
147
+
148
+ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
149
+ r"""
150
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
151
+ shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional
152
+ self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up
153
+ to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`.
154
+
155
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
156
+ for all models (such as downloading or saving).
157
+
158
+ Parameters:
159
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
160
+ Height and width of input/output sample.
161
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
162
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
163
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
164
+ Whether to flip the sin to cos in the time embedding.
165
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
166
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
167
+ The tuple of downsample blocks to use.
168
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
169
+ Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2.
170
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
171
+ The tuple of upsample blocks to use.
172
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, default to `False`):
173
+ Whether to include self-attention in the basic transformer blocks, see
174
+ [`~models.attention.BasicTransformerBlock`].
175
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
176
+ The tuple of output channels for each block.
177
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
178
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
179
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
180
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
181
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
182
+ If `None`, normalization and activation layers is skipped in post-processing.
183
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
184
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
185
+ The dimension of the cross attention features.
186
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
187
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
188
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
189
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
190
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
191
+ num_attention_heads (`int`, *optional*):
192
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
193
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
194
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
195
+ class_embed_type (`str`, *optional*, defaults to `None`):
196
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
197
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
198
+ num_class_embeds (`int`, *optional*, defaults to `None`):
199
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
200
+ class conditioning with `class_embed_type` equal to `None`.
201
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
202
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
203
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
204
+ An optional override for the dimension of the projected time embedding.
205
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
206
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
207
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
208
+ timestep_post_act (`str`, *optional*, defaults to `None`):
209
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
210
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
211
+ The dimension of `cond_proj` layer in the timestep embedding.
212
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
213
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
214
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
215
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
216
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
217
+ embeddings with the class embeddings.
218
+ """
219
+
220
+ _supports_gradient_checkpointing = True
221
+
222
+ @register_to_config
223
+ def __init__(
224
+ self,
225
+ sample_size: Optional[int] = None,
226
+ in_channels: int = 4,
227
+ out_channels: int = 4,
228
+ flip_sin_to_cos: bool = True,
229
+ freq_shift: int = 0,
230
+ down_block_types: Tuple[str] = (
231
+ "CrossAttnDownBlock2D",
232
+ "CrossAttnDownBlock2D",
233
+ "CrossAttnDownBlock2D",
234
+ "DownBlock2D",
235
+ ),
236
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
237
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
238
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
239
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
240
+ layers_per_block: Union[int, Tuple[int]] = 2,
241
+ downsample_padding: int = 1,
242
+ mid_block_scale_factor: float = 1,
243
+ act_fn: str = "silu",
244
+ norm_num_groups: Optional[int] = 32,
245
+ norm_eps: float = 1e-5,
246
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
247
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
248
+ attention_head_dim: Union[int, Tuple[int]] = 8,
249
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
250
+ use_linear_projection: bool = False,
251
+ class_embed_type: Optional[str] = None,
252
+ num_class_embeds: Optional[int] = None,
253
+ upcast_attention: bool = False,
254
+ resnet_time_scale_shift: str = "default",
255
+ time_embedding_type: str = "positional",
256
+ time_embedding_dim: Optional[int] = None,
257
+ time_embedding_act_fn: Optional[str] = None,
258
+ timestep_post_act: Optional[str] = None,
259
+ time_cond_proj_dim: Optional[int] = None,
260
+ conv_in_kernel: int = 3,
261
+ conv_out_kernel: int = 3,
262
+ projection_class_embeddings_input_dim: Optional[int] = None,
263
+ class_embeddings_concat: bool = False,
264
+ ):
265
+ super().__init__()
266
+
267
+ self.sample_size = sample_size
268
+
269
+ if num_attention_heads is not None:
270
+ raise ValueError(
271
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
272
+ )
273
+
274
+ # If `num_attention_heads` is not defined (which is the case for most models)
275
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
276
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
277
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
278
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
279
+ # which is why we correct for the naming here.
280
+ num_attention_heads = num_attention_heads or attention_head_dim
281
+
282
+ # Check inputs
283
+ if len(down_block_types) != len(up_block_types):
284
+ raise ValueError(
285
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
286
+ )
287
+
288
+ if len(block_out_channels) != len(down_block_types):
289
+ raise ValueError(
290
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
291
+ )
292
+
293
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
294
+ raise ValueError(
295
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
296
+ )
297
+
298
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
299
+ raise ValueError(
300
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
301
+ )
302
+
303
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
304
+ raise ValueError(
305
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
306
+ )
307
+
308
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
309
+ raise ValueError(
310
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
311
+ )
312
+
313
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
314
+ raise ValueError(
315
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
316
+ )
317
+
318
+ # input
319
+ conv_in_padding = (conv_in_kernel - 1) // 2
320
+ self.conv_in = nn.Conv2d(
321
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
322
+ )
323
+
324
+ # time
325
+ if time_embedding_type == "positional":
326
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
327
+
328
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
329
+ timestep_input_dim = block_out_channels[0]
330
+ else:
331
+ raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.")
332
+
333
+ self.time_embedding = TimestepEmbedding(
334
+ timestep_input_dim,
335
+ time_embed_dim,
336
+ act_fn=act_fn,
337
+ post_act_fn=timestep_post_act,
338
+ cond_proj_dim=time_cond_proj_dim,
339
+ )
340
+
341
+ # class embedding
342
+ if class_embed_type is None and num_class_embeds is not None:
343
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
344
+ elif class_embed_type == "timestep":
345
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
346
+ elif class_embed_type == "identity":
347
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
348
+ elif class_embed_type == "projection":
349
+ if projection_class_embeddings_input_dim is None:
350
+ raise ValueError(
351
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
352
+ )
353
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
354
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
355
+ # 2. it projects from an arbitrary input dimension.
356
+ #
357
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
358
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
359
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
360
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
361
+ elif class_embed_type == "simple_projection":
362
+ if projection_class_embeddings_input_dim is None:
363
+ raise ValueError(
364
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
365
+ )
366
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
367
+ else:
368
+ self.class_embedding = None
369
+
370
+ if time_embedding_act_fn is None:
371
+ self.time_embed_act = None
372
+ else:
373
+ self.time_embed_act = get_activation(time_embedding_act_fn)
374
+
375
+ self.down_blocks = nn.ModuleList([])
376
+ self.up_blocks = nn.ModuleList([])
377
+
378
+ if isinstance(only_cross_attention, bool):
379
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
380
+
381
+ if isinstance(num_attention_heads, int):
382
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
383
+
384
+ if isinstance(cross_attention_dim, int):
385
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
386
+
387
+ if isinstance(layers_per_block, int):
388
+ layers_per_block = [layers_per_block] * len(down_block_types)
389
+
390
+ if isinstance(transformer_layers_per_block, int):
391
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
392
+
393
+ if class_embeddings_concat:
394
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
395
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
396
+ # regular time embeddings
397
+ blocks_time_embed_dim = time_embed_dim * 2
398
+ else:
399
+ blocks_time_embed_dim = time_embed_dim
400
+
401
+ # down
402
+ output_channel = block_out_channels[0]
403
+ for i, down_block_type in enumerate(down_block_types):
404
+ input_channel = output_channel
405
+ output_channel = block_out_channels[i]
406
+ is_final_block = i == len(block_out_channels) - 1
407
+
408
+ down_block = get_down_block(
409
+ down_block_type,
410
+ num_layers=layers_per_block[i],
411
+ transformer_layers_per_block=transformer_layers_per_block[i],
412
+ in_channels=input_channel,
413
+ out_channels=output_channel,
414
+ temb_channels=blocks_time_embed_dim,
415
+ add_downsample=not is_final_block,
416
+ resnet_eps=norm_eps,
417
+ resnet_act_fn=act_fn,
418
+ resnet_groups=norm_num_groups,
419
+ cross_attention_dim=cross_attention_dim[i],
420
+ num_attention_heads=num_attention_heads[i],
421
+ downsample_padding=downsample_padding,
422
+ use_linear_projection=use_linear_projection,
423
+ only_cross_attention=only_cross_attention[i],
424
+ upcast_attention=upcast_attention,
425
+ resnet_time_scale_shift=resnet_time_scale_shift,
426
+ )
427
+ self.down_blocks.append(down_block)
428
+
429
+ # mid
430
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
431
+ self.mid_block = UNetMidBlock2DCrossAttn(
432
+ transformer_layers_per_block=transformer_layers_per_block[-1],
433
+ in_channels=block_out_channels[-1],
434
+ temb_channels=blocks_time_embed_dim,
435
+ resnet_eps=norm_eps,
436
+ resnet_act_fn=act_fn,
437
+ output_scale_factor=mid_block_scale_factor,
438
+ resnet_time_scale_shift=resnet_time_scale_shift,
439
+ cross_attention_dim=cross_attention_dim[-1],
440
+ num_attention_heads=num_attention_heads[-1],
441
+ resnet_groups=norm_num_groups,
442
+ use_linear_projection=use_linear_projection,
443
+ upcast_attention=upcast_attention,
444
+ )
445
+ else:
446
+ raise ValueError(
447
+ f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2."
448
+ )
449
+
450
+ # count how many layers upsample the images
451
+ self.num_upsamplers = 0
452
+
453
+ # up
454
+ reversed_block_out_channels = list(reversed(block_out_channels))
455
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
456
+ reversed_layers_per_block = list(reversed(layers_per_block))
457
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
458
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
459
+ only_cross_attention = list(reversed(only_cross_attention))
460
+
461
+ output_channel = reversed_block_out_channels[0]
462
+ for i, up_block_type in enumerate(up_block_types):
463
+ is_final_block = i == len(block_out_channels) - 1
464
+
465
+ prev_output_channel = output_channel
466
+ output_channel = reversed_block_out_channels[i]
467
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
468
+
469
+ # add upsample block for all BUT final layer
470
+ if not is_final_block:
471
+ add_upsample = True
472
+ self.num_upsamplers += 1
473
+ else:
474
+ add_upsample = False
475
+
476
+ up_block = get_up_block(
477
+ up_block_type,
478
+ num_layers=reversed_layers_per_block[i] + 1,
479
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
480
+ in_channels=input_channel,
481
+ out_channels=output_channel,
482
+ prev_output_channel=prev_output_channel,
483
+ temb_channels=blocks_time_embed_dim,
484
+ add_upsample=add_upsample,
485
+ resnet_eps=norm_eps,
486
+ resnet_act_fn=act_fn,
487
+ resnet_groups=norm_num_groups,
488
+ cross_attention_dim=reversed_cross_attention_dim[i],
489
+ num_attention_heads=reversed_num_attention_heads[i],
490
+ use_linear_projection=use_linear_projection,
491
+ only_cross_attention=only_cross_attention[i],
492
+ upcast_attention=upcast_attention,
493
+ resnet_time_scale_shift=resnet_time_scale_shift,
494
+ )
495
+ self.up_blocks.append(up_block)
496
+ prev_output_channel = output_channel
497
+
498
+ # out
499
+ if norm_num_groups is not None:
500
+ self.conv_norm_out = nn.GroupNorm(
501
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
502
+ )
503
+
504
+ self.conv_act = get_activation(act_fn)
505
+
506
+ else:
507
+ self.conv_norm_out = None
508
+ self.conv_act = None
509
+
510
+ conv_out_padding = (conv_out_kernel - 1) // 2
511
+ self.conv_out = nn.Conv2d(
512
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
513
+ )
514
+
515
+ @property
516
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
517
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
518
+ r"""
519
+ Returns:
520
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
521
+ indexed by its weight name.
522
+ """
523
+ # set recursively
524
+ processors = {}
525
+
526
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
527
+ if hasattr(module, "get_processor"):
528
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
529
+
530
+ for sub_name, child in module.named_children():
531
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
532
+
533
+ return processors
534
+
535
+ for name, module in self.named_children():
536
+ fn_recursive_add_processors(name, module, processors)
537
+
538
+ return processors
539
+
540
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
541
+ def set_attn_processor(
542
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
543
+ ):
544
+ r"""
545
+ Sets the attention processor to use to compute attention.
546
+
547
+ Parameters:
548
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
549
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
550
+ for **all** `Attention` layers.
551
+
552
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
553
+ processor. This is strongly recommended when setting trainable attention processors.
554
+
555
+ """
556
+ count = len(self.attn_processors.keys())
557
+
558
+ if isinstance(processor, dict) and len(processor) != count:
559
+ raise ValueError(
560
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
561
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
562
+ )
563
+
564
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
565
+ if hasattr(module, "set_processor"):
566
+ if not isinstance(processor, dict):
567
+ module.set_processor(processor, _remove_lora=_remove_lora)
568
+ else:
569
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
570
+
571
+ for sub_name, child in module.named_children():
572
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
573
+
574
+ for name, module in self.named_children():
575
+ fn_recursive_attn_processor(name, module, processor)
576
+
577
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
578
+ def set_default_attn_processor(self):
579
+ """
580
+ Disables custom attention processors and sets the default attention implementation.
581
+ """
582
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
583
+ processor = AttnAddedKVProcessor()
584
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
585
+ processor = AttnProcessor()
586
+ else:
587
+ raise ValueError(
588
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
589
+ )
590
+
591
+ self.set_attn_processor(processor, _remove_lora=True)
592
+
593
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
594
+ def set_attention_slice(self, slice_size):
595
+ r"""
596
+ Enable sliced attention computation.
597
+
598
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
599
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
600
+
601
+ Args:
602
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
603
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
604
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
605
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
606
+ must be a multiple of `slice_size`.
607
+ """
608
+ sliceable_head_dims = []
609
+
610
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
611
+ if hasattr(module, "set_attention_slice"):
612
+ sliceable_head_dims.append(module.sliceable_head_dim)
613
+
614
+ for child in module.children():
615
+ fn_recursive_retrieve_sliceable_dims(child)
616
+
617
+ # retrieve number of attention layers
618
+ for module in self.children():
619
+ fn_recursive_retrieve_sliceable_dims(module)
620
+
621
+ num_sliceable_layers = len(sliceable_head_dims)
622
+
623
+ if slice_size == "auto":
624
+ # half the attention head size is usually a good trade-off between
625
+ # speed and memory
626
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
627
+ elif slice_size == "max":
628
+ # make smallest slice possible
629
+ slice_size = num_sliceable_layers * [1]
630
+
631
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
632
+
633
+ if len(slice_size) != len(sliceable_head_dims):
634
+ raise ValueError(
635
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
636
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
637
+ )
638
+
639
+ for i in range(len(slice_size)):
640
+ size = slice_size[i]
641
+ dim = sliceable_head_dims[i]
642
+ if size is not None and size > dim:
643
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
644
+
645
+ # Recursively walk through all the children.
646
+ # Any children which exposes the set_attention_slice method
647
+ # gets the message
648
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
649
+ if hasattr(module, "set_attention_slice"):
650
+ module.set_attention_slice(slice_size.pop())
651
+
652
+ for child in module.children():
653
+ fn_recursive_set_attention_slice(child, slice_size)
654
+
655
+ reversed_slice_size = list(reversed(slice_size))
656
+ for module in self.children():
657
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
658
+
659
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
660
+ def _set_gradient_checkpointing(self, module, value=False):
661
+ if hasattr(module, "gradient_checkpointing"):
662
+ module.gradient_checkpointing = value
663
+
664
+ def forward(
665
+ self,
666
+ sample: torch.FloatTensor,
667
+ timestep: Union[torch.Tensor, float, int],
668
+ encoder_hidden_states: torch.Tensor,
669
+ class_labels: Optional[torch.Tensor] = None,
670
+ timestep_cond: Optional[torch.Tensor] = None,
671
+ attention_mask: Optional[torch.Tensor] = None,
672
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
673
+ encoder_attention_mask: Optional[torch.Tensor] = None,
674
+ return_dict: bool = True,
675
+ encoder_hidden_states_1: Optional[torch.Tensor] = None,
676
+ encoder_attention_mask_1: Optional[torch.Tensor] = None,
677
+ ) -> Union[UNet2DConditionOutput, Tuple]:
678
+ r"""
679
+ The [`AudioLDM2UNet2DConditionModel`] forward method.
680
+
681
+ Args:
682
+ sample (`torch.FloatTensor`):
683
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
684
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
685
+ encoder_hidden_states (`torch.FloatTensor`):
686
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
687
+ encoder_attention_mask (`torch.Tensor`):
688
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
689
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
690
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
691
+ return_dict (`bool`, *optional*, defaults to `True`):
692
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
693
+ tuple.
694
+ cross_attention_kwargs (`dict`, *optional*):
695
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
696
+ encoder_hidden_states_1 (`torch.FloatTensor`, *optional*):
697
+ A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be
698
+ used to condition the model on a different set of embeddings to `encoder_hidden_states`.
699
+ encoder_attention_mask_1 (`torch.Tensor`, *optional*):
700
+ A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`.
701
+ If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
702
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
703
+
704
+ Returns:
705
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
706
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
707
+ a `tuple` is returned where the first element is the sample tensor.
708
+ """
709
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
710
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
711
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
712
+ # on the fly if necessary.
713
+ default_overall_up_factor = 2**self.num_upsamplers
714
+
715
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
716
+ forward_upsample_size = False
717
+ upsample_size = None
718
+
719
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
720
+ logger.info("Forward upsample size to force interpolation output size.")
721
+ forward_upsample_size = True
722
+
723
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
724
+ # expects mask of shape:
725
+ # [batch, key_tokens]
726
+ # adds singleton query_tokens dimension:
727
+ # [batch, 1, key_tokens]
728
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
729
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
730
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
731
+ if attention_mask is not None:
732
+ # assume that mask is expressed as:
733
+ # (1 = keep, 0 = discard)
734
+ # convert mask into a bias that can be added to attention scores:
735
+ # (keep = +0, discard = -10000.0)
736
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
737
+ attention_mask = attention_mask.unsqueeze(1)
738
+
739
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
740
+ if encoder_attention_mask is not None:
741
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
742
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
743
+
744
+ if encoder_attention_mask_1 is not None:
745
+ encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
746
+ encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)
747
+
748
+ # 1. time
749
+ timesteps = timestep
750
+ if not torch.is_tensor(timesteps):
751
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
752
+ # This would be a good case for the `match` statement (Python 3.10+)
753
+ is_mps = sample.device.type == "mps"
754
+ if isinstance(timestep, float):
755
+ dtype = torch.float32 if is_mps else torch.float64
756
+ else:
757
+ dtype = torch.int32 if is_mps else torch.int64
758
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
759
+ elif len(timesteps.shape) == 0:
760
+ timesteps = timesteps[None].to(sample.device)
761
+
762
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
763
+ timesteps = timesteps.expand(sample.shape[0])
764
+
765
+ t_emb = self.time_proj(timesteps)
766
+
767
+ # `Timesteps` does not contain any weights and will always return f32 tensors
768
+ # but time_embedding might actually be running in fp16. so we need to cast here.
769
+ # there might be better ways to encapsulate this.
770
+ t_emb = t_emb.to(dtype=sample.dtype)
771
+
772
+ emb = self.time_embedding(t_emb, timestep_cond)
773
+ aug_emb = None
774
+
775
+ if self.class_embedding is not None:
776
+ if class_labels is None:
777
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
778
+
779
+ if self.config.class_embed_type == "timestep":
780
+ class_labels = self.time_proj(class_labels)
781
+
782
+ # `Timesteps` does not contain any weights and will always return f32 tensors
783
+ # there might be better ways to encapsulate this.
784
+ class_labels = class_labels.to(dtype=sample.dtype)
785
+
786
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
787
+
788
+ if self.config.class_embeddings_concat:
789
+ emb = torch.cat([emb, class_emb], dim=-1)
790
+ else:
791
+ emb = emb + class_emb
792
+
793
+ emb = emb + aug_emb if aug_emb is not None else emb
794
+
795
+ if self.time_embed_act is not None:
796
+ emb = self.time_embed_act(emb)
797
+
798
+ # 2. pre-process
799
+ sample = self.conv_in(sample)
800
+
801
+ # 3. down
802
+ down_block_res_samples = (sample,)
803
+ for downsample_block in self.down_blocks:
804
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
805
+ sample, res_samples = downsample_block(
806
+ hidden_states=sample,
807
+ temb=emb,
808
+ encoder_hidden_states=encoder_hidden_states,
809
+ attention_mask=attention_mask,
810
+ cross_attention_kwargs=cross_attention_kwargs,
811
+ encoder_attention_mask=encoder_attention_mask,
812
+ encoder_hidden_states_1=encoder_hidden_states_1,
813
+ encoder_attention_mask_1=encoder_attention_mask_1,
814
+ )
815
+ else:
816
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
817
+
818
+ down_block_res_samples += res_samples
819
+
820
+ # 4. mid
821
+ if self.mid_block is not None:
822
+ sample = self.mid_block(
823
+ sample,
824
+ emb,
825
+ encoder_hidden_states=encoder_hidden_states,
826
+ attention_mask=attention_mask,
827
+ cross_attention_kwargs=cross_attention_kwargs,
828
+ encoder_attention_mask=encoder_attention_mask,
829
+ encoder_hidden_states_1=encoder_hidden_states_1,
830
+ encoder_attention_mask_1=encoder_attention_mask_1,
831
+ )
832
+
833
+ # 5. up
834
+ for i, upsample_block in enumerate(self.up_blocks):
835
+ is_final_block = i == len(self.up_blocks) - 1
836
+
837
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
838
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
839
+
840
+ # if we have not reached the final block and need to forward the
841
+ # upsample size, we do it here
842
+ if not is_final_block and forward_upsample_size:
843
+ upsample_size = down_block_res_samples[-1].shape[2:]
844
+
845
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
846
+ sample = upsample_block(
847
+ hidden_states=sample,
848
+ temb=emb,
849
+ res_hidden_states_tuple=res_samples,
850
+ encoder_hidden_states=encoder_hidden_states,
851
+ cross_attention_kwargs=cross_attention_kwargs,
852
+ upsample_size=upsample_size,
853
+ attention_mask=attention_mask,
854
+ encoder_attention_mask=encoder_attention_mask,
855
+ encoder_hidden_states_1=encoder_hidden_states_1,
856
+ encoder_attention_mask_1=encoder_attention_mask_1,
857
+ )
858
+ else:
859
+ sample = upsample_block(
860
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
861
+ )
862
+
863
+ # 6. post-process
864
+ if self.conv_norm_out:
865
+ sample = self.conv_norm_out(sample)
866
+ sample = self.conv_act(sample)
867
+ sample = self.conv_out(sample)
868
+
869
+ if not return_dict:
870
+ return (sample,)
871
+
872
+ return UNet2DConditionOutput(sample=sample)
873
+
874
+
875
+ def get_down_block(
876
+ down_block_type,
877
+ num_layers,
878
+ in_channels,
879
+ out_channels,
880
+ temb_channels,
881
+ add_downsample,
882
+ resnet_eps,
883
+ resnet_act_fn,
884
+ transformer_layers_per_block=1,
885
+ num_attention_heads=None,
886
+ resnet_groups=None,
887
+ cross_attention_dim=None,
888
+ downsample_padding=None,
889
+ use_linear_projection=False,
890
+ only_cross_attention=False,
891
+ upcast_attention=False,
892
+ resnet_time_scale_shift="default",
893
+ ):
894
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
895
+ if down_block_type == "DownBlock2D":
896
+ return DownBlock2D(
897
+ num_layers=num_layers,
898
+ in_channels=in_channels,
899
+ out_channels=out_channels,
900
+ temb_channels=temb_channels,
901
+ add_downsample=add_downsample,
902
+ resnet_eps=resnet_eps,
903
+ resnet_act_fn=resnet_act_fn,
904
+ resnet_groups=resnet_groups,
905
+ downsample_padding=downsample_padding,
906
+ resnet_time_scale_shift=resnet_time_scale_shift,
907
+ )
908
+ elif down_block_type == "CrossAttnDownBlock2D":
909
+ if cross_attention_dim is None:
910
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
911
+ return CrossAttnDownBlock2D(
912
+ num_layers=num_layers,
913
+ transformer_layers_per_block=transformer_layers_per_block,
914
+ in_channels=in_channels,
915
+ out_channels=out_channels,
916
+ temb_channels=temb_channels,
917
+ add_downsample=add_downsample,
918
+ resnet_eps=resnet_eps,
919
+ resnet_act_fn=resnet_act_fn,
920
+ resnet_groups=resnet_groups,
921
+ downsample_padding=downsample_padding,
922
+ cross_attention_dim=cross_attention_dim,
923
+ num_attention_heads=num_attention_heads,
924
+ use_linear_projection=use_linear_projection,
925
+ only_cross_attention=only_cross_attention,
926
+ upcast_attention=upcast_attention,
927
+ resnet_time_scale_shift=resnet_time_scale_shift,
928
+ )
929
+ raise ValueError(f"{down_block_type} does not exist.")
930
+
931
+
932
+ def get_up_block(
933
+ up_block_type,
934
+ num_layers,
935
+ in_channels,
936
+ out_channels,
937
+ prev_output_channel,
938