| """ |
| Wrapper to match VAE interface to that of SD VAE. |
| """ |
|
|
from types import SimpleNamespace

import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin
from medvae.models import AutoencoderKL_2D
from medvae.utils.factory import (
    FILE_DICT_ASSOCIATIONS,
    download_model_weights,
)
from omegaconf import OmegaConf


class LatentDist:
    """Minimal stand-in for diffusers' DiagonalGaussianDistribution interface."""

    def __init__(self, dist):
        self.dist = dist

    def sample(self):
        # Note: returns the mode rather than a random sample, so encoding
        # is deterministic.
        return self.dist.mode()

    def mode(self):
        return self.dist.mode()


class MedVAEWrapper(ModelMixin, ConfigMixin):
    config_name = "config.json"
    ignore_for_config = ["vae"]

    @register_to_config
    def __init__(self, vae=None, scaling_factor=1.0, downsampling_factor=4):
        super().__init__()
        assert downsampling_factor in [
            4,
            8,
        ], "Only 4x and 8x downsampling are currently supported"
        if vae is None:
            model_name = (
                "medvae_4_4_2d_c" if downsampling_factor == 4 else "medvae_8_4_2d_c"
            )
            config_fpath = download_model_weights(
                FILE_DICT_ASSOCIATIONS[model_name]["config"]
            )
            if model_name == "medvae_8_4_2d_c":
                # Local override of the downloaded 8x config; adjust this
                # hardcoded path to your own checkout.
                config_fpath = "/data/yurman/repos/fast-mri-ldm/submodules/medvae/configs/ours-8x1-new.yaml"

            conf = OmegaConf.load(config_fpath)
            # Override the config for a 4-channel latent space and
            # 2-channel inputs/outputs.
            conf.embed_dim = 4
            conf.ddconfig.z_channels = 4
            conf.ddconfig.in_channels = 2
            conf.ddconfig.out_ch = 2

            vae = AutoencoderKL_2D(
                ddconfig=conf.ddconfig,
                embed_dim=conf.embed_dim,
            )

        self.vae = vae

        # diffusers pipelines typically infer the spatial scale factor as
        # 2 ** (len(block_out_channels) - 1), so register a dummy list of
        # the matching length for the requested downsampling factor.
        n_blocks = int(np.log2(downsampling_factor)) + 1
        self.register_to_config(
            block_out_channels=[1] * n_blocks,
            in_channels=2,
            scaling_factor=scaling_factor,
            downsampling_factor=downsampling_factor,
        )

    def encode(self, x):
        # Mirror AutoencoderKL.encode, which returns an output object
        # exposing a .latent_dist attribute.
        dist = self.vae.encode(x)
        return SimpleNamespace(latent_dist=LatentDist(dist))

    def decode(self, x, return_dict=False, generator=None):
        # return_dict and generator are accepted for interface
        # compatibility but ignored; decoding runs in full precision
        # regardless of any ambient autocast context.
        with torch.amp.autocast(device_type="cuda", enabled=False):
            x = self.vae.decode(x)
        return (x,)
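

if __name__ == "__main__":
    # Quick smoke test of the SD-style interface. A minimal sketch, not
    # part of the wrapper: it assumes the MedVAE config is downloadable,
    # and it exercises the randomly initialized model (trained weights
    # are expected to be loaded elsewhere, e.g. via
    # ModelMixin.from_pretrained). Shapes assume 2-channel inputs and
    # 4x downsampling.
    wrapper = MedVAEWrapper(downsampling_factor=4)
    x = torch.randn(1, 2, 64, 64)

    latents = wrapper.encode(x).latent_dist.sample()  # expected: (1, 4, 16, 16)
    recon = wrapper.decode(latents)[0]  # expected: (1, 2, 64, 64)

    # diffusers pipelines recover the spatial factor from the config.
    vae_scale_factor = 2 ** (len(wrapper.config.block_out_channels) - 1)
    print(latents.shape, recon.shape, vae_scale_factor)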