JTriggerFish commited on
Commit
3196863
·
1 Parent(s): f5cfba7

Fix posterior VP interpolation to use float32 precision

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. fcdm_diffae/encoder.py +18 -13
README.md CHANGED
@@ -17,6 +17,7 @@ library_name: fcdm_diffae
17
 
18
  | Date | Change |
19
  |------|--------|
 
20
  | 2026-04-07 | Rename package `capacitor_diffae` → `fcdm_diffae`, class `FCDMDiffAE`; encode() now returns whitened latents, decode() dewhitens internally |
21
  | 2026-04-06 | Initial release |
22
 
 
17
 
18
  | Date | Change |
19
  |------|--------|
20
+ | 2026-04-08 | Fix posterior VP interpolation to use float32 precision (was using model dtype) |
21
  | 2026-04-07 | Rename package `capacitor_diffae` → `fcdm_diffae`, class `FCDMDiffAE`; encode() now returns whitened latents, decode() dewhitens internally |
22
  | 2026-04-06 | Initial release |
23
 
fcdm_diffae/encoder.py CHANGED
@@ -30,24 +30,28 @@ class EncoderPosterior:
30
 
31
  @property
32
  def alpha(self) -> Tensor:
33
- """VP signal coefficient: sqrt(sigmoid(logsnr))."""
34
- return torch.sigmoid(self.logsnr).sqrt()
 
35
 
36
  @property
37
  def sigma(self) -> Tensor:
38
- """VP noise coefficient: sqrt(sigmoid(-logsnr))."""
39
- return torch.sigmoid(-self.logsnr).sqrt()
 
40
 
41
  def mode(self) -> Tensor:
42
- """Posterior mode in token space: alpha * mean."""
43
- return self.alpha.to(dtype=self.mean.dtype) * self.mean
44
 
45
  def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
46
- """Sample from posterior: alpha * mean + sigma * eps."""
47
- eps = torch.randn_like(self.mean, generator=generator) # type: ignore[call-overload]
48
- alpha = self.alpha.to(dtype=self.mean.dtype)
49
- sigma = self.sigma.to(dtype=self.mean.dtype)
50
- return alpha * self.mean + sigma * eps
 
 
51
 
52
 
53
  class Encoder(nn.Module):
@@ -123,7 +127,8 @@ class Encoder(nn.Module):
123
  if self.bottleneck_posterior_kind == "diagonal_gaussian":
124
  mean, logsnr = projection.chunk(2, dim=1)
125
  mean = self.norm_out(mean)
126
- alpha = torch.sigmoid(logsnr).sqrt().to(dtype=mean.dtype)
127
- return alpha * mean
 
128
  z = self.norm_out(projection)
129
  return z
 
30
 
31
  @property
32
  def alpha(self) -> Tensor:
33
+ """VP signal coefficient: sqrt(sigmoid(logsnr)), computed in float32."""
34
+ logsnr_fp32 = self.logsnr.to(torch.float32)
35
+ return torch.sigmoid(logsnr_fp32).sqrt()
36
 
37
  @property
38
  def sigma(self) -> Tensor:
39
+ """VP noise coefficient: sqrt(sigmoid(-logsnr)), computed in float32."""
40
+ logsnr_fp32 = self.logsnr.to(torch.float32)
41
+ return torch.sigmoid(-logsnr_fp32).sqrt()
42
 
43
  def mode(self) -> Tensor:
44
+ """Posterior mode in token space: alpha * mean, computed in float32."""
45
+ return (self.alpha * self.mean.to(torch.float32)).to(dtype=self.mean.dtype)
46
 
47
  def sample(self, *, generator: torch.Generator | None = None) -> Tensor:
48
+ """Sample from posterior: alpha * mean + sigma * eps, computed in float32."""
49
+ mean_fp32 = self.mean.to(torch.float32)
50
+ eps = torch.randn(
51
+ mean_fp32.shape, device=mean_fp32.device, dtype=torch.float32,
52
+ generator=generator,
53
+ )
54
+ return (self.alpha * mean_fp32 + self.sigma * eps).to(dtype=self.mean.dtype)
55
 
56
 
57
  class Encoder(nn.Module):
 
127
  if self.bottleneck_posterior_kind == "diagonal_gaussian":
128
  mean, logsnr = projection.chunk(2, dim=1)
129
  mean = self.norm_out(mean)
130
+ logsnr_fp32 = logsnr.to(torch.float32)
131
+ alpha = torch.sigmoid(logsnr_fp32).sqrt()
132
+ return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
133
  z = self.norm_out(projection)
134
  return z