Upload folder using huggingface_hub
Browse files
- fcdm_diffae/__pycache__/__init__.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/adaln.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/config.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/decoder.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/encoder.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/fcdm_block.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/model.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/norms.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/samplers.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/straight_through_encoder.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/time_embed.cpython-312.pyc +0 -0
- fcdm_diffae/__pycache__/vp_diffusion.cpython-312.pyc +0 -0
- fcdm_diffae/decoder.py +3 -4
- fcdm_diffae/samplers.py +6 -14
fcdm_diffae/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (1.29 kB). View file
|
|
|
fcdm_diffae/__pycache__/adaln.cpython-312.pyc
ADDED
|
Binary file (3.88 kB). View file
|
|
|
fcdm_diffae/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (3.31 kB). View file
|
|
|
fcdm_diffae/__pycache__/decoder.cpython-312.pyc
ADDED
|
Binary file (7.38 kB). View file
|
|
|
fcdm_diffae/__pycache__/encoder.cpython-312.pyc
ADDED
|
Binary file (6.76 kB). View file
|
|
|
fcdm_diffae/__pycache__/fcdm_block.cpython-312.pyc
ADDED
|
Binary file (5.8 kB). View file
|
|
|
fcdm_diffae/__pycache__/model.cpython-312.pyc
ADDED
|
Binary file (15.9 kB). View file
|
|
|
fcdm_diffae/__pycache__/norms.cpython-312.pyc
ADDED
|
Binary file (2.78 kB). View file
|
|
|
fcdm_diffae/__pycache__/samplers.cpython-312.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
fcdm_diffae/__pycache__/straight_through_encoder.cpython-312.pyc
ADDED
|
Binary file (2.08 kB). View file
|
|
|
fcdm_diffae/__pycache__/time_embed.cpython-312.pyc
ADDED
|
Binary file (4.82 kB). View file
|
|
|
fcdm_diffae/__pycache__/vp_diffusion.cpython-312.pyc
ADDED
|
Binary file (7.51 kB). View file
|
|
|
fcdm_diffae/decoder.py
CHANGED
|
@@ -23,10 +23,9 @@ class Decoder(nn.Module):
|
|
| 23 |
-> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
|
| 24 |
-> Conv1x1 -> PixelShuffle
|
| 25 |
|
| 26 |
-
|
| 27 |
-
-
|
| 28 |
-
|
| 29 |
-
``latent_mask_feature`` before fusion.
|
| 30 |
"""
|
| 31 |
|
| 32 |
def __init__(
|
|
|
|
| 23 |
-> Start blocks (2) -> Middle blocks (4) -> Skip fuse -> End blocks (2)
|
| 24 |
-> Conv1x1 -> PixelShuffle
|
| 25 |
|
| 26 |
+
Path-Drop Guidance (PDG) at inference:
|
| 27 |
+
- Replace middle block output with ``path_drop_mask_feature`` to create
|
| 28 |
+
an unconditional prediction, then extrapolate.
|
|
|
|
| 29 |
"""
|
| 30 |
|
| 31 |
def __init__(
|
fcdm_diffae/samplers.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
"""DDIM and DPM++2M samplers for VP diffusion with
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
@@ -15,7 +15,7 @@ from .vp_diffusion import (
|
|
| 15 |
|
| 16 |
|
| 17 |
class DecoderForwardFn(Protocol):
|
| 18 |
-
"""Callable that predicts x0 from (x_t, t, latents) with
|
| 19 |
|
| 20 |
def __call__(
|
| 21 |
self,
|
|
@@ -72,7 +72,7 @@ def _predict_with_pdg(
|
|
| 72 |
state: Current noised state [B, C, H, W].
|
| 73 |
t_vec: Timestep vector [B].
|
| 74 |
latents: Encoder latents.
|
| 75 |
-
pdg_mode: "disabled"
|
| 76 |
pdg_strength: CFG-like strength for PDG.
|
| 77 |
|
| 78 |
Returns:
|
|
@@ -86,14 +86,6 @@ def _predict_with_pdg(
|
|
| 86 |
torch.float32
|
| 87 |
)
|
| 88 |
return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
|
| 89 |
-
elif pdg_mode == "token_mask":
|
| 90 |
-
x0_uncond = forward_fn(state, t_vec, latents, mask_latent_tokens=True).to(
|
| 91 |
-
torch.float32
|
| 92 |
-
)
|
| 93 |
-
x0_cond = forward_fn(state, t_vec, latents, mask_latent_tokens=False).to(
|
| 94 |
-
torch.float32
|
| 95 |
-
)
|
| 96 |
-
return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
|
| 97 |
else:
|
| 98 |
return forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
|
| 99 |
torch.float32
|
|
@@ -114,7 +106,7 @@ def run_ddim(
|
|
| 114 |
pdg_strength: float = 1.5,
|
| 115 |
device: torch.device | None = None,
|
| 116 |
) -> Tensor:
|
| 117 |
-
"""Run DDIM sampling loop with
|
| 118 |
|
| 119 |
Args:
|
| 120 |
forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
|
|
@@ -123,7 +115,7 @@ def run_ddim(
|
|
| 123 |
latents: Encoder latents [B, bottleneck_dim, h, w].
|
| 124 |
logsnr_min, logsnr_max: VP schedule endpoints.
|
| 125 |
log_change_high, log_change_low: Shifted-cosine schedule parameters.
|
| 126 |
-
pdg_mode: "disabled"
|
| 127 |
pdg_strength: CFG-like strength for PDG.
|
| 128 |
device: Target device.
|
| 129 |
|
|
@@ -190,7 +182,7 @@ def run_dpmpp_2m(
|
|
| 190 |
pdg_strength: float = 1.5,
|
| 191 |
device: torch.device | None = None,
|
| 192 |
) -> Tensor:
|
| 193 |
-
"""Run DPM++2M sampling loop with
|
| 194 |
|
| 195 |
Multi-step solver using exponential integrator formulation in half-lambda space.
|
| 196 |
"""
|
|
|
|
| 1 |
+
"""DDIM and DPM++2M samplers for VP diffusion with path-drop PDG support."""
|
| 2 |
|
| 3 |
from __future__ import annotations
|
| 4 |
|
|
|
|
| 15 |
|
| 16 |
|
| 17 |
class DecoderForwardFn(Protocol):
|
| 18 |
+
"""Callable that predicts x0 from (x_t, t, latents) with path-drop PDG flag."""
|
| 19 |
|
| 20 |
def __call__(
|
| 21 |
self,
|
|
|
|
| 72 |
state: Current noised state [B, C, H, W].
|
| 73 |
t_vec: Timestep vector [B].
|
| 74 |
latents: Encoder latents.
|
| 75 |
+
pdg_mode: "disabled" or "path_drop".
|
| 76 |
pdg_strength: CFG-like strength for PDG.
|
| 77 |
|
| 78 |
Returns:
|
|
|
|
| 86 |
torch.float32
|
| 87 |
)
|
| 88 |
return x0_uncond + pdg_strength * (x0_cond - x0_uncond)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
else:
|
| 90 |
return forward_fn(state, t_vec, latents, drop_middle_blocks=False).to(
|
| 91 |
torch.float32
|
|
|
|
| 106 |
pdg_strength: float = 1.5,
|
| 107 |
device: torch.device | None = None,
|
| 108 |
) -> Tensor:
|
| 109 |
+
"""Run DDIM sampling loop with path-drop PDG support.
|
| 110 |
|
| 111 |
Args:
|
| 112 |
forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
|
|
|
|
| 115 |
latents: Encoder latents [B, bottleneck_dim, h, w].
|
| 116 |
logsnr_min, logsnr_max: VP schedule endpoints.
|
| 117 |
log_change_high, log_change_low: Shifted-cosine schedule parameters.
|
| 118 |
+
pdg_mode: "disabled" or "path_drop".
|
| 119 |
pdg_strength: CFG-like strength for PDG.
|
| 120 |
device: Target device.
|
| 121 |
|
|
|
|
| 182 |
pdg_strength: float = 1.5,
|
| 183 |
device: torch.device | None = None,
|
| 184 |
) -> Tensor:
|
| 185 |
+
"""Run DPM++2M sampling loop with path-drop PDG support.
|
| 186 |
|
| 187 |
Multi-step solver using exponential integrator formulation in half-lambda space.
|
| 188 |
"""
|