| import torch | |
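# "Fake" VAE bridging IMAGE tensors (B, H, W, C, values on the scale of 0..1)
# and LATENT tensors (B, C, H, W, values on the scale of -1..1). The module
# itself is an identity: encode/decode return their input unchanged, so the
# layout and value-range conversion presumably happens in the calling code.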
class PixelspaceConversionVAE(torch.nn.Module):
    """Identity stand-in for a VAE.

    Neither ``encode`` nor ``decode`` transforms its input. Each returns the
    very tensor object it was given (no copy). Any extra positional or keyword
    arguments are accepted and silently ignored, presumably so calls written
    for a real VAE's signature still work here.

    NOTE(review): no IMAGE -> LATENT layout change or value-range rescaling
    happens in this class; presumably the caller handles it — confirm against
    the code that wraps this module.
    """

    def __init__(self):
        super().__init__()
        # A lone scalar parameter holding 1.0. It is never read by
        # encode/decode. Presumably it exists so the module exposes a
        # recognizable ``pixel_space_vae`` entry in its parameters/state
        # dict — TODO confirm with whatever loader checks for it.
        marker = torch.tensor(1.0)
        self.pixel_space_vae = torch.nn.Parameter(marker)

    @staticmethod
    def _passthrough(tensor: torch.Tensor) -> torch.Tensor:
        """Return *tensor* as-is: same object, no conversion."""
        return tensor

    def encode(
        self, pixels: torch.Tensor, *_unused_args, **_unused_kwargs
    ) -> torch.Tensor:
        """Return ``pixels`` unchanged; any extra arguments are ignored."""
        return self._passthrough(pixels)

    def decode(
        self, samples: torch.Tensor, *_unused_args, **_unused_kwargs
    ) -> torch.Tensor:
        """Return ``samples`` unchanged; any extra arguments are ignored."""
        return self._passthrough(samples)