manbeast3b committed on
Commit
7e1e9b6
1 Parent(s): 71d2928

Create output.py

Files changed (1)
  1. src/output.py +70 -0
src/output.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+ import torch.nn as nn
+
+ def conv(n_in, n_out, **kwargs):
+     return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+ class Clamp(nn.Module):
+     def forward(self, x):
+         return torch.tanh(x / 3) * 3
+
+ class Block(nn.Module):
+     def __init__(self, n_in, n_out):
+         super().__init__()
+         self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
+         self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+         self.fuse = nn.ReLU()
+     def forward(self, x):
+         return self.fuse(self.conv(x) + self.skip(x))
+
+ def Encoder(latent_channels=4):
+     return nn.Sequential(
+         conv(3, 64), Block(64, 64),
+         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+         conv(64, latent_channels),
+     )
+
+ def Decoder(latent_channels=4):
+     return nn.Sequential(
+         Clamp(), conv(latent_channels, 64), nn.ReLU(),
+         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+         Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+         Block(64, 64), conv(64, 3),
+     )
+
+ class TAESD(nn.Module):
+     latent_magnitude = 3
+     latent_shift = 0.5
+
+     def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth", latent_channels=None):
+         """Initialize pretrained TAESD on the given device from the given checkpoints."""
+         super().__init__()
+         if latent_channels is None:
+             latent_channels = self.guess_latent_channels(str(encoder_path))
+         self.encoder = Encoder(latent_channels)
+         self.decoder = Decoder(latent_channels)
+         if encoder_path is not None:
+             self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True))
+         if decoder_path is not None:
+             self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True))
+
+     def guess_latent_channels(self, encoder_path):
+         """guess latent channel count based on encoder filename"""
+         if "taef1" in encoder_path:
+             return 16
+         if "taesd3" in encoder_path:
+             return 16
+         return 4
+
+     @staticmethod
+     def scale_latents(x):
+         """raw latents -> [0, 1]"""
+         return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
+
+     @staticmethod
+     def unscale_latents(x):
+         """[0, 1] -> raw latents"""
+         return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
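
For reference, a minimal round-trip sketch of how the added TAESD class might be used (not part of the commit). It assumes the default checkpoint files `taesd_encoder.pth` and `taesd_decoder.pth` are present in the working directory and that `src/` is importable as a package; the random 512x512 tensor is a hypothetical stand-in for a batch of RGB images scaled to [0, 1].

```python
# Usage sketch, under the assumptions stated above.
import torch
from src.output import TAESD  # assumes src/ is on the import path

taesd = TAESD().eval()  # loads the default encoder/decoder checkpoints

with torch.no_grad():
    image = torch.rand(1, 3, 512, 512)   # hypothetical [0, 1] RGB batch
    latents = taesd.encoder(image)       # 8x downsampled: (1, 4, 64, 64)
    # scale_latents maps raw latents into [0, 1] (e.g. for previewing or
    # storing them as an image); unscale_latents inverts that mapping.
    stored = TAESD.scale_latents(latents)
    decoded = taesd.decoder(TAESD.unscale_latents(stored)).clamp(0, 1)

print(image.shape, latents.shape, decoded.shape)
```

The scale/unscale helpers only remap latents to and from [0, 1]; they are not required for a plain encode/decode round trip, which can call `taesd.encoder` and `taesd.decoder` directly.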