smatta committed on
Commit
ad632ac
·
verified ·
1 Parent(s): bf657ff

Upload vae/ae_model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vae/ae_model.py +141 -0
vae/ae_model.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Hugging Face Team and Overworld
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ """VAE model for WorldEngine frame encoding/decoding."""
17
+
18
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

from .dcae import Encoder, Decoder, bake_weight_norm
27
+
28
+
29
+ @dataclass
30
+ class EncoderDecoderConfig:
31
+ """Config object for Encoder/Decoder initialization."""
32
+
33
+ channels: int
34
+ latent_channels: int
35
+ ch_0: int
36
+ ch_max: int
37
+ encoder_blocks_per_stage: List[int]
38
+ decoder_blocks_per_stage: List[int]
39
+ skip_logvar: bool = False
40
+
41
+
42
+ class WorldEngineVAE(ModelMixin, ConfigMixin):
43
+ """
44
+ VAE for encoding/decoding video frames using DCAE architecture.
45
+
46
+ Encodes RGB uint8 images to latent space and decodes latents back to RGB.
47
+ """
48
+
49
+ _supports_gradient_checkpointing = False
50
+
51
+ @register_to_config
52
+ def __init__(
53
+ self,
54
+ # Common parameters
55
+ sample_size: Tuple[int, int] = (360, 640),
56
+ channels: int = 3,
57
+ latent_channels: int = 16,
58
+ # Encoder parameters
59
+ encoder_ch_0: int = 64,
60
+ encoder_ch_max: int = 256,
61
+ encoder_blocks_per_stage: List[int] = None,
62
+ # Decoder parameters
63
+ decoder_ch_0: int = 128,
64
+ decoder_ch_max: int = 1024,
65
+ decoder_blocks_per_stage: List[int] = None,
66
+ # Shared parameters
67
+ skip_logvar: bool = False,
68
+ # Scaling factors
69
+ scale_factor: float = 1.0,
70
+ shift_factor: float = 0.0,
71
+ ):
72
+ super().__init__()
73
+
74
+ # Default blocks per stage
75
+ if encoder_blocks_per_stage is None:
76
+ encoder_blocks_per_stage = [1, 1, 1, 1]
77
+ if decoder_blocks_per_stage is None:
78
+ decoder_blocks_per_stage = [1, 1, 1, 1]
79
+
80
+ # Create encoder config
81
+ encoder_config = EncoderDecoderConfig(
82
+ channels=channels,
83
+ latent_channels=latent_channels,
84
+ ch_0=encoder_ch_0,
85
+ ch_max=encoder_ch_max,
86
+ encoder_blocks_per_stage=list(encoder_blocks_per_stage),
87
+ decoder_blocks_per_stage=list(decoder_blocks_per_stage),
88
+ skip_logvar=skip_logvar,
89
+ )
90
+
91
+ # Create decoder config
92
+ decoder_config = EncoderDecoderConfig(
93
+ channels=channels,
94
+ latent_channels=latent_channels,
95
+ ch_0=decoder_ch_0,
96
+ ch_max=decoder_ch_max,
97
+ encoder_blocks_per_stage=list(encoder_blocks_per_stage),
98
+ decoder_blocks_per_stage=list(decoder_blocks_per_stage),
99
+ skip_logvar=skip_logvar,
100
+ )
101
+
102
+ self.encoder = Encoder(encoder_config)
103
+ self.decoder = Decoder(decoder_config)
104
+
105
+ def encode(self, img: Tensor):
106
+ """RGB -> RGB+D -> latent"""
107
+ assert img.dim() == 3, "Expected [H, W, C] image tensor"
108
+ img = img.unsqueeze(0).to(device=self.device, dtype=self.dtype)
109
+ rgb = img.permute(0, 3, 1, 2).contiguous().div(255).mul(2).sub(1)
110
+ return self.encoder(rgb)
111
+
112
+ def decode(self, latent: Tensor):
113
+ decoded = self.decoder(latent)
114
+ decoded = (decoded / 2 + 0.5).clamp(0, 1)
115
+ decoded = (decoded * 255).round().to(torch.uint8)
116
+ return decoded.squeeze(0).permute(1, 2, 0)[..., :3]
117
+
118
+ def forward(self, x: Tensor, encode: bool = True) -> Tensor:
119
+ """
120
+ Forward pass - encode or decode based on flag.
121
+
122
+ Args:
123
+ x: Input tensor (image for encode, latent for decode)
124
+ encode: If True, encode; if False, decode
125
+
126
+ Returns:
127
+ Encoded latent or decoded image
128
+ """
129
+ if encode:
130
+ return self.encode(x)
131
+ else:
132
+ return self.decode(x)
133
+
134
+ def bake_weight_norm(self):
135
+ """Remove weight_norm parametrizations, baking normalized weights into regular tensors.
136
+
137
+ Call this after loading weights and before torch.compile to avoid
138
+ CUDA graph capture errors from in-place weight updates.
139
+ """
140
+ bake_weight_norm(self)
141
+ return self