BorelTHU committed on
Commit
223d932
1 Parent(s): d3109f4
app.py ADDED
@@ -0,0 +1,194 @@
+ import gradio as gr
+ from io import BytesIO
+ import os
+ import sys
+
+ import matplotlib.pyplot as plt
+ import matplotlib
+ import numpy as np
+ from PIL import Image
+ from omegaconf import OmegaConf
+
+ import torch
+ from torchvision import transforms as T
+
+ from optvq.models.quantizer import sinkhorn
+ from optvq.utils.init import seed_everything
+ seed_everything(42)
+ from optvq.models.vqgan_hf import VQModelHF
+ matplotlib.rcParams['font.family'] = 'Times New Roman'
+
+ #################
+ N_data = 50
+ N_code = 20
+ dim = 2
+ handler = None
+ device = torch.device("cpu")
+ #################
+
+ def nearest(src, trg):
+     dis_mat = torch.cdist(src, trg)
+     min_idx = torch.argmin(dis_mat, dim=-1)
+     return min_idx
+
+ def normalize(A, dim, mode="all"):
+     if mode == "all":
+         A = (A - A.mean()) / (A.std() + 1e-6)
+         A = A - A.min()
+     elif mode == "dim":
+         A = A / dim
+     elif mode == "null":
+         pass
+     return A
+
+ def draw_NN(data, code):
+     # nearest neighbor method
+     indices = nearest(data, code)
+     data = data.numpy()
+     code = code.numpy()
+
+     plt.figure(figsize=(3, 2.5), dpi=400)
+     # draw an arrow from each data point to its assigned code
+     for i in range(data.shape[0]):
+         idx = indices[i].item()
+         start = data[i]
+         end = code[idx]
+         plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
+                   head_width=0.05, head_length=0.05, fc='red', ec='red', alpha=0.6,
+                   ls="-", lw=0.5)
+     plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
+     plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
+     plt.legend(loc="lower right")
+     plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
+     plt.title("Nearest neighbor")
+
+     buf = BytesIO()
+     plt.savefig(buf, format="png")
+     buf.seek(0)
+     image = Image.open(buf)
+     return image
+
+ def draw_optvq(data, code):
+     cost = torch.cdist(data, code, p=2.0)
+     cost = normalize(cost, dim, mode="all")
+     Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)
+     indices = torch.argmax(Q, dim=-1)
+     data = data.numpy()
+     code = code.numpy()
+
+     plt.figure(figsize=(3, 2.5), dpi=400)
+     # draw an arrow from each data point to its assigned code
+     for i in range(data.shape[0]):
+         idx = indices[i].item()
+         start = data[i]
+         end = code[idx]
+         plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
+                   head_width=0.05, head_length=0.05, fc='green', ec='green', alpha=0.6,
+                   ls="-", lw=0.5)
+     plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
+     plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
+     plt.legend(loc="lower right")
+     plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
+     plt.title("Optimal Transport (OptVQ)")
+
+     buf = BytesIO()
+     plt.savefig(buf, format="png")
+     buf.seek(0)
+     image = Image.open(buf)
+     return image
+
+ def draw_process(x, y, std):
+     data = torch.randn(N_data, dim)
+     code = torch.randn(N_code, dim) * std
+     code[:, 0] += x
+     code[:, 1] += y
+
+     image_NN = draw_NN(data, code)
+     image_optvq = draw_optvq(data, code)
+
+     return image_NN, image_optvq
+
+ class Handler:
+     def __init__(self, device):
+         self.transform = T.Compose([
+             T.Resize(256),
+             T.CenterCrop(256),
+             T.ToTensor()
+         ])
+         self.device = device
+
+
+         self.basevq = VQModelHF.from_pretrained("BorelTHU/basevq-16x16x4")
+         self.basevq.to(self.device)
+         self.basevq.eval()
+
+         self.vqgan = VQModelHF.from_pretrained("BorelTHU/vqgan-16x16")
+         self.vqgan.to(self.device)
+         self.vqgan.eval()
+
+         self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
+         self.optvq.to(self.device)
+         self.optvq.eval()
+
+     def tensor_to_image(self, tensor):
+         img = tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
+         img = (img + 1) / 2 * 255
+         img = img.astype("uint8")
+         return img
+
+     def process_image(self, img: np.ndarray):
+         img = Image.fromarray(img.astype("uint8"))
+         img = self.transform(img)
+         img = img.unsqueeze(0).to(self.device)
+         with torch.no_grad():
+             img = 2 * img - 1
+             # basevq
+             quant, *_ = self.basevq.encode(img)
+             basevq_rec = self.basevq.decode(quant)
+             # vqgan
+             quant, *_ = self.vqgan.encode(img)
+             vqgan_rec = self.vqgan.decode(quant)
+             # optvq
+             quant, *_ = self.optvq.encode(img)
+             optvq_rec = self.optvq.decode(quant)
+
+         # tensor to PIL image
+         img = self.tensor_to_image(img)
+         basevq_rec = self.tensor_to_image(basevq_rec)
+         vqgan_rec = self.tensor_to_image(vqgan_rec)
+         optvq_rec = self.tensor_to_image(optvq_rec)
+         return img, basevq_rec, vqgan_rec, optvq_rec
+
+ if __name__ == "__main__":
+     # create the model handler
+     handler = Handler(device=device)
+
+     # create the interface
+     with gr.Blocks() as demo:
+         gr.Textbox(value="This demo shows the image reconstruction comparison between OptVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes.", label="Demo 1: Image reconstruction results")
+         with gr.Row():
+             with gr.Column():
+                 image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
+                 btn_demo1 = gr.Button(value="Run reconstruction")
+             image_basevq = gr.Image(label="BaseVQ rec.")
+             image_vqgan = gr.Image(label="VQGAN rec.")
+             image_optvq = gr.Image(label="OptVQ rec.")
+         btn_demo1.click(fn=handler.process_image, inputs=[image_input], outputs=[image_input, image_basevq, image_vqgan, image_optvq])
+
+         gr.Textbox(value="This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.", label="Demo 2: 2D visualizations of matching results")
+         with gr.Row():
+             with gr.Column():
+                 input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
+                 input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
+                 input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
+                 btn_demo2 = gr.Button(value="Run 2D example")
+             output_nn = gr.Image(label="NN")
+             output_optvq = gr.Image(label="OptVQ")
+
+         # set the function
+         input_x.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
+         input_y.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
+         input_std.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
+         btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
+
+     demo.launch()
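
The 2D demo above contrasts the two assignment rules: draw_NN sends each data point to its nearest code, while draw_optvq builds a transport plan with sinkhorn and assigns by argmax over that plan. The sinkhorn routine itself lives in optvq/models/quantizer.py, which is truncated in this view, so the following is only a minimal sketch assuming a standard entropic Sinkhorn-Knopp normalization with uniform marginals; the name sinkhorn_sketch and the omission of the is_distributed flag are illustrative assumptions, not the repository's implementation.

import torch

def sinkhorn_sketch(cost: torch.Tensor, n_iters: int = 5, epsilon: float = 10.0) -> torch.Tensor:
    """Approximate entropic OT plan Q for an (n_data, n_code) cost matrix."""
    Q = torch.exp(-cost / epsilon)   # Gibbs kernel: low cost -> high affinity
    n, m = Q.shape
    Q = Q / Q.sum()                  # start from a joint distribution
    for _ in range(n_iters):
        Q = Q / (Q.sum(dim=0, keepdim=True) * m)   # each code receives total mass 1/m
        Q = Q / (Q.sum(dim=1, keepdim=True) * n)   # each data point sends total mass 1/n
    return Q

# assignment as used in draw_optvq: per data point, pick the code receiving the most mass
cost = torch.cdist(torch.randn(50, 2), torch.randn(20, 2))
indices = torch.argmax(sinkhorn_sketch(cost), dim=-1)

Compared with the plain nearest-neighbor rule, the column normalization spreads data points across all codes, which is the behavior the second Gradio demo visualizes.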
optvq/data/__pycache__/dataloader.cpython-310.pyc ADDED
Binary file (1.63 kB).

optvq/data/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (2.22 kB).

optvq/data/__pycache__/preprocessor.cpython-310.pyc ADDED
Binary file (1.71 kB).

optvq/data/dataloader.py ADDED
@@ -0,0 +1,50 @@
+ # ------------------------------------------------------------------------------
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
+ # Licensed under the MIT License [see LICENSE for details]
+ # ------------------------------------------------------------------------------
+
+ from typing import Union
+
+ import torch
+ from torch.utils.data import Subset
+
+ def maybe_get_subset(dataset, subset_size: Union[int, float] = None, num_data_repeat: int = None):
+     """
+     num_data_repeat repeats the sampled indices so that small datasets do not incur excessive per-epoch loading overhead.
+     """
+     if subset_size is None:
+         return dataset
+     else:
+         if subset_size < 1.0:
+             subset_size = len(dataset) * subset_size
+         subset_size = int(subset_size)
+         selected_indices = torch.randperm(len(dataset))[:subset_size]
+         if num_data_repeat is not None:
+             selected_indices = selected_indices.repeat(num_data_repeat)
+         return Subset(dataset, selected_indices)
+
+ class LoaderWrapper:
+     """
+     Wraps a dataloader so that iteration yields a fixed total number of steps, cycling through the underlying loader as needed.
+     """
+     def __init__(self, loader, total_iterations: int = None):
+         self.loader = loader
+         self.total_iterations = total_iterations if total_iterations is not None else len(loader)
+
+     def __iter__(self):
+         self.generator = iter(self.loader)
+         self.counter = 0
+         return self
+
+     def __next__(self):
+         if self.counter >= self.total_iterations:
+             self.counter = 0
+             raise StopIteration
+         else:
+             self.counter += 1
+             try:
+                 return next(self.generator)
+             except StopIteration:
+                 self.generator = iter(self.loader)
+                 return next(self.generator)
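
A hedged usage sketch of the two helpers above; the toy TensorDataset and the chosen sizes are hypothetical, and only maybe_get_subset and LoaderWrapper come from this file.

import torch
from torch.utils.data import DataLoader, TensorDataset

from optvq.data.dataloader import maybe_get_subset, LoaderWrapper

# hypothetical toy dataset standing in for a real training set
dataset = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))

# keep a random 10% subset and repeat it 5 times (500 samples in total),
# so that one pass over the subset is long enough to amortize loading costs
subset = maybe_get_subset(dataset, subset_size=0.1, num_data_repeat=5)

# iterate for exactly 100 steps, restarting the underlying loader when it runs out
loader = LoaderWrapper(DataLoader(subset, batch_size=32, shuffle=True), total_iterations=100)
for step, (images, labels) in enumerate(loader):
    pass  # training step goes here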
optvq/data/dataset.py ADDED
@@ -0,0 +1,62 @@
+ # ------------------------------------------------------------------------------
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
+ # Licensed under the MIT License [see LICENSE for details]
+ # ------------------------------------------------------------------------------
+
+ import os
+ from PIL import Image
+ import numpy as np
+
+ import torch
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+
+ from .preprocessor import normalize_params
+
+ class ImageNetDataset(Dataset):
+     def __init__(self, root, transform=None, convert_to_numpy: bool = True, post_normalize: str = "plain"):
+         self.root = root
+         self.transform = transform
+         self.convert_to_numpy = convert_to_numpy
+         self.post_normalize = transforms.Normalize(
+             **normalize_params[post_normalize]
+         )
+
+         # find classes
+         classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
+         class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+
+         # make dataset
+         self.samples = []
+         self.extensions = []
+         for target_class in sorted(class_to_idx.keys()):
+             class_idx = class_to_idx[target_class]
+             target_dir = os.path.join(root, target_class)
+             if not os.path.isdir(target_dir):
+                 continue
+             for fname in sorted(os.listdir(target_dir)):
+                 path = os.path.join(target_dir, fname)
+                 item = (path, class_idx)
+                 self.samples.append(item)
+                 ext = path.split(".")[-1]
+                 if ext not in self.extensions:
+                     self.extensions.append(ext)
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, index):
+         path, label = self.samples[index]
+         image = Image.open(path)
+         if not image.mode == "RGB":
+             image = image.convert("RGB")
+         if self.convert_to_numpy:
+             image = np.array(image).astype("uint8")
+         # image augmentation
+         image = self.transform(image=image)["image"]
+         # to tensor and normalize
+         image = (image / 255).astype(np.float32)
+         image = torch.from_numpy(image).permute(2, 0, 1)
+         image = self.post_normalize(image)
+         return image, label
optvq/data/preprocessor.py ADDED
@@ -0,0 +1,71 @@
+ # ------------------------------------------------------------------------------
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
+ # Licensed under the MIT License [see LICENSE for details]
+ # ------------------------------------------------------------------------------
+
+ from typing import Optional
+ import numpy as np
+
+ from torchvision import transforms
+ import albumentations as A
+
+ BICUBIC = transforms.InterpolationMode.BICUBIC
+
+ normalize_params = {
+     "plain": {"mean": (0.5,), "std": (0.5,)},
+     "cnn": {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
+     "clip": {"mean": (0.48145466, 0.4578275, 0.40821073), "std": (0.26862954, 0.26130258, 0.27577711)}
+ }
+
+ recover_map_dict = {
+     "plain": transforms.Normalize(
+         mean=(-1,), std=(2,)
+     ),
+     "cnn": transforms.Normalize(
+         mean=(-0.485/0.229, -0.456/0.224, -0.406/0.225),
+         std=(1/0.229, 1/0.224, 1/0.225)
+     ),
+     "clip": transforms.Normalize(
+         mean=(-0.48145466/0.26862954, -0.4578275/0.26130258, -0.40821073/0.27577711),
+         std=(1/0.26862954, 1/0.26130258, 1/0.27577711)
+     )
+ }
+
+ def get_recover_map(name: str):
+     return recover_map_dict[name]
+
+ ###########################################
+ # Preprocessor
+ ###########################################
+
+ def plain_preprocessor(resize: Optional[int] = 32):
+     return transforms.Compose([
+         transforms.ToTensor(),
+         transforms.Normalize((0.5,), (0.5,)),
+         transforms.Resize(resize),
+     ])
+
+ def imagenet_preprocessor(resize: Optional[int] = 256, is_train: bool = True):
+     if is_train:
+         # augmentation v1
+         # transform = A.Compose([
+         #     A.SmallestMaxSize(max_size=resize),
+         #     A.RandomCrop(height=resize, width=resize),
+         #     A.HorizontalFlip(p=0.5),
+         # ])
+
+         # augmentation v2
+         transform = A.Compose([
+             A.SmallestMaxSize(max_size=resize),
+             A.RandomResizedCrop(width=resize, height=resize, scale=(0.2, 1.0)),
+             A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
+             A.GaussianBlur(blur_limit=7, p=0.5),
+             A.HorizontalFlip(p=0.5),
+         ])
+     else:
+         transform = A.Compose([
+             A.SmallestMaxSize(max_size=resize),
+             A.CenterCrop(height=resize, width=resize),
+         ])
+     return transform
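
For reference, a small sketch of how these transforms are meant to be called; the random input array is hypothetical, and the dict-style transform(image=...)["image"] call simply mirrors how ImageNetDataset.__getitem__ above invokes the albumentations pipeline.

import numpy as np

from optvq.data.preprocessor import imagenet_preprocessor

# hypothetical HWC uint8 image
img = np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8)

train_tf = imagenet_preprocessor(resize=256, is_train=True)
eval_tf = imagenet_preprocessor(resize=256, is_train=False)

aug = train_tf(image=img)["image"]   # albumentations returns a dict of results
print(aug.shape, aug.dtype)          # (256, 256, 3) uint8

# ImageNetDataset then rescales to [0, 1], permutes to CHW, and applies
# transforms.Normalize(**normalize_params["plain"]), i.e. maps pixels into [-1, 1];
# get_recover_map("plain") undoes that mapping for visualization.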
optvq/losses/__pycache__/aeloss_disc.cpython-310.pyc ADDED
Binary file (4.27 kB).

optvq/losses/aeloss.py ADDED
@@ -0,0 +1,76 @@
+ # ------------------------------------------------------------------------------
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
+ # Licensed under the MIT License [see LICENSE for details]
+ # ------------------------------------------------------------------------------
+ # Modified from [thuanz123/enhancing-transformers](https://github.com/thuanz123/enhancing-transformers)
+ # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
+ # ------------------------------------------------------------------------------
+ # Modified from [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers)
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
+ # ------------------------------------------------------------------------------
+
+ from typing import Tuple
+
+ import lpips
+
+ import torch
+ import torch.nn as nn
+
+ class AELoss(nn.Module):
+     """
+     Args:
+         loss_q_weight (float): weight for quantization loss
+         loss_l1_weight (float): weight for L1 loss (loglaplace)
+         loss_l2_weight (float): weight for L2 loss (loggaussian)
+         loss_p_weight (float): weight for perceptual loss
+     """
+     def __init__(self, loss_q_weight: float = 1.0,
+                  loss_l1_weight: float = 1.0,
+                  loss_l2_weight: float = 1.0,
+                  loss_p_weight: float = 1.0) -> None:
+         super().__init__()
+         self.loss_type = ["aeloss"]
+
+         self.perceptual_loss = lpips.LPIPS(net="vgg", verbose=False)
+         # freeze the perceptual loss
+         for param in self.perceptual_loss.parameters():
+             param.requires_grad = False
+
+         self.loss_q_weight = loss_q_weight
+         self.loss_l1_weight = loss_l1_weight
+         self.loss_l2_weight = loss_l2_weight
+         self.loss_p_weight = loss_p_weight
+
+     @torch.autocast(device_type="cuda", enabled=False)
+     def forward(self, q_loss: torch.FloatTensor,
+                 x: torch.FloatTensor,
+                 x_rec: torch.FloatTensor, *args, **kwargs) -> Tuple:
+         x = x.float()
+         x_rec = x_rec.float()
+
+         # compute l1 loss
+         loss_l1 = (x_rec - x).abs().mean() if self.loss_l1_weight > 0.0 else torch.tensor(0.0, device=x.device)
+
+         # compute l2 loss
+         loss_l2 = (x_rec - x).pow(2).mean() if self.loss_l2_weight > 0.0 else torch.tensor(0.0, device=x.device)
+
+         # compute perceptual loss
+         loss_p = self.perceptual_loss(x, x_rec).mean() if self.loss_p_weight > 0.0 else torch.tensor(0.0, device=x.device)
+
+         # compute total loss
+         loss = self.loss_p_weight * loss_p + \
+                self.loss_l1_weight * loss_l1 + \
+                self.loss_l2_weight * loss_l2
+         loss += self.loss_q_weight * q_loss
+
+         # get the log
+         log = {
+             "loss": loss.detach(),
+             "loss_p": loss_p.detach(),
+             "loss_l1": loss_l1.detach(),
+             "loss_l2": loss_l2.detach(),
+             "loss_q": q_loss.detach()
+         }
+
+         return loss, log
optvq/losses/aeloss_disc.py ADDED
@@ -0,0 +1,177 @@
+ # ------------------------------------------------------------------------------
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
+ # Licensed under the MIT License [see LICENSE for details]
+ # ------------------------------------------------------------------------------
+ # Modified from [thuanz123/enhancing-transformers](https://github.com/thuanz123/enhancing-transformers)
+ # Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
+ # ------------------------------------------------------------------------------
+ # Modified from [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers)
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
+ # ------------------------------------------------------------------------------
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import lpips
+
+ from optvq.models.discriminator import NLayerDiscriminator, weights_init
+
+ class DummyLoss(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+
+ def hinge_d_loss(logits_real, logits_fake):
+     loss_real = torch.mean(F.relu(1. - logits_real))
+     loss_fake = torch.mean(F.relu(1. + logits_fake))
+     d_loss = 0.5 * (loss_real + loss_fake)
+     return d_loss
+
+
+ def vanilla_d_loss(logits_real, logits_fake):
+     d_loss = 0.5 * (
+         torch.mean(torch.nn.functional.softplus(-logits_real)) +
+         torch.mean(torch.nn.functional.softplus(logits_fake)))
+     return d_loss
+
+
+ class AELossWithDisc(nn.Module):
+     def __init__(self,
+                  disc_start,
+                  pixelloss_weight=1.0,
+                  disc_in_channels=3,
+                  disc_num_layers=3,
+                  use_actnorm=False,
+                  disc_ndf=64,
+                  disc_conditional=False,
+                  disc_loss="hinge",
+                  loss_l1_weight: float = 1.0,
+                  loss_l2_weight: float = 1.0,
+                  loss_p_weight: float = 1.0,
+                  loss_q_weight: float = 1.0,
+                  loss_g_weight: float = 1.0,
+                  loss_d_weight: float = 1.0
+                  ):
+         super(AELossWithDisc, self).__init__()
+         assert disc_loss in ["hinge", "vanilla"]
+
+         self.pixel_weight = pixelloss_weight
+         self.perceptual_loss = lpips.LPIPS(net="vgg", verbose=False).eval()
+
+         self.loss_l1_weight = loss_l1_weight
+         self.loss_l2_weight = loss_l2_weight
+         self.loss_p_weight = loss_p_weight
+         self.loss_q_weight = loss_q_weight
+         self.loss_g_weight = loss_g_weight
+         self.loss_d_weight = loss_d_weight
+
+         self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+                                                  n_layers=disc_num_layers,
+                                                  use_actnorm=use_actnorm,
+                                                  ndf=disc_ndf
+                                                  ).apply(weights_init)
+         self.discriminator_iter_start = disc_start
+         if disc_loss == "hinge":
+             self.disc_loss = hinge_d_loss
+         elif disc_loss == "vanilla":
+             self.disc_loss = vanilla_d_loss
+         else:
+             raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+         print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+
+         self.disc_conditional = disc_conditional
+
+     def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+         nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+         g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+
+         g_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+         g_weight = torch.clamp(g_weight, 0.0, 1e4).detach()
+         g_weight = g_weight * self.loss_g_weight
+
+         # detect NaN
+         if torch.isnan(g_weight).any():
+             g_weight = torch.tensor(0.0, device=g_weight.device)
+         return g_weight
+
+     @torch.autocast(device_type="cuda", enabled=False)
+     def forward(self, codebook_loss, inputs, reconstructions, mode, last_layer=None, cond=None, global_step=0):
+         x = inputs.contiguous().float()
+         x_rec = reconstructions.contiguous().float()
+
+         # compute q loss
+         loss_q = codebook_loss.mean()
+
+         # compute l1 loss
+         loss_l1 = (x_rec - x).abs().mean() if self.loss_l1_weight > 0.0 else torch.tensor(0.0, device=x.device)
+
+         # compute l2 loss
+         loss_l2 = (x_rec - x).pow(2).mean() if self.loss_l2_weight > 0.0 else torch.tensor(0.0, device=x.device)
+
+         # compute perceptual loss
+         loss_p = self.perceptual_loss(x, x_rec).mean() if self.loss_p_weight > 0.0 else torch.tensor(0.0, device=x.device)
+
+         # integrate reconstruction loss
+         loss_rec = loss_l1 * self.loss_l1_weight + \
+                    loss_l2 * self.loss_l2_weight + \
+                    loss_p * self.loss_p_weight
+
+         # setup the factor_disc
+         if global_step < self.discriminator_iter_start:
+             factor_disc = 0.0
+         else:
+             factor_disc = 1.0
+
+         # now the GAN part
+         if mode == 0:
+             # generator update
+             if cond is None:
+                 assert not self.disc_conditional
+                 logits_fake = self.discriminator(x_rec)
+             else:
+                 assert self.disc_conditional
+                 logits_fake = self.discriminator(torch.cat((x_rec, cond), dim=1))
+
+             # compute g loss
+             loss_g = - logits_fake.mean()
+
+             try:
+                 loss_g_weight = self.calculate_adaptive_weight(loss_rec, loss_g, last_layer=last_layer)
+             except RuntimeError:
+                 # assert not self.training
+                 loss_g_weight = torch.tensor(0.0)
+
+             loss = loss_g * loss_g_weight * factor_disc + \
+                    loss_q * self.loss_q_weight + \
+                    loss_rec
+
+             log = {"total_loss": loss.item(),
+                    "loss_q": loss_q.item(),
+                    "loss_rec": loss_rec.item(),
+                    "loss_l1": loss_l1.item(),
+                    "loss_l2": loss_l2.item(),
+                    "loss_p": loss_p.item(),
+                    "loss_g": loss_g.item(),
+                    "loss_g_weight": loss_g_weight.item(),
+                    "factor_disc": factor_disc,
+                    }
+             return loss, log
+
+         if mode == 1:
+             # second pass for discriminator update
+             if cond is None:
+                 logits_real = self.discriminator(x.detach())
+                 logits_fake = self.discriminator(x_rec.detach())
+             else:
+                 logits_real = self.discriminator(torch.cat((x.detach(), cond), dim=1))
+                 logits_fake = self.discriminator(torch.cat((x_rec.detach(), cond), dim=1))
+
+             loss_d = self.disc_loss(logits_real, logits_fake).mean()
+             loss = loss_d * self.loss_d_weight
+
+             log = {"loss_d": loss_d.item(),
+                    "logits_real": logits_real.mean().item(),
+                    "logits_fake": logits_fake.mean().item()
+                    }
+             return loss, log
optvq/models/__pycache__/discriminator.cpython-310.pyc ADDED
Binary file (4.4 kB).

optvq/models/__pycache__/quantizer.cpython-310.pyc ADDED
Binary file (8.15 kB).

optvq/models/__pycache__/vqgan.cpython-310.pyc ADDED
Binary file (5.28 kB).

optvq/models/__pycache__/vqgan_hf.cpython-310.pyc ADDED
Binary file (2.11 kB).

optvq/models/backbone/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (15.5 kB).

optvq/models/backbone/diffusion.py ADDED
@@ -0,0 +1,787 @@
1
+ # ------------------------------------------------------------------------------
2
+ # Copy from [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers)
3
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4
+ # ------------------------------------------------------------------------------
5
+
6
+ # pytorch_diffusion + derived encoder decoder
7
+ import math
8
+ import torch
9
+ import torch.nn as nn
10
+ import numpy as np
11
+
12
+
13
+ def get_timestep_embedding(timesteps, embedding_dim):
14
+ """
15
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
16
+ From Fairseq.
17
+ Build sinusoidal embeddings.
18
+ This matches the implementation in tensor2tensor, but differs slightly
19
+ from the description in Section 3.5 of "Attention Is All You Need".
20
+ """
21
+ assert len(timesteps.shape) == 1
22
+
23
+ half_dim = embedding_dim // 2
24
+ emb = math.log(10000) / (half_dim - 1)
25
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
26
+ emb = emb.to(device=timesteps.device)
27
+ emb = timesteps.float()[:, None] * emb[None, :]
28
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
29
+ if embedding_dim % 2 == 1: # zero pad
30
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
31
+ return emb
32
+
33
+
34
+ def nonlinearity(x):
35
+ # swish
36
+ return x*torch.sigmoid(x)
37
+
38
+
39
+ def Normalize(in_channels):
40
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
41
+
42
+
43
+ class Upsample(nn.Module):
44
+ def __init__(self, in_channels, with_conv):
45
+ super().__init__()
46
+ self.with_conv = with_conv
47
+ if self.with_conv:
48
+ self.conv = torch.nn.Conv2d(in_channels,
49
+ in_channels,
50
+ kernel_size=3,
51
+ stride=1,
52
+ padding=1)
53
+
54
+ def forward(self, x):
55
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
56
+ if self.with_conv:
57
+ x = self.conv(x)
58
+ return x
59
+
60
+
61
+ class Downsample(nn.Module):
62
+ def __init__(self, in_channels, with_conv):
63
+ super().__init__()
64
+ self.with_conv = with_conv
65
+ if self.with_conv:
66
+ # no asymmetric padding in torch conv, must do it ourselves
67
+ self.conv = torch.nn.Conv2d(in_channels,
68
+ in_channels,
69
+ kernel_size=3,
70
+ stride=2,
71
+ padding=0)
72
+
73
+ def forward(self, x):
74
+ if self.with_conv:
75
+ pad = (0,1,0,1)
76
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
77
+ x = self.conv(x)
78
+ else:
79
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
80
+ return x
81
+
82
+
83
+ class ResnetBlock(nn.Module):
84
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
85
+ dropout, temb_channels=512):
86
+ super().__init__()
87
+ self.in_channels = in_channels
88
+ out_channels = in_channels if out_channels is None else out_channels
89
+ self.out_channels = out_channels
90
+ self.use_conv_shortcut = conv_shortcut
91
+
92
+ self.norm1 = Normalize(in_channels)
93
+ self.conv1 = torch.nn.Conv2d(in_channels,
94
+ out_channels,
95
+ kernel_size=3,
96
+ stride=1,
97
+ padding=1)
98
+ if temb_channels > 0:
99
+ self.temb_proj = torch.nn.Linear(temb_channels,
100
+ out_channels)
101
+ self.norm2 = Normalize(out_channels)
102
+ self.dropout = torch.nn.Dropout(dropout)
103
+ self.conv2 = torch.nn.Conv2d(out_channels,
104
+ out_channels,
105
+ kernel_size=3,
106
+ stride=1,
107
+ padding=1)
108
+ if self.in_channels != self.out_channels:
109
+ if self.use_conv_shortcut:
110
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
111
+ out_channels,
112
+ kernel_size=3,
113
+ stride=1,
114
+ padding=1)
115
+ else:
116
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
117
+ out_channels,
118
+ kernel_size=1,
119
+ stride=1,
120
+ padding=0)
121
+
122
+ def forward(self, x, temb):
123
+ h = x
124
+ h = self.norm1(h)
125
+ h = nonlinearity(h)
126
+ h = self.conv1(h)
127
+
128
+ if temb is not None:
129
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
130
+
131
+ h = self.norm2(h)
132
+ h = nonlinearity(h)
133
+ h = self.dropout(h)
134
+ h = self.conv2(h)
135
+
136
+ if self.in_channels != self.out_channels:
137
+ if self.use_conv_shortcut:
138
+ x = self.conv_shortcut(x)
139
+ else:
140
+ x = self.nin_shortcut(x)
141
+
142
+ return x+h
143
+
144
+
145
+ class AttnBlock(nn.Module):
146
+ def __init__(self, in_channels):
147
+ super().__init__()
148
+ self.in_channels = in_channels
149
+
150
+ self.norm = Normalize(in_channels)
151
+ self.q = torch.nn.Conv2d(in_channels,
152
+ in_channels,
153
+ kernel_size=1,
154
+ stride=1,
155
+ padding=0)
156
+ self.k = torch.nn.Conv2d(in_channels,
157
+ in_channels,
158
+ kernel_size=1,
159
+ stride=1,
160
+ padding=0)
161
+ self.v = torch.nn.Conv2d(in_channels,
162
+ in_channels,
163
+ kernel_size=1,
164
+ stride=1,
165
+ padding=0)
166
+ self.proj_out = torch.nn.Conv2d(in_channels,
167
+ in_channels,
168
+ kernel_size=1,
169
+ stride=1,
170
+ padding=0)
171
+
172
+
173
+ def forward(self, x):
174
+ h_ = x
175
+ h_ = self.norm(h_)
176
+ q = self.q(h_)
177
+ k = self.k(h_)
178
+ v = self.v(h_)
179
+
180
+ # compute attention
181
+ b,c,h,w = q.shape
182
+ q = q.reshape(b,c,h*w)
183
+ q = q.permute(0,2,1) # b,hw,c
184
+ k = k.reshape(b,c,h*w) # b,c,hw
185
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
186
+ w_ = w_ * (int(c)**(-0.5))
187
+ w_ = torch.nn.functional.softmax(w_, dim=2)
188
+
189
+ # attend to values
190
+ v = v.reshape(b,c,h*w)
191
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
192
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
193
+ h_ = h_.reshape(b,c,h,w)
194
+
195
+ h_ = self.proj_out(h_)
196
+
197
+ return x+h_
198
+
199
+
200
+ class Model(nn.Module):
201
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
202
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
203
+ resolution, use_timestep=True):
204
+ super().__init__()
205
+ self.ch = ch
206
+ self.temb_ch = self.ch*4
207
+ self.num_resolutions = len(ch_mult)
208
+ self.num_res_blocks = num_res_blocks
209
+ self.resolution = resolution
210
+ self.in_channels = in_channels
211
+
212
+ self.use_timestep = use_timestep
213
+ if self.use_timestep:
214
+ # timestep embedding
215
+ self.temb = nn.Module()
216
+ self.temb.dense = nn.ModuleList([
217
+ torch.nn.Linear(self.ch,
218
+ self.temb_ch),
219
+ torch.nn.Linear(self.temb_ch,
220
+ self.temb_ch),
221
+ ])
222
+
223
+ # downsampling
224
+ self.conv_in = torch.nn.Conv2d(in_channels,
225
+ self.ch,
226
+ kernel_size=3,
227
+ stride=1,
228
+ padding=1)
229
+
230
+ curr_res = resolution
231
+ in_ch_mult = (1,)+tuple(ch_mult)
232
+ self.down = nn.ModuleList()
233
+ for i_level in range(self.num_resolutions):
234
+ block = nn.ModuleList()
235
+ attn = nn.ModuleList()
236
+ block_in = ch*in_ch_mult[i_level]
237
+ block_out = ch*ch_mult[i_level]
238
+ for i_block in range(self.num_res_blocks):
239
+ block.append(ResnetBlock(in_channels=block_in,
240
+ out_channels=block_out,
241
+ temb_channels=self.temb_ch,
242
+ dropout=dropout))
243
+ block_in = block_out
244
+ if curr_res in attn_resolutions:
245
+ attn.append(AttnBlock(block_in))
246
+ down = nn.Module()
247
+ down.block = block
248
+ down.attn = attn
249
+ if i_level != self.num_resolutions-1:
250
+ down.downsample = Downsample(block_in, resamp_with_conv)
251
+ curr_res = curr_res // 2
252
+ self.down.append(down)
253
+
254
+ # middle
255
+ self.mid = nn.Module()
256
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
257
+ out_channels=block_in,
258
+ temb_channels=self.temb_ch,
259
+ dropout=dropout)
260
+ self.mid.attn_1 = AttnBlock(block_in)
261
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
262
+ out_channels=block_in,
263
+ temb_channels=self.temb_ch,
264
+ dropout=dropout)
265
+
266
+ # upsampling
267
+ self.up = nn.ModuleList()
268
+ for i_level in reversed(range(self.num_resolutions)):
269
+ block = nn.ModuleList()
270
+ attn = nn.ModuleList()
271
+ block_out = ch*ch_mult[i_level]
272
+ skip_in = ch*ch_mult[i_level]
273
+ for i_block in range(self.num_res_blocks+1):
274
+ if i_block == self.num_res_blocks:
275
+ skip_in = ch*in_ch_mult[i_level]
276
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
277
+ out_channels=block_out,
278
+ temb_channels=self.temb_ch,
279
+ dropout=dropout))
280
+ block_in = block_out
281
+ if curr_res in attn_resolutions:
282
+ attn.append(AttnBlock(block_in))
283
+ up = nn.Module()
284
+ up.block = block
285
+ up.attn = attn
286
+ if i_level != 0:
287
+ up.upsample = Upsample(block_in, resamp_with_conv)
288
+ curr_res = curr_res * 2
289
+ self.up.insert(0, up) # prepend to get consistent order
290
+
291
+ # end
292
+ self.norm_out = Normalize(block_in)
293
+ self.conv_out = torch.nn.Conv2d(block_in,
294
+ out_ch,
295
+ kernel_size=3,
296
+ stride=1,
297
+ padding=1)
298
+
299
+
300
+ def forward(self, x, t=None):
301
+ #assert x.shape[2] == x.shape[3] == self.resolution
302
+
303
+ if self.use_timestep:
304
+ # timestep embedding
305
+ assert t is not None
306
+ temb = get_timestep_embedding(t, self.ch)
307
+ temb = self.temb.dense[0](temb)
308
+ temb = nonlinearity(temb)
309
+ temb = self.temb.dense[1](temb)
310
+ else:
311
+ temb = None
312
+
313
+ # downsampling
314
+ hs = [self.conv_in(x)]
315
+ for i_level in range(self.num_resolutions):
316
+ for i_block in range(self.num_res_blocks):
317
+ h = self.down[i_level].block[i_block](hs[-1], temb)
318
+ if len(self.down[i_level].attn) > 0:
319
+ h = self.down[i_level].attn[i_block](h)
320
+ hs.append(h)
321
+ if i_level != self.num_resolutions-1:
322
+ hs.append(self.down[i_level].downsample(hs[-1]))
323
+
324
+ # middle
325
+ h = hs[-1]
326
+ h = self.mid.block_1(h, temb)
327
+ h = self.mid.attn_1(h)
328
+ h = self.mid.block_2(h, temb)
329
+
330
+ # upsampling
331
+ for i_level in reversed(range(self.num_resolutions)):
332
+ for i_block in range(self.num_res_blocks+1):
333
+ h = self.up[i_level].block[i_block](
334
+ torch.cat([h, hs.pop()], dim=1), temb)
335
+ if len(self.up[i_level].attn) > 0:
336
+ h = self.up[i_level].attn[i_block](h)
337
+ if i_level != 0:
338
+ h = self.up[i_level].upsample(h)
339
+
340
+ # end
341
+ h = self.norm_out(h)
342
+ h = nonlinearity(h)
343
+ h = self.conv_out(h)
344
+ return h
345
+
346
+
347
+ class Encoder(nn.Module):
348
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
349
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
350
+ resolution, z_channels, double_z=True, **ignore_kwargs):
351
+ super().__init__()
352
+ self.ch = ch
353
+ self.temb_ch = 0
354
+ self.num_resolutions = len(ch_mult)
355
+ self.num_res_blocks = num_res_blocks
356
+ self.resolution = resolution
357
+ self.in_channels = in_channels
358
+
359
+ # downsampling
360
+ self.conv_in = torch.nn.Conv2d(in_channels,
361
+ self.ch,
362
+ kernel_size=3,
363
+ stride=1,
364
+ padding=1)
365
+
366
+ curr_res = resolution
367
+ in_ch_mult = (1,)+tuple(ch_mult)
368
+ self.down = nn.ModuleList()
369
+ for i_level in range(self.num_resolutions):
370
+ block = nn.ModuleList()
371
+ attn = nn.ModuleList()
372
+ block_in = ch*in_ch_mult[i_level]
373
+ block_out = ch*ch_mult[i_level]
374
+ for i_block in range(self.num_res_blocks):
375
+ block.append(ResnetBlock(in_channels=block_in,
376
+ out_channels=block_out,
377
+ temb_channels=self.temb_ch,
378
+ dropout=dropout))
379
+ block_in = block_out
380
+ if curr_res in attn_resolutions:
381
+ attn.append(AttnBlock(block_in))
382
+ down = nn.Module()
383
+ down.block = block
384
+ down.attn = attn
385
+ if i_level != self.num_resolutions-1:
386
+ down.downsample = Downsample(block_in, resamp_with_conv)
387
+ curr_res = curr_res // 2
388
+ self.down.append(down)
389
+
390
+ # middle
391
+ self.mid = nn.Module()
392
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
393
+ out_channels=block_in,
394
+ temb_channels=self.temb_ch,
395
+ dropout=dropout)
396
+ self.mid.attn_1 = AttnBlock(block_in)
397
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
398
+ out_channels=block_in,
399
+ temb_channels=self.temb_ch,
400
+ dropout=dropout)
401
+
402
+ # end
403
+ self.norm_out = Normalize(block_in)
404
+ self.conv_out = torch.nn.Conv2d(block_in,
405
+ 2*z_channels if double_z else z_channels,
406
+ kernel_size=3,
407
+ stride=1,
408
+ padding=1)
409
+
410
+ @property
411
+ def hidden_dim(self):
412
+ return self.conv_out.out_channels
413
+
414
+ def forward(self, x):
415
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
416
+
417
+ # timestep embedding
418
+ temb = None
419
+
420
+ # downsampling
421
+ hs = [self.conv_in(x)]
422
+ for i_level in range(self.num_resolutions):
423
+ for i_block in range(self.num_res_blocks):
424
+ h = self.down[i_level].block[i_block](hs[-1], temb)
425
+ if len(self.down[i_level].attn) > 0:
426
+ h = self.down[i_level].attn[i_block](h)
427
+ hs.append(h)
428
+ if i_level != self.num_resolutions-1:
429
+ hs.append(self.down[i_level].downsample(hs[-1]))
430
+
431
+ # middle
432
+ h = hs[-1]
433
+ h = self.mid.block_1(h, temb)
434
+ h = self.mid.attn_1(h)
435
+ h = self.mid.block_2(h, temb)
436
+
437
+ # end
438
+ h = self.norm_out(h)
439
+ h = nonlinearity(h)
440
+ h = self.conv_out(h)
441
+ return h
442
+
443
+
444
+ class Decoder(nn.Module):
445
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
446
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
447
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
448
+ super().__init__()
449
+ self.ch = ch
450
+ self.temb_ch = 0
451
+ self.num_resolutions = len(ch_mult)
452
+ self.num_res_blocks = num_res_blocks
453
+ self.resolution = resolution
454
+ self.in_channels = in_channels
455
+ self.give_pre_end = give_pre_end
456
+
457
+ # compute in_ch_mult, block_in and curr_res at lowest res
458
+ in_ch_mult = (1,)+tuple(ch_mult)
459
+ block_in = ch*ch_mult[self.num_resolutions-1]
460
+ curr_res = resolution // 2**(self.num_resolutions-1)
461
+ self.z_shape = (1,z_channels,curr_res,curr_res)
462
+ print("Working with z of shape {} = {} dimensions.".format(
463
+ self.z_shape, np.prod(self.z_shape)))
464
+
465
+ # z to block_in
466
+ self.conv_in = torch.nn.Conv2d(z_channels,
467
+ block_in,
468
+ kernel_size=3,
469
+ stride=1,
470
+ padding=1)
471
+
472
+ # middle
473
+ self.mid = nn.Module()
474
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
475
+ out_channels=block_in,
476
+ temb_channels=self.temb_ch,
477
+ dropout=dropout)
478
+ self.mid.attn_1 = AttnBlock(block_in)
479
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
480
+ out_channels=block_in,
481
+ temb_channels=self.temb_ch,
482
+ dropout=dropout)
483
+
484
+ # upsampling
485
+ self.up = nn.ModuleList()
486
+ for i_level in reversed(range(self.num_resolutions)):
487
+ block = nn.ModuleList()
488
+ attn = nn.ModuleList()
489
+ block_out = ch*ch_mult[i_level]
490
+ for i_block in range(self.num_res_blocks+1):
491
+ block.append(ResnetBlock(in_channels=block_in,
492
+ out_channels=block_out,
493
+ temb_channels=self.temb_ch,
494
+ dropout=dropout))
495
+ block_in = block_out
496
+ if curr_res in attn_resolutions:
497
+ attn.append(AttnBlock(block_in))
498
+ up = nn.Module()
499
+ up.block = block
500
+ up.attn = attn
501
+ if i_level != 0:
502
+ up.upsample = Upsample(block_in, resamp_with_conv)
503
+ curr_res = curr_res * 2
504
+ self.up.insert(0, up) # prepend to get consistent order
505
+
506
+ # end
507
+ self.norm_out = Normalize(block_in)
508
+ self.conv_out = torch.nn.Conv2d(block_in,
509
+ out_ch,
510
+ kernel_size=3,
511
+ stride=1,
512
+ padding=1)
513
+
514
+ @property
515
+ def hidden_dim(self):
516
+ return self.conv_in.in_channels
517
+
518
+ def forward(self, z):
519
+ #assert z.shape[1:] == self.z_shape[1:]
520
+ self.last_z_shape = z.shape
521
+
522
+ # timestep embedding
523
+ temb = None
524
+
525
+ # z to block_in
526
+ h = self.conv_in(z)
527
+
528
+ # middle
529
+ h = self.mid.block_1(h, temb)
530
+ h = self.mid.attn_1(h)
531
+ h = self.mid.block_2(h, temb)
532
+
533
+ # upsampling
534
+ for i_level in reversed(range(self.num_resolutions)):
535
+ for i_block in range(self.num_res_blocks+1):
536
+ h = self.up[i_level].block[i_block](h, temb)
537
+ if len(self.up[i_level].attn) > 0:
538
+ h = self.up[i_level].attn[i_block](h)
539
+ if i_level != 0:
540
+ h = self.up[i_level].upsample(h)
541
+
542
+ # end
543
+ if self.give_pre_end:
544
+ return h
545
+
546
+ h = self.norm_out(h)
547
+ h = nonlinearity(h)
548
+ h = self.conv_out(h)
549
+ return h
550
+
551
+
552
+ class VUNet(nn.Module):
553
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
554
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
555
+ in_channels, c_channels,
556
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
557
+ super().__init__()
558
+ self.ch = ch
559
+ self.temb_ch = self.ch*4
560
+ self.num_resolutions = len(ch_mult)
561
+ self.num_res_blocks = num_res_blocks
562
+ self.resolution = resolution
563
+
564
+ self.use_timestep = use_timestep
565
+ if self.use_timestep:
566
+ # timestep embedding
567
+ self.temb = nn.Module()
568
+ self.temb.dense = nn.ModuleList([
569
+ torch.nn.Linear(self.ch,
570
+ self.temb_ch),
571
+ torch.nn.Linear(self.temb_ch,
572
+ self.temb_ch),
573
+ ])
574
+
575
+ # downsampling
576
+ self.conv_in = torch.nn.Conv2d(c_channels,
577
+ self.ch,
578
+ kernel_size=3,
579
+ stride=1,
580
+ padding=1)
581
+
582
+ curr_res = resolution
583
+ in_ch_mult = (1,)+tuple(ch_mult)
584
+ self.down = nn.ModuleList()
585
+ for i_level in range(self.num_resolutions):
586
+ block = nn.ModuleList()
587
+ attn = nn.ModuleList()
588
+ block_in = ch*in_ch_mult[i_level]
589
+ block_out = ch*ch_mult[i_level]
590
+ for i_block in range(self.num_res_blocks):
591
+ block.append(ResnetBlock(in_channels=block_in,
592
+ out_channels=block_out,
593
+ temb_channels=self.temb_ch,
594
+ dropout=dropout))
595
+ block_in = block_out
596
+ if curr_res in attn_resolutions:
597
+ attn.append(AttnBlock(block_in))
598
+ down = nn.Module()
599
+ down.block = block
600
+ down.attn = attn
601
+ if i_level != self.num_resolutions-1:
602
+ down.downsample = Downsample(block_in, resamp_with_conv)
603
+ curr_res = curr_res // 2
604
+ self.down.append(down)
605
+
606
+ self.z_in = torch.nn.Conv2d(z_channels,
607
+ block_in,
608
+ kernel_size=1,
609
+ stride=1,
610
+ padding=0)
611
+ # middle
612
+ self.mid = nn.Module()
613
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
614
+ out_channels=block_in,
615
+ temb_channels=self.temb_ch,
616
+ dropout=dropout)
617
+ self.mid.attn_1 = AttnBlock(block_in)
618
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
619
+ out_channels=block_in,
620
+ temb_channels=self.temb_ch,
621
+ dropout=dropout)
622
+
623
+ # upsampling
624
+ self.up = nn.ModuleList()
625
+ for i_level in reversed(range(self.num_resolutions)):
626
+ block = nn.ModuleList()
627
+ attn = nn.ModuleList()
628
+ block_out = ch*ch_mult[i_level]
629
+ skip_in = ch*ch_mult[i_level]
630
+ for i_block in range(self.num_res_blocks+1):
631
+ if i_block == self.num_res_blocks:
632
+ skip_in = ch*in_ch_mult[i_level]
633
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
634
+ out_channels=block_out,
635
+ temb_channels=self.temb_ch,
636
+ dropout=dropout))
637
+ block_in = block_out
638
+ if curr_res in attn_resolutions:
639
+ attn.append(AttnBlock(block_in))
640
+ up = nn.Module()
641
+ up.block = block
642
+ up.attn = attn
643
+ if i_level != 0:
644
+ up.upsample = Upsample(block_in, resamp_with_conv)
645
+ curr_res = curr_res * 2
646
+ self.up.insert(0, up) # prepend to get consistent order
647
+
648
+ # end
649
+ self.norm_out = Normalize(block_in)
650
+ self.conv_out = torch.nn.Conv2d(block_in,
651
+ out_ch,
652
+ kernel_size=3,
653
+ stride=1,
654
+ padding=1)
655
+
656
+
657
+ def forward(self, x, z):
658
+ #assert x.shape[2] == x.shape[3] == self.resolution
659
+
660
+ if self.use_timestep:
661
+ # timestep embedding
662
+ assert t is not None
663
+ temb = get_timestep_embedding(t, self.ch)
664
+ temb = self.temb.dense[0](temb)
665
+ temb = nonlinearity(temb)
666
+ temb = self.temb.dense[1](temb)
667
+ else:
668
+ temb = None
669
+
670
+ # downsampling
671
+ hs = [self.conv_in(x)]
672
+ for i_level in range(self.num_resolutions):
673
+ for i_block in range(self.num_res_blocks):
674
+ h = self.down[i_level].block[i_block](hs[-1], temb)
675
+ if len(self.down[i_level].attn) > 0:
676
+ h = self.down[i_level].attn[i_block](h)
677
+ hs.append(h)
678
+ if i_level != self.num_resolutions-1:
679
+ hs.append(self.down[i_level].downsample(hs[-1]))
680
+
681
+ # middle
682
+ h = hs[-1]
683
+ z = self.z_in(z)
684
+ h = torch.cat((h,z),dim=1)
685
+ h = self.mid.block_1(h, temb)
686
+ h = self.mid.attn_1(h)
687
+ h = self.mid.block_2(h, temb)
688
+
689
+ # upsampling
690
+ for i_level in reversed(range(self.num_resolutions)):
691
+ for i_block in range(self.num_res_blocks+1):
692
+ h = self.up[i_level].block[i_block](
693
+ torch.cat([h, hs.pop()], dim=1), temb)
694
+ if len(self.up[i_level].attn) > 0:
695
+ h = self.up[i_level].attn[i_block](h)
696
+ if i_level != 0:
697
+ h = self.up[i_level].upsample(h)
698
+
699
+ # end
700
+ h = self.norm_out(h)
701
+ h = nonlinearity(h)
702
+ h = self.conv_out(h)
703
+ return h
704
+
705
+
706
+ class SimpleDecoder(nn.Module):
707
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
708
+ super().__init__()
709
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
710
+ ResnetBlock(in_channels=in_channels,
711
+ out_channels=2 * in_channels,
712
+ temb_channels=0, dropout=0.0),
713
+ ResnetBlock(in_channels=2 * in_channels,
714
+ out_channels=4 * in_channels,
715
+ temb_channels=0, dropout=0.0),
716
+ ResnetBlock(in_channels=4 * in_channels,
717
+ out_channels=2 * in_channels,
718
+ temb_channels=0, dropout=0.0),
719
+ nn.Conv2d(2*in_channels, in_channels, 1),
720
+ Upsample(in_channels, with_conv=True)])
721
+ # end
722
+ self.norm_out = Normalize(in_channels)
723
+ self.conv_out = torch.nn.Conv2d(in_channels,
724
+ out_channels,
725
+ kernel_size=3,
726
+ stride=1,
727
+ padding=1)
728
+
729
+ def forward(self, x):
730
+ for i, layer in enumerate(self.model):
731
+ if i in [1,2,3]:
732
+ x = layer(x, None)
733
+ else:
734
+ x = layer(x)
735
+
736
+ h = self.norm_out(x)
737
+ h = nonlinearity(h)
738
+ x = self.conv_out(h)
739
+ return x
740
+
741
+
742
+ class UpsampleDecoder(nn.Module):
743
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
744
+ ch_mult=(2,2), dropout=0.0):
745
+ super().__init__()
746
+ # upsampling
747
+ self.temb_ch = 0
748
+ self.num_resolutions = len(ch_mult)
749
+ self.num_res_blocks = num_res_blocks
750
+ block_in = in_channels
751
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
752
+ self.res_blocks = nn.ModuleList()
753
+ self.upsample_blocks = nn.ModuleList()
754
+ for i_level in range(self.num_resolutions):
755
+ res_block = []
756
+ block_out = ch * ch_mult[i_level]
757
+ for i_block in range(self.num_res_blocks + 1):
758
+ res_block.append(ResnetBlock(in_channels=block_in,
759
+ out_channels=block_out,
760
+ temb_channels=self.temb_ch,
761
+ dropout=dropout))
762
+ block_in = block_out
763
+ self.res_blocks.append(nn.ModuleList(res_block))
764
+ if i_level != self.num_resolutions - 1:
765
+ self.upsample_blocks.append(Upsample(block_in, True))
766
+ curr_res = curr_res * 2
767
+
768
+ # end
769
+ self.norm_out = Normalize(block_in)
770
+ self.conv_out = torch.nn.Conv2d(block_in,
771
+ out_channels,
772
+ kernel_size=3,
773
+ stride=1,
774
+ padding=1)
775
+
776
+ def forward(self, x):
777
+ # upsampling
778
+ h = x
779
+ for k, i_level in enumerate(range(self.num_resolutions)):
780
+ for i_block in range(self.num_res_blocks + 1):
781
+ h = self.res_blocks[i_level][i_block](h, None)
782
+ if i_level != self.num_resolutions - 1:
783
+ h = self.upsample_blocks[k](h)
784
+ h = self.norm_out(h)
785
+ h = nonlinearity(h)
786
+ h = self.conv_out(h)
787
+ return h
optvq/models/backbone/simple_cnn.py ADDED
@@ -0,0 +1,90 @@
+ # ------------------------------------------------------------------------------
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
+ # Licensed under the MIT License [see LICENSE for details]
+ # ------------------------------------------------------------------------------
+
+ import torch.nn as nn
+
+ class PlainCNNEncoder(nn.Module):
+     def __init__(self, in_dim: int = 3):
+         super(PlainCNNEncoder, self).__init__()
+
+         self.in_dim = in_dim
+
+         self.in_fc = nn.Conv2d(in_channels=in_dim, out_channels=16,
+                                kernel_size=3, stride=1, padding=1, bias=True)
+         self.act0 = nn.ReLU(inplace=True)
+
+         self.down1 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=2)
+         self.conv1 = nn.Conv2d(in_channels=16, out_channels=16,
+                                kernel_size=3, stride=1, padding=1, bias=True)
+         self.act1 = nn.ReLU(inplace=True)
+
+         self.down2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=2)
+         self.conv2 = nn.Conv2d(in_channels=16, out_channels=32,
+                                kernel_size=3, stride=1, padding=1, bias=True)
+         self.act2 = nn.ReLU(inplace=True)
+
+         self.out_fc = nn.Conv2d(in_channels=32, out_channels=32,
+                                 kernel_size=3, stride=1, padding=1, bias=True)
+
+     @property
+     def hidden_dim(self):
+         return 32
+
+     def forward(self, x):
+         x = self.in_fc(x)
+         x = self.act0(x)
+
+         x = self.down1(x)
+         x = self.conv1(x)
+         x = self.act1(x)
+
+         x = self.down2(x)
+         x = self.conv2(x)
+         x = self.act2(x)
+
+         x = self.out_fc(x)
+         return x
+
+ class PlainCNNDecoder(nn.Module):
+     def __init__(self, out_dim: int = 3):
+         super(PlainCNNDecoder, self).__init__()
+         self.out_dim = out_dim
+
+         self.in_fc = nn.Conv2d(in_channels=32, out_channels=32,
+                                kernel_size=3, stride=1, padding=1, bias=True)
+
+         self.act1 = nn.ReLU(inplace=True)
+         self.up1 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2)
+         self.conv1 = nn.Conv2d(in_channels=16, out_channels=16,
+                                kernel_size=3, stride=1, padding=1, bias=True)
+
+         self.act2 = nn.ReLU(inplace=True)
+         self.up2 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=2, stride=2)
+         self.conv2 = nn.Conv2d(in_channels=16, out_channels=16,
+                                kernel_size=3, stride=1, padding=1, bias=True)
+
+         self.act3 = nn.ReLU(inplace=True)
+         self.out_fc = nn.Conv2d(in_channels=16, out_channels=out_dim,
+                                 kernel_size=3, stride=1, padding=1, bias=True)
+
+     @property
+     def hidden_dim(self):
+         return 32
+
+     def forward(self, x):
+         x = self.in_fc(x)
+
+         x = self.act1(x)
+         x = self.up1(x)
+         x = self.conv1(x)
+
+         x = self.act2(x)
+         x = self.up2(x)
+         x = self.conv2(x)
+
+         x = self.act3(x)
+         x = self.out_fc(x)
+         return x
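
A quick shape check of the encoder/decoder pair above; the input size is arbitrary, and any spatial size divisible by 4 round-trips the same way.

import torch

from optvq.models.backbone.simple_cnn import PlainCNNEncoder, PlainCNNDecoder

enc = PlainCNNEncoder(in_dim=3)
dec = PlainCNNDecoder(out_dim=3)

x = torch.randn(1, 3, 32, 32)
z = enc(x)                 # two stride-2 convolutions: 32 channels at 1/4 resolution
print(z.shape)             # torch.Size([1, 32, 8, 8])
x_rec = dec(z)             # two transposed convolutions restore the input size
print(x_rec.shape)         # torch.Size([1, 3, 32, 32])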
optvq/models/discriminator.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------
2
+ # Copy from [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers)
3
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4
+ # ------------------------------------------------------------------------------
5
+
6
+ import functools
7
+ import torch.nn as nn
8
+ import torch
9
+
10
+ class ActNorm(nn.Module):
11
+ def __init__(self, num_features, logdet=False, affine=True,
12
+ allow_reverse_init=False):
13
+ assert affine
14
+ super().__init__()
15
+ self.logdet = logdet
16
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18
+ self.allow_reverse_init = allow_reverse_init
19
+
20
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21
+
22
+ def initialize(self, input):
23
+ with torch.no_grad():
24
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25
+ mean = (
26
+ flatten.mean(1)
27
+ .unsqueeze(1)
28
+ .unsqueeze(2)
29
+ .unsqueeze(3)
30
+ .permute(1, 0, 2, 3)
31
+ )
32
+ std = (
33
+ flatten.std(1)
34
+ .unsqueeze(1)
35
+ .unsqueeze(2)
36
+ .unsqueeze(3)
37
+ .permute(1, 0, 2, 3)
38
+ )
39
+
40
+ self.loc.data.copy_(-mean)
41
+ self.scale.data.copy_(1 / (std + 1e-6))
42
+
43
+ def forward(self, input, reverse=False):
44
+ if reverse:
45
+ return self.reverse(input)
46
+ if len(input.shape) == 2:
47
+ input = input[:,:,None,None]
48
+ squeeze = True
49
+ else:
50
+ squeeze = False
51
+
52
+ _, _, height, width = input.shape
53
+
54
+ if self.training and self.initialized.item() == 0:
55
+ self.initialize(input)
56
+ self.initialized.fill_(1)
57
+
58
+ h = self.scale * (input + self.loc)
59
+
60
+ if squeeze:
61
+ h = h.squeeze(-1).squeeze(-1)
62
+
63
+ if self.logdet:
64
+ log_abs = torch.log(torch.abs(self.scale))
65
+ logdet = height*width*torch.sum(log_abs)
66
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
67
+ return h, logdet
68
+
69
+ return h
70
+
71
+ def reverse(self, output):
72
+ if self.training and self.initialized.item() == 0:
73
+ if not self.allow_reverse_init:
74
+ raise RuntimeError(
75
+ "Initializing ActNorm in reverse direction is "
76
+ "disabled by default. Use allow_reverse_init=True to enable."
77
+ )
78
+ else:
79
+ self.initialize(output)
80
+ self.initialized.fill_(1)
81
+
82
+ if len(output.shape) == 2:
83
+ output = output[:,:,None,None]
84
+ squeeze = True
85
+ else:
86
+ squeeze = False
87
+
88
+ h = output / self.scale - self.loc
89
+
90
+ if squeeze:
91
+ h = h.squeeze(-1).squeeze(-1)
92
+ return h
93
+
94
+
95
+
96
+ def weights_init(m):
97
+ classname = m.__class__.__name__
98
+ if classname.find('Conv') != -1:
99
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
100
+ elif classname.find('BatchNorm') != -1:
101
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
102
+ nn.init.constant_(m.bias.data, 0)
103
+
104
+
105
+ class NLayerDiscriminator(nn.Module):
106
+ """Defines a PatchGAN discriminator as in Pix2Pix
107
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
108
+ """
109
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
110
+ """Construct a PatchGAN discriminator
111
+ Parameters:
112
+ input_nc (int) -- the number of channels in input images
113
+ ndf (int) -- the number of filters in the last conv layer
114
+ n_layers (int) -- the number of conv layers in the discriminator
115
+ norm_layer -- normalization layer
116
+ """
117
+ super(NLayerDiscriminator, self).__init__()
118
+ if not use_actnorm:
119
+ norm_layer = nn.BatchNorm2d
120
+ else:
121
+ norm_layer = ActNorm
122
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
123
+ use_bias = norm_layer.func != nn.BatchNorm2d
124
+ else:
125
+ use_bias = norm_layer != nn.BatchNorm2d
126
+
127
+ kw = 4
128
+ padw = 1
129
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
130
+ nf_mult = 1
131
+ nf_mult_prev = 1
132
+ for n in range(1, n_layers): # gradually increase the number of filters
133
+ nf_mult_prev = nf_mult
134
+ nf_mult = min(2 ** n, 8)
135
+ sequence += [
136
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
137
+ norm_layer(ndf * nf_mult),
138
+ nn.LeakyReLU(0.2, True)
139
+ ]
140
+
141
+ nf_mult_prev = nf_mult
142
+ nf_mult = min(2 ** n_layers, 8)
143
+ sequence += [
144
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
145
+ norm_layer(ndf * nf_mult),
146
+ nn.LeakyReLU(0.2, True)
147
+ ]
148
+
149
+ sequence += [
150
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
151
+ self.main = nn.Sequential(*sequence)
152
+
153
+ def forward(self, input):
154
+ """Standard forward."""
155
+ return self.main(input)
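
As a usage sketch (not part of the commit), the PatchGAN discriminator maps an image to a grid of per-patch real/fake logits rather than a single scalar; with the default n_layers=3 a 256x256 input yields a 30x30 logit map. This assumes the optvq package from this commit is importable:

    import torch
    from optvq.models.discriminator import NLayerDiscriminator, weights_init

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
    logits = disc(torch.randn(1, 3, 256, 256))
    print(logits.shape)   # torch.Size([1, 1, 30, 30]): one logit per image patch
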
optvq/models/quantizer.py ADDED
@@ -0,0 +1,260 @@
1
+ # ------------------------------------------------------------------------------
2
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
3
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
4
+ # Licensed under the MIT License [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch import Tensor
14
+
15
+ import torch.distributed as dist
16
+
17
+ import optvq.utils.logger as L
18
+
19
+ class VectorQuantizer(nn.Module):
20
+ def __init__(self, n_e: int = 1024, e_dim: int = 128,
21
+ beta: float = 1.0, use_norm: bool = False,
22
+ use_proj: bool = True, fix_codes: bool = False,
23
+ loss_q_type: str = "ce",
24
+ num_head: int = 1,
25
+ start_quantize_steps: int = None):
26
+ super(VectorQuantizer, self).__init__()
27
+ self.n_e = n_e
28
+ self.e_dim = e_dim
29
+ self.beta = beta
30
+ self.loss_q_type = loss_q_type
31
+ self.num_head = num_head
32
+ self.start_quantize_steps = start_quantize_steps
33
+ self.code_dim = self.e_dim // self.num_head
34
+
35
+ self.norm = lambda x: F.normalize(x, p=2.0, dim=-1, eps=1e-6) if use_norm else x
36
+ assert not use_norm, "use_norm=True is no longer supported: the norm operation lacks theoretical analysis and may cause unpredictable instability."
37
+ self.use_proj = use_proj
38
+
39
+ self.embedding = nn.Embedding(num_embeddings=n_e, embedding_dim=self.code_dim)
40
+ if use_proj:
41
+ self.proj = nn.Linear(self.code_dim, self.code_dim)
42
+ torch.nn.init.normal_(self.proj.weight, std=self.code_dim ** -0.5)
43
+ if fix_codes:
44
+ self.embedding.weight.requires_grad = False
45
+
46
+ def reshape_input(self, x: Tensor):
47
+ """
48
+ (B, C, H, W) / (B, T, C) -> (B, T, C)
49
+ """
50
+ if x.ndim == 4:
51
+ _, C, H, W = x.size()
52
+ x = x.permute(0, 2, 3, 1).contiguous().view(-1, H * W, C)
53
+ return x, {"size": (H, W)}
54
+ elif x.ndim == 3:
55
+ return x, None
56
+ else:
57
+ raise ValueError("Invalid input shape!")
58
+
59
+ def recover_output(self, x: Tensor, info):
60
+ if info is not None:
61
+ H, W = info["size"]
62
+ if x.ndim == 3: # features (B, T, C) -> (B, C, H, W)
63
+ C = x.size(2)
64
+ return x.view(-1, H, W, C).permute(0, 3, 1, 2).contiguous()
65
+ elif x.ndim == 2: # indices (B, T) -> (B, H, W)
66
+ return x.view(-1, H, W)
67
+ else:
68
+ raise ValueError("Invalid input shape!")
69
+ else: # features (B, T, C) or indices (B, T)
70
+ return x
71
+
72
+ def get_codebook(self, return_numpy: bool = True):
73
+ embed = self.proj(self.embedding.weight) if self.use_proj else self.embedding.weight
74
+ if return_numpy:
75
+ return embed.data.cpu().numpy()
76
+ else:
77
+ return embed.data
78
+
79
+ def quantize_input(self, query, reference):
80
+ # compute the distance matrix
81
+ query2ref = torch.cdist(query, reference, p=2.0) # (B1, B2)
82
+
83
+ # find the nearest embedding
84
+ indices = torch.argmin(query2ref, dim=-1) # (B1,)
85
+ nearest_ref = reference[indices] # (B1, C)
86
+
87
+ return indices, nearest_ref, query2ref
88
+
89
+ def compute_codebook_loss(self, query, indices, nearest_ref, beta: float, query2ref):
90
+ # compute the loss
91
+ if self.loss_q_type == "l2":
92
+ loss = torch.mean((query - nearest_ref.detach()).pow(2)) + \
93
+ torch.mean((nearest_ref - query.detach()).pow(2)) * beta
94
+ elif self.loss_q_type == "l1":
95
+ loss = torch.mean((query - nearest_ref.detach()).abs()) + \
96
+ torch.mean((nearest_ref - query.detach()).abs()) * beta
97
+ elif self.loss_q_type == "ce":
98
+ loss = F.cross_entropy(- query2ref, indices)
99
+
100
+ return loss
101
+
102
+ def compute_quantized_output(self, x, x_q):
103
+ if self.start_quantize_steps is not None:
104
+ if self.training and L.log.total_steps < self.start_quantize_steps:
105
+ L.log.add_scalar("params/quantize_ratio", 0.0)
106
+ return x
107
+ else:
108
+ L.log.add_scalar("params/quantize_ratio", 1.0)
109
+ return x + (x_q - x).detach()
110
+ else:
111
+ L.log.add_scalar("params/quantize_ratio", 1.0)
112
+ return x + (x_q - x).detach()
113
+
114
+ @torch.autocast(device_type="cuda", enabled=False)
115
+ def forward(self, x: Tensor):
116
+ """
117
+ Quantize the input tensor x with the embedding table.
118
+
119
+ Args:
120
+ x (Tensor): input tensor with shape (B, C, H, W) or (B, T, C)
121
+ Returns:
122
+ (tuple) containing: (x_q, loss, indices)
123
+ """
124
+ x = x.float()
125
+ x, info = self.reshape_input(x)
126
+ B, T, C = x.size()
127
+ x = x.view(-1, C) # (B * T, C)
128
+ embed = self.proj(self.embedding.weight) if self.use_proj else self.embedding.weight
129
+
130
+ # split the x if multi-head is used
131
+ if self.num_head > 1:
132
+ x = x.view(-1, self.code_dim) # (B * T * nH, dC)
133
+
134
+ # compute the distance between x and each embedding
135
+ x, embed = self.norm(x), self.norm(embed)
136
+
137
+ # compute losses
138
+ indices, x_q, query2ref = self.quantize_input(x, embed)
139
+ loss = self.compute_codebook_loss(
140
+ query=x, indices=indices, nearest_ref=x_q,
141
+ beta=self.beta, query2ref=query2ref
142
+ )
143
+
144
+ # compute statistics
145
+ if self.training and L.GET_STATS:
146
+ with torch.no_grad():
147
+ num_unique = torch.unique(indices).size(0)
148
+ x_norm_mean = torch.mean(x.norm(dim=-1))
149
+ embed_norm_mean = torch.mean(embed.norm(dim=-1))
150
+ diff_norm_mean = torch.mean((x_q - x).norm(dim=-1))
151
+ x2e_mean = query2ref.mean()
152
+ L.log.add_scalar("params/num_unique", num_unique)
153
+ L.log.add_scalar("params/x_norm", x_norm_mean.item())
154
+ L.log.add_scalar("params/embed_norm", embed_norm_mean.item())
155
+ L.log.add_scalar("params/diff_norm", diff_norm_mean.item())
156
+ L.log.add_scalar("params/x2e_mean", x2e_mean.item())
157
+
158
+ # compute the final x_q
159
+ x_q = self.compute_quantized_output(x, x_q).view(B, T, C)
160
+ indices = indices.view(B, T, self.num_head)
161
+
162
+ # for output
163
+ x_q = self.recover_output(x_q, info)
164
+ indices = self.recover_output(indices, info)
165
+
166
+ return x_q, loss, indices
167
+
168
+ def sinkhorn(cost: Tensor, n_iters: int = 3, epsilon: float = 1, is_distributed: bool = False):
169
+ """
170
+ Sinkhorn algorithm.
171
+ Args:
172
+ cost (Tensor): shape with (B, K)
173
+ """
174
+ Q = torch.exp(- cost * epsilon).t() # (K, B)
175
+ if is_distributed:
176
+ B = Q.size(1) * dist.get_world_size()
177
+ else:
178
+ B = Q.size(1)
179
+ K = Q.size(0)
180
+
181
+ # make the matrix sums to 1
182
+ sum_Q = torch.sum(Q)
183
+ if is_distributed:
184
+ dist.all_reduce(sum_Q)
185
+ Q /= (sum_Q + 1e-8)
186
+
187
+ for _ in range(n_iters):
188
+ # normalize each row: total weight per prototype must be 1/K
189
+ sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
190
+ if is_distributed:
191
+ dist.all_reduce(sum_of_rows)
192
+ Q /= (sum_of_rows + 1e-8)
193
+ Q /= K
194
+
195
+ # normalize each column: total weight per sample must be 1/B
196
+ Q /= (torch.sum(Q, dim=0, keepdim=True) + 1e-8)
197
+ Q /= B
198
+
199
+ Q *= B # the columns must sum to 1 so that Q is an assignment
200
+ return Q.t() # (B, K)
201
+
202
+ class VectorQuantizerSinkhorn(VectorQuantizer):
203
+ def __init__(self, epsilon: float = 10.0, n_iters: int = 5,
204
+ normalize_mode: str = "all", use_prob: bool = True,
205
+ *args, **kwargs):
206
+ super(VectorQuantizerSinkhorn, self).__init__(*args, **kwargs)
207
+ self.epsilon = epsilon
208
+ self.n_iters = n_iters
209
+ self.normalize_mode = normalize_mode
210
+ self.use_prob = use_prob
211
+
212
+ def normalize(self, A, dim, mode="all"):
213
+ if mode == "all":
214
+ A = (A - A.mean()) / (A.std() + 1e-6)
215
+ A = A - A.min()
216
+ elif mode == "dim":
217
+ A = A / math.sqrt(dim)
218
+ elif mode == "null":
219
+ pass
220
+ return A
221
+
222
+ def quantize_input(self, query, reference):
223
+ # compute the distance matrix
224
+ query2ref = torch.cdist(query, reference, p=2.0) # (B1, B2)
225
+
226
+ # compute the assignment matrix
227
+ with torch.no_grad():
228
+ is_distributed = dist.is_initialized() and dist.get_world_size() > 1
229
+ normalized_cost = self.normalize(query2ref, dim=reference.size(1), mode=self.normalize_mode)
230
+ Q = sinkhorn(normalized_cost, n_iters=self.n_iters, epsilon=self.epsilon, is_distributed=is_distributed)
231
+
232
+ if self.use_prob:
233
+ # avoid the zero value problem
234
+ max_q_id = torch.argmax(Q, dim=-1)
235
+ Q[torch.arange(Q.size(0)), max_q_id] += 1e-8
236
+ indices = torch.multinomial(Q, num_samples=1).squeeze()
237
+ else:
238
+ indices = torch.argmax(Q, dim=-1)
239
+ nearest_ref = reference[indices]
240
+
241
+ if self.training and L.GET_STATS:
242
+ if L.log.total_steps % 1000 == 0:
243
+ L.log.add_histogram("params/normalized_cost", normalized_cost)
244
+
245
+ return indices, nearest_ref, query2ref
246
+
247
+ class Identity(VectorQuantizer):
248
+ @torch.autocast(device_type="cuda", enabled=False)
249
+ def forward(self, x: Tensor):
250
+ x = x.float()
251
+ loss_q = torch.tensor(0.0, device=x.device, dtype=x.dtype)
252
+
253
+ # compute statistics
254
+ if self.training and L.GET_STATS:
255
+ with torch.no_grad():
256
+ x_flatten, _ = self.reshape_input(x)
257
+ x_norm_mean = torch.mean(x_flatten.norm(dim=-1))
258
+ L.log.add_scalar("params/x_norm", x_norm_mean.item())
259
+
260
+ return x, loss_q, None
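
The sinkhorn helper turns the query-to-codebook cost matrix into a (near) doubly normalized assignment, which VectorQuantizerSinkhorn reads out with argmax (or multinomial sampling when use_prob=True) instead of plain nearest-neighbour search. A small self-contained sketch on random 2-D points, using the same "all" normalization and the class defaults epsilon=10, n_iters=5, and assuming the optvq package from this commit is importable:

    import torch
    from optvq.models.quantizer import sinkhorn

    torch.manual_seed(0)
    data = torch.randn(50, 2)    # queries
    code = torch.randn(20, 2)    # codebook entries

    cost = torch.cdist(data, code, p=2.0)
    cost = (cost - cost.mean()) / (cost.std() + 1e-6)   # "all" normalization mode
    cost = cost - cost.min()

    Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)   # (50, 20) soft assignment
    nn_idx = torch.argmin(cost, dim=-1)    # nearest-neighbour assignment
    ot_idx = torch.argmax(Q, dim=-1)       # optimal-transport assignment

    print(Q.sum(dim=-1)[:3])                   # each row sums to ~1 after the final rescaling
    print((nn_idx != ot_idx).float().mean())   # fraction of queries whose code changes under OT
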
optvq/models/vqgan.py ADDED
@@ -0,0 +1,168 @@
1
+ # ------------------------------------------------------------------------------
2
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
3
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
4
+ # Licensed under the MIT License [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------
6
+ # Modified from [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers)
7
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
8
+ # ------------------------------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import torch.nn as nn
13
+
14
+ import optvq.utils.logger as L
15
+
16
+ class Identity(nn.Module):
17
+ def forward(self, x):
18
+ return x
19
+
20
+ class VQModel(nn.Module):
21
+ def __init__(self,
22
+ encoder: nn.Module,
23
+ decoder: nn.Module,
24
+ loss: nn.Module,
25
+ quantize: nn.Module,
26
+ ckpt_path: str = None,
27
+ ignore_keys=[],
28
+ image_key="image",
29
+ colorize_nlabels=None,
30
+ monitor=None,
31
+ use_connector: bool = True,
32
+ ):
33
+ super(VQModel, self).__init__()
34
+ self.encoder = encoder
35
+ self.decoder = decoder
36
+ self.loss = loss
37
+ self.quantize = quantize
38
+ self.use_connector = use_connector
39
+
40
+ encoder_dim = self.encoder.hidden_dim
41
+ decoder_dim = self.decoder.hidden_dim
42
+ embed_dim = self.quantize.e_dim
43
+
44
+ if not use_connector:
45
+ self.quant_conv = Identity()
46
+ self.post_quant_conv = Identity()
47
+ assert encoder_dim == embed_dim, f"{encoder_dim} != {embed_dim}"
48
+ assert decoder_dim == embed_dim, f"{decoder_dim} != {embed_dim}"
49
+ else:
50
+ self.quant_conv = torch.nn.Conv2d(encoder_dim, embed_dim, 1)
51
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, decoder_dim, 1)
52
+
53
+ if ckpt_path is not None:
54
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
55
+ self.image_key = image_key
56
+ if colorize_nlabels is not None:
57
+ assert type(colorize_nlabels)==int
58
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
59
+ if monitor is not None:
60
+ self.monitor = monitor
61
+
62
+ def init_from_ckpt(self, path, ignore_keys=list()):
63
+ sd = torch.load(path, map_location="cpu")["state_dict"]
64
+ keys = list(sd.keys())
65
+ for k in keys:
66
+ for ik in ignore_keys:
67
+ if k.startswith(ik):
68
+ print("Deleting key {} from state_dict.".format(k))
69
+ del sd[k]
70
+ self.load_state_dict(sd, strict=False)
71
+ print(f"Restored from {path}")
72
+
73
+ def encode(self, x):
74
+ h = self.encoder(x)
75
+ h = self.quant_conv(h)
76
+ quant, emb_loss, indices = self.quantize(h)
77
+ return quant, emb_loss, indices
78
+
79
+ def decode(self, quant):
80
+ quant = self.post_quant_conv(quant)
81
+ dec = self.decoder(quant)
82
+ return dec
83
+
84
+ def decode_code(self, code_b):
85
+ quant_b = self.quantize.embed_code(code_b)
86
+ dec = self.decode(quant_b)
87
+ return dec
88
+
89
+ def forward(self, x, mode: int = 0, global_step: int = None):
90
+ """
91
+ Args:
92
+ x (torch.Tensor): input tensor
93
+ mode (int): 0 for autoencoder, 1 for discriminator
94
+ global_step (int): global step for adaptive discriminator weight
95
+ """
96
+ global_step = global_step if global_step is not None else L.log.total_steps
97
+ quant, qloss, indices = self.encode(x)
98
+ xrec = self.decode(quant)
99
+ if mode == 0:
100
+ # compute the autoencoder loss
101
+ loss, log_dict = self.loss(qloss, x, xrec, mode, last_layer=self.get_last_layer(), global_step=global_step)
102
+ elif mode == 1:
103
+ # compute the discriminator loss
104
+ loss, log_dict = self.loss(qloss, x, xrec, mode, last_layer=self.get_last_layer(), global_step=global_step)
105
+ elif mode == 2:
106
+ # compute the hidden embedding
107
+ h = self.encoder(x)
108
+ h = self.quant_conv(h)
109
+ return h
110
+ return loss, log_dict, indices
111
+
112
+ def get_input(self, batch, k):
113
+ x = batch[k]
114
+ if len(x.shape) == 3:
115
+ x = x[..., None]
116
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
117
+ return x.float()
118
+
119
+ def get_last_layer(self):
120
+ if hasattr(self.decoder, "conv_out"):
121
+ return self.decoder.conv_out.weight
122
+ elif hasattr(self.decoder, "out_fc"):
123
+ return self.decoder.out_fc.weight
124
+ elif hasattr(self.decoder, "inv_conv"):
125
+ return self.decoder.inv_conv.weight
126
+ else:
127
+ raise NotImplementedError(f"Cannot find last layer in decoder")
128
+
129
+ def log_images(self, batch, **kwargs):
130
+ log = dict()
131
+ x = self.get_input(batch, self.image_key)
132
+ x = x.to(self.device)
133
+ xrec, _ = self(x)
134
+ if x.shape[1] > 3:
135
+ # colorize with random projection
136
+ assert xrec.shape[1] > 3
137
+ x = self.to_rgb(x)
138
+ xrec = self.to_rgb(xrec)
139
+ log["inputs"] = x
140
+ log["reconstructions"] = xrec
141
+ return log
142
+
143
+ def to_rgb(self, x):
144
+ assert self.image_key == "segmentation"
145
+ if not hasattr(self, "colorize"):
146
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
147
+ x = F.conv2d(x, weight=self.colorize)
148
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
149
+ return x
150
+
151
+ # The functions below are deprecated
152
+
153
+ def validation_step(self, batch, batch_idx):
154
+ x = self.get_input(batch, self.image_key)
155
+ xrec, qloss = self(x)
156
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
157
+ last_layer=self.get_last_layer(), split="val")
158
+
159
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
160
+ last_layer=self.get_last_layer(), split="val")
161
+ rec_loss = log_dict_ae["val/rec_loss"]
162
+ self.log("val/rec_loss", rec_loss,
163
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
164
+ self.log("val/aeloss", aeloss,
165
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
166
+ self.log_dict(log_dict_ae)
167
+ self.log_dict(log_dict_disc)
168
+ return self.log_dict
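
A sketch of how the forward modes are meant to be driven in a training step: mode=0 returns the autoencoder (generator) loss, mode=1 the discriminator loss, and mode=2 the pre-quantization embedding. The encoder, decoder, and loss below are deliberately tiny stand-ins, only to make the plumbing runnable; the real modules are built from the configs elsewhere in the repo. Assumes the optvq package from this commit is importable:

    import torch
    import torch.nn as nn
    from optvq.models.vqgan import VQModel
    from optvq.models.quantizer import VectorQuantizer

    class ToyEncoder(nn.Module):
        hidden_dim = 8                                 # VQModel reads this to size quant_conv
        def __init__(self):
            super().__init__()
            self.net = nn.Conv2d(3, 8, 4, stride=4)    # 32x32 image -> 8x8 latent grid
        def forward(self, x):
            return self.net(x)

    class ToyDecoder(nn.Module):
        hidden_dim = 8
        def __init__(self):
            super().__init__()
            self.out_fc = nn.ConvTranspose2d(8, 3, 4, stride=4)   # named out_fc so get_last_layer() finds it
        def forward(self, x):
            return self.out_fc(x)

    class ToyLoss(nn.Module):
        # Stand-in for AELossWithDisc: same call signature, plain L1 reconstruction term.
        def forward(self, qloss, x, xrec, mode, last_layer=None, global_step=None):
            rec = torch.mean(torch.abs(x - xrec))
            return rec + qloss, {"rec_loss": rec.item()}

    model = VQModel(encoder=ToyEncoder(), decoder=ToyDecoder(), loss=ToyLoss(),
                    quantize=VectorQuantizer(n_e=64, e_dim=8), use_connector=True)

    x = torch.randn(2, 3, 32, 32)
    # global_step is passed explicitly; the global logger is still the EmptyManager
    # placeholder here, so the statistics hooks only print notices.
    loss, log_dict, indices = model(x, mode=0, global_step=0)   # autoencoder / generator step
    print(loss.item(), indices.shape)                           # indices: (2, 1, 8, 8), one map per head
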
optvq/models/vqgan_hf.py ADDED
@@ -0,0 +1,69 @@
1
+ # ------------------------------------------------------------------------------
2
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
3
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
4
+ # Licensed under the MIT License [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------
6
+
7
+ # Convert a PyTorch model to a Hugging Face model
8
+
9
+ import torch.nn as nn
10
+
11
+ from huggingface_hub import PyTorchModelHubMixin
12
+
13
+ from optvq.models.backbone.diffusion import Encoder, Decoder
14
+ from optvq.models.quantizer import VectorQuantizer, VectorQuantizerSinkhorn
15
+ from optvq.losses.aeloss_disc import AELossWithDisc
16
+ from optvq.models.vqgan import VQModel
17
+
18
+ class VQModelHF(nn.Module, PyTorchModelHubMixin):
19
+ def __init__(self,
20
+ encoder: dict = {},
21
+ decoder: dict = {},
22
+ loss: dict = {},
23
+ quantize: dict = {},
24
+ quantize_type: str = "optvq",
25
+ ckpt_path: str = None,
26
+ ignore_keys=[],
27
+ image_key="image",
28
+ colorize_nlabels=None,
29
+ monitor=None,
30
+ use_connector: bool = True,
31
+ ):
32
+ super(VQModelHF, self).__init__()
33
+ encoder = Encoder(**encoder)
34
+ decoder = Decoder(**decoder)
35
+ quantizer = self.setup_quantizer(quantize, quantize_type)
36
+ loss = AELossWithDisc(**loss)
37
+
38
+ self.model = VQModel(
39
+ encoder=encoder,
40
+ decoder=decoder,
41
+ loss=loss,
42
+ quantize=quantizer,
43
+ ckpt_path=ckpt_path,
44
+ ignore_keys=ignore_keys,
45
+ image_key=image_key,
46
+ colorize_nlabels=colorize_nlabels,
47
+ monitor=monitor,
48
+ use_connector=use_connector,
49
+ )
50
+
51
+ def setup_quantizer(self, quantizer_config, quantize_type):
52
+ if quantize_type == "optvq":
53
+ quantizer = VectorQuantizerSinkhorn(**quantizer_config)
54
+ elif quantize_type == "basevq":
55
+ quantizer = VectorQuantizer(**quantizer_config)
56
+ else:
57
+ raise ValueError(f"Unknown quantizer type: {quantize_type}")
58
+ return quantizer
59
+
60
+ def encode(self, x):
61
+ return self.model.encode(x)
62
+
63
+ def decode(self, x):
64
+ return self.model.decode(x)
65
+
66
+ def forward(self, x):
67
+ quant, *_ = self.encode(x)
68
+ rec = self.decode(quant)
69
+ return quant, rec
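
Since VQModelHF mixes in PyTorchModelHubMixin, a trained model can be pushed to and reloaded from the Hugging Face Hub with the mixin's standard save_pretrained / push_to_hub / from_pretrained methods. A usage sketch; the repository id is a placeholder, not a published checkpoint name:

    import torch
    from optvq.models.vqgan_hf import VQModelHF

    model = VQModelHF.from_pretrained("your-username/optvq-checkpoint")   # placeholder repo id
    model.eval()

    with torch.no_grad():
        x = torch.randn(1, 3, 256, 256)   # an image batch in the model's expected input range
        quant, rec = model(x)             # quantized latents and the reconstruction
    print(quant.shape, rec.shape)
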
optvq/trainer/__pycache__/arguments.cpython-310.pyc ADDED
Binary file (1.58 kB).
 
optvq/trainer/__pycache__/pipeline.cpython-310.pyc ADDED
Binary file (9.02 kB).
 
optvq/trainer/arguments.py ADDED
@@ -0,0 +1,53 @@
1
+ # ------------------------------------------------------------------------------
2
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
3
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
4
+ # Licensed under the MIT License [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------
6
+
7
+ import argparse
8
+
9
+ def get_parser():
10
+ parser = argparse.ArgumentParser()
11
+ # arguments with high priority
12
+ parser.add_argument("--seed", type=int, default=42,
13
+ help="The random seed.")
14
+ parser.add_argument("--gpu", type=int, nargs="+", default=None,
15
+ help="The GPU ids to use.")
16
+ parser.add_argument("--is_distributed", action="store_true", default=False)
17
+ parser.add_argument("--config", type=str, default=None,
18
+ help="The path to the configuration file.")
19
+ parser.add_argument("--resume", type=str, default=None,
20
+ help="The path to the checkpoint to resume.")
21
+ parser.add_argument("--device_rank", type=int, default=0)
22
+
23
+ # arguments for the training
24
+ parser.add_argument("--log_dir", type=str, default=None,
25
+ help="The path to the log directory.")
26
+ parser.add_argument("--mode", type=str, default="train",
27
+ help="options: train, test")
28
+ parser.add_argument("--use_initiate", type=str, default=None,
29
+ help="Options: random, kmeans")
30
+ parser.add_argument("--epochs", type=int, default=None)
31
+ parser.add_argument("--enterpoint", type=str, default=None)
32
+ parser.add_argument("--code_path", type=str, default=None)
33
+ parser.add_argument("--embed_path", type=str, default=None)
34
+
35
+ # arguments for the data
36
+ parser.add_argument("--use_train_subset", type=float, default=None,
37
+ help="The size of the training subset. None means using the full training set.")
38
+ parser.add_argument("--use_train_repeat", type=int, default=None,
39
+ help="The number of times to repeat the training set.")
40
+ parser.add_argument("--use_val_subset", type=float, default=None,
41
+ help="The size of the validation subset. None means using the full validation set.")
42
+ parser.add_argument("--batch_size", type=int, default=None)
43
+ parser.add_argument("--gradient_accumulate", type=int, default=None)
44
+
45
+ # arguments for the optimizer
46
+ parser.add_argument("--lr", type=float, default=None)
47
+ parser.add_argument("--mul_lr", type=float, default=None)
48
+
49
+ # arguments for the model
50
+ parser.add_argument("--num_codes", type=int, default=None,
51
+ help="The number of codes.")
52
+
53
+ return parser
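
A minimal sketch of driving the parser programmatically; the config path is hypothetical, and any option left unspecified stays None so that setup_config keeps the value from the YAML file:

    from optvq.trainer.arguments import get_parser

    parser = get_parser()
    # Equivalent to: python <entry script> --config configs/example.yaml --gpu 0 --batch_size 16
    opt = parser.parse_args(["--config", "configs/example.yaml", "--gpu", "0", "--batch_size", "16"])
    print(opt.config, opt.gpu, opt.batch_size)   # configs/example.yaml [0] 16
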
optvq/trainer/pipeline.py ADDED
@@ -0,0 +1,362 @@
1
+ # ------------------------------------------------------------------------------
2
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
3
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
4
+ # Licensed under the MIT License [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------
6
+
7
+ from typing import Callable
8
+ import argparse
9
+ import os
10
+ from omegaconf import OmegaConf
11
+ from functools import partial
12
+ from torchinfo import summary
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.utils.data.dataloader import DataLoader
17
+
18
+ from optvq.utils.init import initiate_from_config_recursively
19
+ from optvq.data.dataloader import maybe_get_subset
20
+ import optvq.utils.logger as L
21
+
22
+ def setup_config(opt: argparse.Namespace):
23
+ L.log.info("\n\n### Setting up the configurations. ###")
24
+
25
+ # load the config files
26
+ config = OmegaConf.load(opt.config)
27
+
28
+ # overwrite the certain arguments according to the config.args mapping
29
+ for key, value in config.args_map.items():
30
+ if hasattr(opt, key) and getattr(opt, key) is not None:
31
+ msg = f"config.{value} = opt.{key}"
32
+ L.log.info(f"Overwrite the config: {msg}")
33
+ exec(msg)
34
+
35
+ return config
36
+
37
+ def setup_dataloader(data, batch_size, is_distributed: bool = True, is_train: bool = True, num_workers: int = 8):
38
+ if is_train:
39
+ if is_distributed:
40
+ # setup the sampler
41
+ sampler = torch.utils.data.distributed.DistributedSampler(data, shuffle=True, drop_last=True)
42
+ # setup the dataloader
43
+ loader = DataLoader(
44
+ dataset=data, batch_size=batch_size, num_workers=num_workers,
45
+ drop_last=True, sampler=sampler, persistent_workers=True, pin_memory=True
46
+ )
47
+ else:
48
+ # setup the dataloader
49
+ loader = DataLoader(
50
+ dataset=data, batch_size=batch_size, num_workers=num_workers,
51
+ drop_last=True, shuffle=True, persistent_workers=True, pin_memory=True
52
+ )
53
+ else:
54
+ if is_distributed:
55
+ # setup the sampler
56
+ sampler = torch.utils.data.distributed.DistributedSampler(data, shuffle=False, drop_last=False)
57
+ # setup the dataloader
58
+ loader = DataLoader(
59
+ dataset=data, batch_size=batch_size, num_workers=num_workers,
60
+ drop_last=False, sampler=sampler, persistent_workers=True, pin_memory=True
61
+ )
62
+ else:
63
+ # setup the dataloader
64
+ loader = DataLoader(
65
+ dataset=data, batch_size=batch_size, num_workers=num_workers,
66
+ drop_last=False, shuffle=False, persistent_workers=True, pin_memory=True
67
+ )
68
+
69
+ return loader
70
+
71
+ def setup_dataset(config: OmegaConf):
72
+ L.log.info("\n\n### Setting up the datasets. ###")
73
+
74
+ # setup the training dataset
75
+ train_data = initiate_from_config_recursively(config.data.train)
76
+ if config.data.use_train_subset is not None:
77
+ train_data = maybe_get_subset(train_data, subset_size=config.data.use_train_subset, num_data_repeat=config.data.use_train_repeat)
78
+ L.log.info(f"Training dataset size: {len(train_data)}")
79
+
80
+ # setup the validation dataset
81
+ val_data = initiate_from_config_recursively(config.data.val)
82
+ if config.data.use_val_subset is not None:
83
+ val_data = maybe_get_subset(val_data, subset_size=config.data.use_val_subset)
84
+ L.log.info(f"Validation dataset size: {len(val_data)}")
85
+
86
+ return train_data, val_data
87
+
88
+ def setup_model(config: OmegaConf, device):
89
+ L.log.info("\n\n### Setting up the models. ###")
90
+
91
+ # setup the model
92
+ model = initiate_from_config_recursively(config.model.autoencoder)
93
+ if config.is_distributed:
94
+ # apply syncBN
95
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
96
+ # model to devices
97
+ model = model.to(device)
98
+ find_unused_parameters = True
99
+ model = torch.nn.parallel.DistributedDataParallel(
100
+ module=model, device_ids=[config.gpu],
101
+ find_unused_parameters=find_unused_parameters
102
+ )
103
+ model_ori = model.module
104
+ else:
105
+ model = model.to(device)
106
+ model_ori = model
107
+
108
+ input_size = config.data.train.params.transform.params.resize
109
+ in_channels = getattr(model_ori.encoder, "in_dim", 3)
110
+ sout = summary(model_ori, (1, in_channels, input_size, input_size), device="cuda", verbose=0)
111
+ L.log.info(sout)
112
+
113
+ # count the total number of parameters
114
+ for name, module in model_ori.named_children():
115
+ num_params = sum(p.numel() for p in module.parameters())
116
+ num_trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
117
+ L.log.info(f"Module: {name}, Total params: {num_params}, Trainable params: {num_trainable}")
118
+
119
+ return model
120
+
121
+ ### factory functions
122
+
123
+ def get_setup_optimizers(config):
124
+ name = config.train.pipeline
125
+ func_name = "setup_optimizers_" + name
126
+ return globals()[func_name]
127
+
128
+ def get_pipeline(config):
129
+ name = config.train.pipeline
130
+ func_name = "pipeline_" + name
131
+ return globals()[func_name]
132
+
133
+ def _forward_backward(
134
+ config,
135
+ x: torch.Tensor,
136
+ forward: Callable,
137
+ model: nn.Module,
138
+ optimizer: torch.optim.Optimizer,
139
+ scheduler: torch.optim.lr_scheduler._LRScheduler,
140
+ scaler: torch.cuda.amp.GradScaler,
141
+ ):
142
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
143
+ enabled=config.use_amp):
144
+ # forward pass
145
+ loss, *output = forward(x)
146
+ loss_acc = loss / config.data.gradient_accumulate
147
+ scaler.scale(loss_acc).backward()
148
+ # gradient accumulate
149
+ if L.log.total_steps % config.data.gradient_accumulate == 0:
150
+ scaler.unscale_(optimizer)
151
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
152
+ scaler.step(optimizer)
153
+ optimizer.zero_grad()
154
+ scaler.update()
155
+
156
+ if scheduler is not None:
157
+ scheduler.step()
158
+ return loss, output
159
+
160
+ ### autoencoder version
161
+ def _find_weight_decay_id(modules: list, params_ids: list,
162
+ include_class: tuple = (nn.Linear, nn.Conv2d,
163
+ nn.ConvTranspose2d,
164
+ nn.MultiheadAttention),
165
+ include_name: list = ["weight"]):
166
+ for mod in modules:
167
+ for sub_mod in mod.modules():
168
+ if isinstance(sub_mod, include_class):
169
+ for name, param in sub_mod.named_parameters():
170
+ if any([k in name for k in include_name]):
171
+ params_ids.append(id(param))
172
+ params_ids = list(set(params_ids))
173
+ return params_ids
174
+
175
+ def set_weight_decay(modules: list):
176
+ weight_decay_ids = _find_weight_decay_id(modules, [])
177
+ wd_params, wd_names, no_wd_params, no_wd_names = [], [], [], []
178
+ for mod in modules:
179
+ for name, param in mod.named_parameters():
180
+ if id(param) in weight_decay_ids:
181
+ wd_params.append(param)
182
+ wd_names.append(name)
183
+ else:
184
+ no_wd_params.append(param)
185
+ no_wd_names.append(name)
186
+ return wd_params, wd_names, no_wd_params, no_wd_names
187
+
188
+ def setup_optimizers_ae(config: OmegaConf, model: nn.Module, total_steps: int):
189
+ L.log.info("\n\n### Setting up the optimizers and schedulers. ###")
190
+
191
+ # compute the total batch size and the learning rate
192
+ total_batch_size = config.data.batch_size * config.world_size * config.data.gradient_accumulate
193
+ total_learning_rate = config.train.learning_rate * total_batch_size
194
+ multipled_learning_rate = total_learning_rate * config.train.mul_learning_rate
195
+ L.log.info(f"Total batch size: {total_batch_size} = {config.data.batch_size} * {config.world_size} * {config.data.gradient_accumulate}")
196
+ L.log.info(f"Total learning rate: {total_learning_rate} = {config.train.learning_rate} * {total_batch_size}")
197
+ L.log.info(f"Multiplied learning rate: {multipled_learning_rate} = {total_learning_rate} * {config.train.mul_learning_rate}")
198
+
199
+ # setup the optimizers
200
+ param_group = []
201
+ ## base learning rate
202
+ wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.encoder, model.decoder, model.quant_conv, model.post_quant_conv])
203
+ param_group.append({
204
+ "params": wd_params, "lr": total_learning_rate, "eps": 1e-7,
205
+ "weight_decay": config.train.weight_decay, "beta": (0.9, 0.999),
206
+ })
207
+ param_group.append({
208
+ "params": no_wd_params, "lr": total_learning_rate, "eps": 1e-7,
209
+ "weight_decay": 0.0, "beta": (0.9, 0.999),
210
+ })
211
+ ## multipled learning rate
212
+ wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.quantize])
213
+ param_group.append({
214
+ "params": wd_params, "lr": multipled_learning_rate, "eps": 1e-7,
215
+ "weight_decay": config.train.weight_decay, "beta": (0.9, 0.999),
216
+ })
217
+ param_group.append({
218
+ "params": no_wd_params, "lr": multipled_learning_rate, "eps": 1e-7,
219
+ "weight_decay": 0.0, "beta": (0.9, 0.999),
220
+ })
221
+
222
+ optimizer_ae = torch.optim.AdamW(param_group)
223
+ optimizer_dict = {"optimizer_ae": optimizer_ae}
224
+
225
+ # setup the schedulers
226
+ scheduler_ae = torch.optim.lr_scheduler.OneCycleLR(
227
+ optimizer=optimizer_ae, max_lr=[total_learning_rate, total_learning_rate, multipled_learning_rate, multipled_learning_rate],
228
+ total_steps=total_steps, pct_start=0.01, anneal_strategy="cos"
229
+ )
230
+ scheduler_dict = {"scheduler_ae": scheduler_ae}
231
+
232
+ # setup the scalers
233
+ scaler_dict = {"scaler_ae": torch.GradScaler(enabled=config.use_amp)}
234
+ L.log.info(f"Enable AMP: {config.use_amp}")
235
+ return optimizer_dict, scheduler_dict, scaler_dict
236
+
237
+ def pipeline_ae(
238
+ config,
239
+ x: torch.Tensor,
240
+ model: nn.Module,
241
+ optimizers: dict,
242
+ schedulers: dict,
243
+ scalers: dict,
244
+ ):
245
+ assert "optimizer_ae" in optimizers
246
+ assert "scheduler_ae" in schedulers
247
+ assert "scaler_ae" in scalers
248
+
249
+ optimizer = optimizers["optimizer_ae"]
250
+ scheduler = schedulers["scheduler_ae"]
251
+ scaler = scalers["scaler_ae"]
252
+
253
+ forward = partial(model, mode=0)
254
+ _, (loss_ae_dict, indices) = _forward_backward(config, x, forward, model, optimizer, scheduler, scaler)
255
+
256
+ log_per_step = loss_ae_dict
257
+ log_per_epoch = {"indices": indices}
258
+ return log_per_step, log_per_epoch
259
+
260
+ ### autoencoder + disc version
261
+
262
+ def setup_optimizers_ae_disc(config: OmegaConf, model: nn.Module, total_steps: int):
263
+ L.log.info("\n\n### Setting up the optimizers and schedulers. ###")
264
+
265
+ # compute the total batch size and the learning rate
266
+ total_batch_size = config.data.batch_size * config.world_size * config.data.gradient_accumulate
267
+ total_learning_rate = config.train.learning_rate * total_batch_size
268
+ multipled_learning_rate = total_learning_rate * config.train.mul_learning_rate
269
+ L.log.info(f"Total batch size: {total_batch_size} = {config.data.batch_size} * {config.world_size} * {config.data.gradient_accumulate}")
270
+ L.log.info(f"Total learning rate: {total_learning_rate} = {config.train.learning_rate} * {total_batch_size}")
271
+ L.log.info(f"Multiplied learning rate: {multipled_learning_rate} = {total_learning_rate} * {config.train.mul_learning_rate}")
272
+
273
+ # setup the optimizers
274
+ param_group = []
275
+ ## base learning rate
276
+ wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.encoder, model.decoder, model.quant_conv, model.post_quant_conv])
277
+ param_group.append({
278
+ "params": wd_params, "lr": total_learning_rate, "eps": 1e-7,
279
+ "weight_decay": config.train.weight_decay, "beta": (0.9, 0.999),
280
+ })
281
+ param_group.append({
282
+ "params": no_wd_params, "lr": total_learning_rate, "eps": 1e-7,
283
+ "weight_decay": 0.0, "beta": (0.9, 0.999),
284
+ })
285
+ ## multipled learning rate
286
+ wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.quantize])
287
+ param_group.append({
288
+ "params": wd_params, "lr": multipled_learning_rate, "eps": 1e-7,
289
+ "weight_decay": config.train.weight_decay, "beta": (0.9, 0.999),
290
+ })
291
+ param_group.append({
292
+ "params": no_wd_params, "lr": multipled_learning_rate, "eps": 1e-7,
293
+ "weight_decay": 0.0, "beta": (0.9, 0.999),
294
+ })
295
+ optimizer_ae = torch.optim.AdamW(param_group)
296
+
297
+ param_group = []
298
+ wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.loss.discriminator])
299
+ param_group.append({
300
+ "params": wd_params, "lr": total_learning_rate, "eps": 1e-7,
301
+ "weight_decay": config.train.weight_decay, "beta": (0.9, 0.999),
302
+ })
303
+ param_group.append({
304
+ "params": no_wd_params, "lr": total_learning_rate, "eps": 1e-7,
305
+ "weight_decay": 0.0, "beta": (0.9, 0.999),
306
+ })
307
+ optimizer_disc = torch.optim.AdamW(param_group)
308
+ optimizer_dict = {"optimizer_ae": optimizer_ae, "optimizer_disc": optimizer_disc}
309
+
310
+ # setup the schedulers
311
+ scheduler_ae = torch.optim.lr_scheduler.OneCycleLR(
312
+ optimizer=optimizer_ae, max_lr=[total_learning_rate, total_learning_rate, multipled_learning_rate, multipled_learning_rate],
313
+ total_steps=total_steps, pct_start=0.01, anneal_strategy="cos"
314
+ )
315
+ scheduler_disc = torch.optim.lr_scheduler.OneCycleLR(
316
+ optimizer=optimizer_disc, max_lr=[total_learning_rate, total_learning_rate],
317
+ total_steps=total_steps, pct_start=0.01, anneal_strategy="cos"
318
+ )
319
+ scheduler_dict = {"scheduler_ae": scheduler_ae, "scheduler_disc": scheduler_disc}
320
+
321
+ # setup the scalers
322
+ scaler_dict = {"scaler_ae": torch.GradScaler(enabled=config.use_amp),
323
+ "scaler_disc": torch.GradScaler(enabled=config.use_amp)}
324
+ L.log.info(f"Enable AMP: {config.use_amp}")
325
+ return optimizer_dict, scheduler_dict, scaler_dict
326
+
327
+ def pipeline_ae_disc(
328
+ config,
329
+ x: torch.Tensor,
330
+ model: nn.Module,
331
+ optimizers: dict,
332
+ schedulers: dict,
333
+ scalers: dict,
334
+ ):
335
+ # autoencoder step
336
+ assert "optimizer_ae" in optimizers
337
+ assert "scheduler_ae" in schedulers
338
+ assert "scaler_ae" in scalers
339
+
340
+ optimizer = optimizers["optimizer_ae"]
341
+ scheduler = schedulers["scheduler_ae"]
342
+ scaler = scalers["scaler_ae"]
343
+
344
+ forward = partial(model, mode=0)
345
+ _, (loss_ae_dict, indices) = _forward_backward(config, x, forward, model, optimizer, scheduler, scaler)
346
+
347
+ log_per_step = loss_ae_dict
348
+ log_per_epoch = {"indices": indices}
349
+
350
+ # discriminator step
351
+ assert "optimizer_disc" in optimizers
352
+ assert "scheduler_disc" in schedulers
353
+ assert "scaler_disc" in scalers
354
+
355
+ optimizer = optimizers["optimizer_disc"]
356
+ scheduler = schedulers["scheduler_disc"]
357
+ scaler = scalers["scaler_disc"]
358
+
359
+ forward = partial(model, mode=1)
360
+ _, (loss_disc_dict, _) = _forward_backward(config, x, forward, model, optimizer, scheduler, scaler)
361
+ log_per_step.update(loss_disc_dict)
362
+ return log_per_step, log_per_epoch
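
setup_config above relies on the YAML carrying an args_map section that maps command-line option names to dotted config paths, and it overwrites only the entries whose CLI value is not None. A small self-contained illustration of that overwrite rule with a toy config (the keys and paths are hypothetical):

    import argparse
    from omegaconf import OmegaConf

    config = OmegaConf.create({
        "args_map": {"batch_size": "data.batch_size", "lr": "train.learning_rate"},
        "data": {"batch_size": 32},
        "train": {"learning_rate": 1e-4},
    })
    opt = argparse.Namespace(batch_size=16, lr=None)

    # Same rule as setup_config: only non-None CLI values replace config entries.
    for key, value in config.args_map.items():
        if hasattr(opt, key) and getattr(opt, key) is not None:
            exec(f"config.{value} = opt.{key}")

    print(config.data.batch_size, config.train.learning_rate)   # 16 0.0001
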
optvq/utils/__pycache__/func.cpython-310.pyc ADDED
Binary file (1.08 kB).
 
optvq/utils/__pycache__/init.cpython-310.pyc ADDED
Binary file (1.25 kB).
 
optvq/utils/__pycache__/logger.cpython-310.pyc ADDED
Binary file (14 kB).
 
optvq/utils/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (2.08 kB).
 
optvq/utils/func.py ADDED
@@ -0,0 +1,27 @@
1
+ # ------------------------------------------------------------------------------
2
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
3
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
4
+ # Licensed under the MIT License [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------
6
+
7
+ from typing import Tuple, Union, Iterable
8
+ from omegaconf import OmegaConf
9
+ import os
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+
14
+ def dist_all_gather(x):
15
+ tensor_list = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
16
+ dist.all_gather(tensor_list, x)
17
+ x = torch.cat(tensor_list, dim=0)
18
+ return x
19
+
20
+ def any_2tuple(data: Union[int, Tuple[int]]) -> Tuple[int]:
21
+ if isinstance(data, int):
22
+ return (data, data)
23
+ elif isinstance(data, Iterable):
24
+ assert len(data) == 2, "target size must be tuple of (w, h)"
25
+ return tuple(data)
26
+ else:
27
+ raise ValueError("target size must be int or tuple of (w, h)")
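
A quick sketch of any_2tuple, which normalizes an integer size into a (w, h) pair and passes 2-element iterables through, assuming the optvq package from this commit is importable:

    from optvq.utils.func import any_2tuple

    print(any_2tuple(256))          # (256, 256)
    print(any_2tuple((256, 128)))   # (256, 128)
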
optvq/utils/init.py ADDED
@@ -0,0 +1,36 @@
1
+ # ------------------------------------------------------------------------------
2
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
3
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
4
+ # Licensed under the MIT License [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------
6
+
7
+ import importlib
8
+ import random
9
+ import numpy as np
10
+ from typing import Mapping, Optional
11
+
12
+ import torch
13
+
14
+ def initiate_from_config(config: Mapping):
15
+ assert "target" in config, f"Expected key `target` to initialize!"
16
+ module, cls = config["target"].rsplit(".", 1)
17
+ meta_class = getattr(importlib.import_module(module, package=None), cls)
18
+ return meta_class(**config.get("params", dict()))
19
+
20
+ def initiate_from_config_recursively(config: Mapping):
21
+ assert "target" in config, f"Expected key `target` to initialize!"
22
+ update_config = {"target": config["target"], "params": {}}
23
+ for k, v in config["params"].items():
24
+ if isinstance(v, Mapping) and "target" in v:
25
+ sub_instance = initiate_from_config_recursively(v)
26
+ update_config["params"][k] = sub_instance
27
+ else:
28
+ update_config["params"][k] = v
29
+ return initiate_from_config(update_config)
30
+
31
+ def seed_everything(seed: Optional[int] = None):
32
+ seed = int(seed)
33
+ random.seed(seed)
34
+ np.random.seed(seed)
35
+ torch.manual_seed(seed)
36
+ torch.backends.cudnn.deterministic = True
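
initiate_from_config_recursively instantiates nested objects whenever a params entry itself carries a target key, so whole module trees can be described declaratively. A sketch with plain torch callables standing in for the repo's configured classes (the choice of weight_norm/Linear is illustrative only):

    from optvq.utils.init import initiate_from_config_recursively

    config = {
        "target": "torch.nn.utils.weight_norm",
        "params": {
            # nested config: instantiated first, then passed as the `module` argument
            "module": {"target": "torch.nn.Linear",
                       "params": {"in_features": 8, "out_features": 4}},
            "name": "weight",
        },
    }
    layer = initiate_from_config_recursively(config)
    print(layer)   # Linear(in_features=8, out_features=4, bias=True) with weight normalization applied
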
optvq/utils/logger.py ADDED
@@ -0,0 +1,386 @@
1
+ # ------------------------------------------------------------------------------
2
+ # OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
3
+ # Copyright (c) 2024 Borui Zhang. All Rights Reserved.
4
+ # Licensed under the MIT License [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------
6
+
7
+ import time
8
+ import datetime
9
+ from typing import List
10
+ import functools
11
+ import os
12
+ from PIL import Image
13
+ from termcolor import colored
14
+ import sys
15
+ import logging
16
+ from omegaconf import OmegaConf
17
+ import json
18
+
19
+ try:
20
+ from torch.utils.tensorboard import SummaryWriter
21
+ from torch import Tensor
22
+ import torch
23
+ except:
24
+ raise ImportError("Please install torch to use this module!")
25
+
26
+ """
27
+ NOTE: The `log` instance is a global variable, which should be imported by other modules as:
28
+ `import optvq.utils.logger as logger`
29
+ rather than
30
+ `from optvq.utils.logger import log`.
31
+ """
32
+
33
+ def setup_printer(file_log_dir: str, use_console: bool = True):
34
+ printer = logging.getLogger("LOG")
35
+ printer.setLevel(logging.DEBUG)
36
+ printer.propagate = False
37
+
38
+ # create formatter
39
+ fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
40
+ color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
41
+ colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
42
+
43
+ # create the console handler
44
+ if use_console:
45
+ console_handler = logging.StreamHandler(sys.stdout)
46
+ console_handler.setLevel(logging.DEBUG)
47
+ console_handler.setFormatter(
48
+ logging.Formatter(fmt=color_fmt, datefmt="%Y-%m-%d %H:%M:%S")
49
+ )
50
+ printer.addHandler(console_handler)
51
+
52
+ # create the file handler
53
+ file_handler = logging.FileHandler(os.path.join(file_log_dir, "record.txt"), mode="a")
54
+ file_handler.setLevel(logging.DEBUG)
55
+ file_handler.setFormatter(
56
+ logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S")
57
+ )
58
+ printer.addHandler(file_handler)
59
+
60
+ return printer
61
+
62
+ @functools.lru_cache()
63
+ def config_loggers(log_dir: str, local_rank: int = 0, master_rank: int = 0):
64
+ global log
65
+
66
+ if local_rank == master_rank:
67
+ log = LogManager(log_dir=log_dir, main_logger=True)
68
+ else:
69
+ log = LogManager(log_dir=log_dir, main_logger=False)
70
+
71
+ class ProgressWithIndices:
72
+ def __init__(self, total: int, sep_char: str = "| ",
73
+ num_per_row: int = 4):
74
+ self.total = total
75
+ self.sep_char = sep_char
76
+ self.num_per_row = num_per_row
77
+
78
+ self.count = 0
79
+ self.start_time = time.time()
80
+ self.past_time = None
81
+ self.current_time = None
82
+ self.eta = None
83
+ self.speed = None
84
+ self.used_time = 0
85
+
86
+ def update(self):
87
+ self.count += 1
88
+ if self.count <= self.total:
89
+ self.past_time = self.current_time
90
+ self.current_time = time.time()
91
+ # compute eta
92
+ if self.past_time is not None:
93
+ self.eta = (self.total - self.count) * (self.current_time - self.past_time)
94
+ self.eta = str(datetime.timedelta(seconds=int(self.eta)))
95
+ self.speed = 1 / (self.current_time - self.past_time + 1e-8)
96
+ # compute used time
97
+ self.used_time = self.current_time - self.start_time
98
+ self.used_time = str(datetime.timedelta(seconds=int(self.used_time)))
99
+ else:
100
+ self.eta = 0
101
+ self.speed = 0
102
+ self.past_time = None
103
+ self.current_time = None
104
+
105
+ def print(self, prefix: str = "", content: str = "", ):
106
+ global log
107
+ prefix_str = f"{prefix}\t" + f"[{self.count}/{self.total} {self.used_time}/Eta:{self.eta}], Speed:{self.speed}iters/s\n"
108
+ content_list = content.split(self.sep_char)
109
+ content_list = [content.strip() for content in content_list]
110
+ content_list = [
111
+ "\t\t" + self.sep_char.join(content_list[i:i + self.num_per_row])
112
+ for i in range(0, len(content_list), self.num_per_row)
113
+ ]
114
+ content = prefix_str + "\n".join(content_list)
115
+ log.info(content)
116
+
117
+ class LogManager:
118
+ """
119
+ This class encapsulates the tensorboard writer, the statistic meters, the console printer, and the progress counters.
120
+
121
+ Args:
122
+ log_dir (str): the parent directory to save all the logs
123
+ init_meters (List[str]): the initial meters to be shown
124
+ show_avg (bool): whether to show the average value of the meters
125
+ """
126
+ def __init__(self, log_dir: str, init_meters: List[str] = [],
127
+ show_avg: bool = True, main_logger: bool = False):
128
+
129
+ # initiate all the directories
130
+ self.show_avg = show_avg
131
+ self.log_dir = log_dir
132
+ self.main_logger = main_logger
133
+ self.setup_dirs()
134
+
135
+ # initiate the statistic meters
136
+ self.meters = {meter: AverageMeter() for meter in init_meters}
137
+
138
+ # initiate the progress counters
139
+ self.total_steps = 0
140
+ self.total_epochs = 0
141
+
142
+ if self.main_logger:
143
+ # initiate the tensorboard writer
144
+ self.board = SummaryWriter(log_dir=self.tb_log_dir)
145
+
146
+ # initiate the console printer
147
+ self.printer = setup_printer(self.file_log_dir, use_console=True)
148
+
149
+ def state_dict(self):
150
+ return {
151
+ "total_steps": self.total_steps,
152
+ "total_epochs": self.total_epochs,
153
+ "meters": {
154
+ meter_name: meter.state_dict() for meter_name, meter in self.meters.items()
155
+ }
156
+ }
157
+
158
+ def load_state_dict(self, state_dict: dict):
159
+ self.total_steps = state_dict["total_steps"]
160
+ self.total_epochs = state_dict["total_epochs"]
161
+ for meter_name, meter_state_dict in state_dict["meters"].items():
162
+ if meter_name not in self.meters:
163
+ self.meters[meter_name] = AverageMeter()
164
+ self.meters[meter_name].load_state_dict(meter_state_dict)
165
+
166
+ ### About directories
167
+ def setup_dirs(self):
168
+ """
169
+ The structure of the log directory:
170
+ - log_dir: [tb_log, txt_log, img_log, model_log]
171
+ """
172
+ self.tb_log_dir = os.path.join(self.log_dir, "tb_log")
173
+ # NOTE: For now, we save the txt records in the parent directory
174
+ # self.file_log_dir = os.path.join(self.log_dir, "txt_log")
175
+ self.file_log_dir = self.log_dir
176
+ self.img_log_dir = os.path.join(self.log_dir, "img_log")
177
+
178
+ self.config_path = os.path.join(self.log_dir, "config.yaml")
179
+ self.checkpoint_path = os.path.join(self.log_dir, "checkpoint.pth")
180
+ self.backup_checkpoint_path = os.path.join(self.log_dir, "checkpoint.pth")
181
+ self.save_logger_path = os.path.join(self.log_dir, "logger.json")
182
+
183
+ if self.main_logger:
184
+ os.makedirs(self.tb_log_dir, exist_ok=True)
185
+ os.makedirs(self.file_log_dir, exist_ok=True)
186
+ os.makedirs(self.img_log_dir, exist_ok=True)
187
+
188
+ ### About printer
189
+
190
+ def info(self, msg, *args, **kwargs):
191
+ if self.main_logger:
192
+ self.printer.info(msg, *args, **kwargs)
193
+
194
+ def show(self, include_key: str = ""):
195
+ if isinstance(include_key, str):
196
+ include_key = [include_key]
197
+ if self.show_avg:
198
+ return "| ".join([f"{meter_name}: {meter.val:.4f}/{meter.avg:.4f}" for meter_name, meter in self.meters.items() if any([k in meter_name for k in include_key])])
199
+ else:
200
+ return "| ".join([f"{meter_name}: {meter.val:.4f}" for meter_name, meter in self.meters.items() if any([k in meter_name for k in include_key])])
201
+
202
+ ### About counter
203
+
204
+ def update_steps(self):
205
+ self.total_steps += 1
206
+ return self.total_steps
207
+
208
+ def update_epochs(self):
209
+ self.total_epochs += 1
210
+ return self.total_epochs
211
+
212
+ ### About tensorboard
213
+ def add_histogram(self, tag: str, values: Tensor, global_step: int = None):
214
+ if self.main_logger:
215
+ global_step = self.total_steps if global_step is None else global_step
216
+ self.board.add_histogram(tag, values, global_step)
217
+
218
+ def add_scalar(self, tag: str, scalar_value: float, global_step: int = None):
219
+ if isinstance(scalar_value, Tensor):
220
+ scalar_value = scalar_value.item()
221
+ if tag in self.meters:
222
+ cur_step = self.meters[tag].update(scalar_value)
223
+ cur_step = cur_step if global_step is None else global_step
224
+ if self.main_logger:
225
+ self.board.add_scalar(tag, scalar_value, cur_step)
226
+ else:
227
+ self.meters[tag] = AverageMeter()
228
+ cur_step = self.meters[tag].update(scalar_value)
229
+ cur_step = cur_step if global_step is None else global_step
230
+ if self.main_logger:
231
+ print(f"Create new meter: {tag}!")
232
+ self.board.add_scalar(tag, scalar_value, cur_step)
233
+
234
+ def add_scalar_dict(self, scalar_dict: dict, global_step: int = None):
235
+ for tag, scalar_value in scalar_dict.items():
236
+ self.add_scalar(tag, scalar_value, global_step)
237
+
238
+ def add_images(self, tag: str, images: Tensor, global_step: int = None):
239
+ if self.main_logger:
240
+ global_step = self.total_steps if global_step is None else global_step
241
+ self.board.add_images(tag, images, global_step, dataformats="NCHW")
242
+
243
+ ### About saving and resuming
244
+ def save_configs(self, config):
245
+ if self.main_logger:
246
+ # save config as yaml file
247
+ OmegaConf.save(config, self.config_path)
248
+ self.info(f"Save config to {self.config_path}.")
249
+
250
+ # save logger
251
+ state_dict = self.state_dict()
252
+ with open(self.save_logger_path, "w") as f:
253
+ json.dump(state_dict, f)
254
+
255
+ def load_configs(self):
256
+ # load config
257
+ assert os.path.exists(self.config_path), f"Config {self.config_path} does not exist!"
258
+ config = OmegaConf.load(self.config_path)
259
+
260
+ # load logger
261
+ assert os.path.exists(self.save_logger_path), f"Logger {self.save_logger_path} does not exist!"
262
+ state_dict = json.load(open(self.save_logger_path, "r"))
263
+ self.load_state_dict(state_dict)
264
+
265
+ return config
266
+
267
+ def save_checkpoint(self, model, optimizers, schedulers, scalers, suffix: str = ""):
268
+ """
269
+ checkpoint_dict: model, optimizer, scheduler, scalers
270
+ """
271
+ if self.main_logger:
272
+
273
+ # save checkpoint_dict
274
+ checkpoint_dict = {
275
+ "model": model.state_dict(),
276
+ "epoch": self.total_epochs,
277
+ "step": self.total_steps
278
+ }
279
+ checkpoint_dict.update({k: v.state_dict() for k, v in optimizers.items()})
280
+ checkpoint_dict.update({k: v.state_dict() for k, v in schedulers.items() if v is not None})
281
+ checkpoint_dict.update({k: v.state_dict() for k, v in scalers.items()})
282
+
283
+ checkpoint_path = self.checkpoint_path + suffix
284
+ torch.save(checkpoint_dict, checkpoint_path)
285
+ if os.path.exists(self.backup_checkpoint_path):
286
+ os.remove(self.backup_checkpoint_path)
287
+ self.backup_checkpoint_path = checkpoint_path + f".epoch{self.total_epochs}"
288
+ torch.save(checkpoint_dict, self.backup_checkpoint_path)
289
+
290
+ self.info(f"### Epoch: {self.total_epochs}| Steps: {self.total_steps}| Save checkpoint to {checkpoint_path}.")
291
+
292
+ def load_checkpoint(self, device, model, optimizers, schedulers, scalers, resume: str = None):
293
+ resume_path = self.checkpoint_path if resume is None else resume
294
+ assert os.path.exists(resume_path), f"Resume {resume_path} does not exist!"
295
+
296
+ # load checkpoint_dict
297
+ checkpoint_dict = torch.load(resume_path, map_location=device)
298
+ model.load_state_dict(checkpoint_dict["model"])
299
+ self.total_epochs = checkpoint_dict["epoch"]
300
+ self.total_steps = checkpoint_dict["step"]
301
+ for k, v in optimizers.items():
302
+ v.load_state_dict(checkpoint_dict[k])
303
+ for k, v in schedulers.items():
304
+ v.load_state_dict(checkpoint_dict[k])
305
+ for k, v in scalers.items():
306
+ v.load_state_dict(checkpoint_dict[k])
307
+
308
+ self.info(f"### Epoch: {self.total_epochs}| Steps: {self.total_steps}| Resume checkpoint from {resume_path}.")
309
+
310
+ return self.total_epochs
311
+
312
+ class EmptyManager:
313
+ def __init__(self):
314
+ for func_name in LogManager.__dict__.keys():
315
+ if not func_name.startswith("_"):
316
+ setattr(self, func_name, lambda *args, _name=func_name, **kwargs: print(f"Empty Manager! {_name} is not available!"))
317
+
318
+class AverageMeter:
+    def __init__(self):
+        self.reset()
+
+    def state_dict(self):
+        return {
+            "val": self.val,
+            "avg": self.avg,
+            "sum": self.sum,
+            "count": self.count,
+        }
+
+    def load_state_dict(self, state_dict: dict):
+        self.val = state_dict["val"]
+        self.avg = state_dict["avg"]
+        self.sum = state_dict["sum"]
+        self.count = state_dict["count"]
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+        return 0
+
+    def update(self, val: float, n: int = 1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+        return self.count
+
+    def __str__(self):
+        return f"{self.avg:.4f}"
+
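For reference, a minimal usage sketch of AverageMeter (the values are made up):

    meter = AverageMeter()
    meter.update(0.52, n=16)   # mean loss 0.52 over a batch of 16
    meter.update(0.48, n=16)
    print(meter)               # "0.5000" -- running average over 32 samples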
+def save_image(x: Tensor, save_path: str, scale_to_256: bool = True):
+    """
+    Args:
+        x (tensor): default data range is [0, 1]
+    """
+    if scale_to_256:
+        x = x.mul(255).clamp(0, 255)
+    x = x.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")
+    img = Image.fromarray(x)
+    img.save(save_path)
+
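A small sketch of saving a single image with the helper above; the random tensor is a stand-in for real model output:

    import torch
    x = torch.rand(3, 256, 256)   # CHW tensor in [0, 1]
    save_image(x, "sample.png")   # rescaled to [0, 255] and written as a PNG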
+def save_images(images_list, ids_list, meta_path):
+    for image, image_id in zip(images_list, ids_list):
+        save_path = os.path.join(meta_path, f"{image_id}.png")
+        save_image(image, save_path)
+
+def save_images_multithread(images_list, ids_list, meta_path):
+    n_workers = 32
+    from concurrent.futures import ThreadPoolExecutor
+    with ThreadPoolExecutor(max_workers=n_workers) as executor:
+        for i in range(0, len(images_list), n_workers):
+            cur_images = images_list[i:(i + n_workers)]
+            cur_ids = ids_list[i:(i + n_workers)]
+            executor.submit(save_images, cur_images, cur_ids, meta_path)
+
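And a sketch of the multi-threaded variant, which submits images to the thread pool in chunks of n_workers; the output directory is hypothetical and must already exist:

    images = [torch.rand(3, 64, 64) for _ in range(128)]
    ids = list(range(128))
    save_images_multithread(images, ids, "outputs/recon")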
+def add_prefix(log_dict: dict, prefix: str):
+    return {
+        f"{prefix}/{key}": val for key, val in log_dict.items()
+    }
+
+##################### GLOBAL VARIABLES #####################
+log = EmptyManager()
+GET_STATS: bool = (os.environ.get("ENABLE_STATS", "1") == "1")
+###########################################################
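Note that statistics collection is gated by the ENABLE_STATS environment variable. A sketch of disabling it, assuming this file is importable as optvq.utils.logger and the variable is set before the module is first imported:

    import os
    os.environ["ENABLE_STATS"] = "0"
    import optvq.utils.logger as optvq_logger
    assert optvq_logger.GET_STATS is False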
optvq/utils/metrics.py ADDED
@@ -0,0 +1,60 @@
+# ------------------------------------------------------------------------------
+# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
+# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
+# Licensed under the MIT License [see LICENSE for details]
+# ------------------------------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+from pytorch_fid.inception import InceptionV3
+from pytorch_fid.fid_score import calculate_frechet_distance
+
+class FIDMetric:
+    def __init__(self, device, dims=2048):
+        self.device = device
+        self.num_workers = 32
+
+        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+        self.model = InceptionV3([block_idx]).to(device)
+        self.model.eval()
+
+        self.reset_metrics()
+
+    def reset_metrics(self):
+        self.x_pred = []
+        self.x_rec_pred = []
+
+    @torch.no_grad()
+    def get_activates(self, x: torch.Tensor):
+        pred = self.model(x)[0]
+        # If model output is not scalar, apply global spatial average pooling.
+        # This happens if you choose a dimensionality not equal 2048.
+        if pred.size(2) != 1 or pred.size(3) != 1:
+            pred = torch.nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))
+        # squeeze only the spatial dimensions so a batch of size 1 keeps its batch axis
+        return pred.squeeze(3).squeeze(2).cpu().numpy()
+
+    def update(self, x: torch.Tensor, x_rec: torch.Tensor):
+        """
+        Args:
+            x (torch.Tensor): input tensor with values in [0, 1]
+            x_rec (torch.Tensor): reconstructed tensor with values in [0, 1]
+        """
+        self.x_pred.append(self.get_activates(x))
+        self.x_rec_pred.append(self.get_activates(x_rec))
+
+    def result(self):
+        assert len(self.x_pred) != 0, "No data to compute FID"
+        x = np.concatenate(self.x_pred, axis=0)
+        x_rec = np.concatenate(self.x_rec_pred, axis=0)
+
+        x_mu = np.mean(x, axis=0)
+        x_sigma = np.cov(x, rowvar=False)
+
+        x_rec_mu = np.mean(x_rec, axis=0)
+        x_rec_sigma = np.cov(x_rec, rowvar=False)
+
+        fid_score = calculate_frechet_distance(x_mu, x_sigma, x_rec_mu, x_rec_sigma)
+        self.reset_metrics()
+        return fid_score
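A minimal sketch of driving FIDMetric during evaluation; the dataloader and reconstruction model below are assumptions, not part of this file:

    import torch
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    fid = FIDMetric(device)
    for x in eval_loader:                  # hypothetical loader yielding (N, 3, H, W) tensors in [0, 1]
        x = x.to(device)
        x_rec = model(x).clamp(0, 1)       # hypothetical reconstruction model
        fid.update(x, x_rec)
    print("rFID:", fid.result())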
requirements.txt ADDED
@@ -0,0 +1,25 @@
+torch>=2.3.0
+torchvision
+torchaudio
+numpy
+matplotlib
+scikit-learn
+opencv-python
+einops>=0.7.0
+omegaconf>=2.3.0
+lightning>=2.1.3
+transformers>=4.36.2
+pudb
+wandb
+tensorboard
+termcolor
+lpips==0.1.4
+ftfy
+imageio
+imageio-ffmpeg
+pyiqa
+clean-fid
+albumentations
+pytorch-fid
+torchinfo
+gradio