import torch
import subprocess
import lightning.pytorch as pl

import logging


logger = logging.getLogger(__name__)


def class_fn_from_str(class_str):
    """Resolve a class (or function) from its fully qualified dotted path."""
    class_module, from_class = class_str.rsplit(".", 1)
    class_module = __import__(class_module, fromlist=[from_class])
    return getattr(class_module, from_class)
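
# Usage sketch (illustrative only): resolve a class from its dotted path.
#   vae_cls = class_fn_from_str("torch.nn.Identity")
#   vae = vae_cls()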


class BaseVAE(torch.nn.Module):
    """Identity VAE: no compression, only an affine normalization of the input."""

    def __init__(self, scale=1.0, shift=0.0):
        super().__init__()
        self.model = torch.nn.Identity()
        self.scale = scale
        self.shift = shift

    def encode(self, x):
        return x / self.scale + self.shift

    def decode(self, x):
        return (x - self.shift) * self.scale


# Note: nearest-neighbour interpolation causes severe artifacts here, so bicubic is used.
class DownSampleVAE(BaseVAE):
    """VAE that bicubically down-samples by `down_ratio`, plus the affine normalization."""

    def __init__(self, down_ratio, scale=1.0, shift=0.0):
        super().__init__(scale=scale, shift=shift)
        self.down_ratio = down_ratio

    def encode(self, x):
        x = torch.nn.functional.interpolate(
            x, scale_factor=1 / self.down_ratio, mode='bicubic', align_corners=False
        )
        return x / self.scale + self.shift

    def decode(self, x):
        x = (x - self.shift) * self.scale
        x = torch.nn.functional.interpolate(
            x, scale_factor=self.down_ratio, mode='bicubic', align_corners=False
        )
        return x


class LatentVAE(BaseVAE):
    """Wrapper around a pretrained `diffusers` AutoencoderKL.

    If `precompute` is True, `encode` assumes the inputs are already VAE latents
    and only applies the scaling factor.
    """

    def __init__(self, precompute=False, weight_path: str = None):
        super().__init__()
        self.precompute = precompute
        self.weight_path = weight_path

        from diffusers.models import AutoencoderKL
        self.model = AutoencoderKL.from_pretrained(self.weight_path)
        self.scaling_factor = self.model.config.scaling_factor

    @torch.no_grad()
    def encode(self, x):
        assert self.model is not None
        if self.precompute:
            # x is already a latent; only apply the scaling factor (in place).
            return x.mul_(self.scaling_factor)
        return self.model.encode(x).latent_dist.sample().mul_(self.scaling_factor)

    @torch.no_grad()
    def decode(self, x):
        assert self.model is not None
        return self.model.decode(x.div_(self.scaling_factor)).sample
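
# Usage sketch for LatentVAE (illustrative; the checkpoint id below is only an
# example of a diffusers-compatible AutoencoderKL, not a path used by this repo):
#   vae = LatentVAE(weight_path="stabilityai/sd-vae-ft-ema")
#   latents = vae.encode(images)   # images: (N, 3, H, W) in [-1, 1]
#   recon = vae.decode(latents)    # back to image space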


def uint82fp(x):
    """Map uint8 images in [0, 255] to float32 in [-1, 1]."""
    x = x.to(torch.float32)
    x = (x - 127.5) / 127.5
    return x


def fp2uint8(x):
    """Map float images in [-1, 1] back to uint8 in [0, 255] (with rounding)."""
    x = ((x + 1) * 127.5 + 0.5).clamp_(0, 255).to(torch.uint8)
    return x
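

# Minimal smoke test (illustrative sketch, not part of the training code).
# It only exercises the weight-free components; LatentVAE needs a pretrained
# AutoencoderKL checkpoint and is therefore not instantiated here.
if __name__ == "__main__":
    img = torch.rand(1, 3, 64, 64) * 2 - 1  # fake image batch in [-1, 1]

    # BaseVAE: encode followed by decode is the identity map.
    base = BaseVAE(scale=0.5, shift=0.1)
    rec = base.decode(base.encode(img))
    print("BaseVAE max reconstruction error:", (rec - img).abs().max().item())

    # DownSampleVAE: bicubic down/up-sampling, lossy but shape-preserving round trip.
    down = DownSampleVAE(down_ratio=2, scale=0.5, shift=0.1)
    lat = down.encode(img)
    print("DownSampleVAE latent shape:", tuple(lat.shape))
    print("DownSampleVAE output shape:", tuple(down.decode(lat).shape))

    # uint8 <-> float conversion round trip.
    u8 = fp2uint8(img)
    print("uint8 round-trip max error:", (uint82fp(u8) - img).abs().max().item())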