Patch VP logSNR stability in exported SemDisDiffAE p32 code
Browse files — fcdm_diffae/encoder.py (+6 −5)
fcdm_diffae/encoder.py
CHANGED
|
@@ -10,6 +10,7 @@ from __future__ import annotations
|
|
| 10 |
from dataclasses import dataclass
|
| 11 |
|
| 12 |
import torch
|
|
|
|
| 13 |
from torch import Tensor, nn
|
| 14 |
|
| 15 |
from .fcdm_block import FCDMBlock
|
|
@@ -30,15 +31,15 @@ class EncoderPosterior:
|
|
| 30 |
|
| 31 |
@property
|
| 32 |
def alpha(self) -> Tensor:
|
| 33 |
-
"""VP signal coefficient."""
|
| 34 |
logsnr_fp32 = self.logsnr.to(torch.float32)
|
| 35 |
-
return torch.sigmoid(logsnr_fp32).sqrt()
|
| 36 |
|
| 37 |
@property
|
| 38 |
def sigma(self) -> Tensor:
|
| 39 |
-
"""VP noise coefficient."""
|
| 40 |
logsnr_fp32 = self.logsnr.to(torch.float32)
|
| 41 |
-
return torch.sigmoid(-logsnr_fp32).sqrt()
|
| 42 |
|
| 43 |
def mode(self) -> Tensor:
|
| 44 |
"""Posterior mode in token space: alpha * mean, computed in float32."""
|
|
@@ -128,7 +129,7 @@ class Encoder(nn.Module):
|
|
| 128 |
mean, logsnr = projection.chunk(2, dim=1)
|
| 129 |
mean = self.norm_out(mean)
|
| 130 |
logsnr_fp32 = logsnr.to(torch.float32)
|
| 131 |
-
alpha = torch.sigmoid(logsnr_fp32).sqrt()
|
| 132 |
return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
|
| 133 |
z = self.norm_out(projection)
|
| 134 |
return z
|
|
|
|
| 10 |
from dataclasses import dataclass
|
| 11 |
|
| 12 |
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
from torch import Tensor, nn
|
| 15 |
|
| 16 |
from .fcdm_block import FCDMBlock
|
|
|
|
| 31 |
|
| 32 |
@property
|
| 33 |
def alpha(self) -> Tensor:
|
| 34 |
+
"""VP signal coefficient computed stably in float32."""
|
| 35 |
logsnr_fp32 = self.logsnr.to(torch.float32)
|
| 36 |
+
return torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
|
| 37 |
|
| 38 |
@property
|
| 39 |
def sigma(self) -> Tensor:
|
| 40 |
+
"""VP noise coefficient computed stably in float32."""
|
| 41 |
logsnr_fp32 = self.logsnr.to(torch.float32)
|
| 42 |
+
return torch.exp(0.5 * F.logsigmoid(-logsnr_fp32))
|
| 43 |
|
| 44 |
def mode(self) -> Tensor:
|
| 45 |
"""Posterior mode in token space: alpha * mean, computed in float32."""
|
|
|
|
| 129 |
mean, logsnr = projection.chunk(2, dim=1)
|
| 130 |
mean = self.norm_out(mean)
|
| 131 |
logsnr_fp32 = logsnr.to(torch.float32)
|
| 132 |
+
alpha = torch.exp(0.5 * F.logsigmoid(logsnr_fp32))
|
| 133 |
return (alpha * mean.to(torch.float32)).to(dtype=mean.dtype)
|
| 134 |
z = self.norm_out(projection)
|
| 135 |
return z
|