data-archetype commited on
Commit
df52b4d
·
verified ·
1 Parent(s): c9685d2

Patch VP logSNR stability in exported SemDisDiffAE code

Browse files
Files changed (1) hide show
  1. fcdm_diffae/vp_diffusion.py +3 -2
fcdm_diffae/vp_diffusion.py CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
5
  import math
6
 
7
  import torch
 
8
  from torch import Tensor
9
 
10
 
@@ -14,8 +15,8 @@ def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
14
  VP constraint: alpha^2 + sigma^2 = 1.
15
  """
16
  lmb32 = lmb.to(dtype=torch.float32)
17
- alpha = torch.sqrt(torch.sigmoid(lmb32))
18
- sigma = torch.sqrt(torch.sigmoid(-lmb32))
19
  return alpha, sigma
20
 
21
 
 
5
  import math
6
 
7
  import torch
8
+ import torch.nn.functional as F
9
  from torch import Tensor
10
 
11
 
 
15
  VP constraint: alpha^2 + sigma^2 = 1.
16
  """
17
  lmb32 = lmb.to(dtype=torch.float32)
18
+ alpha = torch.exp(0.5 * F.logsigmoid(lmb32))
19
+ sigma = torch.exp(0.5 * F.logsigmoid(-lmb32))
20
  return alpha, sigma
21
 
22