File size: 9,904 Bytes
910e2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import numbers
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from diffusers.utils import is_torch_version


if is_torch_version(">=", "2.1.0"):
    # torch >= 2.1 `nn.LayerNorm` already accepts `bias=...`, so it is used as-is.
    LayerNorm = nn.LayerNorm
else:
    # Has optional bias parameter compared to torch layer norm
    # TODO: replace with torch layernorm once min required torch version >= 2.1
    class LayerNorm(nn.Module):
        r"""
        Fallback layer norm for older torch versions whose `nn.LayerNorm` cannot disable the bias.

        Parameters:
            dim (`int` or sequence of `int`): Trailing shape to normalize over; an int is wrapped in a 1-tuple.
            eps (`float`, defaults to `1e-5`): Passed to `F.layer_norm` for numerical stability.
            elementwise_affine (`bool`, defaults to `True`): Create a learnable weight (initialized to ones) and,
                when `bias` is set, a learnable bias (initialized to zeros).
            bias (`bool`, defaults to `True`): Whether to create the bias when `elementwise_affine` is set.
        """

        def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True):
            super().__init__()

            self.eps = eps
            normalized_shape = (dim,) if isinstance(dim, numbers.Integral) else dim
            self.dim = torch.Size(normalized_shape)

            if not elementwise_affine:
                # No affine parameters at all: F.layer_norm receives None for both.
                self.weight = None
                self.bias = None
                return

            self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape)) if bias else None

        def forward(self, input):
            return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps)


class FP32LayerNorm(nn.LayerNorm):
    r"""
    `nn.LayerNorm` variant that always normalizes in float32.

    The input and, when present, the affine weight and bias are upcast to float32 before calling
    `F.layer_norm`; the normalized result is cast back to the input's original dtype.
    """

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normalized = F.layer_norm(inputs.float(), self.normalized_shape, weight, bias, self.eps)
        # Return in the caller's precision even though the math ran in float32.
        return normalized.to(inputs.dtype)


class RMSNorm(nn.Module):
    r"""
    Root-mean-square normalization over the last dimension.

    The mean of squares is computed in float32. When a learnable weight exists it is applied after
    normalization; a float16/bfloat16 weight first casts the normalized activations to its dtype.
    Without a weight the result is cast back to the input dtype.

    Parameters:
        dim (`int` or sequence of `int`): Shape of the learnable weight (stored as `self.dim`); an int is
            wrapped in a 1-tuple.
        eps (`float`): Added to the mean of squares before the reciprocal square root.
        elementwise_affine (`bool`, defaults to `True`): Whether to create a learnable weight initialized to ones.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
        super().__init__()

        self.eps = eps
        shape = (dim,) if isinstance(dim, numbers.Integral) else dim
        self.dim = torch.Size(shape)
        self.weight = nn.Parameter(torch.ones(shape)) if elementwise_affine else None

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # Multiplying by a float32 tensor promotes lower-precision inputs to float32.
        normed = hidden_states * torch.rsqrt(mean_square + self.eps)

        if self.weight is None:
            return normed.to(original_dtype)

        # Match a half-precision weight before scaling; float32 weights keep the promoted dtype.
        if self.weight.dtype in (torch.float16, torch.bfloat16):
            normed = normed.to(self.weight.dtype)
        return normed * self.weight


class AdaLayerNormContinuous(nn.Module):
    r"""
    Normalization whose scale and shift are regressed from a conditioning embedding.

    `conditioning_embedding` goes through SiLU and a linear layer producing `2 * embedding_dim` features,
    which are split into `(scale, shift)`; the output is `norm(x) * (1 + scale) + shift`.

    Parameters:
        embedding_dim (`int`): Size of the last dimension of `x`.
        conditioning_embedding_dim (`int`): Size of the last dimension of `conditioning_embedding`.
        elementwise_affine (`bool`, defaults to `True`): Whether the inner norm has its own learnable affine
            parameters.
        eps (`float`, defaults to `1e-5`): Epsilon of the inner norm.
        bias (`bool`, defaults to `True`): Bias flag for the projection and, for `"layer_norm"`, the norm.
        norm_type (`str`, defaults to `"layer_norm"`): `"layer_norm"` or `"rms_norm"`.

    Raises:
        ValueError: If `norm_type` is not one of the supported values.
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward_with_pad(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
        """
        Apply per-stage modulation to a sequence made of several concatenated token spans.

        `hidden_length` lists the length of each consecutive span along dim 1 of `x`. Span `i` is modulated
        with rows `i, i + len(hidden_length), i + 2 * len(hidden_length), ...` of the projected conditioning.
        Positions past `sum(hidden_length)` keep zero scale/shift, i.e. they receive plain `norm(x)`.
        Presumably `x` is (batch, seq, embedding_dim) — TODO confirm against callers.

        Raises:
            ValueError: If `hidden_length` is None.
        """
        # Explicit check rather than `assert`, so the guard still runs under `python -O`.
        if hidden_length is None:
            raise ValueError("`hidden_length` must be provided to `forward_with_pad`.")

        # Convert back to x's dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT).
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        batch_emb = torch.zeros_like(x).repeat(1, 1, 2)

        i_sum = 0
        num_stages = len(hidden_length)
        for i_p, length in enumerate(hidden_length):
            # Broadcast the strided rows for stage i_p over this stage's token span.
            batch_emb[:, i_sum:i_sum + length] = emb[i_p::num_stages][:, None]
            i_sum += length

        batch_scale, batch_shift = torch.chunk(batch_emb, 2, dim=2)
        x = self.norm(x) * (1 + batch_scale) + batch_shift
        return x

    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor, hidden_length=None) -> torch.Tensor:
        """
        Modulate `norm(x)` with a scale/shift projected from `conditioning_embedding`.

        When `hidden_length` is given, delegates to `forward_with_pad`. Otherwise one (scale, shift) pair per
        row of `conditioning_embedding` is broadcast over dim 1 of `x`.
        """
        if hidden_length is not None:
            return self.forward_with_pad(x, conditioning_embedding, hidden_length)
        # Convert back to x's dtype in case `conditioning_embedding` is upcasted to float32 (needed for hunyuanDiT).
        emb = self.linear(self.silu(conditioning_embedding).to(x.dtype))
        scale, shift = torch.chunk(emb, 2, dim=1)
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


