JTriggerFish committed on
Commit
940c4c8
·
1 Parent(s): ec37736

Revert "Upload folder using huggingface_hub"

Browse files

This reverts commit ec377361f0dcaf8b153e6940aa88d84bd7ab7515.

fcdm_diffae/__init__.py CHANGED
@@ -26,8 +26,8 @@ from .encoder import EncoderPosterior
26
  from .model import FCDMDiffAE
27
 
28
  __all__ = [
29
- "EncoderPosterior",
30
  "FCDMDiffAE",
31
  "FCDMDiffAEConfig",
32
  "FCDMDiffAEInferenceConfig",
 
33
  ]
 
26
  from .model import FCDMDiffAE
27
 
28
  __all__ = [
 
29
  "FCDMDiffAE",
30
  "FCDMDiffAEConfig",
31
  "FCDMDiffAEInferenceConfig",
32
+ "EncoderPosterior",
33
  ]
fcdm_diffae/__pycache__/__init__.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/__init__.cpython-312.pyc and b/fcdm_diffae/__pycache__/__init__.cpython-312.pyc differ
 
fcdm_diffae/__pycache__/config.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/config.cpython-312.pyc and b/fcdm_diffae/__pycache__/config.cpython-312.pyc differ
 
fcdm_diffae/__pycache__/decoder.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/decoder.cpython-312.pyc and b/fcdm_diffae/__pycache__/decoder.cpython-312.pyc differ
 
fcdm_diffae/__pycache__/encoder.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/encoder.cpython-312.pyc and b/fcdm_diffae/__pycache__/encoder.cpython-312.pyc differ
 
fcdm_diffae/__pycache__/model.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/model.cpython-312.pyc and b/fcdm_diffae/__pycache__/model.cpython-312.pyc differ
 
fcdm_diffae/__pycache__/samplers.cpython-312.pyc CHANGED
Binary files a/fcdm_diffae/__pycache__/samplers.cpython-312.pyc and b/fcdm_diffae/__pycache__/samplers.cpython-312.pyc differ
 
fcdm_diffae/config.py CHANGED
@@ -26,30 +26,12 @@ class FCDMDiffAEConfig:
26
  bottleneck_posterior_kind: str = "diagonal_gaussian"
27
  # Post-bottleneck normalization: "channel_wise" or "disabled"
28
  bottleneck_norm_mode: str = "disabled"
29
- # Bottleneck patchification: "off" or "patch_2x2"
30
- # When "patch_2x2", encoder latents are 2x2 patchified after the bottleneck
31
- # (channels * 4, spatial / 2), and decode unpatchifies before the decoder.
32
- bottleneck_patchify_mode: str = "off"
33
  # VP diffusion schedule endpoints
34
  logsnr_min: float = -10.0
35
  logsnr_max: float = 10.0
36
  # Pixel-space noise std for VP diffusion initialization
37
  pixel_noise_std: float = 0.558
38
 
39
- @property
40
- def latent_channels(self) -> int:
41
- """Channel width of the exported latent space."""
42
- if self.bottleneck_patchify_mode == "patch_2x2":
43
- return self.bottleneck_dim * 4
44
- return self.bottleneck_dim
45
-
46
- @property
47
- def effective_patch_size(self) -> int:
48
- """Effective spatial stride from image to latent grid."""
49
- if self.bottleneck_patchify_mode == "patch_2x2":
50
- return self.patch_size * 2
51
- return self.patch_size
52
-
53
  def save(self, path: str | Path) -> None:
54
  """Save config as JSON."""
55
  p = Path(path)
 
26
  bottleneck_posterior_kind: str = "diagonal_gaussian"
27
  # Post-bottleneck normalization: "channel_wise" or "disabled"
28
  bottleneck_norm_mode: str = "disabled"
 
 
 
 
29
  # VP diffusion schedule endpoints
30
  logsnr_min: float = -10.0
31
  logsnr_max: float = 10.0
32
  # Pixel-space noise std for VP diffusion initialization
33
  pixel_noise_std: float = 0.558
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def save(self, path: str | Path) -> None:
36
  """Save config as JSON."""
37
  p = Path(path)
fcdm_diffae/model.py CHANGED
@@ -71,14 +71,14 @@ class FCDMDiffAE(nn.Module):
71
  super().__init__()
72
  self.config = config
73
 
74
- # Latent running stats for whitening/dewhitening (at exported latent channels)
75
  self.register_buffer(
76
  "latent_norm_running_mean",
77
- torch.zeros((config.latent_channels,), dtype=torch.float32),
78
  )
79
  self.register_buffer(
80
  "latent_norm_running_var",
81
- torch.ones((config.latent_channels,), dtype=torch.float32),
82
  )
83
 
