# DDT/src/models/vae.py
# wangshuai6
# init space
# 9e426da
import torch
import subprocess
import lightning.pytorch as pl
import logging
logger = logging.getLogger(__name__)
def class_fn_from_str(class_str):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    The text before the last dot is imported as a module and the text after
    it is looked up as an attribute of that module.

    Args:
        class_str: Dotted import path; it must contain at least one dot.

    Returns:
        The attribute (class, function, ...) found on the imported module.
    """
    module_path, attr_name = class_str.rsplit(".", 1)
    # A non-empty ``fromlist`` makes ``__import__`` return the leaf module
    # rather than the top-level package.
    module = __import__(module_path, fromlist=[attr_name])
    return getattr(module, attr_name)
class BaseVAE(torch.nn.Module):
    """Identity "autoencoder" that only applies an affine rescaling.

    ``encode`` computes ``x / scale + shift`` and ``decode`` applies the
    inverse, ``(x - shift) * scale``.
    """

    def __init__(self, scale=1.0, shift=0.0):
        """Store the affine parameters; ``model`` is a no-op ``Identity``."""
        super().__init__()
        self.model = torch.nn.Identity()
        self.scale = scale
        self.shift = shift

    def encode(self, x):
        """Return ``x`` divided by ``scale`` and then offset by ``shift``."""
        rescaled = x / self.scale
        return rescaled + self.shift

    def decode(self, x):
        """Undo ``encode``: remove ``shift``, then multiply by ``scale``."""
        unshifted = x - self.shift
        return unshifted * self.scale
# NOTE: nearest-neighbour sampling causes very bad bugs; presumably this is why bicubic interpolation is used below.
class DownSampleVAE(BaseVAE):
    """Affine-rescaling "VAE" that also resizes its input with bicubic interpolation.

    ``encode`` shrinks the input by ``down_ratio`` and then applies
    ``x / scale + shift``; ``decode`` inverts the affine map and then
    enlarges by ``down_ratio``.
    """

    def __init__(self, down_ratio, scale=1.0, shift=0.0):
        """Store the resize ratio and the affine parameters.

        Args:
            down_ratio: Factor by which ``encode`` shrinks (and ``decode``
                enlarges) the interpolated dimensions.
            scale: Divisor applied in ``encode`` / multiplier in ``decode``.
            shift: Offset added in ``encode`` / removed in ``decode``.
        """
        # The base class installs the Identity ``model`` and stores scale/shift.
        super().__init__(scale=scale, shift=shift)
        self.down_ratio = down_ratio

    def encode(self, x):
        """Downsample ``x`` by ``down_ratio``, then apply the affine map."""
        # NOTE(review): bicubic mode presumably implies a 4-D image batch - confirm.
        shrunk = torch.nn.functional.interpolate(
            x,
            scale_factor=1/self.down_ratio,
            mode='bicubic',
            align_corners=False,
        )
        return shrunk/self.scale+self.shift

    def decode(self, x):
        """Invert the affine map, then upsample by ``down_ratio``."""
        restored = (x-self.shift)*self.scale
        return torch.nn.functional.interpolate(
            restored,
            scale_factor=self.down_ratio,
            mode='bicubic',
            align_corners=False,
        )
class LatentVAE(BaseVAE):
    """Wrapper around a pretrained diffusers ``AutoencoderKL``.

    Latents returned by ``encode`` are multiplied by the model's configured
    ``scaling_factor``; ``decode`` divides by the same factor before running
    the decoder. Neither method modifies the tensor passed in by the caller.
    """

    def __init__(self, precompute=False, weight_path: str = None):
        """Load the autoencoder weights.

        Args:
            precompute: When true, ``encode`` does not run the encoder and only
                rescales its input (presumably the inputs are latents that
                were encoded ahead of time - confirm against the data
                pipeline).
            weight_path: Location handed to ``AutoencoderKL.from_pretrained``.
        """
        super().__init__()
        self.precompute = precompute
        self.weight_path = weight_path
        # Imported here, so diffusers is only needed when a LatentVAE is built.
        from diffusers.models import AutoencoderKL

        self.model = AutoencoderKL.from_pretrained(self.weight_path)
        self.scaling_factor = self.model.config.scaling_factor

    @torch.no_grad()
    def encode(self, x):
        """Return scaled latents for ``x`` (gradient tracking is disabled).

        With ``precompute`` set, ``x`` is only multiplied by
        ``scaling_factor``; otherwise it is encoded and a sample is drawn
        from the resulting latent distribution before scaling.
        """
        assert self.model is not None
        if self.precompute:
            # Out-of-place on purpose: an in-place ``mul_`` would rescale the
            # caller's tensor and corrupt it if it is reused.
            return x * self.scaling_factor
        latents = self.model.encode(x).latent_dist.sample()
        # ``latents`` is a freshly sampled tensor, so scaling it in place is safe.
        return latents.mul_(self.scaling_factor)

    @torch.no_grad()
    def decode(self, x):
        """Decode scaled latents ``x`` with the autoencoder (no gradients)."""
        assert self.model is not None
        # Out-of-place division keeps the caller's latent tensor unchanged.
        return self.model.decode(x / self.scaling_factor).sample
def uint82fp(x):
    """Convert a tensor of 0..255 pixel values to float32 in the range [-1, 1].

    The input is cast to ``torch.float32``, then 127.5 is subtracted and the
    result divided by 127.5, so 0 maps to -1 and 255 maps to 1.
    """
    as_float = x.to(torch.float32)
    return (as_float - 127.5) / 127.5
def fp2uint8(x):
    """Convert a tensor with values nominally in [-1, 1] to uint8 in 0..255.

    Values are mapped with ``(x + 1) * 127.5``; 0.5 is added so that the
    truncating cast to ``torch.uint8`` rounds to the nearest integer, and the
    result is clipped to [0, 255] so out-of-range inputs saturate.
    """
    scaled = (x + 1) * 127.5 + 0.5
    # ``scaled`` is a fresh temporary, so clipping it in place leaves ``x`` untouched.
    clipped = torch.clip_(scaled, 0, 255)
    return clipped.to(torch.uint8)