Fabrice-TIERCELIN commited on
Commit
fce9266
·
verified ·
1 Parent(s): a9ea194

Upload 3 files

Browse files
sgm/modules/autoencoding/lpips/__init__.py ADDED
File without changes
sgm/modules/autoencoding/lpips/util.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+
4
+ import requests
5
+ import torch
6
+ import torch.nn as nn
7
+ from tqdm import tqdm
8
+
9
+ URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"}
10
+
11
+ CKPT_MAP = {"vgg_lpips": "vgg.pth"}
12
+
13
+ MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"}
14
+
15
+
16
+ def download(url, local_path, chunk_size=1024):
17
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
18
+ with requests.get(url, stream=True) as r:
19
+ total_size = int(r.headers.get("content-length", 0))
20
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
21
+ with open(local_path, "wb") as f:
22
+ for data in r.iter_content(chunk_size=chunk_size):
23
+ if data:
24
+ f.write(data)
25
+ pbar.update(chunk_size)
26
+
27
+
28
+ def md5_hash(path):
29
+ with open(path, "rb") as f:
30
+ content = f.read()
31
+ return hashlib.md5(content).hexdigest()
32
+
33
+
34
+ def get_ckpt_path(name, root, check=False):
35
+ assert name in URL_MAP
36
+ path = os.path.join(root, CKPT_MAP[name])
37
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
38
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
39
+ download(URL_MAP[name], path)
40
+ md5 = md5_hash(path)
41
+ assert md5 == MD5_MAP[name], md5
42
+ return path
43
+
44
+
45
+ class ActNorm(nn.Module):
46
+ def __init__(
47
+ self, num_features, logdet=False, affine=True, allow_reverse_init=False
48
+ ):
49
+ assert affine
50
+ super().__init__()
51
+ self.logdet = logdet
52
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
53
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
54
+ self.allow_reverse_init = allow_reverse_init
55
+
56
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
57
+
58
+ def initialize(self, input):
59
+ with torch.no_grad():
60
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
61
+ mean = (
62
+ flatten.mean(1)
63
+ .unsqueeze(1)
64
+ .unsqueeze(2)
65
+ .unsqueeze(3)
66
+ .permute(1, 0, 2, 3)
67
+ )
68
+ std = (
69
+ flatten.std(1)
70
+ .unsqueeze(1)
71
+ .unsqueeze(2)
72
+ .unsqueeze(3)
73
+ .permute(1, 0, 2, 3)
74
+ )
75
+
76
+ self.loc.data.copy_(-mean)
77
+ self.scale.data.copy_(1 / (std + 1e-6))
78
+
79
+ def forward(self, input, reverse=False):
80
+ if reverse:
81
+ return self.reverse(input)
82
+ if len(input.shape) == 2:
83
+ input = input[:, :, None, None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ _, _, height, width = input.shape
89
+
90
+ if self.training and self.initialized.item() == 0:
91
+ self.initialize(input)
92
+ self.initialized.fill_(1)
93
+
94
+ h = self.scale * (input + self.loc)
95
+
96
+ if squeeze:
97
+ h = h.squeeze(-1).squeeze(-1)
98
+
99
+ if self.logdet:
100
+ log_abs = torch.log(torch.abs(self.scale))
101
+ logdet = height * width * torch.sum(log_abs)
102
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
103
+ return h, logdet
104
+
105
+ return h
106
+
107
+ def reverse(self, output):
108
+ if self.training and self.initialized.item() == 0:
109
+ if not self.allow_reverse_init:
110
+ raise RuntimeError(
111
+ "Initializing ActNorm in reverse direction is "
112
+ "disabled by default. Use allow_reverse_init=True to enable."
113
+ )
114
+ else:
115
+ self.initialize(output)
116
+ self.initialized.fill_(1)
117
+
118
+ if len(output.shape) == 2:
119
+ output = output[:, :, None, None]
120
+ squeeze = True
121
+ else:
122
+ squeeze = False
123
+
124
+ h = output / self.scale - self.loc
125
+
126
+ if squeeze:
127
+ h = h.squeeze(-1).squeeze(-1)
128
+ return h
sgm/modules/autoencoding/lpips/vqperceptual.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+
5
+ def hinge_d_loss(logits_real, logits_fake):
6
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
7
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
8
+ d_loss = 0.5 * (loss_real + loss_fake)
9
+ return d_loss
10
+
11
+
12
+ def vanilla_d_loss(logits_real, logits_fake):
13
+ d_loss = 0.5 * (
14
+ torch.mean(torch.nn.functional.softplus(-logits_real))
15
+ + torch.mean(torch.nn.functional.softplus(logits_fake))
16
+ )
17
+ return d_loss