| |
| |
| |
| |
| |
| |
|
|
"""Converting between pixel and latent representations of image data."""
|
|
| import os |
| import warnings |
| import numpy as np |
| import torch |
| from torch_utils import persistence |
| from torch_utils import misc |
|
|
| warnings.filterwarnings('ignore', 'torch.utils._pytree._register_pytree_node is deprecated.') |
| warnings.filterwarnings('ignore', '`resume_download` is deprecated') |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
@persistence.persistent_class
class Encoder:
    """Abstract base class for pixel-to-latent encoders.

    Subclasses override encode_pixels() to map raw pixel tensors to their
    latent representation; init() is a lazy, device-specific setup hook.
    """

    def __init__(self):
        pass

    def init(self, device):
        """Perform lazy device-specific setup; the base class does nothing."""
        pass

    def __getstate__(self):
        # Pickle the instance attributes as-is; subclasses may override
        # this to drop state that cannot (or should not) be serialized.
        return self.__dict__

    def encode_pixels(self, x):
        """Encode a batch of pixel data `x`; must be overridden by subclasses."""
        raise NotImplementedError
| |
| |
|
|
@persistence.persistent_class
class StabilityVAEEncoder(Encoder):
    """Encoder backed by a pretrained Stability-AI VAE (loaded via diffusers).

    Pixel values are mapped from [0, 255] to [-1, 1], pushed through the VAE
    encoder in mini-batches, and returned as the latent distribution's mean
    and std concatenated along the channel dimension.
    """

    def __init__(self,
        vae_name   = 'stabilityai/sd-vae-ft-mse',   # Name of the pretrained VAE to load.
        batch_size = 8,                             # Maximum images per VAE forward pass.
    ):
        super().__init__()
        self.vae_name = vae_name
        self.batch_size = int(batch_size)
        self._vae = None  # Loaded lazily on first use; never pickled.

    def init(self, device):
        """Load the VAE on first call, or move the cached one to `device`."""
        super().init(device)
        if self._vae is not None:
            self._vae.to(device)
        else:
            self._vae = load_stability_vae(self.vae_name, device=device)

    def __getstate__(self):
        # Exclude the VAE weights from pickling; they are re-loaded lazily by init().
        return dict(super().__getstate__(), _vae=None)

    def _run_vae_encoder(self, x):
        # Concatenate the latent distribution's mean and std channel-wise.
        dist = self._vae.encode(x)['latent_dist']
        return torch.cat([dist.mean, dist.std], dim=1)

    def encode_pixels(self, x):
        """Encode pixel tensor `x` (values in [0, 255]) into latent mean/std."""
        self.init(x.device)
        scaled = x.to(torch.float32) / 127.5 - 1  # [0, 255] -> [-1, 1]
        chunks = [self._run_vae_encoder(c) for c in scaled.split(self.batch_size)]
        return torch.cat(chunks)
|
|
| |
|
|
def load_stability_vae(vae_name='stabilityai/sd-vae-ft-mse', device=torch.device('cpu')):
    """Load a pretrained Stability-AI AutoencoderKL through diffusers.

    Tries the local Hugging Face cache first (no network access) and only
    falls back to downloading when the files are not available locally.

    Args:
        vae_name: Hugging Face model identifier of the VAE to load.
        device:   Device to place the model on.

    Returns:
        The VAE in eval mode with gradients disabled, moved to `device`.
    """
    import dnnlib
    cache_dir = dnnlib.make_cache_dir_path('diffusers')
    # Quiet down Hugging Face hub chatter and redirect its cache.
    os.environ['HF_HUB_DISABLE_SYMLINKS_WARNING'] = '1'
    os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'
    os.environ['HF_HOME'] = cache_dir

    import diffusers
    try:
        # Fast path: use already-downloaded files without touching the network.
        vae = diffusers.models.AutoencoderKL.from_pretrained(
            vae_name, cache_dir=cache_dir, local_files_only=True
        )
    except Exception:
        # Cache miss or incomplete cache: download from the Hugging Face hub.
        # Deliberately not a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.
        vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, cache_dir=cache_dir)
    return vae.eval().requires_grad_(False).to(device)
|
|
| |