class AdaLayerNormZero(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero).

    The conditioning `emb` is passed through SiLU and a linear layer producing `6 * embedding_dim` features,
    split into (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp). The input is normalized
    (no affine parameters, eps=1e-6) and modulated with the `msa` shift/scale; the remaining four pieces
    are returned for the caller.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): Accepted for signature compatibility but not used by this
            implementation; `self.emb` is always initialized to `None`.
    """

    def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
        super().__init__()
        # NOTE(review): `num_embeddings` is ignored, so the `self.emb` branches below only run if
        # `emb` is assigned externally.
        self.emb = None

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward_with_pad(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
        hidden_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Per-stage adaLN-Zero for a sequence built from consecutive token spans.

        `hidden_length` gives the length of each span along dim 1 of `x`. Span `k` uses rows
        `k, k + len(hidden_length), ...` of the projected embedding. Positions past `sum(hidden_length)`
        get all-zero modulation. Presumably `x` is (batch, seq_len, dim) — TODO confirm.
        """
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        modulation = self.linear(self.silu(emb))

        stage_count = len(hidden_length)
        expanded = torch.zeros_like(x).repeat(1, 1, 6)
        offset = 0
        for stage, span in enumerate(hidden_length):
            # Broadcast this stage's strided rows over its token span.
            expanded[:, offset:offset + span] = modulation[stage::stage_count].unsqueeze(1)
            offset += span

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = expanded.chunk(6, dim=2)
        modulated = self.norm(x) * (1 + scale_msa) + shift_msa
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp

    def forward(
        self,
        x: torch.Tensor,
        timestep: Optional[torch.Tensor] = None,
        class_labels: Optional[torch.LongTensor] = None,
        hidden_dtype: Optional[torch.dtype] = None,
        emb: Optional[torch.Tensor] = None,
        hidden_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply adaLN-Zero, delegating to `forward_with_pad` when `hidden_length` is given.

        Without `hidden_length`, each row of the projected embedding is broadcast over dim 1 of `x`.
        Returns (modulated x, gate_msa, shift_mlp, scale_mlp, gate_mlp).
        """
        if hidden_length is not None:
            return self.forward_with_pad(x, timestep, class_labels, hidden_dtype, emb, hidden_length)
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        pieces = self.linear(self.silu(emb)).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = pieces
        modulated = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormZeroSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm zero (adaLN-Zero), single-modulation variant.

    The conditioning `emb` is passed through SiLU and a linear layer producing `3 * embedding_dim`
    features, split into (shift_msa, scale_msa, gate_msa). The input is normalized (no affine
    parameters, eps=1e-6) and modulated as `norm(x) * (1 + scale_msa) + shift_msa`; `gate_msa` is
    returned alongside it.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        norm_type (`str`, defaults to `"layer_norm"`): Inner norm type; only `"layer_norm"` is supported.
        bias (`bool`, defaults to `True`): Whether the modulation projection has a bias.

    Raises:
        ValueError: If `norm_type` is not `"layer_norm"`.
    """

    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            # Only 'layer_norm' is implemented; the message must not advertise other options.
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
            )

    def forward_with_pad(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
        hidden_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Per-stage modulation for a sequence built from consecutive token spans.

        `hidden_length` gives the length of each span along dim 1 of `x`. Span `i` uses rows
        `i, i + len(hidden_length), ...` of the projected embedding. Positions past
        `sum(hidden_length)` get zero shift/scale/gate. Presumably `x` is (batch, seq_len, dim) — TODO confirm.

        Returns:
            `(modulated_x, gate_msa)`, both shaped like `x`.
        """
        emb = self.linear(self.silu(emb))
        batch_emb = torch.zeros_like(x).repeat(1, 1, 3)

        i_sum = 0
        num_stages = len(hidden_length)
        for i_p, length in enumerate(hidden_length):
            # Broadcast the strided rows for stage i_p over this stage's token span.
            batch_emb[:, i_sum:i_sum + length] = emb[i_p::num_stages][:, None]
            i_sum += length

        batch_shift_msa, batch_scale_msa, batch_gate_msa = batch_emb.chunk(3, dim=2)

        x = self.norm(x) * (1 + batch_scale_msa) + batch_shift_msa

        return x, batch_gate_msa

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
        hidden_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply single adaLN-Zero modulation, delegating to `forward_with_pad` when `hidden_length` is given.

        Returns:
            `(modulated_x, gate_msa)`: a 2-tuple. Without `hidden_length`, `gate_msa` has one row per
            row of `emb`.
        """
        if hidden_length is not None:
            return self.forward_with_pad(x, emb, hidden_length)
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa