# File size: 11,066 Bytes
# 4bf9661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
from tqdm import tqdm
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D


class DownsampleCausal3D(nn.Module):
    """Downsampling layer implemented as a single strided ``CausalConv3d``.

    Args:
        channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size passed to ``CausalConv3d``.
        bias: whether the convolution has a bias term.
        stride: convolution stride (an int or a per-axis tuple, forwarded as-is).
    """

    def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
        super().__init__()
        self.conv = CausalConv3d(
            channels,
            out_channels,
            kernel_size,
            stride=stride,
            bias=bias,
        )

    def forward(self, hidden_states):
        """Apply the strided causal convolution and return its output."""
        return self.conv(hidden_states)


class DownEncoderBlockCausal3D(nn.Module):
    """Encoder stage: ``num_layers`` causal ResNet blocks, optionally followed by a downsampler.

    Args:
        in_channels: channels entering the first ResNet block.
        out_channels: channels produced by every ResNet block (and the downsampler).
        dropout: dropout passed to each ResNet block.
        num_layers: number of ResNet blocks in this stage.
        eps: epsilon passed to each ResNet block.
        num_groups: group count passed to each ResNet block as ``groups``.
        add_downsample: when True, append a ``DownsampleCausal3D``; otherwise
            ``self.downsamplers`` is ``None``.
        downsample_stride: stride forwarded to the downsampler.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dropout=0.0,
        num_layers=1,
        eps=1e-6,
        num_groups=32,
        add_downsample=True,
        downsample_stride=2,
    ):

        super().__init__()
        # Only the first block changes the channel count; the rest map out->out.
        self.resnets = nn.ModuleList(
            ResnetBlockCausal3D(
                in_channels=in_channels if layer_idx == 0 else out_channels,
                out_channels=out_channels,
                groups=num_groups,
                dropout=dropout,
                eps=eps,
            )
            for layer_idx in range(num_layers)
        )

        self.downsamplers = (
            nn.ModuleList([DownsampleCausal3D(out_channels, out_channels, stride=downsample_stride)])
            if add_downsample
            else None
        )

    def forward(self, hidden_states):
        """Run the ResNet blocks in order, then the downsampler (if any)."""
        stages = list(self.resnets)
        if self.downsamplers is not None:
            stages.extend(self.downsamplers)
        for stage in stages:
            hidden_states = stage(hidden_states)
        return hidden_states


class EncoderCausal3D(nn.Module):
    """Causal 3D convolutional encoder.

    Pipeline: ``conv_in`` -> ``down_blocks`` -> ``mid_block`` -> GroupNorm -> SiLU
    -> ``conv_out``. The output has ``2 * out_channels`` channels.

    Down-block ``i`` downsamples spatially (stride 2 in H and W) when
    ``i < log2(spatial_compression_ratio)``, and temporally (stride 2 in T) when it
    is one of the ``log2(time_compression_ratio)`` blocks immediately preceding the
    final block; the final block never downsamples in time.
    NOTE(review): the ratios are presumably powers of two -- ``int(np.log2(...))``
    truncates otherwise.

    Args:
        in_channels: channels of the input video.
        out_channels: half the number of output channels.
        eps: GroupNorm / block epsilon.
        dropout: dropout forwarded to the blocks.
        block_out_channels: output channels of each down block (sequence of ints).
        layers_per_block: ResNet blocks per down block.
        num_groups: GroupNorm group count.
        time_compression_ratio: total temporal downsampling factor.
        spatial_compression_ratio: total spatial downsampling factor.
        gradient_checkpointing: when True and in training mode, down blocks and the
            mid block are run under ``torch.utils.checkpoint``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 16,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.down_blocks = nn.ModuleList([])

        # Loop-invariant: number of blocks that downsample along each axis.
        num_blocks = len(block_out_channels)
        num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
        num_time_downsample_layers = int(np.log2(time_compression_ratio))

        # down
        output_channel = block_out_channels[0]
        for i in range(num_blocks):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == num_blocks - 1

            add_spatial_downsample = i < num_spatial_downsample_layers
            add_time_downsample = (
                i >= num_blocks - 1 - num_time_downsample_layers and not is_final_block
            )

            stride_t = 2 if add_time_downsample else 1
            stride_hw = 2 if add_spatial_downsample else 1
            down_block = DownEncoderBlockCausal3D(
                in_channels=input_channel,
                out_channels=output_channel,
                dropout=dropout,
                num_layers=layers_per_block,
                eps=eps,
                num_groups=num_groups,
                add_downsample=add_spatial_downsample or add_time_downsample,
                downsample_stride=(stride_t, stride_hw, stride_hw),
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            dropout=dropout,
            eps=eps,
            num_groups=num_groups,
            attention_head_dim=block_out_channels[-1],
        )
        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)

        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, hidden_states):
        """Encode ``hidden_states`` and return a tensor with ``2 * out_channels`` channels."""
        hidden_states = self.conv_in(hidden_states)
        if self.training and self.gradient_checkpointing:
            # Local import: `import torch` alone does not guarantee the
            # `torch.utils.checkpoint` submodule is loaded.
            from torch.utils.checkpoint import checkpoint

            # down -- the checkpointed output MUST be fed into the next block.
            for down_block in self.down_blocks:
                hidden_states = checkpoint(down_block, hidden_states, use_reentrant=False)
            # middle
            hidden_states = checkpoint(self.mid_block, hidden_states, use_reentrant=False)
        else:
            # down
            for down_block in self.down_blocks:
                hidden_states = down_block(hidden_states)
            # middle
            hidden_states = self.mid_block(hidden_states)
        # post-process
        hidden_states = self.conv_norm_out(hidden_states)
        hidden_states = self.conv_act(hidden_states)
        hidden_states = self.conv_out(hidden_states)

        return hidden_states


