manbeast3b committed · Commit 9ba7979 · verified · 1 Parent(s): 0440422

Update src/model.py

Files changed (1)
  1. src/model.py +42 -74
src/model.py CHANGED
@@ -1,92 +1,60 @@
- import torch
- import torch as th
- import torch.nn as nn
- import torch.nn.functional as F
-
- def conv(n_in, n_out, **kwargs):
-     return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
-
- class Clamp(nn.Module):
-     def forward(self, x):
-         return torch.tanh(x / 3) * 3
-
- class Block(nn.Module):
-     def __init__(self, n_in, n_out):
          super().__init__()
-         self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
-         self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
-         self.fuse = nn.ReLU()
-     def forward(self, x):
-         return self.fuse(self.conv(x) + self.skip(x))
-
- def Encoder(latent_channels=4):
      return nn.Sequential(
-         conv(3, 64), Block(64, 64),
-         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-         conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
-         conv(64, latent_channels),
      )
-
- def Decoder(latent_channels=16): # Adjusted to match expected input channels
      return nn.Sequential(
-         Clamp(),
-         conv(latent_channels, 48), # Reduced from 64 to 48 channels
-         nn.ReLU(),
-         Block(48, 48), Block(48, 48), # Reduced number of blocks
-         nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
-         Block(48, 48), Block(48, 48), # Reduced number of blocks
-         nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
-         Block(48, 48), # Further reduction in blocks
-         nn.Upsample(scale_factor=2), conv(48, 48, bias=False),
-         Block(48, 48),
-         conv(48, 3), # Final convolution to output channels
      )
-
- class Model(nn.Module):
-     latent_magnitude = 3
-     latent_shift = 0.5
-
-     def __init__(self, encoder_path="encoder.pth", decoder_path="decoder.pth", latent_channels=None):
-         super().__init__()
-         if latent_channels is None:
-             latent_channels = self.guess_latent_channels(str(encoder_path))
-         self.encoder = Encoder(latent_channels)
-         self.decoder = Decoder(latent_channels)
-
-         if encoder_path is not None:
-             encoder_state_dict = torch.load(encoder_path, map_location="cpu", weights_only=True)
-             filtered_state_dict = {k.strip('encoder.'): v for k, v in encoder_state_dict.items() if k.strip('encoder.') in self.encoder.state_dict() and v.size() == self.encoder.state_dict()[k.strip('encoder.')].size()}
-             print(f" num of keys in filtered: {len(filtered_state_dict)} and in decoder: {len(self.encoder.state_dict())}")
-             self.encoder.load_state_dict(filtered_state_dict, strict=False)
-
-         if decoder_path is not None:
-             decoder_state_dict = torch.load(decoder_path, map_location="cpu", weights_only=True)
-             filtered_state_dict = {k.strip('decoder.'): v for k, v in decoder_state_dict.items() if k.strip('decoder.') in self.decoder.state_dict() and v.size() == self.decoder.state_dict()[k.strip('decoder.')].size()}
-             print(f" num of keys in filtered: {len(filtered_state_dict)} and in decoder: {len(self.decoder.state_dict())}")
-             self.decoder.load_state_dict(filtered_state_dict, strict=False)
-
-         self.encoder.requires_grad_(False)
-         self.decoder.requires_grad_(False)
-
-     def guess_latent_channels(self, encoder_path):
-         if "taef1" in encoder_path: return 16
-         if "taesd3" in encoder_path: return 16
-         return 4
-
      @staticmethod
-     def scale_latents(x):
-         return x.div(2 * Model.latent_magnitude).add(Model.latent_shift).clamp(0, 1)
-
      @staticmethod
-     def unscale_latents(x):
-         return x.sub(Model.latent_shift).mul(2 * Model.latent_magnitude)
-
-     def forward(self, x, return_latent=False):
-         latent = self.encoder(x)
-         out = self.decoder(latent)
-         if return_latent:
-             return out.clamp(0, 1), latent
-         return out.clamp(0, 1)
+ import torch as t, torch.nn as nn, torch.nn.functional as F
+
+ def cv(n_i, n_o, **kw): return nn.Conv2d(n_i, n_o, 3, padding=1, **kw)
+
+ class C(nn.Module):
+     def forward(self, x): return t.tanh(x / 3) * 3
+
+ class B(nn.Module):
+     def __init__(s, n_i, n_o):
          super().__init__()
+         s.c = nn.Sequential(cv(n_i, n_o), nn.ReLU(), cv(n_o, n_o), nn.ReLU(), cv(n_o, n_o))
+         s.s = nn.Conv2d(n_i, n_o, 1, bias=False) if n_i != n_o else nn.Identity()
+         s.f = nn.ReLU()
+     def forward(s, x): return s.f(s.c(x) + s.s(x))
+
+ def E(lc=4):
      return nn.Sequential(
+         cv(3, 64), B(64, 64), cv(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
+         cv(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
+         cv(64, 64, stride=2, bias=False), B(64, 64), B(64, 64), B(64, 64),
+         cv(64, lc),
      )
+
+ def D(lc=16):
      return nn.Sequential(
+         C(), cv(lc, 48), nn.ReLU(), B(48, 48), B(48, 48), nn.Upsample(scale_factor=2),
+         cv(48, 48, bias=False), B(48, 48), B(48, 48), nn.Upsample(scale_factor=2),
+         cv(48, 48, bias=False), B(48, 48), nn.Upsample(scale_factor=2),
+         cv(48, 48, bias=False), B(48, 48), cv(48, 3),
      )
+
+ class M(nn.Module):
+     lm, ls = 3, 0.5
+     def __init__(s, ep="encoder.pth", dp="decoder.pth", lc=None):
+         super().__init__()
+         if lc is None: lc = s.glc(str(ep))
+         s.e, s.d = E(lc), D(lc)
+
+         def f(sd, mod, pfx):
+             f_sd = {k.strip(pfx): v for k, v in sd.items() if k.strip(pfx) in mod.state_dict() and v.size() == mod.state_dict()[k.strip(pfx)].size()}
+             print(f"num keys: {len(f_sd)} of {len(mod.state_dict())}")
+             mod.load_state_dict(f_sd, strict=False)
+
+         if ep: f(t.load(ep, map_location="cpu", weights_only=True), s.e, "encoder.")
+         if dp: f(t.load(dp, map_location="cpu", weights_only=True), s.d, "decoder.")
+
+         s.e.requires_grad_(False)
+         s.d.requires_grad_(False)
+
+     def glc(s, ep): return 16 if "taef1" in ep or "taesd3" in ep else 4
+
      @staticmethod
+     def sl(x): return x.div(2 * M.lm).add(M.ls).clamp(0, 1)
+
      @staticmethod
+     def ul(x): return x.sub(M.ls).mul(2 * M.lm)
+
+     def forward(s, x, rl=False):
+         l, o = s.e(x), s.d(s.e(x))
+         return (o.clamp(0, 1), l) if rl else o.clamp(0, 1)
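
A minimal usage sketch of the rewritten module (not part of the commit): it assumes encoder.pth and decoder.pth exist in the working directory and that src/ is importable as a package. With the default paths, glc() falls back to 4 latent channels, and the three stride-2 convolutions in E give a 1/8-resolution latent.

import torch as t
from src.model import M

m = M(ep="encoder.pth", dp="decoder.pth").eval()
x = t.rand(1, 3, 256, 256)            # dummy RGB batch in [0, 1]
with t.no_grad():
    recon, latent = m(x, rl=True)     # rl=True also returns the latent
print(recon.shape, latent.shape)      # torch.Size([1, 3, 256, 256]) torch.Size([1, 4, 32, 32])

Two behaviours worth flagging, as observations rather than anything stated in the commit: the helper f() rewrites checkpoint keys with k.strip(pfx), and str.strip treats its argument as a set of characters to trim from both ends, not as a literal prefix, so keys that happen to end in one of those characters are silently mangled; str.removeprefix (Python 3.9+) is the prefix-safe alternative. The new forward also evaluates s.e(x) twice, so the encoder runs twice per call.

"encoder.0.weight".strip("encoder.")          # -> "0.weight", but only by coincidence
"decoder.9.running_var".strip("decoder.")     # -> "9.running_va" (hypothetical key; trailing "r" also trimmed)
"encoder.0.weight".removeprefix("encoder.")   # -> "0.weight", prefix-safe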