84
  self.encoder = Encoder(
@@ -205,20 +205,6 @@ class FCDMDiffAE(nn.Module):
205
  mean, std = self._latent_norm_stats()
206
  return z * std.to(device=z.device) + mean.to(device=z.device)
207
 
208
- def _patchify(self, z: Tensor) -> Tensor:
209
- """2x2 patchify: [B, C, H, W] -> [B, 4C, H/2, W/2]."""
210
- b, c, h, w = z.shape
211
- z = z.reshape(b, c, h // 2, 2, w // 2, 2)
212
- z = z.permute(0, 1, 3, 5, 2, 4)
213
- return z.reshape(b, c * 4, h // 2, w // 2)
214
-
215
- def _unpatchify(self, z: Tensor) -> Tensor:
216
- """2x2 unpatchify: [B, 4C, H/2, W/2] -> [B, C, H, W]."""
217
- b, c, h, w = z.shape
218
- z = z.reshape(b, c // 4, 2, 2, h, w)
219
- z = z.permute(0, 1, 4, 2, 5, 3)
220
- return z.reshape(b, c // 4, h * 2, w * 2)
221
-
222
  def encode(self, images: Tensor) -> Tensor:
223
  """Encode images to whitened latents (posterior mode).
224
 
@@ -226,19 +212,16 @@ class FCDMDiffAE(nn.Module):
226
  use by downstream latent-space diffusion models.
227
 
228
  Args:
229
- images: [B, 3, H, W] in [-1, 1], H and W divisible by
230
- effective_patch_size.
231
 
232
  Returns:
233
- Whitened latents [B, latent_channels, H/effective_patch, W/effective_patch].
234
  """
235
  try:
236
  model_dtype = next(self.parameters()).dtype
237
  except StopIteration:
238
  model_dtype = torch.float32
239
  z = self.encoder(images.to(dtype=model_dtype))
240
- if self.config.bottleneck_patchify_mode == "patch_2x2":
241
- z = self._patchify(z)
242
  return self.whiten(z).to(dtype=model_dtype)
243
 
244
  def encode_posterior(self, images: Tensor) -> EncoderPosterior:
@@ -267,13 +250,12 @@ class FCDMDiffAE(nn.Module):
267
  ) -> Tensor:
268
  """Decode whitened latents to images via VP diffusion.
269
 
270
- Latents are dewhitened and (if applicable) unpatchified internally
271
- before being passed to the decoder.
272
 
273
  Args:
274
- latents: [B, latent_channels, h, w] whitened encoder latents.
275
- height: Output image height (divisible by effective_patch_size).
276
- width: Output image width (divisible by effective_patch_size).
277
  inference_config: Optional inference parameters.
278
 
279
  Returns:
@@ -289,11 +271,8 @@ class FCDMDiffAE(nn.Module):
289
  except StopIteration:
290
  model_dtype = torch.float32
291
 
292
- # Dewhiten and unpatchify back to raw encoder scale for the decoder
293
- latents = self.dewhiten(latents)
294
- if config.bottleneck_patchify_mode == "patch_2x2":
295
- latents = self._unpatchify(latents)
296
- latents = latents.to(dtype=model_dtype)
297
 
298
  if height % config.patch_size != 0 or width % config.patch_size != 0:
299
  raise ValueError(
 
71
  super().__init__()
72
  self.config = config
73
 
74
+ # Latent running stats for whitening/dewhitening
75
  self.register_buffer(
76
  "latent_norm_running_mean",
77
+ torch.zeros((config.bottleneck_dim,), dtype=torch.float32),
78
  )
79
  self.register_buffer(
80
  "latent_norm_running_var",
81
+ torch.ones((config.bottleneck_dim,), dtype=torch.float32),
82
  )
83
 
84
  self.encoder = Encoder(
 
205
  mean, std = self._latent_norm_stats()
206
  return z * std.to(device=z.device) + mean.to(device=z.device)
207
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  def encode(self, images: Tensor) -> Tensor:
209
  """Encode images to whitened latents (posterior mode).
210
 
 
212
  use by downstream latent-space diffusion models.
213
 
214
  Args:
215
+ images: [B, 3, H, W] in [-1, 1], H and W divisible by patch_size.
 
216
 
217
  Returns:
218
+ Whitened latents [B, bottleneck_dim, H/patch, W/patch].
219
  """
220
  try:
221
  model_dtype = next(self.parameters()).dtype
222
  except StopIteration:
223
  model_dtype = torch.float32
224
  z = self.encoder(images.to(dtype=model_dtype))
 
 
225
  return self.whiten(z).to(dtype=model_dtype)
226
 
227
  def encode_posterior(self, images: Tensor) -> EncoderPosterior:
 
250
  ) -> Tensor:
251
  """Decode whitened latents to images via VP diffusion.
252
 
253
+ Latents are dewhitened internally before being passed to the decoder.
 
254
 
255
  Args:
256
+ latents: [B, bottleneck_dim, h, w] whitened encoder latents.
257
+ height: Output image height (divisible by patch_size).
258
+ width: Output image width (divisible by patch_size).
259
  inference_config: Optional inference parameters.
260
 
261
  Returns:
 
271
  except StopIteration:
272
  model_dtype = torch.float32
273
 
274
+ # Dewhiten back to raw encoder scale for the decoder
275
+ latents = self.dewhiten(latents).to(dtype=model_dtype)
 
 
 
276
 
277
  if height % config.patch_size != 0 or width % config.patch_size != 0:
278
  raise ValueError(