| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
| | |
| | |
| | |
| |
|
class GaussianFourierProjection(nn.Module):
    """Random Gaussian Fourier features for a continuous time value.

    Each input value ``t`` is mapped to ``[sin(W t), cos(W t)]``, where ``W``
    holds ``embed_dim // 2`` frequencies sampled once at construction as
    ``torch.randn(...) * scale``. The sine and cosine halves are concatenated,
    so the output has ``embed_dim`` features in total (not ``2 * embed_dim``).

    The features were originally described for ``t`` in ``[0, 1]``; that range
    is not enforced here.

    Args:
        embed_dim: Size of the output feature dimension; must be even.
        scale: Multiplier applied to the standard-normal frequency samples.

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """

    def __init__(self, embed_dim, scale):
        super().__init__()
        # Raise explicitly instead of ``assert``: asserts vanish under
        # ``python -O`` and an odd value would silently drop one feature.
        if embed_dim % 2 != 0:
            raise ValueError("embed_dim must be even.")
        self.embed_dim = embed_dim
        # NOTE(review): the buffer is non-persistent, so ``W`` is excluded from
        # ``state_dict()``. A model restored from a checkpoint therefore gets
        # freshly sampled frequencies unless the RNG state is reproduced --
        # confirm this is intended before relying on saved weights.
        self.register_buffer("W", torch.randn(embed_dim // 2) * scale, persistent=False)

    def forward(self, t):
        """Return the Fourier features of ``t``.

        Args:
            t: Tensor of any shape ``(...)``, or a Python/NumPy scalar or
                sequence, which is converted to a tensor on ``W``'s device.

        Returns:
            Float tensor of shape ``(..., embed_dim)``: sines in the first
            half of the last dimension, cosines in the second half.
        """
        if not torch.is_tensor(t):
            t = torch.as_tensor(t, device=self.W.device)
        # Add a trailing axis so every time value broadcasts against all
        # frequencies in ``W``.
        t = t.float().unsqueeze(-1)
        angles = t * self.W
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
| |
|
| |
|
class TimeEmbedding(nn.Module):
    """Embed a time value via Fourier features followed by a small MLP.

    The input is first expanded by ``GaussianFourierProjection(fourier_dim,
    scale)`` and then passed through ``Linear(fourier_dim, hidden_dim) ->
    SiLU -> Linear(hidden_dim, hidden_dim)``.

    Args:
        hidden_dim: Width of both linear layers' outputs, and therefore of
            the returned embedding's last dimension.
        fourier_dim: Input width of the first linear layer, also passed to
            the Fourier projection; must be even.
        scale: Forwarded unchanged to the Fourier projection.
    """

    def __init__(self, hidden_dim, fourier_dim, scale):
        super().__init__()
        assert fourier_dim % 2 == 0, "fourier_dim must be even for sine/cosine pairs."

        # Submodules are created in the order projection -> first linear ->
        # second linear, so random initialisation draws happen in that order.
        self.fourier = GaussianFourierProjection(fourier_dim, scale)
        layers = [
            nn.Linear(fourier_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t):
        """Return the MLP output for the Fourier features of ``t``."""
        return self.mlp(self.fourier(t))