Patch VP logSNR stability in exported SemDisDiffAE code
Browse files
fcdm_diffae/vp_diffusion.py
CHANGED
|
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|
| 5 |
import math
|
| 6 |
|
| 7 |
import torch
|
|
|
|
| 8 |
from torch import Tensor
|
| 9 |
|
| 10 |
|
|
@@ -14,8 +15,8 @@ def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
|
|
| 14 |
VP constraint: alpha^2 + sigma^2 = 1.
|
| 15 |
"""
|
| 16 |
lmb32 = lmb.to(dtype=torch.float32)
|
| 17 |
-
alpha = torch.
|
| 18 |
-
sigma = torch.
|
| 19 |
return alpha, sigma
|
| 20 |
|
| 21 |
|
|
|
|
| 5 |
import math
|
| 6 |
|
| 7 |
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
from torch import Tensor
|
| 10 |
|
| 11 |
|
|
|
|
| 15 |
VP constraint: alpha^2 + sigma^2 = 1.
|
| 16 |
"""
|
| 17 |
lmb32 = lmb.to(dtype=torch.float32)
|
| 18 |
+
alpha = torch.exp(0.5 * F.logsigmoid(lmb32))
|
| 19 |
+
sigma = torch.exp(0.5 * F.logsigmoid(-lmb32))
|
| 20 |
return alpha, sigma
|
| 21 |
|
| 22 |
|