JTriggerFish committed on
Commit ·
940c4c8
1
Parent(s): ec37736
Revert "Upload folder using huggingface_hub"
Browse files
This reverts commit ec377361f0dcaf8b153e6940aa88d84bd7ab7515.
- fcdm_diffae/__init__.py +1 -1
- fcdm_diffae/__pycache__/__init__.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/config.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/decoder.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/encoder.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/model.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/samplers.cpython-312.pyc +0 -0
- fcdm_diffae/config.py +0 -18
- fcdm_diffae/model.py +11 -32
fcdm_diffae/__init__.py
CHANGED
|
@@ -26,8 +26,8 @@ from .encoder import EncoderPosterior
|
|
| 26 |
from .model import FCDMDiffAE
|
| 27 |
|
| 28 |
__all__ = [
|
| 29 |
-
"EncoderPosterior",
|
| 30 |
"FCDMDiffAE",
|
| 31 |
"FCDMDiffAEConfig",
|
| 32 |
"FCDMDiffAEInferenceConfig",
|
|
|
|
| 33 |
]
|
|
|
|
| 26 |
from .model import FCDMDiffAE
|
| 27 |
|
| 28 |
__all__ = [
|
|
|
|
| 29 |
"FCDMDiffAE",
|
| 30 |
"FCDMDiffAEConfig",
|
| 31 |
"FCDMDiffAEInferenceConfig",
|
| 32 |
+
"EncoderPosterior",
|
| 33 |
]
|
fcdm_diffae/__pycache__/__init__.cpython-312.pyc
CHANGED
|
Binary files a/fcdm_diffae/__pycache__/__init__.cpython-312.pyc and b/fcdm_diffae/__pycache__/__init__.cpython-312.pyc differ
|
|
|
fcdm_diffae/__pycache__/config.cpython-312.pyc
CHANGED
|
Binary files a/fcdm_diffae/__pycache__/config.cpython-312.pyc and b/fcdm_diffae/__pycache__/config.cpython-312.pyc differ
|
|
|
fcdm_diffae/__pycache__/decoder.cpython-312.pyc
CHANGED
|
Binary files a/fcdm_diffae/__pycache__/decoder.cpython-312.pyc and b/fcdm_diffae/__pycache__/decoder.cpython-312.pyc differ
|
|
|
fcdm_diffae/__pycache__/encoder.cpython-312.pyc
CHANGED
|
Binary files a/fcdm_diffae/__pycache__/encoder.cpython-312.pyc and b/fcdm_diffae/__pycache__/encoder.cpython-312.pyc differ
|
|
|
fcdm_diffae/__pycache__/model.cpython-312.pyc
CHANGED
|
Binary files a/fcdm_diffae/__pycache__/model.cpython-312.pyc and b/fcdm_diffae/__pycache__/model.cpython-312.pyc differ
|
|
|
fcdm_diffae/__pycache__/samplers.cpython-312.pyc
CHANGED
|
Binary files a/fcdm_diffae/__pycache__/samplers.cpython-312.pyc and b/fcdm_diffae/__pycache__/samplers.cpython-312.pyc differ
|
|
|
fcdm_diffae/config.py
CHANGED
|
@@ -26,30 +26,12 @@ class FCDMDiffAEConfig:
|
|
| 26 |
bottleneck_posterior_kind: str = "diagonal_gaussian"
|
| 27 |
# Post-bottleneck normalization: "channel_wise" or "disabled"
|
| 28 |
bottleneck_norm_mode: str = "disabled"
|
| 29 |
-
# Bottleneck patchification: "off" or "patch_2x2"
|
| 30 |
-
# When "patch_2x2", encoder latents are 2x2 patchified after the bottleneck
|
| 31 |
-
# (channels * 4, spatial / 2), and decode unpatchifies before the decoder.
|
| 32 |
-
bottleneck_patchify_mode: str = "off"
|
| 33 |
# VP diffusion schedule endpoints
|
| 34 |
logsnr_min: float = -10.0
|
| 35 |
logsnr_max: float = 10.0
|
| 36 |
# Pixel-space noise std for VP diffusion initialization
|
| 37 |
pixel_noise_std: float = 0.558
|
| 38 |
|
| 39 |
-
@property
|
| 40 |
-
def latent_channels(self) -> int:
|
| 41 |
-
"""Channel width of the exported latent space."""
|
| 42 |
-
if self.bottleneck_patchify_mode == "patch_2x2":
|
| 43 |
-
return self.bottleneck_dim * 4
|
| 44 |
-
return self.bottleneck_dim
|
| 45 |
-
|
| 46 |
-
@property
|
| 47 |
-
def effective_patch_size(self) -> int:
|
| 48 |
-
"""Effective spatial stride from image to latent grid."""
|
| 49 |
-
if self.bottleneck_patchify_mode == "patch_2x2":
|
| 50 |
-
return self.patch_size * 2
|
| 51 |
-
return self.patch_size
|
| 52 |
-
|
| 53 |
def save(self, path: str | Path) -> None:
|
| 54 |
"""Save config as JSON."""
|
| 55 |
p = Path(path)
|
|
|
|
| 26 |
bottleneck_posterior_kind: str = "diagonal_gaussian"
|
| 27 |
# Post-bottleneck normalization: "channel_wise" or "disabled"
|
| 28 |
bottleneck_norm_mode: str = "disabled"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
# VP diffusion schedule endpoints
|
| 30 |
logsnr_min: float = -10.0
|
| 31 |
logsnr_max: float = 10.0
|
| 32 |
# Pixel-space noise std for VP diffusion initialization
|
| 33 |
pixel_noise_std: float = 0.558
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
def save(self, path: str | Path) -> None:
|
| 36 |
"""Save config as JSON."""
|
| 37 |
p = Path(path)
|
fcdm_diffae/model.py
CHANGED
|
@@ -71,14 +71,14 @@ class FCDMDiffAE(nn.Module):
|
|
| 71 |
super().__init__()
|
| 72 |
self.config = config
|
| 73 |
|
| 74 |
-
# Latent running stats for whitening/dewhitening
|
| 75 |
self.register_buffer(
|
| 76 |
"latent_norm_running_mean",
|
| 77 |
-
torch.zeros((config.
|
| 78 |
)
|
| 79 |
self.register_buffer(
|
| 80 |
"latent_norm_running_var",
|
| 81 |
-
torch.ones((config.
|
| 82 |
)
|
| 83 |
|
| 84 |
self.encoder = Encoder(
|
|
@@ -205,20 +205,6 @@ class FCDMDiffAE(nn.Module):
|
|
| 205 |
mean, std = self._latent_norm_stats()
|
| 206 |
return z * std.to(device=z.device) + mean.to(device=z.device)
|
| 207 |
|
| 208 |
-
def _patchify(self, z: Tensor) -> Tensor:
|
| 209 |
-
"""2x2 patchify: [B, C, H, W] -> [B, 4C, H/2, W/2]."""
|
| 210 |
-
b, c, h, w = z.shape
|
| 211 |
-
z = z.reshape(b, c, h // 2, 2, w // 2, 2)
|
| 212 |
-
z = z.permute(0, 1, 3, 5, 2, 4)
|
| 213 |
-
return z.reshape(b, c * 4, h // 2, w // 2)
|
| 214 |
-
|
| 215 |
-
def _unpatchify(self, z: Tensor) -> Tensor:
|
| 216 |
-
"""2x2 unpatchify: [B, 4C, H/2, W/2] -> [B, C, H, W]."""
|
| 217 |
-
b, c, h, w = z.shape
|
| 218 |
-
z = z.reshape(b, c // 4, 2, 2, h, w)
|
| 219 |
-
z = z.permute(0, 1, 4, 2, 5, 3)
|
| 220 |
-
return z.reshape(b, c // 4, h * 2, w * 2)
|
| 221 |
-
|
| 222 |
def encode(self, images: Tensor) -> Tensor:
|
| 223 |
"""Encode images to whitened latents (posterior mode).
|
| 224 |
|
|
@@ -226,19 +212,16 @@ class FCDMDiffAE(nn.Module):
|
|
| 226 |
use by downstream latent-space diffusion models.
|
| 227 |
|
| 228 |
Args:
|
| 229 |
-
images: [B, 3, H, W] in [-1, 1], H and W divisible by
|
| 230 |
-
effective_patch_size.
|
| 231 |
|
| 232 |
Returns:
|
| 233 |
-
Whitened latents [B,
|
| 234 |
"""
|
| 235 |
try:
|
| 236 |
model_dtype = next(self.parameters()).dtype
|
| 237 |
except StopIteration:
|
| 238 |
model_dtype = torch.float32
|
| 239 |
z = self.encoder(images.to(dtype=model_dtype))
|
| 240 |
-
if self.config.bottleneck_patchify_mode == "patch_2x2":
|
| 241 |
-
z = self._patchify(z)
|
| 242 |
return self.whiten(z).to(dtype=model_dtype)
|
| 243 |
|
| 244 |
def encode_posterior(self, images: Tensor) -> EncoderPosterior:
|
|
@@ -267,13 +250,12 @@ class FCDMDiffAE(nn.Module):
|
|
| 267 |
) -> Tensor:
|
| 268 |
"""Decode whitened latents to images via VP diffusion.
|
| 269 |
|
| 270 |
-
Latents are dewhitened
|
| 271 |
-
before being passed to the decoder.
|
| 272 |
|
| 273 |
Args:
|
| 274 |
-
latents: [B,
|
| 275 |
-
height: Output image height (divisible by
|
| 276 |
-
width: Output image width (divisible by
|
| 277 |
inference_config: Optional inference parameters.
|
| 278 |
|
| 279 |
Returns:
|
|
@@ -289,11 +271,8 @@ class FCDMDiffAE(nn.Module):
|
|
| 289 |
except StopIteration:
|
| 290 |
model_dtype = torch.float32
|
| 291 |
|
| 292 |
-
# Dewhiten
|
| 293 |
-
latents = self.dewhiten(latents)
|
| 294 |
-
if config.bottleneck_patchify_mode == "patch_2x2":
|
| 295 |
-
latents = self._unpatchify(latents)
|
| 296 |
-
latents = latents.to(dtype=model_dtype)
|
| 297 |
|
| 298 |
if height % config.patch_size != 0 or width % config.patch_size != 0:
|
| 299 |
raise ValueError(
|
|
|
|
| 71 |
super().__init__()
|
| 72 |
self.config = config
|
| 73 |
|
| 74 |
+
# Latent running stats for whitening/dewhitening
|
| 75 |
self.register_buffer(
|
| 76 |
"latent_norm_running_mean",
|
| 77 |
+
torch.zeros((config.bottleneck_dim,), dtype=torch.float32),
|
| 78 |
)
|
| 79 |
self.register_buffer(
|
| 80 |
"latent_norm_running_var",
|
| 81 |
+
torch.ones((config.bottleneck_dim,), dtype=torch.float32),
|
| 82 |
)
|
| 83 |
|
| 84 |
self.encoder = Encoder(
|
|
|
|
| 205 |
mean, std = self._latent_norm_stats()
|
| 206 |
return z * std.to(device=z.device) + mean.to(device=z.device)
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
def encode(self, images: Tensor) -> Tensor:
|
| 209 |
"""Encode images to whitened latents (posterior mode).
|
| 210 |
|
|
|
|
| 212 |
use by downstream latent-space diffusion models.
|
| 213 |
|
| 214 |
Args:
|
| 215 |
+
images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.
|
|
|
|
| 216 |
|
| 217 |
Returns:
|
| 218 |
+
Whitened latents [B, bottleneck_dim, H/patch, W/patch].
|
| 219 |
"""
|
| 220 |
try:
|
| 221 |
model_dtype = next(self.parameters()).dtype
|
| 222 |
except StopIteration:
|
| 223 |
model_dtype = torch.float32
|
| 224 |
z = self.encoder(images.to(dtype=model_dtype))
|
|
|
|
|
|
|
| 225 |
return self.whiten(z).to(dtype=model_dtype)
|
| 226 |
|
| 227 |
def encode_posterior(self, images: Tensor) -> EncoderPosterior:
|
|
|
|
| 250 |
) -> Tensor:
|
| 251 |
"""Decode whitened latents to images via VP diffusion.
|
| 252 |
|
| 253 |
+
Latents are dewhitened internally before being passed to the decoder.
|
|
|
|
| 254 |
|
| 255 |
Args:
|
| 256 |
+
latents: [B, bottleneck_dim, h, w] whitened encoder latents.
|
| 257 |
+
height: Output image height (divisible by patch_size).
|
| 258 |
+
width: Output image width (divisible by patch_size).
|
| 259 |
inference_config: Optional inference parameters.
|
| 260 |
|
| 261 |
Returns:
|
|
|
|
| 271 |
except StopIteration:
|
| 272 |
model_dtype = torch.float32
|
| 273 |
|
| 274 |
+
# Dewhiten back to raw encoder scale for the decoder
|
| 275 |
+
latents = self.dewhiten(latents).to(dtype=model_dtype)
|
|
|
|
|
|
|
|
|
|
| 276 |
|
| 277 |
if height % config.patch_size != 0 or width % config.patch_size != 0:
|
| 278 |
raise ValueError(
|