data-archetype committed on
Commit
bf1a98e
·
verified ·
1 Parent(s): 6a12ad8

Patch VP logSNR stability in exported SemDisDiffAE p32 code

Browse files
Files changed (1) hide show
  1. fcdm_diffae/encoder.py +6 -5
fcdm_diffae/encoder.py CHANGED
@@ -10,6 +10,7 @@ from __future__ import annotations
10
  from dataclasses import dataclass
11
 
12
  import torch
 
13
  from torch import Tensor, nn
14
 
15
  from .fcdm_block import FCDMBlock
@@ -30,15 +31,15 @@ class EncoderPosterior:
30
 
31
  @property
32
  def alpha(self) -> Tensor:
33
- """VP signal coefficient: sqrt(sigmoid(logsnr)), computed in float32."""
34
  logsnr_fp32 = self.logsnr.to(torch.float32)
35
- return torch.sigmoid(logsnr_fp32).sqrt()
36
 
37
  @property
38
  def sigma(self) -> Tensor:
39
- """VP noise coefficient: sqrt(sigmoid(-logsnr)), computed in float32."""
40
  logsnr_fp32 = self.logsnr.to(torch.float32)
41
- return torch.sigmoid(-logsnr_fp32).sqrt()
42
 
43
  def mode(self) -> Tensor:
44
  """Posterior mode in token space: alpha * mean, computed in float32."""
@@ -128,7 +129,7 @@ class Encoder(nn.Module):
128
  mean, logsnr = projection.chunk(2, dim=1)
129
  mean = self.norm_out(mean)
130
  logsnr_fp32 = logsnr.to(torch.float32)
131
- alpha = torch.sigmoid(logsnr_fp32).sqrt()
132
  return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
133
  z = self.norm_out(projection)
134
  return z
 
10
  from dataclasses import dataclass
11
 
12
  import torch
13
+ import torch.nn.functional as F
14
  from torch import Tensor, nn
15
 
16
  from .fcdm_block import FCDMBlock
 
31
 
32
  @property
33
  def alpha(self) -> Tensor:
34
+ """VP signal coefficient computed stably in float32."""
35
  logsnr_fp32 = self.logsnr.to(torch.float32)
36
+ return torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
37
 
38
  @property
39
  def sigma(self) -> Tensor:
40
+ """VP noise coefficient computed stably in float32."""
41
  logsnr_fp32 = self.logsnr.to(torch.float32)
42
+ return torch.exp(0.5 * F.logsigmoid(-logsnr_fp32))
43
 
44
  def mode(self) -> Tensor:
45
  """Posterior mode in token space: alpha * mean, computed in float32."""
 
129
  mean, logsnr = projection.chunk(2, dim=1)
130
  mean = self.norm_out(mean)
131
  logsnr_fp32 = logsnr.to(torch.float32)
132
+ alpha = torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
133
  return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
134
  z = self.norm_out(projection)
135
  return z