class HunyuanVideoVAEEncoder(nn.Module):
    """HunyuanVideo VAE encoder: ``EncoderCausal3D`` + 1x1x1 ``quant_conv``, with tiled inference.

    ``forward`` keeps the first ``out_channels`` channels of the ``quant_conv``
    output (which has ``2 * out_channels``) and multiplies them by
    ``scaling_factor``. ``encode_video`` / ``tile_forward`` encode large inputs in
    overlapping tiles and blend the results.
    """

    def __init__(
        self,
        in_channels=3,
        out_channels=16,
        eps=1e-6,
        dropout=0.0,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        num_groups=32,
        time_compression_ratio=4,
        spatial_compression_ratio=8,
        gradient_checkpointing=False,
    ):
        super().__init__()
        self.encoder = EncoderCausal3D(
            in_channels=in_channels,
            out_channels=out_channels,
            eps=eps,
            dropout=dropout,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            num_groups=num_groups,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
        self.scaling_factor = 0.476986
        # Stored so tile_forward can map pixel coordinates to latent coordinates.
        self.time_compression_ratio = time_compression_ratio
        self.spatial_compression_ratio = spatial_compression_ratio

    def forward(self, images):
        """Encode ``images`` and return scaled latents with ``out_channels`` channels."""
        latents = self.encoder(images)
        latents = self.quant_conv(latents)
        # quant_conv has 2 * out_channels outputs; only the first half is kept
        # (presumably the distribution mean -- confirm against the decoder side).
        latents = latents[:, : self.quant_conv.out_channels // 2]
        return latents * self.scaling_factor

    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        """Return a length-``length`` blending ramp.

        Non-bound edges get a linear ramp ``1/bw, 2/bw, ..., 1`` over ``border_width``
        elements; bound edges (and ``border_width <= 0``) stay at 1.
        """
        x = torch.ones((length,))
        if border_width <= 0:
            # Without this guard, x[-0:] would select the whole tensor and the
            # empty-ramp assignment would fail with a shape mismatch.
            return x
        ramp = (torch.arange(border_width) + 1) / border_width
        if not left_bound:
            x[:border_width] = ramp
        if not right_bound:
            x[-border_width:] = torch.flip(ramp, dims=(0,))
        return x

    def build_mask(self, data, is_bound, border_width):
        """Build a ``(1, 1, T, H, W)`` blending mask for a tile shaped like ``data``.

        Args:
            data: 5D tensor; only its last three dimensions (T, H, W) are used.
            is_bound: six booleans (t_start, t_end, h_start, h_end, w_start, w_end)
                marking edges that touch the full volume's boundary.
            border_width: ramp widths for (T, H, W).
        """
        _, _, T, H, W = data.shape
        t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
        h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
        w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])

        # Broadcast the three 1D ramps to (T, H, W) and take the elementwise minimum.
        mask = torch.minimum(torch.minimum(t[:, None, None], h[None, :, None]), w[None, None, :])
        return mask[None, None]

    @staticmethod
    def _tile_starts(length, size, stride):
        """Tile start offsets along one axis, skipping starts whose predecessor already reaches the end."""
        return [
            start
            for start in range(0, length, stride)
            if not (start - stride >= 0 and start - stride + size >= length)
        ]

    def tile_forward(self, hidden_states, tile_size, tile_stride):
        """Encode ``hidden_states`` in overlapping tiles and blend the latents.

        Args:
            hidden_states: input of shape (B, C, T, H, W); stays on its own device,
                each tile is moved to ``quant_conv``'s device for encoding.
            tile_size: (T, H, W) tile extent in input (pixel) units.
            tile_stride: (T, H, W) step between tile starts in input units.

        Returns:
            Latents on the input's device, each position being the mask-weighted
            average of all tiles covering it.
        """
        B, _, T, H, W = hidden_states.shape
        size_t, size_h, size_w = tile_size
        stride_t, stride_h, stride_w = tile_stride
        ratio_t = self.time_compression_ratio
        ratio_s = self.spatial_compression_ratio

        # Split tasks
        tasks = [
            (t, t + size_t, h, h + size_h, w, w + size_w)
            for t in self._tile_starts(T, size_t, stride_t)
            for h in self._tile_starts(H, size_h, stride_h)
            for w in self._tile_starts(W, size_w, stride_w)
        ]

        # Run
        torch_dtype = self.quant_conv.weight.dtype
        data_device = hidden_states.device
        computation_device = self.quant_conv.weight.device

        latent_channels = self.quant_conv.out_channels // 2
        latent_shape = ((T - 1) // ratio_t + 1, H // ratio_s, W // ratio_s)
        weight = torch.zeros((1, 1, *latent_shape), dtype=torch_dtype, device=data_device)
        values = torch.zeros((B, latent_channels, *latent_shape), dtype=torch_dtype, device=data_device)

        border_width = (
            (size_t - stride_t) // ratio_t,
            (size_h - stride_h) // ratio_s,
            (size_w - stride_w) // ratio_s,
        )

        for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
            tile = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
            tile = self.forward(tile).to(data_device)
            if t > 0:
                # The leading latent frame of a non-initial tile is dropped; the
                # target_t offset below assumes it corresponds to frame t alone and
                # that t is a multiple of the time ratio -- TODO confirm.
                tile = tile[:, :, 1:]

            mask = self.build_mask(
                tile,
                is_bound=(t == 0, t_ >= T, h == 0, h_ >= H, w == 0, w_ >= W),
                border_width=border_width,
            ).to(dtype=torch_dtype, device=data_device)

            target_t = 0 if t == 0 else t // ratio_t + 1
            target_h = h // ratio_s
            target_w = w // ratio_s
            region = (
                slice(None),
                slice(None),
                slice(target_t, target_t + tile.shape[2]),
                slice(target_h, target_h + tile.shape[3]),
                slice(target_w, target_w + tile.shape[4]),
            )
            values[region] += tile * mask
            weight[region] += mask
        return values / weight

    def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
        """Cast the input video (named ``latents`` for API compatibility) to the model dtype and tile-encode it."""
        latents = latents.to(self.quant_conv.weight.dtype)
        return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)

    @staticmethod
    def state_dict_converter():
        """Return the converter that extracts encoder weights from a diffusers state dict."""
        return HunyuanVideoVAEEncoderStateDictConverter()


class HunyuanVideoVAEEncoderStateDictConverter:
    """Reduce a diffusers-format HunyuanVideo VAE state dict to the encoder's weights."""

    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        """Return a new dict holding only keys prefixed ``encoder.`` or ``quant_conv.``.

        Keys keep their original names, values are not copied, and the input's
        iteration order is preserved.
        """
        kept_prefixes = ('encoder.', 'quant_conv.')
        return {key: state_dict[key] for key in state_dict if key.startswith(kept_prefixes)}